diff --git a/api/ssh.go b/api/ssh.go index fbaa8c5a..9d0bbc14 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -317,7 +317,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { var identityCertificate []Certificate if cr := body.IdentityCSR.CertificateRequest; cr != nil { ctx := authority.NewContextWithSkipTokenReuse(r.Context()) - ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod) signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) diff --git a/authority/authorize.go b/authority/authorize.go index 1e35afe0..f14574a8 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -214,7 +214,7 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner. var opts = []interface{}{errs.WithKeyVal("token", token)} switch m := provisioner.MethodFromContext(ctx); m { - case provisioner.SignMethod: + case provisioner.SignMethod, provisioner.SignIdentityMethod: signOpts, err := a.authorizeSign(ctx, token) return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) case provisioner.RevokeMethod: diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index be641973..e95feedd 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -336,7 +336,7 @@ func (p *AWS) Init(config Config) (err error) { // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *AWS) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { payload, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign") @@ -363,7 +363,7 @@ func (p *AWS) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro net.ParseIP(doc.PrivateIP), }), emailAddressesValidator(nil), - urisValidator(nil), + newURIsValidator(ctx, nil), ) // Template options diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 05f51456..02be1ba9 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -695,8 +695,9 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) case emailAddressesValidator: assert.Equals(t, v, nil) - case urisValidator: - assert.Equals(t, v, nil) + case *urisValidator: + assert.Equals(t, v.uris, nil) + assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) case *x509NamePolicyValidator: diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 76bcebb6..a9d5d1fa 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -316,7 +316,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *Azure) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { _, name, group, subscription, identityObjectID, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign") @@ -382,7 +382,7 @@ func (p *Azure) AuthorizeSign(_ context.Context, token string) ([]SignOption, er dnsNamesValidator([]string{name}), ipAddressesValidator(nil), emailAddressesValidator(nil), - urisValidator(nil), + newURIsValidator(ctx, nil), ) // Enforce SANs in the template. diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 51d46c5a..f262ffbc 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -560,8 +560,9 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case emailAddressesValidator: assert.Equals(t, v, nil) - case urisValidator: - assert.Equals(t, v, nil) + case *urisValidator: + assert.Equals(t, v.uris, nil) + assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) case *x509NamePolicyValidator: diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index b6274f8f..2296b1b0 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -223,7 +223,7 @@ func (p *GCP) Init(config Config) (err error) { // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *GCP) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign") @@ -254,7 +254,7 @@ func (p *GCP) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro }), ipAddressesValidator(nil), emailAddressesValidator(nil), - urisValidator(nil), + newURIsValidator(ctx, nil), ) // Template SANs diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 7705b44a..ef791614 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -567,8 +567,9 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case emailAddressesValidator: assert.Equals(t, v, nil) - case urisValidator: - assert.Equals(t, v, nil) + case *urisValidator: + assert.Equals(t, v.uris, nil) + assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) case *x509NamePolicyValidator: diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 6c5ee657..c98d78f2 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -150,7 +150,7 @@ func (p *JWK) AuthorizeRevoke(_ context.Context, token string) error { } // AuthorizeSign validates the given token. -func (p *JWK) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") @@ -192,7 +192,7 @@ func (p *JWK) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro // validators commonNameValidator(claims.Subject), defaultPublicKeyValidator{}, - defaultSANsValidator(claims.SANs), + newDefaultSANsValidator(ctx, claims.SANs), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController(data, linkedca.Webhook_X509), diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 19cee4fb..bffe1141 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -315,8 +315,9 @@ func TestJWK_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) - case defaultSANsValidator: - assert.Equals(t, []string(v), tt.sans) + case *defaultSANsValidator: + assert.Equals(t, v.sans, tt.sans) + assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go index 01dda2ed..19aa6224 100644 --- a/authority/provisioner/method.go +++ b/authority/provisioner/method.go @@ -14,6 +14,8 @@ type methodKey struct{} const ( // SignMethod is the method used to sign X.509 certificates. SignMethod Method = iota + // SignIdentityMethod is the method used to sign X.509 identity certificates. + SignIdentityMethod // RevokeMethod is the method used to revoke X.509 certificates. RevokeMethod // RenewMethod is the method used to renew X.509 certificates. @@ -33,6 +35,8 @@ func (m Method) String() string { switch m { case SignMethod: return "sign-method" + case SignIdentityMethod: + return "sign-identity-method" case RevokeMethod: return "revoke-method" case RenewMethod: diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index 6c24bd00..66c523dc 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -389,7 +389,7 @@ func (v nebulaSANsValidator) Valid(req *x509.CertificateRequest) error { } } if len(req.URIs) > 0 { - if err := urisValidator(uris).Valid(req); err != nil { + if err := newURIsValidator(context.Background(), uris).Valid(req); err != nil { return err } } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 782a3598..fec9b9f6 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" @@ -233,16 +234,28 @@ func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error { } // urisValidator validates the URI SANs of a certificate request. -type urisValidator []*url.URL +type urisValidator struct { + ctx context.Context + uris []*url.URL +} + +func newURIsValidator(ctx context.Context, uris []*url.URL) *urisValidator { + return &urisValidator{ctx, uris} +} // Valid checks that certificate request IP Addresses match those configured in // the bootstrap (token) flow. func (v urisValidator) Valid(req *x509.CertificateRequest) error { + // SignIdentityMethod does not need to validate URIs. + if MethodFromContext(v.ctx) == SignIdentityMethod { + return nil + } + if len(req.URIs) == 0 { return nil } want := make(map[string]bool) - for _, u := range v { + for _, u := range v.uris { want[u.String()] = true } got := make(map[string]bool) @@ -250,26 +263,33 @@ func (v urisValidator) Valid(req *x509.CertificateRequest) error { got[u.String()] = true } if !reflect.DeepEqual(want, got) { - return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v) + return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v.uris) } return nil } // defaultsSANsValidator stores a set of SANs to eventually validate 1:1 against // the SANs in an x509 certificate request. -type defaultSANsValidator []string +type defaultSANsValidator struct { + ctx context.Context + sans []string +} + +func newDefaultSANsValidator(ctx context.Context, sans []string) *defaultSANsValidator { + return &defaultSANsValidator{ctx, sans} +} // Valid verifies that the SANs stored in the validator match 1:1 with those // requested in the x509 certificate request. func (v defaultSANsValidator) Valid(req *x509.CertificateRequest) (err error) { - dnsNames, ips, emails, uris := x509util.SplitSANs(v) + dnsNames, ips, emails, uris := x509util.SplitSANs(v.sans) if err = dnsNamesValidator(dnsNames).Valid(req); err != nil { return } else if err = emailAddressesValidator(emails).Valid(req); err != nil { return } else if err = ipAddressesValidator(ips).Valid(req); err != nil { return - } else if err = urisValidator(uris).Valid(req); err != nil { + } else if err = newURIsValidator(v.ctx, uris).Valid(req); err != nil { return } return diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index e36d051f..5a55aa86 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -227,23 +228,26 @@ func Test_urisValidator_Valid(t *testing.T) { fu, err := url.Parse("https://unexpected.com") assert.FatalError(t, err) + signContext := NewContextWithMethod(context.Background(), SignMethod) + signIdentityContext := NewContextWithMethod(context.Background(), SignIdentityMethod) + type args struct { req *x509.CertificateRequest } tests := []struct { name string - v urisValidator + v *urisValidator args args wantErr bool }{ - {"ok0", []*url.URL{}, args{&x509.CertificateRequest{URIs: []*url.URL{}}}, false}, - {"ok1", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u1}}}, false}, - {"ok2", []*url.URL{u1, u2}, args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, false}, - {"ok3", []*url.URL{u2, u1, u3}, args{&x509.CertificateRequest{URIs: []*url.URL{u3, u2, u1}}}, false}, - {"ok3", []*url.URL{u2, u1, u3}, args{&x509.CertificateRequest{}}, false}, - {"fail1", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u2}}}, true}, - {"fail2", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, true}, - {"fail3", []*url.URL{u1, u2}, args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, true}, + {"ok0", newURIsValidator(signContext, []*url.URL{}), args{&x509.CertificateRequest{URIs: []*url.URL{}}}, false}, + {"ok1", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u1}}}, false}, + {"ok2", newURIsValidator(signContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, false}, + {"ok3", newURIsValidator(signContext, []*url.URL{u2, u1, u3}), args{&x509.CertificateRequest{URIs: []*url.URL{u3, u2, u1}}}, false}, + {"ok4", newURIsValidator(signIdentityContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, false}, + {"fail1", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u2}}}, true}, + {"fail2", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, true}, + {"fail3", newURIsValidator(signContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -257,13 +261,19 @@ func Test_urisValidator_Valid(t *testing.T) { func Test_defaultSANsValidator_Valid(t *testing.T) { type test struct { csr *x509.CertificateRequest + ctx context.Context expectedSANs []string err error } + + signContext := NewContextWithMethod(context.Background(), SignMethod) + signIdentityContext := NewContextWithMethod(context.Background(), SignIdentityMethod) + tests := map[string]func() test{ "fail/dnsNamesValidator": func() test { return test{ csr: &x509.CertificateRequest{DNSNames: []string{"foo", "bar"}}, + ctx: signContext, expectedSANs: []string{"foo"}, err: errors.New("certificate request does not contain the valid DNS names"), } @@ -271,6 +281,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { "fail/emailAddressesValidator": func() test { return test{ csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}}, + ctx: signContext, expectedSANs: []string{"dcow@fx.com"}, err: errors.New("certificate request does not contain the valid email addresses"), } @@ -278,6 +289,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { "fail/ipAddressesValidator": func() test { return test{ csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}}, + ctx: signContext, expectedSANs: []string{"127.0.0.1"}, err: errors.New("certificate request does not contain the valid IP addresses"), } @@ -289,16 +301,29 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { assert.FatalError(t, err) return test{ csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}}, + ctx: signContext, expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, err: errors.New("certificate request does not contain the valid URIs"), } }, + "ok/urisBadValidator-SignIdentity": func() test { + u1, err := url.Parse("https://google.com") + assert.FatalError(t, err) + u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959") + assert.FatalError(t, err) + return test{ + csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}}, + ctx: signIdentityContext, + expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, + } + }, "ok": func() test { u1, err := url.Parse("https://google.com") assert.FatalError(t, err) u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959") assert.FatalError(t, err) return test{ + ctx: signContext, csr: &x509.CertificateRequest{ DNSNames: []string{"foo", "bar"}, EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}, @@ -312,7 +337,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() - if err := defaultSANsValidator(tt.expectedSANs).Valid(tt.csr); err != nil { + if err := newDefaultSANsValidator(tt.ctx, tt.expectedSANs).Valid(tt.csr); err != nil { if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) { assert.True(t, strings.Contains(err.Error(), tt.err.Error()), fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error())) diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index b6e78697..9b1f2b08 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -194,7 +194,7 @@ func (p *X5C) AuthorizeRevoke(_ context.Context, token string) error { } // AuthorizeSign validates the given token. -func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") @@ -244,7 +244,7 @@ func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro }, // validators commonNameValidator(claims.Subject), - defaultSANsValidator(claims.SANs), + newDefaultSANsValidator(ctx, claims.SANs), defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index f9a2604b..22545446 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -460,7 +460,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { + ctx := NewContextWithMethod(context.Background(), SignIdentityMethod) + if opts, err := tc.p.AuthorizeSign(ctx, tc.token); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { @@ -489,8 +490,9 @@ func TestX5C_AuthorizeSign(t *testing.T) { case commonNameValidator: assert.Equals(t, string(v), "foo") case defaultPublicKeyValidator: - case defaultSANsValidator: - assert.Equals(t, []string(v), tc.sans) + case *defaultSANsValidator: + assert.Equals(t, v.sans, tc.sans) + assert.Equals(t, MethodFromContext(v.ctx), SignIdentityMethod) case *validityValidator: assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())