Merge pull request #1708 from smallstep/herman/csr-expires-header

Add `Expires` header to CRL endpoint
pull/1725/head
Herman Slatman 3 months ago committed by GitHub
commit bb296c9d19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -54,7 +54,7 @@ type Authority interface {
GetRoots() ([]*x509.Certificate, error) GetRoots() ([]*x509.Certificate, error)
GetFederation() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error)
Version() authority.Version Version() authority.Version
GetCertificateRevocationList() ([]byte, error) GetCertificateRevocationList() (*authority.CertificateRevocationListInfo, error)
} }
// mustAuthority will be replaced on unit tests. // mustAuthority will be replaced on unit tests.

@ -200,7 +200,7 @@ type mockAuthority struct {
getEncryptedKey func(kid string) (string, error) getEncryptedKey func(kid string) (string, error)
getRoots func() ([]*x509.Certificate, error) getRoots func() ([]*x509.Certificate, error)
getFederation func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error)
getCRL func() ([]byte, error) getCRL func() (*authority.CertificateRevocationListInfo, error)
signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
@ -214,12 +214,12 @@ type mockAuthority struct {
version func() authority.Version version func() authority.Version
} }
func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) { func (m *mockAuthority) GetCertificateRevocationList() (*authority.CertificateRevocationListInfo, error) {
if m.getCRL != nil { if m.getCRL != nil {
return m.getCRL() return m.getCRL()
} }
return m.ret1.([]byte), m.err return m.ret1.(*authority.CertificateRevocationListInfo), m.err
} }
// TODO: remove once Authorize is deprecated. // TODO: remove once Authorize is deprecated.
@ -789,45 +789,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (
return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err
} }
func Test_CRLGeneration(t *testing.T) {
tests := []struct {
name string
err error
statusCode int
expected []byte
}{
{"empty", nil, http.StatusOK, nil},
}
chiCtx := chi.NewRouteContext()
req := httptest.NewRequest("GET", "http://example.com/crl", http.NoBody)
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockMustAuthority(t, &mockAuthority{ret1: tt.expected, err: tt.err})
w := httptest.NewRecorder()
CRL(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.CRL StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err)
}
if tt.statusCode == 200 {
if !bytes.Equal(bytes.TrimSpace(body), tt.expected) {
t.Errorf("caHandler.Root CRL = %s, wants %s", body, tt.expected)
}
}
})
}
}
func Test_caHandler_Route(t *testing.T) { func Test_caHandler_Route(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority

@ -3,18 +3,32 @@ package api
import ( import (
"encoding/pem" "encoding/pem"
"net/http" "net/http"
"time"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs"
) )
// CRL is an HTTP handler that returns the current CRL in DER or PEM format // CRL is an HTTP handler that returns the current CRL in DER or PEM format
func CRL(w http.ResponseWriter, r *http.Request) { func CRL(w http.ResponseWriter, r *http.Request) {
crlBytes, err := mustAuthority(r.Context()).GetCertificateRevocationList() crlInfo, err := mustAuthority(r.Context()).GetCertificateRevocationList()
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
if crlInfo == nil {
render.Error(w, errs.New(http.StatusNotFound, "no CRL available"))
return
}
expires := crlInfo.ExpiresAt
if expires.IsZero() {
expires = time.Now()
}
w.Header().Add("Expires", expires.Format(time.RFC1123))
_, formatAsPEM := r.URL.Query()["pem"] _, formatAsPEM := r.URL.Query()["pem"]
if formatAsPEM { if formatAsPEM {
w.Header().Add("Content-Type", "application/x-pem-file") w.Header().Add("Content-Type", "application/x-pem-file")
@ -22,11 +36,11 @@ func CRL(w http.ResponseWriter, r *http.Request) {
_ = pem.Encode(w, &pem.Block{ _ = pem.Encode(w, &pem.Block{
Type: "X509 CRL", Type: "X509 CRL",
Bytes: crlBytes, Bytes: crlInfo.Data,
}) })
} else { } else {
w.Header().Add("Content-Type", "application/pkix-crl") w.Header().Add("Content-Type", "application/pkix-crl")
w.Header().Add("Content-Disposition", "attachment; filename=\"crl.der\"") w.Header().Add("Content-Disposition", "attachment; filename=\"crl.der\"")
w.Write(crlBytes) w.Write(crlInfo.Data)
} }
} }

@ -0,0 +1,93 @@
package api
import (
"bytes"
"context"
"encoding/pem"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_CRL(t *testing.T) {
data := []byte{1, 2, 3, 4}
pemData := pem.EncodeToMemory(&pem.Block{
Type: "X509 CRL",
Bytes: data,
})
pemData = bytes.TrimSpace(pemData)
emptyPEMData := pem.EncodeToMemory(&pem.Block{
Type: "X509 CRL",
Bytes: nil,
})
emptyPEMData = bytes.TrimSpace(emptyPEMData)
tests := []struct {
name string
url string
err error
statusCode int
crlInfo *authority.CertificateRevocationListInfo
expectedBody []byte
expectedHeaders http.Header
expectedErrorJSON string
}{
{"ok", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, data, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""},
{"ok/pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, pemData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""},
{"ok/empty", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, nil, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""},
{"ok/empty-pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, emptyPEMData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""},
{"fail/internal", "http://example.com/crl", errs.Wrap(http.StatusInternalServerError, errors.New("failure"), "authority.GetCertificateRevocationList"), http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"The certificate authority encountered an Internal Server Error. Please see the certificate authority logs for more info."}`},
{"fail/nil", "http://example.com/crl", nil, http.StatusNotFound, nil, nil, http.Header{}, `{"status":404,"message":"no CRL available"}`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockMustAuthority(t, &mockAuthority{ret1: tt.crlInfo, err: tt.err})
chiCtx := chi.NewRouteContext()
req := httptest.NewRequest("GET", tt.url, http.NoBody)
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
w := httptest.NewRecorder()
CRL(w, req)
res := w.Result()
assert.Equal(t, tt.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
require.NoError(t, err)
if tt.statusCode >= 300 {
assert.JSONEq(t, tt.expectedErrorJSON, string(bytes.TrimSpace(body)))
return
}
// check expected header values
for _, h := range []string{"content-type", "content-disposition"} {
v := tt.expectedHeaders.Get(h)
require.NotEmpty(t, v)
actual := res.Header.Get(h)
assert.Equal(t, v, actual)
}
// check expires header value
assert.NotEmpty(t, res.Header.Get("expires"))
t1, err := time.Parse(time.RFC1123, res.Header.Get("expires"))
if assert.NoError(t, err) {
assert.False(t, t1.IsZero())
}
// check body contents
assert.Equal(t, tt.expectedBody, bytes.TrimSpace(body))
})
}
}

@ -696,9 +696,17 @@ func (a *Authority) revokeSSH(crt *ssh.Certificate, rci *db.RevokedCertificateIn
return a.db.RevokeSSH(rci) return a.db.RevokeSSH(rci)
} }
// CertificateRevocationListInfo contains a CRL in DER format and associated metadata.
type CertificateRevocationListInfo struct {
Number int64
ExpiresAt time.Time
Duration time.Duration
Data []byte
}
// GetCertificateRevocationList will return the currently generated CRL from the DB, or a not implemented // GetCertificateRevocationList will return the currently generated CRL from the DB, or a not implemented
// error if the underlying AuthDB does not support CRLs // error if the underlying AuthDB does not support CRLs
func (a *Authority) GetCertificateRevocationList() ([]byte, error) { func (a *Authority) GetCertificateRevocationList() (*CertificateRevocationListInfo, error) {
if !a.config.CRL.IsEnabled() { if !a.config.CRL.IsEnabled() {
return nil, errs.Wrap(http.StatusNotFound, errors.Errorf("Certificate Revocation Lists are not enabled"), "authority.GetCertificateRevocationList") return nil, errs.Wrap(http.StatusNotFound, errors.Errorf("Certificate Revocation Lists are not enabled"), "authority.GetCertificateRevocationList")
} }
@ -713,7 +721,12 @@ func (a *Authority) GetCertificateRevocationList() ([]byte, error) {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetCertificateRevocationList") return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.GetCertificateRevocationList")
} }
return crlInfo.DER, nil return &CertificateRevocationListInfo{
Number: crlInfo.Number,
ExpiresAt: crlInfo.ExpiresAt,
Duration: crlInfo.Duration,
Data: crlInfo.DER,
}, nil
} }
// GenerateCertificateRevocationList generates a DER representation of a signed CRL and stores it in the // GenerateCertificateRevocationList generates a DER representation of a signed CRL and stores it in the

@ -24,7 +24,7 @@ import (
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/assert" sassert "github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/policy"
@ -33,6 +33,8 @@ import (
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
var ( var (
@ -80,25 +82,25 @@ func generateCertificate(t *testing.T, commonName string, sans []string, opts ..
t.Helper() t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err) require.NoError(t, err)
cr, err := x509util.CreateCertificateRequest(commonName, sans, priv) cr, err := x509util.CreateCertificateRequest(commonName, sans, priv)
assert.FatalError(t, err) require.NoError(t, err)
template, err := x509util.NewCertificate(cr) template, err := x509util.NewCertificate(cr)
assert.FatalError(t, err) require.NoError(t, err)
cert := template.GetCertificate() cert := template.GetCertificate()
for _, m := range opts { for _, m := range opts {
switch m := m.(type) { switch m := m.(type) {
case provisioner.CertificateModifierFunc: case provisioner.CertificateModifierFunc:
err = m.Modify(cert, provisioner.SignOptions{}) err = m.Modify(cert, provisioner.SignOptions{})
assert.FatalError(t, err) require.NoError(t, err)
case signerFunc: case signerFunc:
cert, err = m(cert, priv.Public()) cert, err = m(cert, priv.Public())
assert.FatalError(t, err) require.NoError(t, err)
default: default:
t.Fatalf("unknown type %T", m) require.Fail(t, "", "unknown type %T", m)
} }
} }
@ -108,36 +110,36 @@ func generateCertificate(t *testing.T, commonName string, sans []string, opts ..
func generateRootCertificate(t *testing.T) (*x509.Certificate, crypto.Signer) { func generateRootCertificate(t *testing.T) (*x509.Certificate, crypto.Signer) {
t.Helper() t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err) require.NoError(t, err)
cr, err := x509util.CreateCertificateRequest("TestRootCA", nil, priv) cr, err := x509util.CreateCertificateRequest("TestRootCA", nil, priv)
assert.FatalError(t, err) require.NoError(t, err)
data := x509util.CreateTemplateData("TestRootCA", nil) data := x509util.CreateTemplateData("TestRootCA", nil)
template, err := x509util.NewCertificate(cr, x509util.WithTemplate(x509util.DefaultRootTemplate, data)) template, err := x509util.NewCertificate(cr, x509util.WithTemplate(x509util.DefaultRootTemplate, data))
assert.FatalError(t, err) require.NoError(t, err)
cert := template.GetCertificate() cert := template.GetCertificate()
cert, err = x509util.CreateCertificate(cert, cert, priv.Public(), priv) cert, err = x509util.CreateCertificate(cert, cert, priv.Public(), priv)
assert.FatalError(t, err) require.NoError(t, err)
return cert, priv return cert, priv
} }
func generateIntermidiateCertificate(t *testing.T, issuer *x509.Certificate, signer crypto.Signer) (*x509.Certificate, crypto.Signer) { func generateIntermidiateCertificate(t *testing.T, issuer *x509.Certificate, signer crypto.Signer) (*x509.Certificate, crypto.Signer) {
t.Helper() t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err) require.NoError(t, err)
cr, err := x509util.CreateCertificateRequest("TestIntermediateCA", nil, priv) cr, err := x509util.CreateCertificateRequest("TestIntermediateCA", nil, priv)
assert.FatalError(t, err) require.NoError(t, err)
data := x509util.CreateTemplateData("TestIntermediateCA", nil) data := x509util.CreateTemplateData("TestIntermediateCA", nil)
template, err := x509util.NewCertificate(cr, x509util.WithTemplate(x509util.DefaultRootTemplate, data)) template, err := x509util.NewCertificate(cr, x509util.WithTemplate(x509util.DefaultRootTemplate, data))
assert.FatalError(t, err) require.NoError(t, err)
cert := template.GetCertificate() cert := template.GetCertificate()
cert, err = x509util.CreateCertificate(cert, issuer, priv.Public(), signer) cert, err = x509util.CreateCertificate(cert, issuer, priv.Public(), signer)
assert.FatalError(t, err) require.NoError(t, err)
return cert, priv return cert, priv
} }
@ -192,9 +194,9 @@ func getCSR(t *testing.T, priv interface{}, opts ...func(*x509.CertificateReques
opt(_csr) opt(_csr)
} }
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, _csr, priv) csrBytes, err := x509.CreateCertificateRequest(rand.Reader, _csr, priv)
assert.FatalError(t, err) require.NoError(t, err)
csr, err := x509.ParseCertificateRequest(csrBytes) csr, err := x509.ParseCertificateRequest(csrBytes)
assert.FatalError(t, err) require.NoError(t, err)
return csr return csr
} }
@ -239,10 +241,10 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error {
func TestAuthority_Sign(t *testing.T) { func TestAuthority_Sign(t *testing.T) {
pub, priv, err := keyutil.GenerateDefaultKeyPair() pub, priv, err := keyutil.GenerateDefaultKeyPair()
assert.FatalError(t, err) require.NoError(t, err)
a := testAuthority(t) a := testAuthority(t)
assert.FatalError(t, err) require.NoError(t, err)
a.config.AuthorityConfig.Template = &ASN1DN{ a.config.AuthorityConfig.Template = &ASN1DN{
Country: "Tazmania", Country: "Tazmania",
Organization: "Acme Co", Organization: "Acme Co",
@ -262,12 +264,12 @@ func TestAuthority_Sign(t *testing.T) {
// Create a token to get test extra opts. // Create a token to get test extra opts.
p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK) p := a.config.AuthorityConfig.Provisioners[1].(*provisioner.JWK)
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err) require.NoError(t, err)
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key) token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err) require.NoError(t, err)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
extraOpts, err := a.Authorize(ctx, token) extraOpts, err := a.Authorize(ctx, token)
assert.FatalError(t, err) require.NoError(t, err)
type signTest struct { type signTest struct {
auth *Authority auth *Authority
@ -372,9 +374,9 @@ W5kR63lNVHBHgQmv5mA8YFsfrJHstaz5k727v2LMHEYIf5/3i16d5zhuxUoaPTYr
ZYtQ9Ot36qc= ZYtQ9Ot36qc=
-----END CERTIFICATE REQUEST-----` -----END CERTIFICATE REQUEST-----`
block, _ := pem.Decode([]byte(shortRSAKeyPEM)) block, _ := pem.Decode([]byte(shortRSAKeyPEM))
assert.FatalError(t, err) require.NoError(t, err)
csr, err := x509.ParseCertificateRequest(block.Bytes) csr, err := x509.ParseCertificateRequest(block.Bytes)
assert.FatalError(t, err) require.NoError(t, err)
return &signTest{ return &signTest{
auth: a, auth: a,
@ -413,10 +415,10 @@ ZYtQ9Ot36qc=
X509: &provisioner.X509Options{Template: `{{ fail "fail message" }}`}, X509: &provisioner.X509Options{Template: `{{ fail "fail message" }}`},
} }
testExtraOpts, err := testAuthority.Authorize(ctx, token) testExtraOpts, err := testAuthority.Authorize(ctx, token)
assert.FatalError(t, err) require.NoError(t, err)
testAuthority.db = &db.MockAuthDB{ testAuthority.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -442,10 +444,10 @@ ZYtQ9Ot36qc=
}, },
} }
testExtraOpts, err := testAuthority.Authorize(ctx, token) testExtraOpts, err := testAuthority.Authorize(ctx, token)
assert.FatalError(t, err) require.NoError(t, err)
testAuthority.db = &db.MockAuthDB{ testAuthority.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -471,10 +473,10 @@ ZYtQ9Ot36qc=
}, },
} }
testExtraOpts, err := testAuthority.Authorize(ctx, token) testExtraOpts, err := testAuthority.Authorize(ctx, token)
assert.FatalError(t, err) require.NoError(t, err)
testAuthority.db = &db.MockAuthDB{ testAuthority.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -492,7 +494,7 @@ ZYtQ9Ot36qc=
aa := testAuthority(t) aa := testAuthority(t)
aa.db = &db.MockAuthDB{ aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -517,7 +519,7 @@ ZYtQ9Ot36qc=
})) }))
aa.db = &db.MockAuthDB{ aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -537,7 +539,7 @@ ZYtQ9Ot36qc=
aa.db = &db.MockAuthDB{ aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
fmt.Println(crt.Subject) fmt.Println(crt.Subject)
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -549,7 +551,7 @@ ZYtQ9Ot36qc=
}, },
} }
engine, err := policy.New(options) engine, err := policy.New(options)
assert.FatalError(t, err) require.NoError(t, err)
aa.policyEngine = engine aa.policyEngine = engine
return &signTest{ return &signTest{
auth: aa, auth: aa,
@ -598,7 +600,7 @@ ZYtQ9Ot36qc=
_a := testAuthority(t) _a := testAuthority(t)
_a.db = &db.MockAuthDB{ _a.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -617,7 +619,7 @@ ZYtQ9Ot36qc=
bcExt.Id = asn1.ObjectIdentifier{2, 5, 29, 19} bcExt.Id = asn1.ObjectIdentifier{2, 5, 29, 19}
bcExt.Critical = false bcExt.Critical = false
bcExt.Value, err = asn1.Marshal(basicConstraints{IsCA: true, MaxPathLen: 4}) bcExt.Value, err = asn1.Marshal(basicConstraints{IsCA: true, MaxPathLen: 4})
assert.FatalError(t, err) require.NoError(t, err)
csr := getCSR(t, priv, setExtraExtsCSR([]pkix.Extension{ csr := getCSR(t, priv, setExtraExtsCSR([]pkix.Extension{
bcExt, bcExt,
@ -632,7 +634,7 @@ ZYtQ9Ot36qc=
_a := testAuthority(t) _a := testAuthority(t)
_a.db = &db.MockAuthDB{ _a.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -663,10 +665,10 @@ ZYtQ9Ot36qc=
}`}, }`},
} }
testExtraOpts, err := testAuthority.Authorize(ctx, token) testExtraOpts, err := testAuthority.Authorize(ctx, token)
assert.FatalError(t, err) require.NoError(t, err)
testAuthority.db = &db.MockAuthDB{ testAuthority.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -697,10 +699,10 @@ ZYtQ9Ot36qc=
}`}, }`},
} }
testExtraOpts, err := testAuthority.Authorize(ctx, token) testExtraOpts, err := testAuthority.Authorize(ctx, token)
assert.FatalError(t, err) require.NoError(t, err)
testAuthority.db = &db.MockAuthDB{ testAuthority.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -737,7 +739,7 @@ ZYtQ9Ot36qc=
_a.config.AuthorityConfig.Template = &ASN1DN{} _a.config.AuthorityConfig.Template = &ASN1DN{}
_a.db = &db.MockAuthDB{ _a.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject, pkix.Name{}) sassert.Equals(t, crt.Subject, pkix.Name{})
return nil return nil
}, },
} }
@ -762,8 +764,8 @@ ZYtQ9Ot36qc=
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
aa.db = &db.MockAuthDB{ aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
assert.Equals(t, crt.CRLDistributionPoints, []string{"http://ca.example.org/leaf.crl"}) sassert.Equals(t, crt.CRLDistributionPoints, []string{"http://ca.example.org/leaf.crl"})
return nil return nil
}, },
} }
@ -783,7 +785,7 @@ ZYtQ9Ot36qc=
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
aa.db = &db.MockAuthDB{ aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error { MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test") sassert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil return nil
}, },
} }
@ -796,7 +798,7 @@ ZYtQ9Ot36qc=
}, },
} }
engine, err := policy.New(options) engine, err := policy.New(options)
assert.FatalError(t, err) require.NoError(t, err)
aa.policyEngine = engine aa.policyEngine = engine
return &signTest{ return &signTest{
auth: aa, auth: aa,
@ -816,13 +818,13 @@ ZYtQ9Ot36qc=
MStoreCertificateChain: func(prov provisioner.Interface, certs ...*x509.Certificate) error { MStoreCertificateChain: func(prov provisioner.Interface, certs ...*x509.Certificate) error {
p, ok := prov.(attProvisioner) p, ok := prov.(attProvisioner)
if assert.True(t, ok) { if assert.True(t, ok) {
assert.Equals(t, &provisioner.AttestationData{ sassert.Equals(t, &provisioner.AttestationData{
PermanentIdentifier: "1234567890", PermanentIdentifier: "1234567890",
}, p.AttestationData()) }, p.AttestationData())
} }
if assert.Len(t, 2, certs) { if assert.Len(t, certs, 2) {
assert.Equals(t, certs[0].Subject.CommonName, "smallstep test") sassert.Equals(t, certs[0].Subject.CommonName, "smallstep test")
assert.Equals(t, certs[1].Subject.CommonName, "smallstep Intermediate CA") sassert.Equals(t, certs[1].Subject.CommonName, "smallstep Intermediate CA")
} }
return nil return nil
}, },
@ -851,26 +853,26 @@ ZYtQ9Ot36qc=
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
var sc render.StatusCodedError var sc render.StatusCodedError
assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") sassert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) sassert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) sassert.HasPrefix(t, err.Error(), tc.err.Error())
var ctxErr *errs.Error var ctxErr *errs.Error
assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") sassert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["csr"], tc.csr) sassert.Equals(t, ctxErr.Details["csr"], tc.csr)
assert.Equals(t, ctxErr.Details["signOptions"], tc.signOpts) sassert.Equals(t, ctxErr.Details["signOptions"], tc.signOpts)
} }
} else { } else {
leaf := certChain[0] leaf := certChain[0]
intermediate := certChain[1] intermediate := certChain[1]
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, leaf.NotBefore, tc.notBefore) sassert.Equals(t, leaf.NotBefore, tc.notBefore)
assert.Equals(t, leaf.NotAfter, tc.notAfter) sassert.Equals(t, leaf.NotAfter, tc.notAfter)
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
if tc.csr.Subject.CommonName == "" { if tc.csr.Subject.CommonName == "" {
assert.Equals(t, leaf.Subject, pkix.Name{}) sassert.Equals(t, leaf.Subject, pkix.Name{})
} else { } else {
assert.Equals(t, leaf.Subject.String(), sassert.Equals(t, leaf.Subject.String(),
pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
@ -879,18 +881,18 @@ ZYtQ9Ot36qc=
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: "smallstep test", CommonName: "smallstep test",
}.String()) }.String())
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"}) sassert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})
} }
assert.Equals(t, leaf.Issuer, intermediate.Subject) sassert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) sassert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) sassert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
assert.Equals(t, leaf.ExtKeyUsage, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) sassert.Equals(t, leaf.ExtKeyUsage, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
issuer := getDefaultIssuer(a) issuer := getDefaultIssuer(a)
subjectKeyID, err := generateSubjectKeyID(pub) subjectKeyID, err := generateSubjectKeyID(pub)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, leaf.SubjectKeyId, subjectKeyID) sassert.Equals(t, leaf.SubjectKeyId, subjectKeyID)
assert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) sassert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId)
// Verify Provisioner OID // Verify Provisioner OID
found := 0 found := 0
@ -900,18 +902,18 @@ ZYtQ9Ot36qc=
found++ found++
val := stepProvisionerASN1{} val := stepProvisionerASN1{}
_, err := asn1.Unmarshal(ext.Value, &val) _, err := asn1.Unmarshal(ext.Value, &val)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, val.Type, provisionerTypeJWK) sassert.Equals(t, val.Type, provisionerTypeJWK)
assert.Equals(t, val.Name, []byte(p.Name)) sassert.Equals(t, val.Name, []byte(p.Name))
assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID)) sassert.Equals(t, val.CredentialID, []byte(p.Key.KeyID))
// Basic Constraints // Basic Constraints
case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})): case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})):
val := basicConstraints{} val := basicConstraints{}
_, err := asn1.Unmarshal(ext.Value, &val) _, err := asn1.Unmarshal(ext.Value, &val)
assert.FatalError(t, err) require.NoError(t, err)
assert.False(t, val.IsCA, false) assert.False(t, val.IsCA, false)
assert.Equals(t, val.MaxPathLen, 0) sassert.Equals(t, val.MaxPathLen, 0)
// SAN extension // SAN extension
case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 17})): case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 17})):
@ -922,11 +924,11 @@ ZYtQ9Ot36qc=
} }
} }
} }
assert.Equals(t, found, 1) sassert.Equals(t, found, 1)
realIntermediate, err := x509.ParseCertificate(issuer.Raw) realIntermediate, err := x509.ParseCertificate(issuer.Raw)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, intermediate, realIntermediate) sassert.Equals(t, intermediate, realIntermediate)
assert.Len(t, tc.extensionsCount, leaf.Extensions) assert.Len(t, leaf.Extensions, tc.extensionsCount)
} }
} }
}) })
@ -1056,7 +1058,7 @@ func TestAuthority_Renew(t *testing.T) {
for name, genTestCase := range tests { for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc, err := genTestCase() tc, err := genTestCase()
assert.FatalError(t, err) require.NoError(t, err)
var certChain []*x509.Certificate var certChain []*x509.Certificate
if tc.auth != nil { if tc.auth != nil {
@ -1068,19 +1070,19 @@ func TestAuthority_Renew(t *testing.T) {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
var sc render.StatusCodedError var sc render.StatusCodedError
assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") sassert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) sassert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) sassert.HasPrefix(t, err.Error(), tc.err.Error())
var ctxErr *errs.Error var ctxErr *errs.Error
assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") sassert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String()) sassert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String())
} }
} else { } else {
leaf := certChain[0] leaf := certChain[0]
intermediate := certChain[1] intermediate := certChain[1]
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore)) sassert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore))
assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute))) assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute)))
assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute))) assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute)))
@ -1090,30 +1092,30 @@ func TestAuthority_Renew(t *testing.T) {
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.RawSubject, tc.cert.RawSubject) sassert.Equals(t, leaf.RawSubject, tc.cert.RawSubject)
assert.Equals(t, leaf.Subject.Country, []string{tmplt.Country}) sassert.Equals(t, leaf.Subject.Country, []string{tmplt.Country})
assert.Equals(t, leaf.Subject.Organization, []string{tmplt.Organization}) sassert.Equals(t, leaf.Subject.Organization, []string{tmplt.Organization})
assert.Equals(t, leaf.Subject.Locality, []string{tmplt.Locality}) sassert.Equals(t, leaf.Subject.Locality, []string{tmplt.Locality})
assert.Equals(t, leaf.Subject.StreetAddress, []string{tmplt.StreetAddress}) sassert.Equals(t, leaf.Subject.StreetAddress, []string{tmplt.StreetAddress})
assert.Equals(t, leaf.Subject.Province, []string{tmplt.Province}) sassert.Equals(t, leaf.Subject.Province, []string{tmplt.Province})
assert.Equals(t, leaf.Subject.CommonName, tmplt.CommonName) sassert.Equals(t, leaf.Subject.CommonName, tmplt.CommonName)
assert.Equals(t, leaf.Issuer, intermediate.Subject) sassert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) sassert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) sassert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
assert.Equals(t, leaf.ExtKeyUsage, sassert.Equals(t, leaf.ExtKeyUsage,
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"}) sassert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"})
subjectKeyID, err := generateSubjectKeyID(leaf.PublicKey) subjectKeyID, err := generateSubjectKeyID(leaf.PublicKey)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, leaf.SubjectKeyId, subjectKeyID) sassert.Equals(t, leaf.SubjectKeyId, subjectKeyID)
// We did not change the intermediate before renewing. // We did not change the intermediate before renewing.
authIssuer := getDefaultIssuer(tc.auth) authIssuer := getDefaultIssuer(tc.auth)
if issuer.SerialNumber == authIssuer.SerialNumber { if issuer.SerialNumber == authIssuer.SerialNumber {
assert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) sassert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId)
// Compare extensions: they can be in a different order // Compare extensions: they can be in a different order
for _, ext1 := range tc.cert.Extensions { for _, ext1 := range tc.cert.Extensions {
//skip SubjectKeyIdentifier //skip SubjectKeyIdentifier
@ -1133,7 +1135,7 @@ func TestAuthority_Renew(t *testing.T) {
} }
} else { } else {
// We did change the intermediate before renewing. // We did change the intermediate before renewing.
assert.Equals(t, leaf.AuthorityKeyId, authIssuer.SubjectKeyId) sassert.Equals(t, leaf.AuthorityKeyId, authIssuer.SubjectKeyId)
// Compare extensions: they can be in a different order // Compare extensions: they can be in a different order
for _, ext1 := range tc.cert.Extensions { for _, ext1 := range tc.cert.Extensions {
//skip SubjectKeyIdentifier //skip SubjectKeyIdentifier
@ -1161,8 +1163,8 @@ func TestAuthority_Renew(t *testing.T) {
} }
realIntermediate, err := x509.ParseCertificate(authIssuer.Raw) realIntermediate, err := x509.ParseCertificate(authIssuer.Raw)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, intermediate, realIntermediate) sassert.Equals(t, intermediate, realIntermediate)
} }
} }
}) })
@ -1171,7 +1173,7 @@ func TestAuthority_Renew(t *testing.T) {
func TestAuthority_Rekey(t *testing.T) { func TestAuthority_Rekey(t *testing.T) {
pub, _, err := keyutil.GenerateDefaultKeyPair() pub, _, err := keyutil.GenerateDefaultKeyPair()
assert.FatalError(t, err) require.NoError(t, err)
a := testAuthority(t) a := testAuthority(t)
a.config.AuthorityConfig.Template = &ASN1DN{ a.config.AuthorityConfig.Template = &ASN1DN{
@ -1261,7 +1263,7 @@ func TestAuthority_Rekey(t *testing.T) {
for name, genTestCase := range tests { for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc, err := genTestCase() tc, err := genTestCase()
assert.FatalError(t, err) require.NoError(t, err)
var certChain []*x509.Certificate var certChain []*x509.Certificate
if tc.auth != nil { if tc.auth != nil {
@ -1273,19 +1275,19 @@ func TestAuthority_Rekey(t *testing.T) {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
var sc render.StatusCodedError var sc render.StatusCodedError
assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") sassert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) sassert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) sassert.HasPrefix(t, err.Error(), tc.err.Error())
var ctxErr *errs.Error var ctxErr *errs.Error
assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") sassert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String()) sassert.Equals(t, ctxErr.Details["serialNumber"], tc.cert.SerialNumber.String())
} }
} else { } else {
leaf := certChain[0] leaf := certChain[0]
intermediate := certChain[1] intermediate := certChain[1]
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore)) sassert.Equals(t, leaf.NotAfter.Sub(leaf.NotBefore), tc.cert.NotAfter.Sub(cert.NotBefore))
assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute))) assert.True(t, leaf.NotBefore.After(now.Add(-2*time.Minute)))
assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute))) assert.True(t, leaf.NotBefore.Before(now.Add(time.Minute)))
@ -1295,7 +1297,7 @@ func TestAuthority_Rekey(t *testing.T) {
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(), sassert.Equals(t, leaf.Subject.String(),
pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
@ -1304,32 +1306,32 @@ func TestAuthority_Rekey(t *testing.T) {
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: tmplt.CommonName, CommonName: tmplt.CommonName,
}.String()) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) sassert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) sassert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) sassert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
assert.Equals(t, leaf.ExtKeyUsage, sassert.Equals(t, leaf.ExtKeyUsage,
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"}) sassert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com", "test"})
// Test Public Key and SubjectKeyId // Test Public Key and SubjectKeyId
expectedPK := tc.pk expectedPK := tc.pk
if tc.pk == nil { if tc.pk == nil {
expectedPK = cert.PublicKey expectedPK = cert.PublicKey
} }
assert.Equals(t, leaf.PublicKey, expectedPK) sassert.Equals(t, leaf.PublicKey, expectedPK)
subjectKeyID, err := generateSubjectKeyID(expectedPK) subjectKeyID, err := generateSubjectKeyID(expectedPK)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, leaf.SubjectKeyId, subjectKeyID) sassert.Equals(t, leaf.SubjectKeyId, subjectKeyID)
if tc.pk == nil { if tc.pk == nil {
assert.Equals(t, leaf.SubjectKeyId, cert.SubjectKeyId) sassert.Equals(t, leaf.SubjectKeyId, cert.SubjectKeyId)
} }
// We did not change the intermediate before renewing. // We did not change the intermediate before renewing.
authIssuer := getDefaultIssuer(tc.auth) authIssuer := getDefaultIssuer(tc.auth)
if issuer.SerialNumber == authIssuer.SerialNumber { if issuer.SerialNumber == authIssuer.SerialNumber {
assert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId) sassert.Equals(t, leaf.AuthorityKeyId, issuer.SubjectKeyId)
// Compare extensions: they can be in a different order // Compare extensions: they can be in a different order
for _, ext1 := range tc.cert.Extensions { for _, ext1 := range tc.cert.Extensions {
//skip SubjectKeyIdentifier //skip SubjectKeyIdentifier
@ -1349,7 +1351,7 @@ func TestAuthority_Rekey(t *testing.T) {
} }
} else { } else {
// We did change the intermediate before renewing. // We did change the intermediate before renewing.
assert.Equals(t, leaf.AuthorityKeyId, authIssuer.SubjectKeyId) sassert.Equals(t, leaf.AuthorityKeyId, authIssuer.SubjectKeyId)
// Compare extensions: they can be in a different order // Compare extensions: they can be in a different order
for _, ext1 := range tc.cert.Extensions { for _, ext1 := range tc.cert.Extensions {
//skip SubjectKeyIdentifier //skip SubjectKeyIdentifier
@ -1377,8 +1379,8 @@ func TestAuthority_Rekey(t *testing.T) {
} }
realIntermediate, err := x509.ParseCertificate(authIssuer.Raw) realIntermediate, err := x509.ParseCertificate(authIssuer.Raw)
assert.FatalError(t, err) require.NoError(t, err)
assert.Equals(t, intermediate, realIntermediate) sassert.Equals(t, intermediate, realIntermediate)
} }
} }
}) })
@ -1413,10 +1415,10 @@ func TestAuthority_GetTLSOptions(t *testing.T) {
for name, genTestCase := range tests { for name, genTestCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc, err := genTestCase() tc, err := genTestCase()
assert.FatalError(t, err) require.NoError(t, err)
opts := tc.auth.GetTLSOptions() opts := tc.auth.GetTLSOptions()
assert.Equals(t, opts, tc.opts) sassert.Equals(t, opts, tc.opts)
}) })
} }
} }
@ -1429,11 +1431,11 @@ func TestAuthority_Revoke(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err) require.NoError(t, err)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err) require.NoError(t, err)
a := testAuthority(t) a := testAuthority(t)
@ -1472,7 +1474,7 @@ func TestAuthority_Revoke(t *testing.T) {
ID: "44", ID: "44",
} }
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
auth: a, auth: a,
@ -1486,9 +1488,9 @@ func TestAuthority_Revoke(t *testing.T) {
err: errors.New("authority.Revoke; no persistence layer configured"), err: errors.New("authority.Revoke; no persistence layer configured"),
code: http.StatusNotImplemented, code: http.StatusNotImplemented,
checkErrDetails: func(err *errs.Error) { checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw) sassert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44") sassert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") sassert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
}, },
} }
}, },
@ -1512,7 +1514,7 @@ func TestAuthority_Revoke(t *testing.T) {
ID: "44", ID: "44",
} }
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
auth: _a, auth: _a,
@ -1526,9 +1528,9 @@ func TestAuthority_Revoke(t *testing.T) {
err: errors.New("authority.Revoke: force"), err: errors.New("authority.Revoke: force"),
code: http.StatusInternalServerError, code: http.StatusInternalServerError,
checkErrDetails: func(err *errs.Error) { checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw) sassert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44") sassert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") sassert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
}, },
} }
}, },
@ -1552,7 +1554,7 @@ func TestAuthority_Revoke(t *testing.T) {
ID: "44", ID: "44",
} }
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
auth: _a, auth: _a,
@ -1566,9 +1568,9 @@ func TestAuthority_Revoke(t *testing.T) {
err: errors.New("certificate with serial number 'sn' is already revoked"), err: errors.New("certificate with serial number 'sn' is already revoked"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
checkErrDetails: func(err *errs.Error) { checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw) sassert.Equals(t, err.Details["token"], raw)
assert.Equals(t, err.Details["tokenID"], "44") sassert.Equals(t, err.Details["tokenID"], "44")
assert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc") sassert.Equals(t, err.Details["provisionerID"], "step-cli:4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc")
}, },
} }
}, },
@ -1591,7 +1593,7 @@ func TestAuthority_Revoke(t *testing.T) {
ID: "44", ID: "44",
} }
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
auth: _a, auth: _a,
ctx: tlsRevokeCtx, ctx: tlsRevokeCtx,
@ -1607,7 +1609,7 @@ func TestAuthority_Revoke(t *testing.T) {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) _a := testAuthority(t, WithDatabase(&db.MockAuthDB{}))
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
auth: _a, auth: _a,
@ -1625,7 +1627,7 @@ func TestAuthority_Revoke(t *testing.T) {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) _a := testAuthority(t, WithDatabase(&db.MockAuthDB{}))
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err) require.NoError(t, err)
// Filter out provisioner extension. // Filter out provisioner extension.
for i, ext := range crt.Extensions { for i, ext := range crt.Extensions {
if ext.Id.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}) { if ext.Id.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}) {
@ -1650,7 +1652,7 @@ func TestAuthority_Revoke(t *testing.T) {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) _a := testAuthority(t, WithDatabase(&db.MockAuthDB{}))
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
auth: _a, auth: _a,
@ -1683,7 +1685,7 @@ func TestAuthority_Revoke(t *testing.T) {
ID: "44", ID: "44",
} }
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) require.NoError(t, err)
return test{ return test{
auth: a, auth: a,
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod),
@ -1702,17 +1704,17 @@ func TestAuthority_Revoke(t *testing.T) {
if err := tc.auth.Revoke(tc.ctx, tc.opts); err != nil { if err := tc.auth.Revoke(tc.ctx, tc.opts); err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
var sc render.StatusCodedError var sc render.StatusCodedError
assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") sassert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) sassert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) sassert.HasPrefix(t, err.Error(), tc.err.Error())
var ctxErr *errs.Error var ctxErr *errs.Error
assert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error") sassert.Fatal(t, errors.As(err, &ctxErr), "error is not of type *errs.Error")
assert.Equals(t, ctxErr.Details["serialNumber"], tc.opts.Serial) sassert.Equals(t, ctxErr.Details["serialNumber"], tc.opts.Serial)
assert.Equals(t, ctxErr.Details["reasonCode"], tc.opts.ReasonCode) sassert.Equals(t, ctxErr.Details["reasonCode"], tc.opts.ReasonCode)
assert.Equals(t, ctxErr.Details["reason"], tc.opts.Reason) sassert.Equals(t, ctxErr.Details["reason"], tc.opts.Reason)
assert.Equals(t, ctxErr.Details["MTLS"], tc.opts.MTLS) sassert.Equals(t, ctxErr.Details["MTLS"], tc.opts.MTLS)
assert.Equals(t, ctxErr.Details["context"], provisioner.RevokeMethod.String()) sassert.Equals(t, ctxErr.Details["context"], provisioner.RevokeMethod.String())
if tc.checkErrDetails != nil { if tc.checkErrDetails != nil {
tc.checkErrDetails(ctxErr) tc.checkErrDetails(ctxErr)
@ -1814,13 +1816,11 @@ func TestAuthority_CRL(t *testing.T) {
validIssuer := "step-cli" validIssuer := "step-cli"
validAudience := testAudiences.Revoke validAudience := testAudiences.Revoke
now := time.Now().UTC() now := time.Now().UTC()
//
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err) require.NoError(t, err)
//
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err) require.NoError(t, err)
crlCtx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) crlCtx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
@ -1865,7 +1865,7 @@ func TestAuthority_CRL(t *testing.T) {
auth: a, auth: a,
ctx: crlCtx, ctx: crlCtx,
expected: nil, expected: nil,
err: database.ErrNotFound, err: errors.New("authority.GetCertificateRevocationList: not found"),
} }
}, },
"ok/crl-full": func() test { "ok/crl-full": func() test {
@ -1910,7 +1910,7 @@ func TestAuthority_CRL(t *testing.T) {
ID: sn, ID: sn,
} }
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err) require.NoError(t, err)
err = a.Revoke(crlCtx, &RevokeOptions{ err = a.Revoke(crlCtx, &RevokeOptions{
Serial: sn, Serial: sn,
ReasonCode: reasonCode, ReasonCode: reasonCode,
@ -1918,7 +1918,7 @@ func TestAuthority_CRL(t *testing.T) {
OTT: raw, OTT: raw,
}) })
assert.FatalError(t, err) require.NoError(t, err)
ex = append(ex, sn) ex = append(ex, sn)
} }
@ -1933,22 +1933,22 @@ func TestAuthority_CRL(t *testing.T) {
for name, f := range tests { for name, f := range tests {
tc := f() tc := f()
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
if crlBytes, err := tc.auth.GetCertificateRevocationList(); err == nil { crlInfo, err := tc.auth.GetCertificateRevocationList()
crl, parseErr := x509.ParseRevocationList(crlBytes) if tc.err != nil {
if parseErr != nil { assert.EqualError(t, err, tc.err.Error())
t.Errorf("x509.ParseCertificateRequest() error = %v, wantErr %v", parseErr, nil) assert.Nil(t, crlInfo)
return return
} }
var cmpList []string crl, parseErr := x509.ParseRevocationList(crlInfo.Data)
for _, c := range crl.RevokedCertificates { require.NoError(t, parseErr)
cmpList = append(cmpList, c.SerialNumber.String())
}
assert.Equals(t, cmpList, tc.expected) var cmpList []string
} else { for _, c := range crl.RevokedCertificateEntries {
assert.NotNil(t, tc.err, err.Error()) cmpList = append(cmpList, c.SerialNumber.String())
} }
assert.Equal(t, tc.expected, cmpList)
}) })
} }
} }

Loading…
Cancel
Save