package vaultcas

import (
	"bytes"
	"context"
	"crypto/x509"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"testing"
	"time"

	vault "github.com/hashicorp/vault/api"
	"github.com/smallstep/certificates/cas/apiv1"
	"go.step.sm/crypto/pemutil"
)

// Test fixtures shared by the VaultCAS tests.
var (
	// testCertificateSigned is the leaf certificate returned by the mocked
	// signing endpoints.
	testCertificateSigned = `-----BEGIN CERTIFICATE-----
MIIB/DCCAaKgAwIBAgIQHHFuGMz0cClfde5kqP5prTAKBggqhkjOPQQDAjAqMSgw
JgYDVQQDEx9Hb29nbGUgQ0FTIFRlc3QgSW50ZXJtZWRpYXRlIENBMB4XDTIwMDkx
NTAwMDQ0M1oXDTMwMDkxMzAwMDQ0MFowHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0
ZXAuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEMqNCiXMvbn74LsHzRv+8
17m9vEzH6RHrg3m82e0uEc36+fZWV/zJ9SKuONmnl5VP79LsjL5SVH0RDj73U2XO
DKOBtjCBszAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG
AQUFBwMCMB0GA1UdDgQWBBRTA2cTs7PCNjnps/+T0dS8diqv0DAfBgNVHSMEGDAW
gBRIOVqyLDSlErJLuWWEvRm5UU1r1TBCBgwrBgEEAYKkZMYoQAIEMjAwEwhjbG91
ZGNhcxMkZDhkMThhNjgtNTI5Ni00YWYzLWFlNGItMmY4NzdkYTNmYmQ5MAoGCCqG
SM49BAMCA0gAMEUCIGxl+pqJ50WYWUqK2l4V1FHoXSi0Nht5kwTxFxnWZu1xAiEA
zemu3bhWLFaGg3s8i+HTEhw4RqkHP74vF7AVYp88bAw=
-----END CERTIFICATE-----`
	// testCertificateCsrEc is a CSR with an EC key (CN=EC).
	testCertificateCsrEc = `-----BEGIN CERTIFICATE REQUEST-----
MIHoMIGPAgEAMA0xCzAJBgNVBAMTAkVDMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD
QgAEUVVVZGD6eUrB20T/qrjKZoYzseQ18AIm9jtUNpQn5hIClpdk2zKy5bja3iUa
nmqRKCIz/B/MU55zuNDeckqqX6AgMB4GCSqGSIb3DQEJDjERMA8wDQYDVR0RBAYw
BIICRUMwCgYIKoZIzj0EAwIDSAAwRQIhAJxpWyH7cctbzcnK1JBWDAmc/G61bq9y
otHrQDfYvS8bAiBVGQz2cfO2SqhvkkQbOqWUFjk1wHzISvlTjyc3IJ7FLw==
-----END CERTIFICATE REQUEST-----`
	// testCertificateCsrRsa is a CSR with an RSA key (CN=RSA).
	testCertificateCsrRsa = `-----BEGIN CERTIFICATE REQUEST-----
MIICdDCCAVwCAQAwDjEMMAoGA1UEAxMDUlNBMIIBIjANBgkqhkiG9w0BAQEFAAOC
AQ8AMIIBCgKCAQEAxe5XLSZrTCzzH0FJCXvZwghAY5XztzjseSRcm0jL8Q7nvNWi
Vpu1n7EmfVU9b8sbvtVYqMQV+hMdj2C/NIw4Yal4Wg+BgunYOrRqfY7oDm4csG0R
g5v0h2yQw14kqVrftNyojX0Nv/CPboCGl64PA9zsEXQTB3Y1AUWrUGPiBWNACYIH
mjv70Ay9JKBBAqov38I7nka/RgYAl5DCHzU2vvODriBYFWagnzycA4Ni5EKTz93W
SPdDEhkWi3ugUqal3SvgHl8re+8d7ghLn85Y3TFuyU2nSMDPHaymsiNFw1mRwOw3
lAseidHJkPQs7q6FiYXaeqetf1j/gw0n23ZogwIDAQABoCEwHwYJKoZIhvcNAQkO
MRIwEDAOBgNVHREEBzAFggNSU0EwDQYJKoZIhvcNAQELBQADggEBALnO5vcDkgGO
GQoSINa2NmNFxAtYQGYHok5KXYX+S+etmOmDrmrhsl/pSjN3GPCPlThFlbLStB70
oJw67nEjGf0hPEBVlm+qFUsYQ1KGRZFAWDSMQ//pU225XFDCmlzHfV7gZjSkP9GN
Gc5VECOzx6hAFR+IEL/l/1GG5HHkPPrr/8OvuIfm2V5ofYmhsXMVVYH52qPofMAV
B8UdNnZK3nyLdUqVd+PYUUJmN4bJ8YfxofKKgbLkhvkKp4OZ9vkwUi2+61NdHTf2
wIauOyxEoTlJpU6oA/sxu/2Ht2DP+8y6mognLBuKklE/VH3/2iqQWyg1NV5hyg3b
loVSdLsIh5Y=
-----END CERTIFICATE REQUEST-----`
	// testCertificateCsrEd25519 is a CSR with an Ed25519 key (CN=OKP).
	testCertificateCsrEd25519 = `-----BEGIN CERTIFICATE REQUEST-----
MIGuMGICAQAwDjEMMAoGA1UEAxMDT0tQMCowBQYDK2VwAyEAopc6daK4zYR6BDAM
pV/v53oR/ewbtrkHZQkN/amFMLagITAfBgkqhkiG9w0BCQ4xEjAQMA4GA1UdEQQH
MAWCA09LUDAFBgMrZXADQQDJi47MAgl/WKAz+V/kDu1k/zbKk1nrHHAUonbofHUW
M6ihSD43+awq3BPeyPbToeH5orSH9l3MuTfbxPb5BVEH
-----END CERTIFICATE REQUEST-----`
	// testRootCertificate is the root returned as the last element of the
	// mocked certificate chain.
	testRootCertificate = `-----BEGIN CERTIFICATE-----
MIIBeDCCAR+gAwIBAgIQcXWWjtSZ/PAyH8D1Ou4L9jAKBggqhkjOPQQDAjAbMRkw
FwYDVQQDExBDbG91ZENBUyBSb290IENBMB4XDTIwMTAyNzIyNTM1NFoXDTMwMTAy
NzIyNTM1NFowGzEZMBcGA1UEAxMQQ2xvdWRDQVMgUm9vdCBDQTBZMBMGByqGSM49
AgEGCCqGSM49AwEHA0IABIySHA4b78Yu4LuGhZIlv/PhNwXz4ZoV1OUZQ0LrK3vj
B13O12DLZC5uj1z3kxdQzXUttSbtRv49clMpBiTpsZKjRTBDMA4GA1UdDwEB/wQE
AwIBBjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSZ+t9RMHbFTl5BatM3
5bJlHPOu3DAKBggqhkjOPQQDAgNHADBEAiASah6gg0tVM3WI0meCQ4SEKk7Mjhbv
+SmhuZHWV1QlXQIgRXNyWcpVUrAoG6Uy1KQg07LDpF5dFeK9InrDxSJAkVo=
-----END CERTIFICATE-----`
	// testRootFingerprint is the fingerprint the tests configure as the
	// expected root fingerprint.
	testRootFingerprint = `62e816cbac5c501b7705e18415503852798dfbcd67062f06bcb4af67c290e3c8`
)

// mustParseCertificate returns the first certificate found in pemCert,
// failing the test immediately if no certificate can be parsed (instead of
// panicking with an index-out-of-range error).
func mustParseCertificate(t *testing.T, pemCert string) *x509.Certificate {
	t.Helper()
	crts := parseCertificates(pemCert)
	if len(crts) == 0 {
		t.Fatal("mustParseCertificate: no certificate found in PEM data")
	}
	return crts[0]
}

// mustParseCertificateRequest parses a PEM-encoded CSR, failing the test on
// error.
func mustParseCertificateRequest(t *testing.T, pemData string) *x509.CertificateRequest {
	t.Helper()
	csr, err := pemutil.ParseCertificateRequest([]byte(pemData))
	if err != nil {
		t.Fatal(err)
	}
	return csr
}

// testCAHelper starts an httptest server that mocks the Vault HTTP endpoints
// exercised by these tests and returns its URL plus a Vault client pointed at
// it. The server is shut down through t.Cleanup, which also runs when the
// test aborts via t.Fatal.
//
// Mocked endpoints (matched on the exact request URI):
//   - /v1/auth/approle/login: 200 with a fixed client token.
//   - /v1/pki/sign/{ec,rsa,ed25519} and /v1/pki/cert/ca_chain: 200 with
//     testCertificateSigned followed by testRootCertificate.
//   - /v1/pki/revoke: 200 for the serial numbers in revocableSerials,
//     404 otherwise.
//   - anything else: 404 with a JSON error body.
func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
	t.Helper()

	writeJSON := func(w http.ResponseWriter, v interface{}) {
		// The encode error is deliberately ignored: the payloads are fixed
		// string maps, and a broken response would surface as a failure in
		// the client-side assertions.
		_ = json.NewEncoder(w).Encode(v)
	}

	// Body shared by every signing endpoint and the CA chain endpoint.
	chainResponse := map[string]interface{}{
		"data": map[string]interface{}{
			"certificate": testCertificateSigned + "\n" + testRootCertificate,
		},
	}

	// Serial numbers the revoke endpoint accepts, in the hyphen-separated
	// hex form it receives them.
	revocableSerials := map[string]bool{
		"1c-71-6e-18-cc-f4-70-29-5f-75-ee-64-a8-fe-69-ad": true, // serial of testCertificateSigned
		"01-e2-40": true, // 123456 ("ok serial number")
		"01-34-3e": true, // 78910 ("ok both")
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.RequestURI {
		case "/v1/auth/approle/login":
			w.WriteHeader(http.StatusOK)
			fmt.Fprint(w, `{ "auth": { "client_token": "98a4c7ab-b1fe-361b-ba0b-e307aacfd587" } }`)
		case "/v1/pki/sign/ec", "/v1/pki/sign/rsa", "/v1/pki/sign/ed25519", "/v1/pki/cert/ca_chain":
			w.WriteHeader(http.StatusOK)
			writeJSON(w, chainResponse)
		case "/v1/pki/revoke":
			buf := new(bytes.Buffer)
			if _, err := buf.ReadFrom(r.Body); err != nil {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			// Best-effort decode: json.Unmarshal skips fields whose values
			// are not strings, and only serial_number is needed here.
			m := make(map[string]string)
			_ = json.Unmarshal(buf.Bytes(), &m)
			if revocableSerials[m["serial_number"]] {
				w.WriteHeader(http.StatusOK)
			} else {
				w.WriteHeader(http.StatusNotFound)
			}
		default:
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprint(w, `{"error":"not found"}`)
		}
	}))
	// Cleanup also runs after t.Fatal, so no manual Close is needed below.
	t.Cleanup(srv.Close)

	u, err := url.Parse(srv.URL)
	if err != nil {
		t.Fatal(err)
	}

	config := vault.DefaultConfig()
	config.Address = srv.URL
	client, err := vault.NewClient(config)
	if err != nil {
		t.Fatal(err)
	}
	return u, client
}
func TestNew_register(t *testing.T) { caURL, _ := testCAHelper(t) fn, ok := apiv1.LoadCertificateAuthorityServiceNewFunc(apiv1.VaultCAS) if !ok { t.Errorf("apiv1.Register() ok = %v, want true", ok) return } _, err := fn(context.Background(), apiv1.Options{ CertificateAuthority: caURL.String(), CertificateAuthorityFingerprint: testRootFingerprint, Config: json.RawMessage(`{ "AuthType": "approle", "AuthOptions": {"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false} }`), }) if err != nil { t.Errorf("New() error = %v", err) return } } func TestVaultCAS_Type(t *testing.T) { tests := []struct { name string want apiv1.Type }{ {"ok", apiv1.VaultCAS}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &VaultCAS{} if got := c.Type(); got != tt.want { t.Errorf("VaultCAS.Type() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_CreateCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", } type fields struct { client *vault.Client options VaultOptions } type args struct { req *apiv1.CreateCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.CreateCertificateResponse wantErr bool }{ {"ok ec", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrEc), Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testCertificateSigned), CertificateChain: nil, }, false}, {"ok rsa", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrRsa), Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testCertificateSigned), CertificateChain: nil, }, false}, {"ok ed25519", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, 
testCertificateCsrEd25519), Lifetime: time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: mustParseCertificate(t, testCertificateSigned), CertificateChain: nil, }, false}, {"fail CSR", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: nil, Lifetime: time.Hour, }}, nil, true}, {"fail lifetime", fields{client, options}, args{&apiv1.CreateCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrEc), Lifetime: 0, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &VaultCAS{ client: tt.fields.client, config: tt.fields.options, } got, err := c.CreateCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.CreateCertificate() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_GetCertificateAuthority(t *testing.T) { caURL, client := testCAHelper(t) type fields struct { client *vault.Client options VaultOptions fingerprint string } type args struct { req *apiv1.GetCertificateAuthorityRequest } options := VaultOptions{ PKIMountPath: "pki", } rootCert := parseCertificates(testRootCertificate)[0] tests := []struct { name string fields fields args args want *apiv1.GetCertificateAuthorityResponse wantErr bool }{ {"ok", fields{client, options, testRootFingerprint}, args{&apiv1.GetCertificateAuthorityRequest{ Name: caURL.String(), }}, &apiv1.GetCertificateAuthorityResponse{ RootCertificate: rootCert, }, false}, {"fail fingerprint", fields{client, options, "fail"}, args{&apiv1.GetCertificateAuthorityRequest{ Name: caURL.String(), }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &VaultCAS{ client: tt.fields.client, fingerprint: tt.fields.fingerprint, config: tt.fields.options, } got, err := s.GetCertificateAuthority(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.GetCertificateAuthority() error = %v, 
wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.GetCertificateAuthority() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_RevokeCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", } type fields struct { client *vault.Client options VaultOptions } type args struct { req *apiv1.RevokeCertificateRequest } testCrt := parseCertificates(testCertificateSigned)[0] tests := []struct { name string fields fields args args want *apiv1.RevokeCertificateResponse wantErr bool }{ {"ok serial number", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "123456", Certificate: nil, }}, &apiv1.RevokeCertificateResponse{}, false}, {"ok certificate", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "", Certificate: testCrt, }}, &apiv1.RevokeCertificateResponse{ Certificate: testCrt, }, false}, {"ok both", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "78910", Certificate: testCrt, }}, &apiv1.RevokeCertificateResponse{ Certificate: testCrt, }, false}, {"fail serial string", fields{client, options}, args{&apiv1.RevokeCertificateRequest{ SerialNumber: "fail", Certificate: nil, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &VaultCAS{ client: tt.fields.client, config: tt.fields.options, } got, err := s.RevokeCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.RevokeCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.RevokeCertificate() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_RenewCertificate(t *testing.T) { _, client := testCAHelper(t) options := VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", } 
type fields struct { client *vault.Client options VaultOptions } type args struct { req *apiv1.RenewCertificateRequest } tests := []struct { name string fields fields args args want *apiv1.RenewCertificateResponse wantErr bool }{ {"not implemented", fields{client, options}, args{&apiv1.RenewCertificateRequest{ CSR: mustParseCertificateRequest(t, testCertificateCsrEc), Lifetime: time.Hour, }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &VaultCAS{ client: tt.fields.client, config: tt.fields.options, } got, err := s.RenewCertificate(tt.args.req) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.RenewCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.RenewCertificate() = %v, want %v", got, tt.want) } }) } } func TestVaultCAS_loadOptions(t *testing.T) { tests := []struct { name string raw string want *VaultOptions wantErr bool }{ { "ok mandatory PKIRole PKIRoleEd25519", `{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "role", PKIRoleEC: "role", PKIRoleEd25519: "ed25519", }, false, }, { "ok mandatory PKIRole PKIRoleEC", `{"PKIRoleDefault": "role", "PKIRoleEC": "ec"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "role", PKIRoleEC: "ec", PKIRoleEd25519: "role", }, false, }, { "ok mandatory PKIRole PKIRoleRSA", `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "role", PKIRoleEd25519: "role", }, false, }, { "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519", `{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "default", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", }, false, }, { "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519 with useless PKIRoleDefault", `{"PKIRoleDefault": 
"role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`, &VaultOptions{ PKIMountPath: "pki", PKIRoleDefault: "role", PKIRoleRSA: "rsa", PKIRoleEC: "ec", PKIRoleEd25519: "ed25519", }, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := loadOptions(json.RawMessage(tt.raw)) if (err != nil) != tt.wantErr { t.Errorf("VaultCAS.loadOptions() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("VaultCAS.loadOptions() = %v, want %v", got, tt.want) } }) } }