|
|
|
@ -11,10 +11,10 @@ import (
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/smallstep/assert"
|
|
|
|
|
"github.com/smallstep/certificates/api/render"
|
|
|
|
|
"github.com/smallstep/certificates/authority/provisioner/wire"
|
|
|
|
|
sassert "github.com/stretchr/testify/assert"
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func TestACMEChallenge_Validate(t *testing.T) {
|
|
|
|
@ -27,14 +27,20 @@ func TestACMEChallenge_Validate(t *testing.T) {
|
|
|
|
|
{"dns-01", DNS_01, false},
|
|
|
|
|
{"tls-alpn-01", TLS_ALPN_01, false},
|
|
|
|
|
{"device-attest-01", DEVICE_ATTEST_01, false},
|
|
|
|
|
{"wire-oidc-01", DEVICE_ATTEST_01, false},
|
|
|
|
|
{"wire-dpop-01", DEVICE_ATTEST_01, false},
|
|
|
|
|
{"uppercase", "HTTP-01", false},
|
|
|
|
|
{"fail", "http-02", true},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
if err := tt.c.Validate(); (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("ACMEChallenge.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
err := tt.c.Validate()
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
assert.Error(t, err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -53,26 +59,24 @@ func TestACMEAttestationFormat_Validate(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
if err := tt.f.Validate(); (err != nil) != tt.wantErr {
|
|
|
|
|
t.Errorf("ACMEAttestationFormat.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
|
|
|
|
err := tt.f.Validate()
|
|
|
|
|
if tt.wantErr {
|
|
|
|
|
assert.Error(t, err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestACME_Getters(t *testing.T) {
|
|
|
|
|
p, err := generateACME()
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
id := "acme/" + p.Name
|
|
|
|
|
if got := p.GetID(); got != id {
|
|
|
|
|
t.Errorf("ACME.GetID() = %v, want %v", got, id)
|
|
|
|
|
}
|
|
|
|
|
if got := p.GetName(); got != p.Name {
|
|
|
|
|
t.Errorf("ACME.GetName() = %v, want %v", got, p.Name)
|
|
|
|
|
}
|
|
|
|
|
if got := p.GetType(); got != TypeACME {
|
|
|
|
|
t.Errorf("ACME.GetType() = %v, want %v", got, TypeACME)
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
id := "acme/test@acme-provisioner.com"
|
|
|
|
|
assert.Equal(t, id, p.GetID())
|
|
|
|
|
assert.Equal(t, "test@acme-provisioner.com", p.GetName())
|
|
|
|
|
assert.Equal(t, TypeACME, p.GetType())
|
|
|
|
|
kid, key, ok := p.GetEncryptedKey()
|
|
|
|
|
if kid != "" || key != "" || ok == true {
|
|
|
|
|
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
|
|
|
@ -82,13 +86,9 @@ func TestACME_Getters(t *testing.T) {
|
|
|
|
|
|
|
|
|
|
func TestACME_Init(t *testing.T) {
|
|
|
|
|
appleCA, err := os.ReadFile("testdata/certs/apple-att-ca.crt")
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
yubicoCA, err := os.ReadFile("testdata/certs/yubico-piv-ca.crt")
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
fakeWireDPoPKey := []byte(`-----BEGIN PUBLIC KEY-----
|
|
|
|
|
MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
|
|
|
|
|
-----END PUBLIC KEY-----`)
|
|
|
|
@ -224,11 +224,11 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
|
|
|
|
|
t.Log(string(tc.p.AttestationRoots))
|
|
|
|
|
err := tc.p.Init(config)
|
|
|
|
|
if tc.err != nil {
|
|
|
|
|
sassert.EqualError(t, err, tc.err.Error())
|
|
|
|
|
assert.EqualError(t, err, tc.err.Error())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sassert.NoError(t, err)
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -244,12 +244,12 @@ func TestACME_AuthorizeRenew(t *testing.T) {
|
|
|
|
|
tests := map[string]func(*testing.T) test{
|
|
|
|
|
"fail/renew-disabled": func(t *testing.T) test {
|
|
|
|
|
p, err := generateACME()
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
// disable renewal
|
|
|
|
|
disable := true
|
|
|
|
|
p.Claims = &Claims{DisableRenewal: &disable}
|
|
|
|
|
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
return test{
|
|
|
|
|
p: p,
|
|
|
|
|
cert: &x509.Certificate{
|
|
|
|
@ -262,7 +262,7 @@ func TestACME_AuthorizeRenew(t *testing.T) {
|
|
|
|
|
},
|
|
|
|
|
"ok": func(t *testing.T) test {
|
|
|
|
|
p, err := generateACME()
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
return test{
|
|
|
|
|
p: p,
|
|
|
|
|
cert: &x509.Certificate{
|
|
|
|
@ -275,16 +275,19 @@ func TestACME_AuthorizeRenew(t *testing.T) {
|
|
|
|
|
for name, tt := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := tt(t)
|
|
|
|
|
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
|
|
|
|
|
sc, ok := err.(render.StatusCodedError)
|
|
|
|
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
|
|
|
|
if assert.NotNil(t, tc.err) {
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
|
|
|
err := tc.p.AuthorizeRenew(context.Background(), tc.cert)
|
|
|
|
|
if tc.err != nil {
|
|
|
|
|
if assert.Implements(t, (*render.StatusCodedError)(nil), err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if errors.As(err, &sc) {
|
|
|
|
|
assert.Equal(t, tc.code, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
assert.Nil(t, tc.err)
|
|
|
|
|
assert.EqualError(t, err, tc.err.Error())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -299,7 +302,7 @@ func TestACME_AuthorizeSign(t *testing.T) {
|
|
|
|
|
tests := map[string]func(*testing.T) test{
|
|
|
|
|
"ok": func(t *testing.T) test {
|
|
|
|
|
p, err := generateACME()
|
|
|
|
|
assert.FatalError(t, err)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
return test{
|
|
|
|
|
p: p,
|
|
|
|
|
token: "foo",
|
|
|
|
@ -309,39 +312,43 @@ func TestACME_AuthorizeSign(t *testing.T) {
|
|
|
|
|
for name, tt := range tests {
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
tc := tt(t)
|
|
|
|
|
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
|
|
|
|
if assert.NotNil(t, tc.err) {
|
|
|
|
|
sc, ok := err.(render.StatusCodedError)
|
|
|
|
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
|
|
|
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
|
|
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
|
|
|
opts, err := tc.p.AuthorizeSign(context.Background(), tc.token)
|
|
|
|
|
if tc.err != nil {
|
|
|
|
|
if assert.Implements(t, (*render.StatusCodedError)(nil), err) {
|
|
|
|
|
var sc render.StatusCodedError
|
|
|
|
|
if errors.As(err, &sc) {
|
|
|
|
|
assert.Equal(t, tc.code, sc.StatusCode())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
|
|
|
|
|
assert.Equals(t, 8, len(opts)) // number of SignOptions returned
|
|
|
|
|
for _, o := range opts {
|
|
|
|
|
switch v := o.(type) {
|
|
|
|
|
case *ACME:
|
|
|
|
|
case *provisionerExtensionOption:
|
|
|
|
|
assert.Equals(t, v.Type, TypeACME)
|
|
|
|
|
assert.Equals(t, v.Name, tc.p.GetName())
|
|
|
|
|
assert.Equals(t, v.CredentialID, "")
|
|
|
|
|
assert.Len(t, 0, v.KeyValuePairs)
|
|
|
|
|
case *forceCNOption:
|
|
|
|
|
assert.Equals(t, v.ForceCN, tc.p.ForceCN)
|
|
|
|
|
case profileDefaultDuration:
|
|
|
|
|
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
|
|
|
|
|
case defaultPublicKeyValidator:
|
|
|
|
|
case *validityValidator:
|
|
|
|
|
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
|
|
|
|
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
|
|
|
|
case *x509NamePolicyValidator:
|
|
|
|
|
assert.Equals(t, nil, v.policyEngine)
|
|
|
|
|
case *WebhookController:
|
|
|
|
|
assert.Len(t, 0, v.webhooks)
|
|
|
|
|
default:
|
|
|
|
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
|
|
|
|
}
|
|
|
|
|
assert.EqualError(t, err, tc.err.Error())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
if assert.NotNil(t, opts) {
|
|
|
|
|
assert.Len(t, opts, 8) // number of SignOptions returned
|
|
|
|
|
for _, o := range opts {
|
|
|
|
|
switch v := o.(type) {
|
|
|
|
|
case *ACME:
|
|
|
|
|
case *provisionerExtensionOption:
|
|
|
|
|
assert.Equal(t, v.Type, TypeACME)
|
|
|
|
|
assert.Equal(t, v.Name, tc.p.GetName())
|
|
|
|
|
assert.Equal(t, v.CredentialID, "")
|
|
|
|
|
assert.Len(t, v.KeyValuePairs, 0)
|
|
|
|
|
case *forceCNOption:
|
|
|
|
|
assert.Equal(t, v.ForceCN, tc.p.ForceCN)
|
|
|
|
|
case profileDefaultDuration:
|
|
|
|
|
assert.Equal(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
|
|
|
|
|
case defaultPublicKeyValidator:
|
|
|
|
|
case *validityValidator:
|
|
|
|
|
assert.Equal(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
|
|
|
|
assert.Equal(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
|
|
|
|
case *x509NamePolicyValidator:
|
|
|
|
|
assert.Equal(t, nil, v.policyEngine)
|
|
|
|
|
case *WebhookController:
|
|
|
|
|
assert.Len(t, v.webhooks, 0)
|
|
|
|
|
default:
|
|
|
|
|
require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -372,10 +379,14 @@ func TestACME_IsChallengeEnabled(t *testing.T) {
|
|
|
|
|
{"ok dns-01 enabled", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, DNS_01}, true},
|
|
|
|
|
{"ok tls-alpn-01 enabled", fields{[]ACMEChallenge{"http-01", "dns-01", "tls-alpn-01"}}, args{ctx, TLS_ALPN_01}, true},
|
|
|
|
|
{"ok device-attest-01 enabled", fields{[]ACMEChallenge{"device-attest-01", "dns-01"}}, args{ctx, DEVICE_ATTEST_01}, true},
|
|
|
|
|
{"ok wire-oidc-01 enabled", fields{[]ACMEChallenge{"wire-oidc-01"}}, args{ctx, WIREOIDC_01}, true},
|
|
|
|
|
{"ok wire-dpop-01 enabled", fields{[]ACMEChallenge{"wire-dpop-01"}}, args{ctx, WIREDPOP_01}, true},
|
|
|
|
|
{"fail http-01", fields{[]ACMEChallenge{"dns-01"}}, args{ctx, "http-01"}, false},
|
|
|
|
|
{"fail dns-01", fields{[]ACMEChallenge{"http-01", "tls-alpn-01"}}, args{ctx, "dns-01"}, false},
|
|
|
|
|
{"fail tls-alpn-01", fields{[]ACMEChallenge{"http-01", "dns-01", "device-attest-01"}}, args{ctx, "tls-alpn-01"}, false},
|
|
|
|
|
{"fail device-attest-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "device-attest-01"}, false},
|
|
|
|
|
{"fail wire-oidc-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "wire-oidc-01"}, false},
|
|
|
|
|
{"fail wire-dpop-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "wire-dpop-01"}, false},
|
|
|
|
|
{"fail unknown", fields{[]ACMEChallenge{"http-01", "dns-01", "tls-alpn-01", "device-attest-01"}}, args{ctx, "unknown"}, false},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
@ -383,9 +394,8 @@ func TestACME_IsChallengeEnabled(t *testing.T) {
|
|
|
|
|
p := &ACME{
|
|
|
|
|
Challenges: tt.fields.Challenges,
|
|
|
|
|
}
|
|
|
|
|
if got := p.IsChallengeEnabled(tt.args.ctx, tt.args.challenge); got != tt.want {
|
|
|
|
|
t.Errorf("ACME.AuthorizeChallenge() = %v, want %v", got, tt.want)
|
|
|
|
|
}
|
|
|
|
|
got := p.IsChallengeEnabled(tt.args.ctx, tt.args.challenge)
|
|
|
|
|
assert.Equal(t, tt.want, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -419,9 +429,8 @@ func TestACME_IsAttestationFormatEnabled(t *testing.T) {
|
|
|
|
|
p := &ACME{
|
|
|
|
|
AttestationFormats: tt.fields.AttestationFormats,
|
|
|
|
|
}
|
|
|
|
|
if got := p.IsAttestationFormatEnabled(tt.args.ctx, tt.args.format); got != tt.want {
|
|
|
|
|
t.Errorf("ACME.IsAttestationFormatEnabled() = %v, want %v", got, tt.want)
|
|
|
|
|
}
|
|
|
|
|
got := p.IsAttestationFormatEnabled(tt.args.ctx, tt.args.format)
|
|
|
|
|
assert.Equal(t, tt.want, got)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|