diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 1097c003..14d357f1 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -56,7 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - resp, err := wh.Do(wc.client, req, wc.TemplateData) + // TODO(hs): propagate context from above + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -88,7 +92,12 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - resp, err := wh.Do(wc.client, req, wc.TemplateData) + + // TODO(hs): propagate context from above + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -124,13 +133,6 @@ type Webhook struct { } `json:"-"` } -func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - return w.DoWithContext(ctx, client, reqBody, data) -} - func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) if err != nil { diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 656d75d8..a61da39c 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/hmac" "crypto/sha256" "crypto/tls" @@ -13,6 +14,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/pkg/errors" "github.com/smallstep/assert" @@ -522,7 +524,11 @@ func TestWebhook_Do(t *testing.T) { reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) assert.FatalError(t, err) - got, err := tc.webhook.Do(http.DefaultClient, reqBody, tc.dataArg) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { assert.Equals(t, tc.expectErr.Error(), err.Error()) return @@ -553,11 +559,18 @@ func TestWebhook_Do(t *testing.T) { } reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) assert.FatalError(t, err) - _, err = wh.Do(client, reqBody, nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + _, err = wh.DoWithContext(ctx, client, reqBody, nil) assert.FatalError(t, err) + ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + wh.DisableTLSClientAuth = true - _, err = wh.Do(client, reqBody, nil) + _, err = wh.DoWithContext(ctx, client, reqBody, nil) assert.Error(t, err) }) }