diff --git a/acme/api/handler.go b/acme/api/handler.go index 11cd74f2..8c3a7cd5 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -123,7 +123,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { api.WriteError(w, err) return } - dir := h.Auth.GetDirectory(prov) + dir := h.Auth.GetDirectory(prov, baseURLFromRequest(r)) api.JSON(w, dir) } diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index ebafbbb8..8d42916f 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -30,7 +30,7 @@ type mockAcmeAuthority struct { getAuthz func(p provisioner.Interface, accID string, id string) (*acme.Authz, error) getCertificate func(accID string, id string) ([]byte, error) getChallenge func(p provisioner.Interface, accID string, id string) (*acme.Challenge, error) - getDirectory func(provisioner.Interface) *acme.Directory + getDirectory func(provisioner.Interface, string) *acme.Directory getLink func(acme.Link, string, bool, ...string) string getOrder func(p provisioner.Interface, accID string, id string) (*acme.Order, error) getOrdersByAccount func(p provisioner.Interface, id string) ([]string, error) @@ -108,9 +108,9 @@ func (m *mockAcmeAuthority) GetChallenge(p provisioner.Interface, accID, id stri return m.ret1.(*acme.Challenge), m.err } -func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface) *acme.Directory { +func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface, baseURLFromRequest string) *acme.Directory { if m.getDirectory != nil { - return m.getDirectory(p) + return m.getDirectory(p, baseURLFromRequest) } return m.ret1.(*acme.Directory) } @@ -276,6 +276,7 @@ func TestHandlerGetDirectory(t *testing.T) { t.Run(name, func(t *testing.T) { h := New(auth).(*Handler) req := httptest.NewRequest("GET", url, nil) + req.Header.Add("X-Forwarded-Proto", "https") req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetDirectory(w, req) diff --git a/acme/api/hostutil.go b/acme/api/hostutil.go new file mode 100644 index 00000000..0d1df9fd --- /dev/null +++ b/acme/api/hostutil.go @@ -0,0 +1,28 @@ +package api + +import ( + "net/http" +) + +// baseURLFromRequest determines the base URL which should be used for constructing link URLs in e.g. the ACME directory +// result by taking the request Host, TLS and Header[X-Forwarded-Proto] values into consideration. +// If the Request.Host is an empty string, we return an empty string, to indicate that the configured +// URL values should be used instead. +// If this function returns a non-empty result, then this should be used in constructing ACME link URLs. +func baseURLFromRequest(r *http.Request) string { + // TODO: I semantically copied the functionality of determining the protol from boulder web/relative.go + // which allows HTTP. Previously this was always forced to be HTTPS for absolute URLs. Should this be + // changed to also always force HTTPS protocol? + proto := "http" + if specifiedProto := r.Header.Get("X-Forwarded-Proto"); specifiedProto != "" { + proto = specifiedProto + } else if r.TLS != nil { + proto += "s" + } + + host := r.Host + if host == "" { + return "" + } + return proto + "://" + host +} diff --git a/acme/api/hostutil_test.go b/acme/api/hostutil_test.go new file mode 100644 index 00000000..6555e3c7 --- /dev/null +++ b/acme/api/hostutil_test.go @@ -0,0 +1,70 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetBaseUrl(t *testing.T) { + tests := []struct { + testFailedDescription string + targetURL string + expectedResult string + requestPreparer func(*http.Request) + }{ + { + "HTTP host pass-through failed.", + "http://my.dummy.host", + "http://my.dummy.host", + nil, + }, + { + "HTTPS host pass-through failed.", + "https://my.dummy.host", + "https://my.dummy.host", + nil, + }, + { + "Port pass-through failed", + "http://host.with.port:8080", + "http://host.with.port:8080", + nil, + }, + { + "Explicit host from Request.Host was not used.", + "http://some.target.host:8080", + "http://proxied.host", + func(r *http.Request) { + r.Host = "proxied.host" + }, + }, + { + "Explicit forwarded protocol from request header X-Forwarded-Proto was not used.", + "http://some.host", + "ssl://some.host", + func(r *http.Request) { + r.Header.Add("X-Forwarded-Proto", "ssl") + }, + }, + { + "Missing Request.Host value did not result in empty string result.", + "http://some.host", + "", + func(r *http.Request) { + r.Host = "" + }, + }, + } + + for _, test := range tests { + request := httptest.NewRequest("GET", test.targetURL, nil) + if test.requestPreparer != nil { + test.requestPreparer(request) + } + result := baseURLFromRequest(request) + if result != test.expectedResult { + t.Errorf("Expected %q, but got %q", test.expectedResult, result) + } + } +} diff --git a/acme/authority.go b/acme/authority.go index fe51ea9b..5ec0daa3 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -25,7 +25,7 @@ type Interface interface { GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error) GetAuthz(provisioner.Interface, string, string) (*Authz, error) GetCertificate(string, string) ([]byte, error) - GetDirectory(provisioner.Interface) *Directory + GetDirectory(provisioner.Interface, string) *Directory GetLink(Link, string, bool, ...string) string GetOrder(provisioner.Interface, string, string) (*Order, error) GetOrdersByAccount(provisioner.Interface, string) ([]string, error) @@ -82,14 +82,14 @@ func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) } // GetDirectory returns the ACME directory object. -func (a *Authority) GetDirectory(p provisioner.Interface) *Directory { +func (a *Authority) GetDirectory(p provisioner.Interface, baseURLFromRequest string) *Directory { name := url.PathEscape(p.GetName()) return &Directory{ - NewNonce: a.dir.getLink(NewNonceLink, name, true), - NewAccount: a.dir.getLink(NewAccountLink, name, true), - NewOrder: a.dir.getLink(NewOrderLink, name, true), - RevokeCert: a.dir.getLink(RevokeCertLink, name, true), - KeyChange: a.dir.getLink(KeyChangeLink, name, true), + NewNonce: a.dir.getLinkFromBaseURL(NewNonceLink, name, true, baseURLFromRequest), + NewAccount: a.dir.getLinkFromBaseURL(NewAccountLink, name, true, baseURLFromRequest), + NewOrder: a.dir.getLinkFromBaseURL(NewOrderLink, name, true, baseURLFromRequest), + RevokeCert: a.dir.getLinkFromBaseURL(RevokeCertLink, name, true, baseURLFromRequest), + KeyChange: a.dir.getLinkFromBaseURL(KeyChangeLink, name, true, baseURLFromRequest), } } diff --git a/acme/authority_test.go b/acme/authority_test.go index 525a61b9..6e679192 100644 --- a/acme/authority_test.go +++ b/acme/authority_test.go @@ -73,7 +73,7 @@ func TestAuthorityGetDirectory(t *testing.T) { auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) assert.FatalError(t, err) prov := newProv() - acmeDir := auth.GetDirectory(prov) + acmeDir := auth.GetDirectory(prov, "") assert.Equals(t, acmeDir.NewNonce, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", URLSafeProvisionerName(prov))) assert.Equals(t, acmeDir.NewAccount, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", URLSafeProvisionerName(prov))) assert.Equals(t, acmeDir.NewOrder, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", URLSafeProvisionerName(prov))) @@ -82,6 +82,20 @@ func TestAuthorityGetDirectory(t *testing.T) { assert.Equals(t, acmeDir.KeyChange, fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", URLSafeProvisionerName(prov))) } +func TestAuthorityGetDirectoryWithBaseURL(t *testing.T) { + baseURL := "http://my.proxied.host" + auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) + assert.FatalError(t, err) + prov := newProv() + acmeDir := auth.GetDirectory(prov, baseURL) + assert.Equals(t, acmeDir.NewNonce, fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, URLSafeProvisionerName(prov))) + assert.Equals(t, acmeDir.NewAccount, fmt.Sprintf("%s/acme/%s/new-account", baseURL, URLSafeProvisionerName(prov))) + assert.Equals(t, acmeDir.NewOrder, fmt.Sprintf("%s/acme/%s/new-order", baseURL, URLSafeProvisionerName(prov))) + //assert.Equals(t, acmeDir.NewOrder, "%s/acme/new-authz") + assert.Equals(t, acmeDir.RevokeCert, fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, URLSafeProvisionerName(prov))) + assert.Equals(t, acmeDir.KeyChange, fmt.Sprintf("%s/acme/%s/key-change", baseURL, URLSafeProvisionerName(prov))) +} + func TestAuthorityNewNonce(t *testing.T) { type test struct { auth *Authority diff --git a/acme/directory.go b/acme/directory.go index 85819f10..76e9d541 100644 --- a/acme/directory.go +++ b/acme/directory.go @@ -102,6 +102,12 @@ func (l Link) String() string { // getLink returns an absolute or partial path to the given resource. func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string { + return d.getLinkFromBaseURL(typ, provisionerName, abs, "", inputs...) +} + +// getLinkFromBaseURL returns an absolute or partial path to the given resource and a base URL dynamically obtained from the request for which +// the link is being calculated. +func (d *directory) getLinkFromBaseURL(typ Link, provisionerName string, abs bool, baseURLFromRequest string, inputs ...string) string { var link string switch typ { case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink: @@ -114,7 +120,11 @@ func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs . link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0]) } if abs { - return fmt.Sprintf("https://%s/%s%s", d.dns, d.prefix, link) + baseURL := baseURLFromRequest + if baseURL == "" { + baseURL = "https://" + d.dns + } + return fmt.Sprintf("%s/%s%s", baseURL, d.prefix, link) } return link }