diff --git a/api/api_test.go b/api/api_test.go index cbaf806f..edefbd47 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -550,8 +550,6 @@ type mockAuthority struct { getTLSOptions func() *tlsutil.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByID func(provID string) (provisioner.Interface, error) @@ -560,14 +558,16 @@ type mockAuthority struct { getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error) - renewSSH func(cert *ssh.Certificate) (*ssh.Certificate, error) - rekeySSH func(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - getSSHHosts func(*x509.Certificate) ([]sshutil.Host, error) - getSSHRoots func() (*authority.SSHKeys, error) - getSSHFederation func() (*authority.SSHKeys, error) - getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error) + signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) + renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) + rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) + getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) + getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) + getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) checkSSHHost func(ctx context.Context, principal, token string) (bool, error) - getSSHBastion func(user string, hostname string) (*authority.Bastion, error) + getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) version func() authority.Version } @@ -604,20 +604,6 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } -func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.signSSH != nil { - return m.signSSH(key, opts, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.signSSHAddUser != nil { - return m.signSSHAddUser(key, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { if m.renew != nil { return m.renew(cert) @@ -674,44 +660,58 @@ func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { return m.ret1.([]*x509.Certificate), m.err } -func (m *mockAuthority) RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error) { +func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.signSSH != nil { + return m.signSSH(ctx, key, opts, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.signSSHAddUser != nil { + return m.signSSHAddUser(ctx, key, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { if m.renewSSH != nil { - return m.renewSSH(cert) + return m.renewSSH(ctx, cert) } return m.ret1.(*ssh.Certificate), m.err } -func (m *mockAuthority) RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { +func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { if m.rekeySSH != nil { - return m.rekeySSH(cert, key, signOpts...) + return m.rekeySSH(ctx, cert, key, signOpts...) } return m.ret1.(*ssh.Certificate), m.err } -func (m *mockAuthority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) { +func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { if m.getSSHHosts != nil { - return m.getSSHHosts(cert) + return m.getSSHHosts(ctx, cert) } return m.ret1.([]sshutil.Host), m.err } -func (m *mockAuthority) GetSSHRoots() (*authority.SSHKeys, error) { +func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { if m.getSSHRoots != nil { - return m.getSSHRoots() + return m.getSSHRoots(ctx) } return m.ret1.(*authority.SSHKeys), m.err } -func (m *mockAuthority) GetSSHFederation() (*authority.SSHKeys, error) { +func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { if m.getSSHFederation != nil { - return m.getSSHFederation() + return m.getSSHFederation(ctx) } return m.ret1.(*authority.SSHKeys), m.err } -func (m *mockAuthority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) { +func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { if m.getSSHConfig != nil { - return m.getSSHConfig(typ, data) + return m.getSSHConfig(ctx, typ, data) } return m.ret1.([]templates.Output), m.err } @@ -723,9 +723,9 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin return m.ret1.(bool), m.err } -func (m *mockAuthority) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) { +func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) { if m.getSSHBastion != nil { - return m.getSSHBastion(user, hostname) + return m.getSSHBastion(ctx, user, hostname) } return m.ret1.(*authority.Bastion), m.err } diff --git a/api/ssh_test.go b/api/ssh_test.go index cb5c7904..874c00b7 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -319,10 +319,10 @@ func Test_caHandler_SSHSign(t *testing.T) { authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, - signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { return tt.signCert, tt.signErr }, - signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { return tt.addUserCert, tt.addUserErr }, sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { @@ -379,7 +379,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHRoots: func() (*authority.SSHKeys, error) { + getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, }).(*caHandler) @@ -433,7 +433,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHFederation: func() (*authority.SSHKeys, error) { + getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, }).(*caHandler) @@ -493,7 +493,7 @@ func Test_caHandler_SSHConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHConfig: func(typ string, data map[string]string) ([]templates.Output, error) { + getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { return tt.output, tt.err }, }).(*caHandler) @@ -591,7 +591,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHHosts: func(*x509.Certificate) ([]sshutil.Host, error) { + getSSHHosts: func(context.Context, *x509.Certificate) ([]sshutil.Host, error) { return tt.hosts, tt.err }, }).(*caHandler) @@ -646,7 +646,7 @@ func Test_caHandler_SSHBastion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHBastion: func(user, hostname string) (*authority.Bastion, error) { + getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { return tt.bastion, tt.bastionErr }, }).(*caHandler) diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index d0782c1e..fbf71f4b 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -485,10 +485,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p5.Init(config)) - p4.getIdentityFunc = func(p Interface, email string) (*Identity, error) { + p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return &Identity{Usernames: []string{"max", "mariano"}}, nil } - p5.getIdentityFunc = func(p Interface, email string) (*Identity, error) { + p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return nil, errors.New("force") } diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index 2577c62f..238e21a3 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -92,7 +92,7 @@ func TestDefaultIdentityFunc(t *testing.T) { for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) - identity, err := DefaultIdentityFunc(tc.p, tc.email) + identity, err := DefaultIdentityFunc(context.Background(), tc.p, tc.email) if err != nil { if assert.NotNil(t, tc.err) { assert.Equals(t, tc.err.Error(), err.Error()) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index b581740f..6d05e1a9 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -153,7 +153,7 @@ func TestAuthority_SignSSH(t *testing.T) { a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey - got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...) + got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr) return @@ -242,7 +242,7 @@ func TestAuthority_SignSSHAddUser(t *testing.T) { AddUserPrincipal: tt.fields.addUserPrincipal, AddUserCommand: tt.fields.addUserCommand, } - got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject) + got, err := a.SignSSHAddUser(context.Background(), tt.args.key, tt.args.subject) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr) return @@ -295,7 +295,7 @@ func TestAuthority_GetSSHRoots(t *testing.T) { a.sshCAUserCerts = tt.fields.sshCAUserCerts a.sshCAHostCerts = tt.fields.sshCAHostCerts - got, err := a.GetSSHRoots() + got, err := a.GetSSHRoots(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr) return @@ -337,7 +337,7 @@ func TestAuthority_GetSSHFederation(t *testing.T) { a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts - got, err := a.GetSSHFederation() + got, err := a.GetSSHFederation(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr) return @@ -463,7 +463,7 @@ func TestAuthority_GetSSHConfig(t *testing.T) { a.sshCAUserCertSignKey = tt.fields.userSigner a.sshCAHostCertSignKey = tt.fields.hostSigner - got, err := a.GetSSHConfig(tt.args.typ, tt.args.data) + got, err := a.GetSSHConfig(context.Background(), tt.args.typ, tt.args.data) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr) return @@ -614,7 +614,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) { } type fields struct { config *Config - sshBastionFunc func(user, hostname string) (*Bastion, error) + sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error) } type args struct { user string @@ -630,8 +630,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) { {"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false}, {"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false}, {"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false}, - {"func", fields{&Config{}, func(_, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false}, - {"func err", fields{&Config{}, func(_, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true}, + {"func", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false}, + {"func err", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true}, {"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true}, } for _, tt := range tests { @@ -640,7 +640,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) { config: tt.fields.config, sshBastionFunc: tt.fields.sshBastionFunc, } - got, err := a.GetSSHBastion(tt.args.user, tt.args.hostname) + got, err := a.GetSSHBastion(context.Background(), tt.args.user, tt.args.hostname) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) return @@ -659,7 +659,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) { a := testAuthority(t) type test struct { - getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error) + getHostsFunc func(context.Context, *x509.Certificate) ([]sshutil.Host, error) auth *Authority cert *x509.Certificate cmp func(got []sshutil.Host) @@ -669,7 +669,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) { tests := map[string]func(t *testing.T) *test{ "fail/getHostsFunc-fail": func(t *testing.T) *test { return &test{ - getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) { + getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { return nil, errors.New("force") }, cert: &x509.Certificate{}, @@ -684,7 +684,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) { } return &test{ - getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) { + getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { return hosts, nil }, cert: &x509.Certificate{}, @@ -732,7 +732,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) { } auth.sshGetHostsFunc = tc.getHostsFunc - hosts, err := auth.GetSSHHosts(tc.cert) + hosts, err := auth.GetSSHHosts(context.Background(), tc.cert) if err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) @@ -901,7 +901,7 @@ func TestAuthority_RekeySSH(t *testing.T) { a.sshCAUserCertSignKey = tc.userSigner a.sshCAHostCertSignKey = tc.hostSigner - cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...) + cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) if err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder)