|
|
|
@ -70,7 +70,7 @@ func Test_baseURLFromRequest(t *testing.T) {
|
|
|
|
|
if tc.requestPreparer != nil {
|
|
|
|
|
tc.requestPreparer(request)
|
|
|
|
|
}
|
|
|
|
|
result := baseURLFromRequest(request)
|
|
|
|
|
result := getBaseURLFromRequest(request)
|
|
|
|
|
if result == nil || tc.expectedResult == nil {
|
|
|
|
|
assert.Equals(t, result, tc.expectedResult)
|
|
|
|
|
} else if result.String() != tc.expectedResult.String() {
|
|
|
|
@ -81,7 +81,7 @@ func Test_baseURLFromRequest(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestHandler_baseURLFromRequest(t *testing.T) {
|
|
|
|
|
h := &Handler{}
|
|
|
|
|
// h := &Handler{}
|
|
|
|
|
req := httptest.NewRequest("GET", "/foo", nil)
|
|
|
|
|
req.Host = "test.ca.smallstep.com:8080"
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
@ -94,7 +94,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
h.baseURLFromRequest(next)(w, req)
|
|
|
|
|
baseURLFromRequest(next)(w, req)
|
|
|
|
|
|
|
|
|
|
req = httptest.NewRequest("GET", "/foo", nil)
|
|
|
|
|
req.Host = ""
|
|
|
|
@ -103,7 +103,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
|
|
|
|
|
assert.Equals(t, baseURLFromContext(r.Context()), nil)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
h.baseURLFromRequest(next)(w, req)
|
|
|
|
|
baseURLFromRequest(next)(w, req)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestHandler_addNonce(t *testing.T) {
|
|
|
|
@ -139,10 +139,10 @@ func TestHandler_addNonce(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{db: tc.db}
|
|
|
|
|
// h := &Handler{db: tc.db}
|
|
|
|
|
req := httptest.NewRequest("GET", u, nil)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.addNonce(testNext)(w, req)
|
|
|
|
|
addNonce(testNext)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -195,11 +195,11 @@ func TestHandler_addDirLink(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{linker: tc.linker}
|
|
|
|
|
// h := &Handler{linker: tc.linker}
|
|
|
|
|
req := httptest.NewRequest("GET", "/foo", nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.addDirLink(testNext)(w, req)
|
|
|
|
|
addDirLink(testNext)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -242,7 +242,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|
|
|
|
"fail/provisioner-not-set": func(t *testing.T) test {
|
|
|
|
|
return test{
|
|
|
|
|
h: Handler{
|
|
|
|
|
linker: NewLinker("dns", "acme"),
|
|
|
|
|
// linker: NewLinker("dns", "acme"),
|
|
|
|
|
},
|
|
|
|
|
url: u,
|
|
|
|
|
ctx: context.Background(),
|
|
|
|
@ -254,7 +254,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|
|
|
|
"fail/general-bad-content-type": func(t *testing.T) test {
|
|
|
|
|
return test{
|
|
|
|
|
h: Handler{
|
|
|
|
|
linker: NewLinker("dns", "acme"),
|
|
|
|
|
// linker: NewLinker("dns", "acme"),
|
|
|
|
|
},
|
|
|
|
|
url: u,
|
|
|
|
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
|
|
@ -266,7 +266,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|
|
|
|
"fail/certificate-bad-content-type": func(t *testing.T) test {
|
|
|
|
|
return test{
|
|
|
|
|
h: Handler{
|
|
|
|
|
linker: NewLinker("dns", "acme"),
|
|
|
|
|
// linker: NewLinker("dns", "acme"),
|
|
|
|
|
},
|
|
|
|
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
|
|
|
contentType: "foo",
|
|
|
|
@ -277,7 +277,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|
|
|
|
"ok": func(t *testing.T) test {
|
|
|
|
|
return test{
|
|
|
|
|
h: Handler{
|
|
|
|
|
linker: NewLinker("dns", "acme"),
|
|
|
|
|
// linker: NewLinker("dns", "acme"),
|
|
|
|
|
},
|
|
|
|
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
|
|
|
contentType: "application/jose+json",
|
|
|
|
@ -287,7 +287,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|
|
|
|
"ok/certificate/pkix-cert": func(t *testing.T) test {
|
|
|
|
|
return test{
|
|
|
|
|
h: Handler{
|
|
|
|
|
linker: NewLinker("dns", "acme"),
|
|
|
|
|
// linker: NewLinker("dns", "acme"),
|
|
|
|
|
},
|
|
|
|
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
|
|
|
contentType: "application/pkix-cert",
|
|
|
|
@ -297,7 +297,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|
|
|
|
"ok/certificate/jose+json": func(t *testing.T) test {
|
|
|
|
|
return test{
|
|
|
|
|
h: Handler{
|
|
|
|
|
linker: NewLinker("dns", "acme"),
|
|
|
|
|
// linker: NewLinker("dns", "acme"),
|
|
|
|
|
},
|
|
|
|
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
|
|
|
contentType: "application/jose+json",
|
|
|
|
@ -307,7 +307,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|
|
|
|
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
|
|
|
|
|
return test{
|
|
|
|
|
h: Handler{
|
|
|
|
|
linker: NewLinker("dns", "acme"),
|
|
|
|
|
// linker: NewLinker("dns", "acme"),
|
|
|
|
|
},
|
|
|
|
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
|
|
|
|
contentType: "application/pkcs7-mime",
|
|
|
|
@ -326,7 +326,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
req.Header.Add("Content-Type", tc.contentType)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
tc.h.verifyContentType(testNext)(w, req)
|
|
|
|
|
verifyContentType(testNext)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -390,11 +390,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{}
|
|
|
|
|
// h := &Handler{}
|
|
|
|
|
req := httptest.NewRequest("GET", u, nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.isPostAsGet(testNext)(w, req)
|
|
|
|
|
isPostAsGet(testNext)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -481,10 +481,10 @@ func TestHandler_parseJWS(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{}
|
|
|
|
|
// h := &Handler{}
|
|
|
|
|
req := httptest.NewRequest("GET", u, tc.body)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.parseJWS(tc.next)(w, req)
|
|
|
|
|
parseJWS(tc.next)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -679,11 +679,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{}
|
|
|
|
|
// h := &Handler{}
|
|
|
|
|
req := httptest.NewRequest("GET", u, nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
|
|
|
|
verifyAndExtractJWSPayload(tc.next)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -881,11 +881,11 @@ func TestHandler_lookupJWK(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{db: tc.db, linker: tc.linker}
|
|
|
|
|
// h := &Handler{db: tc.db, linker: tc.linker}
|
|
|
|
|
req := httptest.NewRequest("GET", u, nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.lookupJWK(tc.next)(w, req)
|
|
|
|
|
lookupJWK(tc.next)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -1077,11 +1077,11 @@ func TestHandler_extractJWK(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{db: tc.db}
|
|
|
|
|
// h := &Handler{db: tc.db}
|
|
|
|
|
req := httptest.NewRequest("GET", u, nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.extractJWK(tc.next)(w, req)
|
|
|
|
|
extractJWK(tc.next)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -1444,11 +1444,11 @@ func TestHandler_validateJWS(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{db: tc.db}
|
|
|
|
|
// h := &Handler{db: tc.db}
|
|
|
|
|
req := httptest.NewRequest("GET", u, nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.validateJWS(tc.next)(w, req)
|
|
|
|
|
validateJWS(tc.next)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -1628,11 +1628,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|
|
|
|
for name, prep := range tests {
|
|
|
|
|
tc := prep(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{db: tc.db, linker: tc.linker}
|
|
|
|
|
// h := &Handler{db: tc.db, linker: tc.linker}
|
|
|
|
|
req := httptest.NewRequest("GET", u, nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.extractOrLookupJWK(tc.next)(w, req)
|
|
|
|
|
extractOrLookupJWK(tc.next)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
@ -1717,11 +1717,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
|
|
|
|
|
for name, run := range tests {
|
|
|
|
|
tc := run(t)
|
|
|
|
|
t.Run(name, func(t *testing.T) {
|
|
|
|
|
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
|
|
|
|
// h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
|
|
|
|
req := httptest.NewRequest("GET", u, nil)
|
|
|
|
|
req = req.WithContext(tc.ctx)
|
|
|
|
|
w := httptest.NewRecorder()
|
|
|
|
|
h.checkPrerequisites(tc.next)(w, req)
|
|
|
|
|
checkPrerequisites(tc.next)(w, req)
|
|
|
|
|
res := w.Result()
|
|
|
|
|
|
|
|
|
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
|
|
|
|