diff --git a/api/api.go b/api/api.go index 30ba03f9..e057caaa 100644 --- a/api/api.go +++ b/api/api.go @@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) { if v := q.Get("limit"); len(v) > 0 { limit, err = strconv.Atoi(v) if err != nil { - return "", 0, errors.Wrapf(err, "error converting %s to integer", v) + return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v) } } return diff --git a/api/api_test.go b/api/api_test.go index 05d592f0..5cbce8b3 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) { t.Fatal(err) } - expectedError400 := errs.BadRequest("force") + expectedError400 := errs.BadRequest("limit 'abc' is not an integer") expectedError400Bytes, err := json.Marshal(expectedError400) assert.FatalError(t, err) expectedError500 := errs.InternalServer("force") diff --git a/api/rekey.go b/api/rekey.go index c0d88e55..b7958844 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -18,7 +18,7 @@ func (s *RekeyRequest) Validate() error { return errs.BadRequest("missing csr") } if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { - return errs.Wrap(http.StatusBadRequest, err, "invalid csr") + return errs.BadRequestErr(err, "invalid csr") } return nil @@ -26,15 +26,14 @@ func (s *RekeyRequest) Validate() error { // Rekey is similar to renew except that the certificate will be renewed with new key from csr. func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing peer certificate")) + WriteError(w, errs.BadRequest("missing client certificate")) return } var body RekeyRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/renew.go b/api/renew.go index 74ef2034..725322ee 100644 --- a/api/renew.go +++ b/api/renew.go @@ -10,7 +10,7 @@ import ( // new one. func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing peer certificate")) + WriteError(w, errs.BadRequest("missing client certificate")) return } diff --git a/api/revoke.go b/api/revoke.go index 21c3154c..44d52cb9 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -49,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } @@ -80,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing ott or peer certificate")) + WriteError(w, errs.BadRequest("missing ott or client certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body")) + WriteError(w, errs.BadRequest("serial number in client certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. diff --git a/api/sign.go b/api/sign.go index d6fd2bc6..a1e5b998 100644 --- a/api/sign.go +++ b/api/sign.go @@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error { return errs.BadRequest("missing csr") } if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { - return errs.Wrap(http.StatusBadRequest, err, "invalid csr") + return errs.BadRequestErr(err, "invalid csr") } if s.OTT == "" { return errs.BadRequest("missing ott") @@ -50,7 +50,7 @@ type SignResponse struct { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/ssh.go b/api/ssh.go index 7c7a5acd..43ee6b98 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -49,16 +49,16 @@ type SSHSignRequest struct { func (s *SSHSignRequest) Validate() error { switch { case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: - return errors.Errorf("unknown certType %s", s.CertType) + return errs.BadRequest("invalid certType '%s'", s.CertType) case len(s.PublicKey) == 0: - return errors.New("missing or empty publicKey") + return errs.BadRequest("missing or empty publicKey") case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") default: // Validate identity signature if provided if s.IdentityCSR.CertificateRequest != nil { if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil { - return errors.Wrap(err, "invalid identityCSR") + return errs.BadRequestErr(err, "invalid identityCSR") } } return nil @@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error { case provisioner.SSHUserCert, provisioner.SSHHostCert: return nil default: - return errors.Errorf("unsupported type %s", r.Type) + return errs.BadRequest("invalid type '%s'", r.Type) } } @@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct { func (r *SSHCheckPrincipalRequest) Validate() error { switch { case r.Type != provisioner.SSHHostCert: - return errors.Errorf("unsupported type %s", r.Type) + return errs.BadRequest("unsupported type '%s'", r.Type) case r.Principal == "": - return errors.New("missing or empty principal") + return errs.BadRequest("missing or empty principal") default: return nil } @@ -232,7 +232,7 @@ type SSHBastionRequest struct { // Validate checks the values of the SSHBastionRequest. func (r *SSHBastionRequest) Validate() error { if r.Hostname == "" { - return errors.New("missing or empty hostname") + return errs.BadRequest("missing or empty hostname") } return nil } @@ -250,19 +250,19 @@ type SSHBastionResponse struct { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) + WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) return } @@ -270,7 +270,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey")) + WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) return } } @@ -394,11 +394,11 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -426,11 +426,11 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -465,11 +465,11 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } diff --git a/api/sshRekey.go b/api/sshRekey.go index 9d9e17cf..8d2ba5ee 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -4,7 +4,6 @@ import ( "net/http" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" @@ -20,9 +19,9 @@ type SSHRekeyRequest struct { func (s *SSHRekeyRequest) Validate() error { switch { case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") case len(s.PublicKey) == 0: - return errors.New("missing or empty public key") + return errs.BadRequest("missing or empty public key") default: return nil } @@ -40,19 +39,19 @@ type SSHRekeyResponse struct { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) + WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index d0633ecf..5dfd5983 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -19,7 +19,7 @@ type SSHRenewRequest struct { func (s *SSHRenewRequest) Validate() error { switch { case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") default: return nil } @@ -37,13 +37,13 @@ type SSHRenewResponse struct { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index c6ebe99d..cfc25f04 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/utils.go b/api/utils.go index fa56ed6b..a7f4bf58 100644 --- a/api/utils.go +++ b/api/utils.go @@ -93,7 +93,7 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { // pointed by v. func ReadJSON(r io.Reader, v interface{}) error { if err := json.NewDecoder(r).Decode(v); err != nil { - return errs.Wrap(http.StatusBadRequest, err, "error decoding json") + return errs.BadRequestErr(err, "error decoding json") } return nil } @@ -103,7 +103,7 @@ func ReadJSON(r io.Reader, v interface{}) error { func ReadProtoJSON(r io.Reader, m proto.Message) error { data, err := io.ReadAll(r) if err != nil { - return errs.Wrap(http.StatusBadRequest, err, "error reading request body") + return errs.BadRequestErr(err, "error reading request body") } return protojson.Unmarshal(data, m) } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 56768fb7..137915c8 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -228,7 +228,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Use options in the token. if opts.CertType != "" { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "jwk.AuthorizeSSHSign") + return nil, errs.BadRequestErr(err, err.Error()) } } if opts.KeyID != "" { diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 95f7fc39..c4779ea3 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -8,9 +8,7 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/json" - "fmt" "net" - "net/http" "net/url" "reflect" "time" @@ -372,17 +370,6 @@ func newValidityValidator(min, max time.Duration) *validityValidator { return &validityValidator{min: min, max: max} } -// TODO(mariano): refactor errs package to allow sending real errors to the -// user. -func badRequest(format string, args ...interface{}) error { - msg := fmt.Sprintf(format, args...) - return &errs.Error{ - Status: http.StatusBadRequest, - Msg: msg, - Err: errors.New(msg), - } -} - // Valid validates the certificate validity settings (notBefore/notAfter) and // total duration. func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { @@ -395,20 +382,20 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { d := na.Sub(nb) if na.Before(now) { - return badRequest("notAfter cannot be in the past; na=%v", na) + return errs.BadRequest("notAfter cannot be in the past; na=%v", na) } if na.Before(nb) { - return badRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) + return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) } if d < v.min { - return badRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) + return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) } // NOTE: this check is not "technically correct". We're allowing the max // duration of a cert to be "max + backdate" and not all certificates will // be backdated (e.g. if a user passes the NotBefore value then we do not // apply a backdate). This is good enough. if d > v.max+o.Backdate { - return badRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) + return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) } return nil } diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 878d3d02..6cd38c59 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -9,6 +9,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/ssh" ) @@ -55,7 +56,7 @@ type SignSSHOptions struct { // Validate validates the given SignSSHOptions. func (o SignSSHOptions) Validate() error { if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert { - return errors.Errorf("unknown certType %s", o.CertType) + return errs.BadRequest("unknown certificate type '%s'", o.CertType) } return nil } @@ -335,11 +336,11 @@ type sshCertValidityValidator struct { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error { switch { case cert.ValidAfter == 0: - return badRequest("ssh certificate validAfter cannot be 0") + return errs.BadRequest("ssh certificate validAfter cannot be 0") case cert.ValidBefore < uint64(now().Unix()): - return badRequest("ssh certificate validBefore cannot be in the past") + return errs.BadRequest("ssh certificate validBefore cannot be in the past") case cert.ValidBefore < cert.ValidAfter: - return badRequest("ssh certificate validBefore cannot be before validAfter") + return errs.BadRequest("ssh certificate validBefore cannot be before validAfter") } var min, max time.Duration @@ -351,9 +352,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti min = v.MinHostSSHCertDuration() max = v.MaxHostSSHCertDuration() case 0: - return badRequest("ssh certificate type has not been set") + return errs.BadRequest("ssh certificate type has not been set") default: - return badRequest("unknown ssh certificate type %d", cert.CertType) + return errs.BadRequest("unknown ssh certificate type %d", cert.CertType) } // To not take into account the backdate, time.Now() will be used to @@ -362,9 +363,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti switch { case dur < min: - return badRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) + return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) case dur > max+opts.Backdate: - return badRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) + return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) default: return nil } diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 99974ff1..3039d2a3 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -191,8 +191,7 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { - return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " + - "must be equivalent to sshpop certificate serial number") + return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number") } return nil } @@ -205,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { - return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate") + return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, nil @@ -220,7 +219,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } if claims.sshCert.CertType != ssh.HostCert { - return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate") + return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, []SignOption{ // Validate public key diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 3d343967..da036864 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"), + err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"), } }, "ok": func(t *testing.T) test { @@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"), + err: errors.New("sshpop certificate must be a host ssh certificate"), } }, "ok": func(t *testing.T) test { @@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"), + err: errors.New("sshpop certificate must be a host ssh certificate"), } }, "ok": func(t *testing.T) test { diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index a05f39c7..8710acb5 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -271,7 +271,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Use options in the token. if opts.CertType != "" { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "x5c.AuthorizeSSHSign") + return nil, errs.BadRequestErr(err, err.Error()) } } if opts.KeyID != "" { diff --git a/authority/ssh.go b/authority/ssh.go index bef673bf..9c5405c4 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -69,7 +69,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin ts = a.templates.SSH.Host } default: - return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ) + return nil, errs.BadRequest("invalid certificate type '%s'", typ) } // Merge user and default data @@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin // Check for required variables. if err := t.ValidateRequiredData(data); err != nil { - return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set ` flag", err)) + return nil, errs.BadRequestErr(err, "%v, please use `--set ` flag", err) } o, err := t.Output(mergedData) @@ -151,7 +151,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // Validate given options. if err := opts.Validate(); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") + return nil, err } // Set backdate with the configured value @@ -194,8 +194,8 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi certificate, err := sshutil.NewCertificate(cr, certOptions...) if err != nil { if _, ok := err.(*sshutil.TemplateError); ok { - return nil, errs.NewErr(http.StatusBadRequest, err, - errs.WithMessage(err.Error()), + return nil, errs.ApplyOptions( + errs.BadRequestErr(err, err.Error()), errs.WithKeyVal("signOptions", signOpts), ) } @@ -208,7 +208,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // Use SignSSHOptions to modify the certificate validity. It will be later // checked or set if not defined. if err := opts.ModifyValidity(certTpl); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") + return nil, errs.BadRequestErr(err, err.Error()) } // Use provisioner modifiers. @@ -258,7 +258,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period") + return nil, errs.BadRequest("cannot renew a certificate without validity period") } if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { @@ -329,7 +329,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { - return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period") + return nil, errs.BadRequest("cannot rekey a certificate without validity period") } if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { @@ -369,7 +369,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } signer = a.sshCAHostCertSignKey default: - return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType) + return nil, errs.BadRequest("unexpected certificate type '%d'", cert.CertType) } var err error diff --git a/authority/ssh_test.go b/authority/ssh_test.go index 994d015f..b0907a79 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; cannot rekey certificate without validity period"), + err: errors.New("cannot rekey a certificate without validity period"), code: http.StatusBadRequest, } }, @@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; cannot rekey certificate without validity period"), + err: errors.New("cannot rekey a certificate without validity period"), code: http.StatusBadRequest, } }, @@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("rekeySSH; unexpected ssh certificate type: 0"), + err: errors.New("unexpected certificate type '0'"), code: http.StatusBadRequest, } }, diff --git a/authority/tls.go b/authority/tls.go index 839866a2..716d8956 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -76,7 +76,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign opts := []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} if err := csr.CheckSignature(); err != nil { - return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...) + return nil, errs.ApplyOptions( + errs.BadRequestErr(err, "invalid certificate request"), + opts..., + ) } // Set backdate with the configured value @@ -114,8 +117,8 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign cert, err := x509util.NewCertificate(csr, certOptions...) if err != nil { if _, ok := err.(*x509util.TemplateError); ok { - return nil, errs.NewErr(http.StatusBadRequest, err, - errs.WithMessage(err.Error()), + return nil, errs.ApplyOptions( + errs.BadRequestErr(err, err.Error()), errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts), ) @@ -433,8 +436,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error case db.ErrNotImplemented: return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...) case db.ErrAlreadyExists: - return errs.BadRequest("authority.Revoke; certificate with serial "+ - "number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...) + return errs.ApplyOptions( + errs.BadRequest("certificate with serial number '%s' is already revoked", rci.Serial), + opts..., + ) default: return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) } diff --git a/authority/tls_test.go b/authority/tls_test.go index ba05b9fc..03beb5c1 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -256,7 +256,7 @@ func TestAuthority_Sign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: errors.New("authority.Sign; invalid certificate request"), + err: errors.New("invalid certificate request"), code: http.StatusBadRequest, } }, @@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) { Reason: reason, OTT: raw, }, - err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"), + err: errors.New("certificate with serial number 'sn' is already revoked"), code: http.StatusBadRequest, checkErrDetails: func(err *errs.Error) { assert.Equals(t, err.Details["token"], raw) diff --git a/ca/ca_test.go b/ca/ca_test.go index 0f7cb02e..1271659a 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -115,7 +115,7 @@ func TestCASign(t *testing.T) { ca: ca, body: "invalid json", status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "fail invalid-csr-sig": func(t *testing.T) *signTest { @@ -153,7 +153,7 @@ ZEp7knvU2psWRw== ca: ca, body: string(body), status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "fail unauthorized-ott": func(t *testing.T) *signTest { @@ -588,7 +588,7 @@ func TestCARenew(t *testing.T) { ca: ca, tlsConnState: nil, status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "request-missing-peer-certificate": func(t *testing.T) *renewTest { @@ -596,7 +596,7 @@ func TestCARenew(t *testing.T) { ca: ca, tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}}, status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "success": func(t *testing.T) *renewTest { diff --git a/ca/client.go b/ca/client.go index b10c0f86..6bc48a42 100644 --- a/ca/client.go +++ b/ca/client.go @@ -662,7 +662,7 @@ retry: // verify the sha256 sum := sha256.Sum256(root.RootPEM.Raw) if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) { - return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") + return nil, errs.BadRequest("root certificate fingerprint does not match") } return &root, nil } @@ -1108,8 +1108,7 @@ retry: retried = true goto retry } - - return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body)) + return nil, readError(resp.Body) } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { diff --git a/ca/client_test.go b/ca/client_test.go index 187066f0..29a4848d 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -337,8 +337,8 @@ func TestClient_Sign(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")}, } srv := httptest.NewServer(nil) @@ -410,7 +410,7 @@ func TestClient_Revoke(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -455,7 +455,7 @@ func TestClient_Revoke(t *testing.T) { if got != nil { t.Errorf("Client.Revoke() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) @@ -484,8 +484,8 @@ func TestClient_Renew(t *testing.T) { }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -519,7 +519,7 @@ func TestClient_Renew(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -553,8 +553,8 @@ func TestClient_Rekey(t *testing.T) { }{ {"ok", request, ok, 200, false, nil}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, - {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -588,7 +588,7 @@ func TestClient_Rekey(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -735,7 +735,7 @@ func TestClient_Roots(t *testing.T) { }{ {"ok", ok, 200, false, nil}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, - {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -768,7 +768,7 @@ func TestClient_Roots(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Roots() = %v, want %v", got, tt.response) @@ -1016,7 +1016,7 @@ func TestClient_SSHBastion(t *testing.T) { }{ {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, - {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, } srv := httptest.NewServer(nil) @@ -1050,7 +1050,7 @@ func TestClient_SSHBastion(t *testing.T) { sc, ok := err.(errs.StatusCoder) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) - assert.HasPrefix(t, tt.err.Error(), err.Error()) + assert.HasPrefix(t, err.Error(), tt.err.Error()) } default: if !reflect.DeepEqual(got, tt.response) { diff --git a/errs/error.go b/errs/error.go index ebcf0894..60312313 100644 --- a/errs/error.go +++ b/errs/error.go @@ -25,7 +25,7 @@ type Option func(e *Error) error // message only if it is empty. func withDefaultMessage(format string, args ...interface{}) Option { return func(e *Error) error { - if len(e.Msg) > 0 { + if e.Msg != "" { return e } e.Msg = fmt.Sprintf(format, args...) @@ -164,7 +164,8 @@ type Messenger interface { func StatusCodeError(code int, e error, opts ...Option) error { switch code { case http.StatusBadRequest: - return BadRequestErr(e, opts...) + opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) + return NewErr(http.StatusBadRequest, e, opts...) case http.StatusUnauthorized: return UnauthorizedErr(e, opts...) case http.StatusForbidden: @@ -194,6 +195,21 @@ var ( NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs ) +var ( + // BadRequestPrefix is the prefix added to the bad request messages that are + // directly sent to the cli. + BadRequestPrefix = "The request could not be completed: " +) + +func formatMessage(status int, msg string) string { + switch status { + case http.StatusBadRequest: + return BadRequestPrefix + msg + "." + default: + return msg + } +} + // splitOptionArgs splits the variadic length args into string formatting args // and Option(s) to apply to an Error. func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { @@ -218,6 +234,32 @@ func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { return args[:indexOptionStart], opts } +// New creates a new http error with the given status and message. +func New(status int, format string, args ...interface{}) error { + msg := fmt.Sprintf(format, args...) + return &Error{ + Status: status, + Msg: formatMessage(status, msg), + Err: errors.New(msg), + } +} + +// NewError creates a new http error with the given error and message. +func NewError(status int, err error, format string, args ...interface{}) error { + if _, ok := err.(*Error); ok { + return err + } + msg := fmt.Sprintf(format, args...) + if _, ok := err.(StackTracer); !ok { + err = errors.Wrap(err, msg) + } + return &Error{ + Status: status, + Msg: formatMessage(status, msg), + Err: err, + } +} + // NewErr returns a new Error. If the given error implements the StatusCoder // interface we will ignore the given status. func NewErr(status int, err error, opts ...Option) error { @@ -254,6 +296,18 @@ func Errorf(code int, format string, args ...interface{}) error { return e } +// ApplyOptions applies the given options to the error if is the type *Error. +// TODO(mariano): try to get rid of this. +func ApplyOptions(err error, opts ...interface{}) error { + if e, ok := err.(*Error); ok { + _, o := splitOptionArgs(opts) + for _, fn := range o { + fn(e) + } + } + return err +} + // InternalServer creates a 500 error with the given format and arguments. func InternalServer(format string, args ...interface{}) error { args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg)) @@ -280,14 +334,12 @@ func NotImplementedErr(err error, opts ...Option) error { // BadRequest creates a 400 error with the given format and arguments. func BadRequest(format string, args ...interface{}) error { - args = append(args, withDefaultMessage(BadRequestDefaultMsg)) - return Errorf(http.StatusBadRequest, format, args...) + return New(http.StatusBadRequest, format, args...) } // BadRequestErr returns an 400 error with the given error. -func BadRequestErr(err error, opts ...Option) error { - opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) - return NewErr(http.StatusBadRequest, err, opts...) +func BadRequestErr(err error, format string, args ...interface{}) error { + return NewError(http.StatusBadRequest, err, format, args...) } // Unauthorized creates a 401 error with the given format and arguments.