From 3c12b4f5adfaaf34bee815f827dada7526b35151 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 3 Oct 2023 16:32:55 +0200 Subject: [PATCH] Improve decoding SCEP requests --- scep/api/api.go | 22 ++++++++------- scep/api/api_test.go | 64 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 70 insertions(+), 16 deletions(-) diff --git a/scep/api/api.go b/scep/api/api.go index 614b5184..c3159a71 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -151,11 +151,14 @@ func decodeRequest(r *http.Request) (request, error) { defer r.Body.Close() method := r.Method - query := r.URL.Query() + query, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + return request{}, fmt.Errorf("failed parsing URL query: %w", err) + } - var operation string - if _, ok := query["operation"]; ok { - operation = query.Get("operation") + operation := query.Get("operation") + if operation == "" { + return request{}, errors.New("no operation provided") } switch method { @@ -167,14 +170,13 @@ func decodeRequest(r *http.Request) (request, error) { Message: []byte{}, }, nil case opnPKIOperation: - var message string - if _, ok := query["message"]; ok { - message = query.Get("message") + message := query.Get("message") + if message == "" { + return request{}, errors.New("message must not be empty") } - // TODO: verify this; right type of encoding? Needs additional transformations? decodedMessage, err := base64.StdEncoding.DecodeString(message) if err != nil { - return request{}, err + return request{}, fmt.Errorf("failed decoding message: %w", err) } return request{ Operation: operation, @@ -186,7 +188,7 @@ func decodeRequest(r *http.Request) (request, error) { case http.MethodPost: body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) if err != nil { - return request{}, err + return request{}, fmt.Errorf("failed reading request body: %w", err) } return request{ Operation: operation, diff --git a/scep/api/api_test.go b/scep/api/api_test.go index ef3e57ab..2a26f534 100644 --- a/scep/api/api_test.go +++ b/scep/api/api_test.go @@ -3,15 +3,23 @@ package api import ( "bytes" + "encoding/base64" "errors" + "fmt" "net/http" "net/http/httptest" - "reflect" + "net/url" "testing" "testing/iotest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_decodeRequest(t *testing.T) { + randomB64 := "wx/1mQ49TpdLRfvVjQhXNSe8RB3hjZEarqYp5XVIxpSbvOhQSs8hP2TgucID1IputbA8JC6CbsUpcVae3+8hRNqs5pTsSHP2aNxsw8AHGSX9dZVymSclkUV8irk+ztfEfs7aLA==" + expectedRandom, err := base64.StdEncoding.DecodeString(randomB64) + require.NoError(t, err) type args struct { r *http.Request } @@ -21,6 +29,22 @@ func Test_decodeRequest(t *testing.T) { want request wantErr bool }{ + { + name: "fail/invalid-query", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=bla;message=invalid-separator", http.NoBody), + }, + want: request{}, + wantErr: true, + }, + { + name: "fail/empty-operation", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=", http.NoBody), + }, + want: request{}, + wantErr: true, + }, { name: "fail/unsupported-method", args: args{ @@ -37,6 +61,14 @@ func Test_decodeRequest(t *testing.T) { want: request{}, wantErr: true, }, + { + name: "fail/get-PKIOperation-empty-message", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=PKIOperation&message=", http.NoBody), + }, + want: request{}, + wantErr: true, + }, { name: "fail/get-PKIOperation", args: args{ @@ -45,6 +77,14 @@ func Test_decodeRequest(t *testing.T) { want: request{}, wantErr: true, }, + { + name: "fail/get-PKIOperation-not-escaped", + args: args{ + r: httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://scep:8080/?operation=PKIOperation&message=%s", randomB64), http.NoBody), + }, + want: request{}, + wantErr: true, + }, { name: "fail/post-PKIOperation", args: args{ @@ -86,6 +126,17 @@ func Test_decodeRequest(t *testing.T) { }, wantErr: false, }, + { + name: "ok/get-PKIOperation-escaped", + args: args{ + r: httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://scep:8080/?operation=PKIOperation&message=%s", url.QueryEscape(randomB64)), http.NoBody), + }, + want: request{ + Operation: "PKIOperation", + Message: expectedRandom, + }, + wantErr: false, + }, { name: "ok/post-PKIOperation", args: args{ @@ -101,13 +152,14 @@ func Test_decodeRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := decodeRequest(tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeRequest() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Equal(t, tt.want, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeRequest() = %v, want %v", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } }