|
|
|
@ -15,23 +15,55 @@ import (
|
|
|
|
|
|
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
|
"github.com/smallstep/certificates/api"
|
|
|
|
|
"github.com/smallstep/certificates/ca/identity"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// mTLSDialContext will hold the dial context function to use in
|
|
|
|
|
// getDefaultTransport.
|
|
|
|
|
var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
|
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
|
|
|
|
|
|
func init() {
|
|
|
|
|
// STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS
|
|
|
|
|
// tunnel to step-ca using identity credentials. The value must have the
|
|
|
|
|
// form "host:port", if the form is not correct, the default dialer will be
|
|
|
|
|
// used. This feature is EXPERIMENTAL and might change at any time.
|
|
|
|
|
if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" {
|
|
|
|
|
if host, port, err := net.SplitHostPort(hostport); err == nil {
|
|
|
|
|
mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
|
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
|
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
|
|
|
|
|
}
|
|
|
|
|
// STEP_TLS_TUNNEL is an environment that can be set to do an TLS over
|
|
|
|
|
// (m)TLS tunnel to step-ca using identity-like credentials. The value is a
|
|
|
|
|
// path to a json file with the tunnel host, certificate, key and root used
|
|
|
|
|
// to create the (m)TLS tunnel.
|
|
|
|
|
//
|
|
|
|
|
// The configuration should look like:
|
|
|
|
|
// {
|
|
|
|
|
// "type": "tTLS",
|
|
|
|
|
// "host": "tunnel.example.com:443"
|
|
|
|
|
// "crt": "/path/to/tunnel.crt",
|
|
|
|
|
// "key": "/path/to/tunnel.key",
|
|
|
|
|
// "root": "/path/to/tunnel-root.crt"
|
|
|
|
|
// }
|
|
|
|
|
//
|
|
|
|
|
// This feature is EXPERIMENTAL and might change at any time.
|
|
|
|
|
if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
|
|
|
|
|
id, err := identity.LoadIdentity(path)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
if err := id.Validate(); err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
host, port, err := net.SplitHostPort(id.Host)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
pool, err := id.GetCertPool()
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
}
|
|
|
|
|
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
|
d := &tls.Dialer{
|
|
|
|
|
NetDialer: getDefaultDialer(),
|
|
|
|
|
Config: &tls.Config{
|
|
|
|
|
RootCAs: pool,
|
|
|
|
|
GetClientCertificate: id.GetClientCertificateFunc(),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
|
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -71,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Update renew function with transport
|
|
|
|
|
tr, err := getDefaultTransport(tlsConfig)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, nil, err
|
|
|
|
|
}
|
|
|
|
|
tr := getDefaultTransport(tlsConfig)
|
|
|
|
|
// Use mutable tls.Config on renew
|
|
|
|
|
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
|
|
|
|
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
|
|
|
@ -123,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|
|
|
|
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
|
|
|
|
|
|
|
|
|
|
// Update renew function with transport
|
|
|
|
|
tr, err := getDefaultTransport(tlsConfig)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
tr := getDefaultTransport(tlsConfig)
|
|
|
|
|
// Use mutable tls.Config on renew
|
|
|
|
|
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
|
|
|
|
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
|
|
|
@ -164,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell
|
|
|
|
|
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
|
|
|
|
|
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
|
|
|
|
|
return func(network, addr string) (net.Conn, error) {
|
|
|
|
|
return tls.DialWithDialer(&net.Dialer{
|
|
|
|
|
Timeout: 30 * time.Second,
|
|
|
|
|
KeepAlive: 30 * time.Second,
|
|
|
|
|
DualStack: true,
|
|
|
|
|
}, network, addr, ctx.mutableConfig.TLSConfig())
|
|
|
|
|
return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -176,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
|
|
|
|
|
// nolint:unused
|
|
|
|
|
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
|
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
|
|
|
d := getDefaultDialer()
|
|
|
|
|
// TLS dialers do not support context, but we can use the context
|
|
|
|
|
// deadline if it is set.
|
|
|
|
|
var deadline time.Time
|
|
|
|
|
if t, ok := ctx.Deadline(); ok {
|
|
|
|
|
deadline = t
|
|
|
|
|
d.Deadline = t
|
|
|
|
|
}
|
|
|
|
|
return tls.DialWithDialer(&net.Dialer{
|
|
|
|
|
Timeout: 30 * time.Second,
|
|
|
|
|
KeepAlive: 30 * time.Second,
|
|
|
|
|
Deadline: deadline,
|
|
|
|
|
DualStack: true,
|
|
|
|
|
}, network, addr, tlsCtx.mutableConfig.TLSConfig())
|
|
|
|
|
return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -258,25 +275,24 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// getDefaultDialer returns a new dialer with the default configuration.
|
|
|
|
|
func getDefaultDialer() *net.Dialer {
|
|
|
|
|
return &net.Dialer{
|
|
|
|
|
Timeout: 30 * time.Second,
|
|
|
|
|
KeepAlive: 30 * time.Second,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// getDefaultTransport returns an http.Transport with the same parameters than
|
|
|
|
|
// http.DefaultTransport, but adds the given tls.Config and configures the
|
|
|
|
|
// transport for HTTP/2.
|
|
|
|
|
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
|
|
|
|
|
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
|
|
|
|
|
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
|
|
|
|
|
if mTLSDialContext == nil {
|
|
|
|
|
d := &net.Dialer{
|
|
|
|
|
Timeout: 30 * time.Second,
|
|
|
|
|
KeepAlive: 30 * time.Second,
|
|
|
|
|
}
|
|
|
|
|
d := getDefaultDialer()
|
|
|
|
|
dialContext = d.DialContext
|
|
|
|
|
} else {
|
|
|
|
|
dialContext = mTLSDialContext(&tls.Dialer{
|
|
|
|
|
NetDialer: &net.Dialer{
|
|
|
|
|
Timeout: 30 * time.Second,
|
|
|
|
|
KeepAlive: 30 * time.Second,
|
|
|
|
|
},
|
|
|
|
|
Config: tlsConfig,
|
|
|
|
|
})
|
|
|
|
|
dialContext = mTLSDialContext()
|
|
|
|
|
}
|
|
|
|
|
return &http.Transport{
|
|
|
|
|
Proxy: http.ProxyFromEnvironment,
|
|
|
|
@ -287,7 +303,7 @@ func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
|
|
|
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
|
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
|
|
|
TLSClientConfig: tlsConfig,
|
|
|
|
|
}, nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func getPEM(i interface{}) ([]byte, error) {
|
|
|
|
|