diff --git a/api/ssh.go b/api/ssh.go index 4e6e9c0b..f423583b 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -296,7 +296,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { } var addUserCertificate *SSHCertificate - if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { + if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { WriteError(w, errs.ForbiddenErr(err)) diff --git a/authority/ssh.go b/authority/ssh.go index d28205e2..b38cfca9 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -442,16 +442,37 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub return cert, nil } +// IsValidForAddUser checks if a user provisioner certificate can be issued to +// the given certificate. +func IsValidForAddUser(cert *ssh.Certificate) error { + if cert.CertType != ssh.UserCert { + return errors.New("certificate is not a user certificate") + } + + switch len(cert.ValidPrincipals) { + case 0: + return errors.New("certificate does not have any principals") + case 1: + return nil + case 2: + // OIDC provisioners adds a second principal with the email address. + // @ cannot be the first character. + if strings.Index(cert.ValidPrincipals[1], "@") > 0 { + return nil + } + return errors.New("certificate does not have only one principal") + default: + return errors.New("certificate does not have only one principal") + } +} + // SignSSHAddUser signs a certificate that provisions a new user in a server. func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { if a.sshCAUserCertSignKey == nil { return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled") } - if subject.CertType != ssh.UserCert { - return nil, errs.Forbidden("signSSHAddUser: certificate is not a user certificate") - } - if len(subject.ValidPrincipals) != 1 { - return nil, errs.Forbidden("signSSHAddUser: certificate does not have only one principal") + if err := IsValidForAddUser(subject); err != nil { + return nil, errs.Wrap(http.StatusForbidden, err, "signSSHAddUser") } nonce, err := randutil.ASCII(32) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index e32d9d87..e30ff60b 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -917,3 +917,28 @@ func TestAuthority_RekeySSH(t *testing.T) { }) } } + +func TestIsValidForAddUser(t *testing.T) { + type args struct { + cert *ssh.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"john"}}}, false}, + {"ok oidc", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"jane", "jane@smallstep.com"}}}, false}, + {"fail host", args{&ssh.Certificate{CertType: ssh.HostCert, ValidPrincipals: []string{"john"}}}, true}, + {"fail principals", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"john", "jane"}}}, true}, + {"fail no principals", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{}}}, true}, + {"fail extra principals", args{&ssh.Certificate{CertType: ssh.UserCert, ValidPrincipals: []string{"john", "jane", "doe"}}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := IsValidForAddUser(tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("IsValidForAddUser() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}