From 8ce807a6cb409c0615cf9b9dfc2ba11455ec927d Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 18 Nov 2021 15:12:44 -0800 Subject: [PATCH 1/5] Modify errs.BadRequest() calls to always send an error to the client. --- api/api_test.go | 6 +++--- api/rekey.go | 3 +-- api/renew.go | 2 +- api/revoke.go | 4 ++-- api/revoke_test.go | 4 ++-- authority/provisioner/sshpop.go | 7 +++--- authority/provisioner/sshpop_test.go | 6 +++--- authority/ssh.go | 8 +++---- authority/ssh_test.go | 6 +++--- authority/tls.go | 6 ++++-- authority/tls_test.go | 2 +- ca/ca_test.go | 4 ++-- ca/client.go | 2 +- ca/client_test.go | 28 ++++++++++++------------ errs/error.go | 32 ++++++++++++++++++++++++++-- 15 files changed, 74 insertions(+), 46 deletions(-) diff --git a/api/api_test.go b/api/api_test.go index 05d592f0..0fab1a5b 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -403,9 +403,9 @@ func TestSignRequest_Validate(t *testing.T) { fields fields err error }{ - {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")}, + {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing csr.")}, {"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")}, - {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")}, + {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing ott.")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) { t.Fatal(err) } - expectedError400 := errs.BadRequest("force") + expectedError400 := errs.BadRequestErr(errors.New("force")) expectedError400Bytes, err := json.Marshal(expectedError400) assert.FatalError(t, err) expectedError500 := errs.InternalServer("force") diff --git a/api/rekey.go b/api/rekey.go index c0d88e55..2b60eabc 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -26,9 +26,8 @@ func (s *RekeyRequest) Validate() error { // Rekey is similar to renew except that the certificate will be renewed with new key from csr. func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing peer certificate")) + WriteError(w, errs.BadRequest("missing client certificate")) return } diff --git a/api/renew.go b/api/renew.go index 74ef2034..725322ee 100644 --- a/api/renew.go +++ b/api/renew.go @@ -10,7 +10,7 @@ import ( // new one. func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing peer certificate")) + WriteError(w, errs.BadRequest("missing client certificate")) return } diff --git a/api/revoke.go b/api/revoke.go index 21c3154c..f3f47ebb 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -80,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing ott or peer certificate")) + WriteError(w, errs.BadRequest("missing ott or client certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body")) + WriteError(w, errs.BadRequest("serial number in client certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. diff --git a/api/revoke_test.go b/api/revoke_test.go index 4ed4e3fe..b6ba30fb 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -28,7 +28,7 @@ func TestRevokeRequestValidate(t *testing.T) { tests := map[string]test{ "error/missing serial": { rr: &RevokeRequest{}, - err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest}, + err: &errs.Error{Err: errors.New("The request could not be completed: missing serial."), Status: http.StatusBadRequest}, }, "error/bad reasonCode": { rr: &RevokeRequest{ @@ -36,7 +36,7 @@ func TestRevokeRequestValidate(t *testing.T) { ReasonCode: 15, Passive: true, }, - err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest}, + err: &errs.Error{Err: errors.New("The request could not be completed: reasonCode out of bounds."), Status: http.StatusBadRequest}, }, "error/non-passive not implemented": { rr: &RevokeRequest{ diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 99974ff1..3039d2a3 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -191,8 +191,7 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { - return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " + - "must be equivalent to sshpop certificate serial number") + return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number") } return nil } @@ -205,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { - return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate") + return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, nil @@ -220,7 +219,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } if claims.sshCert.CertType != ssh.HostCert { - return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate") + return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, []SignOption{ // Validate public key diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 3d343967..850a698d 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"), + err: errors.New("The request could not be completed: sshpop token subject must be equivalent to sshpop certificate serial number."), } }, "ok": func(t *testing.T) test { @@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"), + err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), } }, "ok": func(t *testing.T) test { @@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"), + err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), } }, "ok": func(t *testing.T) test { diff --git a/authority/ssh.go b/authority/ssh.go index bef673bf..eba48297 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -69,7 +69,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin ts = a.templates.SSH.Host } default: - return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ) + return nil, errs.BadRequest("invalid certificate type '%s'", typ) } // Merge user and default data @@ -258,7 +258,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period") + return nil, errs.BadRequest("cannot renew a certificate without validity period") } if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { @@ -329,7 +329,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period") + return nil, errs.BadRequest("cannot rekey a certificate without validity period") } if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { @@ -369,7 +369,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } signer = a.sshCAHostCertSignKey default: - return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType) + return nil, errs.BadRequest("unexpected certificate type '%d'", cert.CertType) } var err error diff --git a/authority/ssh_test.go b/authority/ssh_test.go index 994d015f..a62c9e54 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; cannot rekey certificate without validity period"), + err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), code: http.StatusBadRequest, } }, @@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; cannot rekey certificate without validity period"), + err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), code: http.StatusBadRequest, } }, @@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; unexpected ssh certificate type: 0"), + err: errors.New("The request could not be completed: unexpected certificate type '0'."), code: http.StatusBadRequest, } }, diff --git a/authority/tls.go b/authority/tls.go index 839866a2..4a5f2fdf 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -433,8 +433,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error case db.ErrNotImplemented: return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...) case db.ErrAlreadyExists: - return errs.BadRequest("authority.Revoke; certificate with serial "+ - "number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...) + return errs.ApplyOptions( + errs.BadRequest("certificate with serial number '%s' is already revoked", rci.Serial), + opts..., + ) default: return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) } diff --git a/authority/tls_test.go b/authority/tls_test.go index ba05b9fc..1796c4a3 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) { Reason: reason, OTT: raw, }, - err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"), + err: errors.New("The request could not be completed: certificate with serial number 'sn' is already revoked"), code: http.StatusBadRequest, checkErrDetails: func(err *errs.Error) { assert.Equals(t, err.Details["token"], raw) diff --git a/ca/ca_test.go b/ca/ca_test.go index 0f7cb02e..64371ac3 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -588,7 +588,7 @@ func TestCARenew(t *testing.T) { ca: ca, tlsConnState: nil, status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "request-missing-peer-certificate": func(t *testing.T) *renewTest { @@ -596,7 +596,7 @@ func TestCARenew(t *testing.T) { ca: ca, tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}}, status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "success": func(t *testing.T) *renewTest { diff --git a/ca/client.go b/ca/client.go index b10c0f86..74a3b7df 100644 --- a/ca/client.go +++ b/ca/client.go @@ -662,7 +662,7 @@ retry: // verify the sha256 sum := sha256.Sum256(root.RootPEM.Raw) if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) { - return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") + return nil, errs.BadRequest("root certificate fingerprint does not match") } return &root, nil } diff --git a/ca/client_test.go b/ca/client_test.go index 187066f0..29a4848d 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -337,8 +337,8 @@ func TestClient_Sign(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")}, } srv := httptest.NewServer(nil) @@ -410,7 +410,7 @@ func TestClient_Revoke(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -455,7 +455,7 @@ func TestClient_Revoke(t *testing.T) { if got != nil { t.Errorf("Client.Revoke() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) @@ -484,8 +484,8 @@ func TestClient_Renew(t *testing.T) { }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -519,7 +519,7 @@ func TestClient_Renew(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -553,8 +553,8 @@ func TestClient_Rekey(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -588,7 +588,7 @@ func TestClient_Rekey(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -735,7 +735,7 @@ func TestClient_Roots(t *testing.T) { }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -768,7 +768,7 @@ func TestClient_Roots(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Roots() = %v, want %v", got, tt.response) @@ -1016,7 +1016,7 @@ func TestClient_SSHBastion(t *testing.T) { }{ {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, - {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -1050,7 +1050,7 @@ func TestClient_SSHBastion(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) } default: if !reflect.DeepEqual(got, tt.response) { diff --git a/errs/error.go b/errs/error.go index ebcf0894..ab488af1 100644 --- a/errs/error.go +++ b/errs/error.go @@ -194,6 +194,12 @@ var ( NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs ) +var ( + // BadRequestPrefix is the prefix added to the bad request messages that are + // directly sent to the cli. + BadRequestPrefix = "The request could not be completed: " +) + // splitOptionArgs splits the variadic length args into string formatting args // and Option(s) to apply to an Error. func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { @@ -218,6 +224,16 @@ func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { return args[:indexOptionStart], opts } +// New creates a new http error with the given status and message. +func New(status int, format string, args ...interface{}) error { + msg := fmt.Sprintf(format, args...) + return &Error{ + Status: status, + Msg: msg, + Err: errors.New(msg), + } +} + // NewErr returns a new Error. If the given error implements the StatusCoder // interface we will ignore the given status. func NewErr(status int, err error, opts ...Option) error { @@ -254,6 +270,18 @@ func Errorf(code int, format string, args ...interface{}) error { return e } +// ApplyOptions applies the given options to the error if is the type *Error. +// TODO(mariano): try to get rid of this. +func ApplyOptions(err error, opts ...interface{}) error { + if e, ok := err.(*Error); ok { + _, o := splitOptionArgs(opts) + for _, fn := range o { + fn(e) + } + } + return err +} + // InternalServer creates a 500 error with the given format and arguments. func InternalServer(format string, args ...interface{}) error { args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg)) @@ -280,8 +308,8 @@ func NotImplementedErr(err error, opts ...Option) error { // BadRequest creates a 400 error with the given format and arguments. func BadRequest(format string, args ...interface{}) error { - args = append(args, withDefaultMessage(BadRequestDefaultMsg)) - return Errorf(http.StatusBadRequest, format, args...) + format = BadRequestPrefix + format + "." + return New(http.StatusBadRequest, format, args...) } // BadRequestErr returns an 400 error with the given error. From 8c8db0d4b7ab17e764375aa5f2419592628415db Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 18 Nov 2021 18:17:36 -0800 Subject: [PATCH 2/5] Modify errs.BadRequestErr() to always return an error to the client. --- api/api.go | 4 ++-- api/api_test.go | 6 ++--- api/revoke_test.go | 4 ++-- api/sign.go | 4 ++-- api/ssh.go | 24 +++++++++---------- api/sshRekey.go | 7 +++--- api/sshRenew.go | 4 ++-- authority/provisioner/sshpop_test.go | 6 ++--- authority/ssh.go | 2 +- authority/ssh_test.go | 6 ++--- authority/tls_test.go | 2 +- ca/ca_test.go | 4 ++-- ca/client.go | 3 +-- errs/error.go | 35 ++++++++++++++++++++++------ 14 files changed, 65 insertions(+), 46 deletions(-) diff --git a/api/api.go b/api/api.go index 30ba03f9..e057caaa 100644 --- a/api/api.go +++ b/api/api.go @@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) { if v := q.Get("limit"); len(v) > 0 { limit, err = strconv.Atoi(v) if err != nil { - return "", 0, errors.Wrapf(err, "error converting %s to integer", v) + return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v) } } return diff --git a/api/api_test.go b/api/api_test.go index 0fab1a5b..5cbce8b3 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -403,9 +403,9 @@ func TestSignRequest_Validate(t *testing.T) { fields fields err error }{ - {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing csr.")}, + {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")}, {"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")}, - {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing ott.")}, + {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) { t.Fatal(err) } - expectedError400 := errs.BadRequestErr(errors.New("force")) + expectedError400 := errs.BadRequest("limit 'abc' is not an integer") expectedError400Bytes, err := json.Marshal(expectedError400) assert.FatalError(t, err) expectedError500 := errs.InternalServer("force") diff --git a/api/revoke_test.go b/api/revoke_test.go index b6ba30fb..4ed4e3fe 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -28,7 +28,7 @@ func TestRevokeRequestValidate(t *testing.T) { tests := map[string]test{ "error/missing serial": { rr: &RevokeRequest{}, - err: &errs.Error{Err: errors.New("The request could not be completed: missing serial."), Status: http.StatusBadRequest}, + err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest}, }, "error/bad reasonCode": { rr: &RevokeRequest{ @@ -36,7 +36,7 @@ func TestRevokeRequestValidate(t *testing.T) { ReasonCode: 15, Passive: true, }, - err: &errs.Error{Err: errors.New("The request could not be completed: reasonCode out of bounds."), Status: http.StatusBadRequest}, + err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest}, }, "error/non-passive not implemented": { rr: &RevokeRequest{ diff --git a/api/sign.go b/api/sign.go index d6fd2bc6..a1e5b998 100644 --- a/api/sign.go +++ b/api/sign.go @@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error { return errs.BadRequest("missing csr") } if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { - return errs.Wrap(http.StatusBadRequest, err, "invalid csr") + return errs.BadRequestErr(err, "invalid csr") } if s.OTT == "" { return errs.BadRequest("missing ott") @@ -50,7 +50,7 @@ type SignResponse struct { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/ssh.go b/api/ssh.go index 7c7a5acd..315b3e83 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -49,16 +49,16 @@ type SSHSignRequest struct { func (s *SSHSignRequest) Validate() error { switch { case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: - return errors.Errorf("unknown certType %s", s.CertType) + return errs.BadRequest("invalid certType '%s'", s.CertType) case len(s.PublicKey) == 0: - return errors.New("missing or empty publicKey") + return errs.BadRequest("missing or empty publicKey") case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") default: // Validate identity signature if provided if s.IdentityCSR.CertificateRequest != nil { if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil { - return errors.Wrap(err, "invalid identityCSR") + return errs.BadRequestErr(err, "invalid identityCSR") } } return nil @@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error { case provisioner.SSHUserCert, provisioner.SSHHostCert: return nil default: - return errors.Errorf("unsupported type %s", r.Type) + return errs.BadRequest("invalid type '%s'", r.Type) } } @@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct { func (r *SSHCheckPrincipalRequest) Validate() error { switch { case r.Type != provisioner.SSHHostCert: - return errors.Errorf("unsupported type %s", r.Type) + return errs.BadRequest("unsupported type '%s'", r.Type) case r.Principal == "": - return errors.New("missing or empty principal") + return errs.BadRequest("missing or empty principal") default: return nil } @@ -232,7 +232,7 @@ type SSHBastionRequest struct { // Validate checks the values of the SSHBastionRequest. func (r *SSHBastionRequest) Validate() error { if r.Hostname == "" { - return errors.New("missing or empty hostname") + return errs.BadRequest("missing or empty hostname") } return nil } @@ -256,7 +256,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -398,7 +398,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -430,7 +430,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -469,7 +469,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } diff --git a/api/sshRekey.go b/api/sshRekey.go index 9d9e17cf..4e29b043 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -4,7 +4,6 @@ import ( "net/http" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" @@ -20,9 +19,9 @@ type SSHRekeyRequest struct { func (s *SSHRekeyRequest) Validate() error { switch { case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") case len(s.PublicKey) == 0: - return errors.New("missing or empty public key") + return errs.BadRequest("missing or empty public key") default: return nil } @@ -46,7 +45,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index d0633ecf..d28b57b5 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -19,7 +19,7 @@ type SSHRenewRequest struct { func (s *SSHRenewRequest) Validate() error { switch { case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") default: return nil } @@ -43,7 +43,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 850a698d..da036864 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("The request could not be completed: sshpop token subject must be equivalent to sshpop certificate serial number."), + err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"), } }, "ok": func(t *testing.T) test { @@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), + err: errors.New("sshpop certificate must be a host ssh certificate"), } }, "ok": func(t *testing.T) test { @@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), + err: errors.New("sshpop certificate must be a host ssh certificate"), } }, "ok": func(t *testing.T) test { diff --git a/authority/ssh.go b/authority/ssh.go index eba48297..5e03ee9e 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin // Check for required variables. if err := t.ValidateRequiredData(data); err != nil { - return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set ` flag", err)) + return nil, errs.BadRequestErr(err, "%v, please use `--set ` flag", err) } o, err := t.Output(mergedData) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index a62c9e54..b0907a79 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), + err: errors.New("cannot rekey a certificate without validity period"), code: http.StatusBadRequest, } }, @@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), + err: errors.New("cannot rekey a certificate without validity period"), code: http.StatusBadRequest, } }, @@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("The request could not be completed: unexpected certificate type '0'."), + err: errors.New("unexpected certificate type '0'"), code: http.StatusBadRequest, } }, diff --git a/authority/tls_test.go b/authority/tls_test.go index 1796c4a3..409c0582 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) { Reason: reason, OTT: raw, }, - err: errors.New("The request could not be completed: certificate with serial number 'sn' is already revoked"), + err: errors.New("certificate with serial number 'sn' is already revoked"), code: http.StatusBadRequest, checkErrDetails: func(err *errs.Error) { assert.Equals(t, err.Details["token"], raw) diff --git a/ca/ca_test.go b/ca/ca_test.go index 64371ac3..1271659a 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -115,7 +115,7 @@ func TestCASign(t *testing.T) { ca: ca, body: "invalid json", status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "fail invalid-csr-sig": func(t *testing.T) *signTest { @@ -153,7 +153,7 @@ ZEp7knvU2psWRw== ca: ca, body: string(body), status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "fail unauthorized-ott": func(t *testing.T) *signTest { diff --git a/ca/client.go b/ca/client.go index 74a3b7df..6bc48a42 100644 --- a/ca/client.go +++ b/ca/client.go @@ -1108,8 +1108,7 @@ retry: retried = true goto retry } - - return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body)) + return nil, readError(resp.Body) } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { diff --git a/errs/error.go b/errs/error.go index ab488af1..3e40b3f3 100644 --- a/errs/error.go +++ b/errs/error.go @@ -25,7 +25,7 @@ type Option func(e *Error) error // message only if it is empty. func withDefaultMessage(format string, args ...interface{}) Option { return func(e *Error) error { - if len(e.Msg) > 0 { + if e.Msg != "" { return e } e.Msg = fmt.Sprintf(format, args...) @@ -164,7 +164,8 @@ type Messenger interface { func StatusCodeError(code int, e error, opts ...Option) error { switch code { case http.StatusBadRequest: - return BadRequestErr(e, opts...) + opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) + return NewErr(http.StatusBadRequest, e, opts...) case http.StatusUnauthorized: return UnauthorizedErr(e, opts...) case http.StatusForbidden: @@ -200,6 +201,15 @@ var ( BadRequestPrefix = "The request could not be completed: " ) +func formatMessage(status int, msg string) string { + switch status { + case http.StatusBadRequest: + return BadRequestPrefix + msg + "." + default: + return msg + } +} + // splitOptionArgs splits the variadic length args into string formatting args // and Option(s) to apply to an Error. func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { @@ -229,11 +239,24 @@ func New(status int, format string, args ...interface{}) error { msg := fmt.Sprintf(format, args...) return &Error{ Status: status, - Msg: msg, + Msg: formatMessage(status, msg), Err: errors.New(msg), } } +// NewError creates a new http error with the given error and message. +func NewError(status int, err error, format string, args ...interface{}) error { + msg := fmt.Sprintf(format, args...) + if _, ok := err.(StackTracer); !ok { + err = errors.Wrap(err, msg) + } + return &Error{ + Status: status, + Msg: formatMessage(status, msg), + Err: err, + } +} + // NewErr returns a new Error. If the given error implements the StatusCoder // interface we will ignore the given status. func NewErr(status int, err error, opts ...Option) error { @@ -308,14 +331,12 @@ func NotImplementedErr(err error, opts ...Option) error { // BadRequest creates a 400 error with the given format and arguments. func BadRequest(format string, args ...interface{}) error { - format = BadRequestPrefix + format + "." return New(http.StatusBadRequest, format, args...) } // BadRequestErr returns an 400 error with the given error. -func BadRequestErr(err error, opts ...Option) error { - opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) - return NewErr(http.StatusBadRequest, err, opts...) +func BadRequestErr(err error, format string, args ...interface{}) error { + return NewError(http.StatusBadRequest, err, format, args...) } // Unauthorized creates a 401 error with the given format and arguments. From 668d3ea6c72578cdfd162e0ac2eaab0850b39373 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 18 Nov 2021 18:44:58 -0800 Subject: [PATCH 3/5] Modify errs.Wrap() with bad request to send messages to users. --- api/rekey.go | 4 ++-- api/revoke.go | 2 +- api/ssh.go | 12 ++++++------ api/sshRekey.go | 4 ++-- api/sshRenew.go | 2 +- api/sshRevoke.go | 2 +- api/utils.go | 4 ++-- authority/provisioner/jwk.go | 2 +- authority/provisioner/sign_ssh_options.go | 3 ++- authority/provisioner/x5c.go | 2 +- authority/ssh.go | 8 ++++---- authority/tls.go | 9 ++++++--- authority/tls_test.go | 2 +- 13 files changed, 30 insertions(+), 26 deletions(-) diff --git a/api/rekey.go b/api/rekey.go index 2b60eabc..b7958844 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -18,7 +18,7 @@ func (s *RekeyRequest) Validate() error { return errs.BadRequest("missing csr") } if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { - return errs.Wrap(http.StatusBadRequest, err, "invalid csr") + return errs.BadRequestErr(err, "invalid csr") } return nil @@ -33,7 +33,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { var body RekeyRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/revoke.go b/api/revoke.go index f3f47ebb..44d52cb9 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -49,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/ssh.go b/api/ssh.go index 315b3e83..43ee6b98 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -250,7 +250,7 @@ type SSHBastionResponse struct { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } @@ -262,7 +262,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) + WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) return } @@ -270,7 +270,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey")) + WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) return } } @@ -394,7 +394,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { @@ -426,7 +426,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { @@ -465,7 +465,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { diff --git a/api/sshRekey.go b/api/sshRekey.go index 4e29b043..8d2ba5ee 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -39,7 +39,7 @@ type SSHRekeyResponse struct { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } @@ -51,7 +51,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) + WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index d28b57b5..5dfd5983 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -37,7 +37,7 @@ type SSHRenewResponse struct { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index c6ebe99d..cfc25f04 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/utils.go b/api/utils.go index fa56ed6b..a7f4bf58 100644 --- a/api/utils.go +++ b/api/utils.go @@ -93,7 +93,7 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { // pointed by v. func ReadJSON(r io.Reader, v interface{}) error { if err := json.NewDecoder(r).Decode(v); err != nil { - return errs.Wrap(http.StatusBadRequest, err, "error decoding json") + return errs.BadRequestErr(err, "error decoding json") } return nil } @@ -103,7 +103,7 @@ func ReadJSON(r io.Reader, v interface{}) error { func ReadProtoJSON(r io.Reader, m proto.Message) error { data, err := io.ReadAll(r) if err != nil { - return errs.Wrap(http.StatusBadRequest, err, "error reading request body") + return errs.BadRequestErr(err, "error reading request body") } return protojson.Unmarshal(data, m) } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 56768fb7..137915c8 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -228,7 +228,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Use options in the token. if opts.CertType != "" { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "jwk.AuthorizeSSHSign") + return nil, errs.BadRequestErr(err, err.Error()) } } if opts.KeyID != "" { diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 878d3d02..78d5dd31 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -9,6 +9,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" ) @@ -55,7 +56,7 @@ type SignSSHOptions struct { // Validate validates the given SignSSHOptions. func (o SignSSHOptions) Validate() error { if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert { - return errors.Errorf("unknown certType %s", o.CertType) + return errs.BadRequest("unknown certificate type '%s'", o.CertType) } return nil } diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index a05f39c7..8710acb5 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -271,7 +271,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Use options in the token. if opts.CertType != "" { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "x5c.AuthorizeSSHSign") + return nil, errs.BadRequestErr(err, err.Error()) } } if opts.KeyID != "" { diff --git a/authority/ssh.go b/authority/ssh.go index 5e03ee9e..9c5405c4 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -151,7 +151,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // Validate given options. if err := opts.Validate(); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") + return nil, err } // Set backdate with the configured value @@ -194,8 +194,8 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi certificate, err := sshutil.NewCertificate(cr, certOptions...) if err != nil { if _, ok := err.(*sshutil.TemplateError); ok { - return nil, errs.NewErr(http.StatusBadRequest, err, - errs.WithMessage(err.Error()), + return nil, errs.ApplyOptions( + errs.BadRequestErr(err, err.Error()), errs.WithKeyVal("signOptions", signOpts), ) } @@ -208,7 +208,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // Use SignSSHOptions to modify the certificate validity. It will be later // checked or set if not defined. if err := opts.ModifyValidity(certTpl); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") + return nil, errs.BadRequestErr(err, err.Error()) } // Use provisioner modifiers. diff --git a/authority/tls.go b/authority/tls.go index 4a5f2fdf..716d8956 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -76,7 +76,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign opts := []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} if err := csr.CheckSignature(); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...) + return nil, errs.ApplyOptions( + errs.BadRequestErr(err, "invalid certificate request"), + opts..., + ) } // Set backdate with the configured value @@ -114,8 +117,8 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign cert, err := x509util.NewCertificate(csr, certOptions...) if err != nil { if _, ok := err.(*x509util.TemplateError); ok { - return nil, errs.NewErr(http.StatusBadRequest, err, - errs.WithMessage(err.Error()), + return nil, errs.ApplyOptions( + errs.BadRequestErr(err, err.Error()), errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts), ) diff --git a/authority/tls_test.go b/authority/tls_test.go index 409c0582..03beb5c1 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -256,7 +256,7 @@ func TestAuthority_Sign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: errors.New("authority.Sign; invalid certificate request"), + err: errors.New("invalid certificate request"), code: http.StatusBadRequest, } }, From b6ebd118fc2d1840b2cd071ec63355d3599a1be0 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 18 Nov 2021 18:47:55 -0800 Subject: [PATCH 4/5] Update temporal solution for sending message to users --- authority/provisioner/sign_options.go | 21 ++++----------------- authority/provisioner/sign_ssh_options.go | 14 +++++++------- 2 files changed, 11 insertions(+), 24 deletions(-) diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 95f7fc39..c4779ea3 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -8,9 +8,7 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/json" - "fmt" "net" - "net/http" "net/url" "reflect" "time" @@ -372,17 +370,6 @@ func newValidityValidator(min, max time.Duration) *validityValidator { return &validityValidator{min: min, max: max} } -// TODO(mariano): refactor errs package to allow sending real errors to the -// user. -func badRequest(format string, args ...interface{}) error { - msg := fmt.Sprintf(format, args...) - return &errs.Error{ - Status: http.StatusBadRequest, - Msg: msg, - Err: errors.New(msg), - } -} - // Valid validates the certificate validity settings (notBefore/notAfter) and // total duration. func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { @@ -395,20 +382,20 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { d := na.Sub(nb) if na.Before(now) { - return badRequest("notAfter cannot be in the past; na=%v", na) + return errs.BadRequest("notAfter cannot be in the past; na=%v", na) } if na.Before(nb) { - return badRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) + return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) } if d < v.min { - return badRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) + return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) } // NOTE: this check is not "technically correct". We're allowing the max // duration of a cert to be "max + backdate" and not all certificates will // be backdated (e.g. if a user passes the NotBefore value then we do not // apply a backdate). This is good enough. if d > v.max+o.Backdate { - return badRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) + return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) } return nil } diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 78d5dd31..6cd38c59 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -336,11 +336,11 @@ type sshCertValidityValidator struct { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error { switch { case cert.ValidAfter == 0: - return badRequest("ssh certificate validAfter cannot be 0") + return errs.BadRequest("ssh certificate validAfter cannot be 0") case cert.ValidBefore < uint64(now().Unix()): - return badRequest("ssh certificate validBefore cannot be in the past") + return errs.BadRequest("ssh certificate validBefore cannot be in the past") case cert.ValidBefore < cert.ValidAfter: - return badRequest("ssh certificate validBefore cannot be before validAfter") + return errs.BadRequest("ssh certificate validBefore cannot be before validAfter") } var min, max time.Duration @@ -352,9 +352,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti min = v.MinHostSSHCertDuration() max = v.MaxHostSSHCertDuration() case 0: - return badRequest("ssh certificate type has not been set") + return errs.BadRequest("ssh certificate type has not been set") default: - return badRequest("unknown ssh certificate type %d", cert.CertType) + return errs.BadRequest("unknown ssh certificate type %d", cert.CertType) } // To not take into account the backdate, time.Now() will be used to @@ -363,9 +363,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti switch { case dur < min: - return badRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) + return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) case dur > max+opts.Backdate: - return badRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) + return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) default: return nil } From aa3fdf8fb97e6142e831581d8398eda4525f17b4 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 18 Nov 2021 19:03:43 -0800 Subject: [PATCH 5/5] Do not overwrite errors. --- errs/error.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/errs/error.go b/errs/error.go index 3e40b3f3..60312313 100644 --- a/errs/error.go +++ b/errs/error.go @@ -246,6 +246,9 @@ func New(status int, format string, args ...interface{}) error { // NewError creates a new http error with the given error and message. func NewError(status int, err error, format string, args ...interface{}) error { + if _, ok := err.(*Error); ok { + return err + } msg := fmt.Sprintf(format, args...) if _, ok := err.(StackTracer); !ok { err = errors.Wrap(err, msg)