diff --git a/ca/identity/client.go b/ca/identity/client.go index 7daafacd..4377638f 100644 --- a/ca/identity/client.go +++ b/ca/identity/client.go @@ -60,7 +60,9 @@ func LoadClient() (*Client, error) { // Prepare transport with information in defaults.json and identity.json tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = &tls.Config{} + tr.TLSClientConfig = &tls.Config{ + GetClientCertificate: identity.GetClientCertificateFunc(), + } // RootCAs b, err = ioutil.ReadFile(defaults.Root) @@ -72,13 +74,6 @@ func LoadClient() (*Client, error) { tr.TLSClientConfig.RootCAs = pool } - // Certificate - crt, err := tls.LoadX509KeyPair(identity.Certificate, identity.Key) - if err != nil { - return nil, fmt.Errorf("error loading certificate: %v", err) - } - tr.TLSClientConfig.Certificates = []tls.Certificate{crt} - return &Client{ CaURL: caURL, Client: &http.Client{ diff --git a/ca/identity/client_test.go b/ca/identity/client_test.go index 8ff27bba..136e839a 100644 --- a/ca/identity/client_test.go +++ b/ca/identity/client_test.go @@ -185,11 +185,21 @@ func TestLoadClient(t *testing.T) { t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) } } else { - if !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || - !reflect.DeepEqual(got.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs, tt.want.Client.Transport.(*http.Transport).TLSClientConfig.RootCAs) || - !reflect.DeepEqual(got.Client.Transport.(*http.Transport).TLSClientConfig.Certificates, tt.want.Client.Transport.(*http.Transport).TLSClientConfig.Certificates) { + gotTransport := got.Client.Transport.(*http.Transport) + wantTransport := tt.want.Client.Transport.(*http.Transport) + if gotTransport.TLSClientConfig.GetClientCertificate == nil { + t.Error("LoadClient() transport does not define GetClientCertificate") + } else if !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs) { t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) + } else { + crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) + if err != nil { + t.Errorf("LoadClient() GetClientCertificate error = %v", err) + } else if !reflect.DeepEqual(*crt, wantTransport.TLSClientConfig.Certificates[0]) { + t.Errorf("LoadClient() GetClientCertificate crt = %#v, want %#v", *crt, wantTransport.TLSClientConfig.Certificates[0]) + } } + } }) } diff --git a/ca/identity/identity.go b/ca/identity/identity.go index d6aee85b..d37628f1 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -203,6 +203,18 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) { } } +// GetClientCertificateFunc returns a method that can be used as the +// GetClientCertificate property in a tls.Config. +func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) + if err != nil { + return nil, errors.Wrap(err, "error loading identity certificate") + } + return &crt, nil + } +} + // Renewer is that interface that a renew client must implement. type Renewer interface { GetRootCAs() *x509.CertPool