pull/299/head
max furman 4 years ago
parent d25e7f64c2
commit 71d87b4e61

@ -449,7 +449,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
}
} else {
if assert.Nil(t, tc.err) {
assert.Len(t, 8, got)
assert.Len(t, 6, got)
}
}
})

@ -10,6 +10,7 @@ import (
"encoding/hex"
"encoding/pem"
"fmt"
"net"
"net/http"
"net/url"
"strings"
@ -529,7 +530,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err)
type args struct {
token string
token, cn string
}
tests := []struct {
name string
@ -539,24 +540,24 @@ func TestAWS_AuthorizeSign(t *testing.T) {
code int
wantErr bool
}{
{"ok", p1, args{t1}, 5, http.StatusOK, false},
{"ok", p2, args{t2}, 7, http.StatusOK, false},
{"ok", p2, args{t2Hostname}, 7, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP}, 7, http.StatusOK, false},
{"ok", p1, args{t4}, 5, http.StatusOK, false},
{"fail account", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{failSubject}, 0, http.StatusUnauthorized, true},
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
{"fail account", p1, args{failAccount}, 0, http.StatusUnauthorized, true},
{"fail instanceID", p1, args{failInstanceID}, 0, http.StatusUnauthorized, true},
{"fail privateIP", p1, args{failPrivateIP}, 0, http.StatusUnauthorized, true},
{"fail region", p1, args{failRegion}, 0, http.StatusUnauthorized, true},
{"fail exp", p1, args{failExp}, 0, http.StatusUnauthorized, true},
{"fail nbf", p1, args{failNbf}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail instance age", p2, args{failInstanceAge}, 0, http.StatusUnauthorized, true},
{"ok", p1, args{t1, "foo.local"}, 5, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 9, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 9, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 9, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 5, http.StatusOK, false},
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
{"fail issuer", p1, args{token: failIssuer}, 0, http.StatusUnauthorized, true},
{"fail audience", p1, args{token: failAudience}, 0, http.StatusUnauthorized, true},
{"fail account", p1, args{token: failAccount}, 0, http.StatusUnauthorized, true},
{"fail instanceID", p1, args{token: failInstanceID}, 0, http.StatusUnauthorized, true},
{"fail privateIP", p1, args{token: failPrivateIP}, 0, http.StatusUnauthorized, true},
{"fail region", p1, args{token: failRegion}, 0, http.StatusUnauthorized, true},
{"fail exp", p1, args{token: failExp}, 0, http.StatusUnauthorized, true},
{"fail nbf", p1, args{token: failNbf}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{token: failKey}, 0, http.StatusUnauthorized, true},
{"fail instance age", p2, args{token: failInstanceAge}, 0, http.StatusUnauthorized, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -571,6 +572,33 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAWS))
assert.Equals(t, v.Name, tt.aws.GetName())
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
assert.Len(t, 2, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration())
case commonNameValidator:
assert.Equals(t, string(v), tt.args.cn)
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration())
case ipAddressesValidator:
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
case emailAddressesValidator:
assert.Equals(t, v, nil)
case urisValidator:
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
}
}
})
}

@ -432,7 +432,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
wantErr bool
}{
{"ok", p1, args{t1}, 4, http.StatusOK, false},
{"ok", p2, args{t2}, 6, http.StatusOK, false},
{"ok", p2, args{t2}, 9, http.StatusOK, false},
{"ok", p1, args{t11}, 4, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
@ -456,6 +456,33 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAzure))
assert.Equals(t, v.Name, tt.azure.GetName())
assert.Equals(t, v.CredentialID, tt.azure.TenantID)
assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration())
case commonNameValidator:
assert.Equals(t, string(v), "virtualMachine")
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration())
case ipAddressesValidator:
assert.Equals(t, v, nil)
case emailAddressesValidator:
assert.Equals(t, v, nil)
case urisValidator:
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"virtualMachine"})
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
}
}
})
}

@ -516,7 +516,7 @@ func TestGCP_AuthorizeSign(t *testing.T) {
wantErr bool
}{
{"ok", p1, args{t1}, 4, http.StatusOK, false},
{"ok", p2, args{t2}, 6, http.StatusOK, false},
{"ok", p2, args{t2}, 9, http.StatusOK, false},
{"ok", p3, args{t3}, 4, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
@ -545,6 +545,33 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Equals(t, sc.StatusCode(), tt.code)
} else {
assert.Len(t, tt.wantLen, got)
for _, o := range got {
switch v := o.(type) {
case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeGCP))
assert.Equals(t, v.Name, tt.gcp.GetName())
assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0])
assert.Len(t, 4, v.KeyValuePairs)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration())
case commonNameSliceValidator:
assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration())
case ipAddressesValidator:
assert.Equals(t, v, nil)
case emailAddressesValidator:
assert.Equals(t, v, nil)
case urisValidator:
assert.Equals(t, v, nil)
case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
}
}
})
}

@ -157,8 +157,8 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
// validators
commonNameValidator(claims.Subject),
defaultSANsValidator(claims.SANs),
defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
}, nil
}

@ -6,7 +6,6 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"net"
"net/http"
"strings"
"testing"
@ -253,18 +252,36 @@ func TestJWK_AuthorizeSign(t *testing.T) {
token string
}
tests := []struct {
name string
prov *JWK
args args
code int
err error
dns []string
emails []string
ips []net.IP
name string
prov *JWK
args args
code int
err error
sans []string
}{
{name: "fail-signature", prov: p1, args: args{failSig}, code: http.StatusUnauthorized, err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive")},
{"ok-sans", p1, args{t1}, http.StatusOK, nil, []string{"foo"}, []string{"max@smallstep.com"}, []net.IP{net.ParseIP("127.0.0.1")}},
{"ok-no-sans", p1, args{t2}, http.StatusOK, nil, []string{"subject"}, []string{}, []net.IP{}},
{
name: "fail-signature",
prov: p1,
args: args{failSig},
code: http.StatusUnauthorized,
err: errors.New("jwk.AuthorizeSign: jwk.authorizeToken; error parsing jwk claims: square/go-jose: error in cryptographic primitive"),
},
{
name: "ok-sans",
prov: p1,
args: args{t1},
code: http.StatusOK,
err: nil,
sans: []string{"127.0.0.1", "max@smallstep.com", "foo"},
},
{
name: "ok-no-sans",
prov: p1,
args: args{t2},
code: http.StatusOK,
err: nil,
sans: []string{"subject"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -278,7 +295,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
}
} else {
if assert.NotNil(t, got) {
assert.Len(t, 8, got)
assert.Len(t, 6, got)
for _, o := range got {
switch v := o.(type) {
case *provisionerExtensionOption:
@ -291,15 +308,11 @@ func TestJWK_AuthorizeSign(t *testing.T) {
case commonNameValidator:
assert.Equals(t, string(v), "subject")
case defaultPublicKeyValidator:
case dnsNamesValidator:
assert.Equals(t, []string(v), tt.dns)
case emailAddressesValidator:
assert.Equals(t, []string(v), tt.emails)
case ipAddressesValidator:
assert.Equals(t, []net.IP(v), tt.ips)
case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans)
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}

@ -426,7 +426,15 @@ func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
if err != nil {
return err
}
crt.ExtraExtensions = append(crt.ExtraExtensions, ext)
// NOTE: HACK.
// Prepend the provisioner extension. In the auth.Sign code we will
// force the resulting certificate to only have one extension, the
// first stepOIDProvisioner that is found in the ExtraExtensions.
// A client could pass a csr containing a malicious stepOIDProvisioner
// ExtraExtension. If we were to append (rather than prepend) the correct
// stepOIDProvisioner extension, then the resulting certificate would
// contain the malicious extension, rather than the one applied by step-ca.
crt.ExtraExtensions = append([]pkix.Extension{ext}, crt.ExtraExtensions...)
return nil
}
}

@ -356,6 +356,7 @@ func Test_ExtraExtsEnforcer_Enforce(t *testing.T) {
e1 := pkix.Extension{Id: []int{1, 2, 3, 4, 5}, Critical: false, Value: []byte("foo")}
e2 := pkix.Extension{Id: []int{2, 2, 2}, Critical: false, Value: []byte("bar")}
stepExt := pkix.Extension{Id: stepOIDProvisioner, Critical: false, Value: []byte("baz")}
fakeStepExt := pkix.Extension{Id: stepOIDProvisioner, Critical: false, Value: []byte("zap")}
type test struct {
cert *x509.Certificate
check func(*x509.Certificate)
@ -379,7 +380,7 @@ func Test_ExtraExtsEnforcer_Enforce(t *testing.T) {
},
"ok/step-provisioner-ext": func() test {
return test{
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{e1, stepExt, e2}},
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{e1, stepExt, fakeStepExt, e2}},
check: func(cert *x509.Certificate) {
assert.Equals(t, len(cert.ExtraExtensions), 1)
assert.Equals(t, cert.ExtraExtensions[0], stepExt)
@ -668,6 +669,47 @@ func Test_profileDefaultDuration_Option(t *testing.T) {
}
}
func Test_newProvisionerExtension_Option(t *testing.T) {
type test struct {
cert *x509.Certificate
valid func(*x509.Certificate)
}
tests := map[string]func() test{
"ok/one-element": func() test {
return test{
cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) {
if assert.Len(t, 1, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner)
}
},
}
},
"ok/prepend": func() test {
return test{
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
valid: func(cert *x509.Certificate) {
if assert.Len(t, 3, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner)
assert.False(t, ext.Critical)
}
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
prof := &x509util.Leaf{}
prof.SetSubject(tt.cert)
assert.FatalError(t, newProvisionerExtensionOption(TypeJWK, "foo", "bar", "baz", "zap").Option(Options{})(prof))
tt.valid(prof.Subject())
})
}
}
func Test_profileLimitDuration_Option(t *testing.T) {
n, fn := mockNow()
defer fn()

@ -2,7 +2,6 @@ package provisioner
import (
"context"
"net"
"net/http"
"testing"
"time"
@ -407,13 +406,11 @@ func TestX5C_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err)
type test struct {
p *X5C
token string
code int
err error
dns []string
emails []string
ips []net.IP
p *X5C
token string
code int
err error
sans []string
}
tests := map[string]func(*testing.T) test{
"fail/invalid-token": func(t *testing.T) test {
@ -434,11 +431,9 @@ func TestX5C_AuthorizeSign(t *testing.T) {
withX5CHdr(certs))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
dns: []string{"foo"},
emails: []string{},
ips: []net.IP{},
p: p,
token: tok,
sans: []string{"foo"},
}
},
"ok/multi-sans": func(t *testing.T) test {
@ -449,11 +444,9 @@ func TestX5C_AuthorizeSign(t *testing.T) {
withX5CHdr(certs))
assert.FatalError(t, err)
return test{
p: p,
token: tok,
dns: []string{"foo"},
emails: []string{"max@smallstep.com"},
ips: []net.IP{net.ParseIP("127.0.0.1")},
p: p,
token: tok,
sans: []string{"127.0.0.1", "foo", "max@smallstep.com"},
}
},
}
@ -470,7 +463,7 @@ func TestX5C_AuthorizeSign(t *testing.T) {
} else {
if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) {
tot := 0
assert.Equals(t, len(opts), 6)
for _, o := range opts {
switch v := o.(type) {
case *provisionerExtensionOption:
@ -487,21 +480,15 @@ func TestX5C_AuthorizeSign(t *testing.T) {
case commonNameValidator:
assert.Equals(t, string(v), "foo")
case defaultPublicKeyValidator:
case dnsNamesValidator:
assert.Equals(t, []string(v), tc.dns)
case emailAddressesValidator:
assert.Equals(t, []string(v), tc.emails)
case ipAddressesValidator:
assert.Equals(t, []net.IP(v), tc.ips)
case defaultSANsValidator:
assert.Equals(t, []string(v), tc.sans)
case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
tot++
}
assert.Equals(t, tot, 8)
}
}
}

@ -89,6 +89,17 @@ func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateReques
return csr
}
func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) {
return func(csr *x509.CertificateRequest) {
csr.ExtraExtensions = exts
}
}
type basicConstraints struct {
IsCA bool `asn1:"optional"`
MaxPathLen int `asn1:"optional,default:-1"`
}
func TestAuthority_Sign(t *testing.T) {
pub, priv, err := keys.GenerateDefaultKeyPair()
assert.FatalError(t, err)
@ -271,7 +282,16 @@ ZYtQ9Ot36qc=
}
},
"ok with enforced modifier": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
bcExt := pkix.Extension{}
bcExt.Id = asn1.ObjectIdentifier{2, 5, 29, 19}
bcExt.Critical = false
bcExt.Value, err = asn1.Marshal(basicConstraints{IsCA: true, MaxPathLen: 4})
assert.FatalError(t, err)
csr := getCSR(t, priv, setExtraExtsCSR([]pkix.Extension{
bcExt,
{Id: stepOIDProvisioner, Value: []byte("foo")},
{Id: []int{1, 1, 1}, Value: []byte("bar")}}))
now := time.Now().UTC()
enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
NotBefore: now,
@ -347,19 +367,26 @@ ZYtQ9Ot36qc=
// Verify Provisioner OID
found := 0
for _, ext := range leaf.Extensions {
id := ext.Id.String()
if id != stepOIDProvisioner.String() {
continue
switch {
case ext.Id.Equal(stepOIDProvisioner):
found++
val := stepProvisionerASN1{}
_, err := asn1.Unmarshal(ext.Value, &val)
assert.FatalError(t, err)
assert.Equals(t, val.Type, provisionerTypeJWK)
assert.Equals(t, val.Name, []byte(p.Name))
assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID))
// Basic Constraints
case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})):
val := basicConstraints{}
_, err := asn1.Unmarshal(ext.Value, &val)
assert.FatalError(t, err)
assert.False(t, val.IsCA, false)
assert.Equals(t, val.MaxPathLen, 0)
}
found++
val := stepProvisionerASN1{}
_, err := asn1.Unmarshal(ext.Value, &val)
assert.FatalError(t, err)
assert.Equals(t, val.Type, provisionerTypeJWK)
assert.Equals(t, val.Name, []byte(p.Name))
assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID))
}
assert.Equals(t, found, 1)
assert.Len(t, 6, leaf.Extensions)
realIntermediate, err := x509.ParseCertificate(a.x509Issuer.Raw)
assert.FatalError(t, err)

Loading…
Cancel
Save