Allow for identity certificate signing (in sshSign) by skipping validators (#1572)

- skip urisValidator for identity certificate signing. Implemented
  by building the validator with the context in a hacky way.
pull/1575/head
Max 7 months ago committed by GitHub
parent 06750b03fe
commit 9f84f7ce35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -317,7 +317,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
var identityCertificate []Certificate
if cr := body.IdentityCSR.CertificateRequest; cr != nil {
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil {
render.Error(w, errs.UnauthorizedErr(err))

@ -214,7 +214,7 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner.
var opts = []interface{}{errs.WithKeyVal("token", token)}
switch m := provisioner.MethodFromContext(ctx); m {
case provisioner.SignMethod:
case provisioner.SignMethod, provisioner.SignIdentityMethod:
signOpts, err := a.authorizeSign(ctx, token)
return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...)
case provisioner.RevokeMethod:

@ -336,7 +336,7 @@ func (p *AWS) Init(config Config) (err error) {
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *AWS) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) {
func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
payload, err := p.authorizeToken(token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign")
@ -363,7 +363,7 @@ func (p *AWS) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
net.ParseIP(doc.PrivateIP),
}),
emailAddressesValidator(nil),
urisValidator(nil),
newURIsValidator(ctx, nil),
)
// Template options

@ -695,8 +695,9 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
case emailAddressesValidator:
assert.Equals(t, v, nil)
case urisValidator:
assert.Equals(t, v, nil)
case *urisValidator:
assert.Equals(t, v.uris, nil)
assert.Equals(t, MethodFromContext(v.ctx), SignMethod)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
case *x509NamePolicyValidator:

@ -316,7 +316,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *Azure) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) {
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
_, name, group, subscription, identityObjectID, err := p.authorizeToken(token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
@ -382,7 +382,7 @@ func (p *Azure) AuthorizeSign(_ context.Context, token string) ([]SignOption, er
dnsNamesValidator([]string{name}),
ipAddressesValidator(nil),
emailAddressesValidator(nil),
urisValidator(nil),
newURIsValidator(ctx, nil),
)
// Enforce SANs in the template.

@ -560,8 +560,9 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil)
case emailAddressesValidator:
assert.Equals(t, v, nil)
case urisValidator:
assert.Equals(t, v, nil)
case *urisValidator:
assert.Equals(t, v.uris, nil)
assert.Equals(t, MethodFromContext(v.ctx), SignMethod)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"virtualMachine"})
case *x509NamePolicyValidator:

@ -223,7 +223,7 @@ func (p *GCP) Init(config Config) (err error) {
// AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation.
func (p *GCP) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) {
func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign")
@ -254,7 +254,7 @@ func (p *GCP) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
}),
ipAddressesValidator(nil),
emailAddressesValidator(nil),
urisValidator(nil),
newURIsValidator(ctx, nil),
)
// Template SANs

@ -567,8 +567,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil)
case emailAddressesValidator:
assert.Equals(t, v, nil)
case urisValidator:
assert.Equals(t, v, nil)
case *urisValidator:
assert.Equals(t, v.uris, nil)
assert.Equals(t, MethodFromContext(v.ctx), SignMethod)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case *x509NamePolicyValidator:

@ -150,7 +150,7 @@ func (p *JWK) AuthorizeRevoke(_ context.Context, token string) error {
}
// AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) {
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
@ -192,7 +192,7 @@ func (p *JWK) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
// validators
commonNameValidator(claims.Subject),
defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs),
newDefaultSANsValidator(ctx, claims.SANs),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
p.ctl.newWebhookController(data, linkedca.Webhook_X509),

@ -315,8 +315,9 @@ func TestJWK_AuthorizeSign(t *testing.T) {
case *validityValidator:
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans)
case *defaultSANsValidator:
assert.Equals(t, v.sans, tt.sans)
assert.Equals(t, MethodFromContext(v.ctx), SignMethod)
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
case *WebhookController:

@ -14,6 +14,8 @@ type methodKey struct{}
const (
// SignMethod is the method used to sign X.509 certificates.
SignMethod Method = iota
// SignIdentityMethod is the method used to sign X.509 identity certificates.
SignIdentityMethod
// RevokeMethod is the method used to revoke X.509 certificates.
RevokeMethod
// RenewMethod is the method used to renew X.509 certificates.
@ -33,6 +35,8 @@ func (m Method) String() string {
switch m {
case SignMethod:
return "sign-method"
case SignIdentityMethod:
return "sign-identity-method"
case RevokeMethod:
return "revoke-method"
case RenewMethod:

@ -389,7 +389,7 @@ func (v nebulaSANsValidator) Valid(req *x509.CertificateRequest) error {
}
}
if len(req.URIs) > 0 {
if err := urisValidator(uris).Valid(req); err != nil {
if err := newURIsValidator(context.Background(), uris).Valid(req); err != nil {
return err
}
}

@ -1,6 +1,7 @@
package provisioner
import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
@ -233,16 +234,28 @@ func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error {
}
// urisValidator validates the URI SANs of a certificate request.
type urisValidator []*url.URL
type urisValidator struct {
ctx context.Context
uris []*url.URL
}
func newURIsValidator(ctx context.Context, uris []*url.URL) *urisValidator {
return &urisValidator{ctx, uris}
}
// Valid checks that certificate request IP Addresses match those configured in
// the bootstrap (token) flow.
func (v urisValidator) Valid(req *x509.CertificateRequest) error {
// SignIdentityMethod does not need to validate URIs.
if MethodFromContext(v.ctx) == SignIdentityMethod {
return nil
}
if len(req.URIs) == 0 {
return nil
}
want := make(map[string]bool)
for _, u := range v {
for _, u := range v.uris {
want[u.String()] = true
}
got := make(map[string]bool)
@ -250,26 +263,33 @@ func (v urisValidator) Valid(req *x509.CertificateRequest) error {
got[u.String()] = true
}
if !reflect.DeepEqual(want, got) {
return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v)
return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v.uris)
}
return nil
}
// defaultsSANsValidator stores a set of SANs to eventually validate 1:1 against
// the SANs in an x509 certificate request.
type defaultSANsValidator []string
type defaultSANsValidator struct {
ctx context.Context
sans []string
}
func newDefaultSANsValidator(ctx context.Context, sans []string) *defaultSANsValidator {
return &defaultSANsValidator{ctx, sans}
}
// Valid verifies that the SANs stored in the validator match 1:1 with those
// requested in the x509 certificate request.
func (v defaultSANsValidator) Valid(req *x509.CertificateRequest) (err error) {
dnsNames, ips, emails, uris := x509util.SplitSANs(v)
dnsNames, ips, emails, uris := x509util.SplitSANs(v.sans)
if err = dnsNamesValidator(dnsNames).Valid(req); err != nil {
return
} else if err = emailAddressesValidator(emails).Valid(req); err != nil {
return
} else if err = ipAddressesValidator(ips).Valid(req); err != nil {
return
} else if err = urisValidator(uris).Valid(req); err != nil {
} else if err = newURIsValidator(v.ctx, uris).Valid(req); err != nil {
return
}
return

@ -1,6 +1,7 @@
package provisioner
import (
"context"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
@ -227,23 +228,26 @@ func Test_urisValidator_Valid(t *testing.T) {
fu, err := url.Parse("https://unexpected.com")
assert.FatalError(t, err)
signContext := NewContextWithMethod(context.Background(), SignMethod)
signIdentityContext := NewContextWithMethod(context.Background(), SignIdentityMethod)
type args struct {
req *x509.CertificateRequest
}
tests := []struct {
name string
v urisValidator
v *urisValidator
args args
wantErr bool
}{
{"ok0", []*url.URL{}, args{&x509.CertificateRequest{URIs: []*url.URL{}}}, false},
{"ok1", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u1}}}, false},
{"ok2", []*url.URL{u1, u2}, args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, false},
{"ok3", []*url.URL{u2, u1, u3}, args{&x509.CertificateRequest{URIs: []*url.URL{u3, u2, u1}}}, false},
{"ok3", []*url.URL{u2, u1, u3}, args{&x509.CertificateRequest{}}, false},
{"fail1", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u2}}}, true},
{"fail2", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, true},
{"fail3", []*url.URL{u1, u2}, args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, true},
{"ok0", newURIsValidator(signContext, []*url.URL{}), args{&x509.CertificateRequest{URIs: []*url.URL{}}}, false},
{"ok1", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u1}}}, false},
{"ok2", newURIsValidator(signContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, false},
{"ok3", newURIsValidator(signContext, []*url.URL{u2, u1, u3}), args{&x509.CertificateRequest{URIs: []*url.URL{u3, u2, u1}}}, false},
{"ok4", newURIsValidator(signIdentityContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, false},
{"fail1", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u2}}}, true},
{"fail2", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, true},
{"fail3", newURIsValidator(signContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -257,13 +261,19 @@ func Test_urisValidator_Valid(t *testing.T) {
func Test_defaultSANsValidator_Valid(t *testing.T) {
type test struct {
csr *x509.CertificateRequest
ctx context.Context
expectedSANs []string
err error
}
signContext := NewContextWithMethod(context.Background(), SignMethod)
signIdentityContext := NewContextWithMethod(context.Background(), SignIdentityMethod)
tests := map[string]func() test{
"fail/dnsNamesValidator": func() test {
return test{
csr: &x509.CertificateRequest{DNSNames: []string{"foo", "bar"}},
ctx: signContext,
expectedSANs: []string{"foo"},
err: errors.New("certificate request does not contain the valid DNS names"),
}
@ -271,6 +281,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
"fail/emailAddressesValidator": func() test {
return test{
csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}},
ctx: signContext,
expectedSANs: []string{"dcow@fx.com"},
err: errors.New("certificate request does not contain the valid email addresses"),
}
@ -278,6 +289,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
"fail/ipAddressesValidator": func() test {
return test{
csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}},
ctx: signContext,
expectedSANs: []string{"127.0.0.1"},
err: errors.New("certificate request does not contain the valid IP addresses"),
}
@ -289,16 +301,29 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
assert.FatalError(t, err)
return test{
csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}},
ctx: signContext,
expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"},
err: errors.New("certificate request does not contain the valid URIs"),
}
},
"ok/urisBadValidator-SignIdentity": func() test {
u1, err := url.Parse("https://google.com")
assert.FatalError(t, err)
u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959")
assert.FatalError(t, err)
return test{
csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}},
ctx: signIdentityContext,
expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"},
}
},
"ok": func() test {
u1, err := url.Parse("https://google.com")
assert.FatalError(t, err)
u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959")
assert.FatalError(t, err)
return test{
ctx: signContext,
csr: &x509.CertificateRequest{
DNSNames: []string{"foo", "bar"},
EmailAddresses: []string{"max@fx.com", "mariano@fx.com"},
@ -312,7 +337,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
if err := defaultSANsValidator(tt.expectedSANs).Valid(tt.csr); err != nil {
if err := newDefaultSANsValidator(tt.ctx, tt.expectedSANs).Valid(tt.csr); err != nil {
if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) {
assert.True(t, strings.Contains(err.Error(), tt.err.Error()),
fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error()))

@ -194,7 +194,7 @@ func (p *X5C) AuthorizeRevoke(_ context.Context, token string) error {
}
// AuthorizeSign validates the given token.
func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) {
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
@ -244,7 +244,7 @@ func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
},
// validators
commonNameValidator(claims.Subject),
defaultSANsValidator(claims.SANs),
newDefaultSANsValidator(ctx, claims.SANs),
defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),

@ -460,7 +460,8 @@ func TestX5C_AuthorizeSign(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
ctx := NewContextWithMethod(context.Background(), SignIdentityMethod)
if opts, err := tc.p.AuthorizeSign(ctx, tc.token); err != nil {
if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
@ -489,8 +490,9 @@ func TestX5C_AuthorizeSign(t *testing.T) {
case commonNameValidator:
assert.Equals(t, string(v), "foo")
case defaultPublicKeyValidator:
case defaultSANsValidator:
assert.Equals(t, []string(v), tc.sans)
case *defaultSANsValidator:
assert.Equals(t, v.sans, tc.sans)
assert.Equals(t, MethodFromContext(v.ctx), SignIdentityMethod)
case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())

Loading…
Cancel
Save