diff --git a/internals/proxy/middlewares/mapping.go b/internals/proxy/middlewares/mapping.go index a5f51790..f2d7fdf6 100644 --- a/internals/proxy/middlewares/mapping.go +++ b/internals/proxy/middlewares/mapping.go @@ -31,10 +31,11 @@ func mappingHandler(next http.Handler) http.Handler { settings.MESSAGE.VARIABLES = getSettings("*").MESSAGE.VARIABLES } - body, err := request.GetReqBody(w, req) + body, err := request.GetReqBody(req) if err != nil { log.Error("Could not get Request Body: ", err.Error()) + http.Error(w, "Bad Request: invalid body", http.StatusBadRequest) } var modifiedBody bool diff --git a/internals/proxy/middlewares/message.go b/internals/proxy/middlewares/message.go index 390f1fc0..47662085 100644 --- a/internals/proxy/middlewares/message.go +++ b/internals/proxy/middlewares/message.go @@ -30,10 +30,11 @@ func messageHandler(next http.Handler) http.Handler { messageTemplate = getSettings("*").MESSAGE.TEMPLATE } - body, err := request.GetReqBody(w, req) + body, err := request.GetReqBody(req) if err != nil { log.Error("Could not get Request Body: ", err.Error()) + http.Error(w, "Bad Request: invalid body", http.StatusBadRequest) } bodyData := map[string]any{} @@ -83,7 +84,7 @@ func messageHandler(next http.Handler) http.Handler { }) } -func TemplateMessage(template string, bodyData map[string]any, headerData map[string]any, variables map[string]any) (map[string]any, error) { +func TemplateMessage(template string, bodyData map[string]any, headerData map[string][]string, variables map[string]any) (map[string]any, error) { bodyData["message_template"] = template data, _, err := TemplateBody(bodyData, headerData, variables) diff --git a/internals/proxy/middlewares/policy.go b/internals/proxy/middlewares/policy.go index 3e3a4558..e7cfc6d9 100644 --- a/internals/proxy/middlewares/policy.go +++ b/internals/proxy/middlewares/policy.go @@ -3,11 +3,11 @@ package middlewares import ( "errors" "net/http" - "strings" "github.com/codeshelldev/secured-signal-api/internals/config/structure" log "github.com/codeshelldev/secured-signal-api/utils/logger" request "github.com/codeshelldev/secured-signal-api/utils/request" + "github.com/codeshelldev/secured-signal-api/utils/request/requestkeys" ) var Policy Middleware = Middleware{ @@ -25,10 +25,11 @@ func policyHandler(next http.Handler) http.Handler { policies = getSettings("*").ACCESS.FIELD_POLOCIES } - body, err := request.GetReqBody(w, req) + body, err := request.GetReqBody(req) if err != nil { log.Error("Could not get Request Body: ", err.Error()) + http.Error(w, "Bad Request: invalid body", http.StatusBadRequest) } if body.Empty { @@ -65,19 +66,10 @@ func getPolicies(policies map[string]structure.FieldPolicy) (map[string]structur return allowedFields, blockedFields } -func getField(field string, body map[string]any, headers map[string]any) (any, error) { - isHeader := strings.HasPrefix(field, "#") - isBody := strings.HasPrefix(field, "@") +func getField(key string, body map[string]any, headers map[string][]string) (any, error) { + field := requestkeys.Parse(key) - fieldWithoutPrefix := field[1:] - - var value any - - if body[fieldWithoutPrefix] != nil && isBody { - value = body[fieldWithoutPrefix] - } else if headers[fieldWithoutPrefix] != nil && isHeader { - value = headers[fieldWithoutPrefix] - } + value := requestkeys.GetFromBodyAndHeaders(field, body, headers) if value != nil { return value, nil @@ -86,7 +78,7 @@ func getField(field string, body map[string]any, headers map[string]any) (any, e return value, errors.New("field not found") } -func doBlock(body map[string]any, headers map[string]any, policies map[string]structure.FieldPolicy) (bool, string) { +func doBlock(body map[string]any, headers map[string][]string, policies map[string]structure.FieldPolicy) (bool, string) { if policies == nil { return false, "" } else if len(policies) <= 0 { diff --git a/internals/proxy/middlewares/template.go b/internals/proxy/middlewares/template.go index a7ea7166..1577d363 100644 --- a/internals/proxy/middlewares/template.go +++ b/internals/proxy/middlewares/template.go @@ -14,6 +14,7 @@ import ( log "github.com/codeshelldev/secured-signal-api/utils/logger" query "github.com/codeshelldev/secured-signal-api/utils/query" request "github.com/codeshelldev/secured-signal-api/utils/request" + "github.com/codeshelldev/secured-signal-api/utils/request/requestkeys" templating "github.com/codeshelldev/secured-signal-api/utils/templating" ) @@ -30,10 +31,11 @@ func templateHandler(next http.Handler) http.Handler { variables = getSettings("*").MESSAGE.VARIABLES } - body, err := request.GetReqBody(w, req) + body, err := request.GetReqBody(req) if err != nil { log.Error("Could not get Request Body: ", err.Error()) + http.Error(w, "Bad Request: invalid body", http.StatusBadRequest) } bodyData := map[string]any{} @@ -146,8 +148,8 @@ func prefixData(prefix string, data map[string]any) map[string]any { return res } -func cleanHeaders(headers map[string]any) map[string]any { - cleanedHeaders := map[string]any{} +func cleanHeaders(headers map[string][]string) map[string][]string { + cleanedHeaders := map[string][]string{} for key, value := range headers { cleanedKey := strings.ReplaceAll(key, "-", "_") @@ -155,30 +157,30 @@ func cleanHeaders(headers map[string]any) map[string]any { cleanedHeaders[cleanedKey] = value } - authHeader, ok := cleanedHeaders["Authorization"].([]string) + authHeader, ok := cleanedHeaders["Authorization"] if !ok { authHeader = []string{"UNKNOWN REDACTED"} } - cleanedHeaders["Authorization"] = strings.Split(authHeader[0], ` `)[0] + " REDACTED" + cleanedHeaders["Authorization"] = []string{strings.Split(authHeader[0], ` `)[0] + " REDACTED"} return cleanedHeaders } -func TemplateBody(body map[string]any, headers map[string]any, VARIABLES map[string]any) (map[string]any, bool, error) { +func TemplateBody(body map[string]any, headers map[string][]string, VARIABLES map[string]any) (map[string]any, bool, error) { var modified bool headers = cleanHeaders(headers) - // Normalize #Var and @Var to .header_key_Var and .body_key_Var - normalizedBody, err := normalizeData("@", "body_key_", body) + // Normalize `keys.BodyPrefix` + "Var" and `keys.HeaderPrefix` + "Var" to "".header_key_Var" and ".body_key_Var" + normalizedBody, err := normalizeData(requestkeys.BodyPrefix, "body_key_", body) if err != nil { return body, false, err } - normalizedBody, err = normalizeData("#", "header_key_", normalizedBody) + normalizedBody, err = normalizeData(requestkeys.HeaderPrefix, "header_key_", normalizedBody) if err != nil { return body, false, err @@ -188,7 +190,7 @@ func TemplateBody(body map[string]any, headers map[string]any, VARIABLES map[str prefixedBody := prefixData("body_key_", normalizedBody) // Prefix Header Data with header_key_ - prefixedHeaders := prefixData("header_key_", headers) + prefixedHeaders := prefixData("header_key_", request.ParseHeaders(headers)) variables := VARIABLES diff --git a/tests/string_test.go b/tests/string_test.go index 41fe0156..075d31a7 100644 --- a/tests/string_test.go +++ b/tests/string_test.go @@ -8,25 +8,25 @@ import ( ) func TestStringEscaping(t *testing.T) { - str1 := `\#` + str1 := `\-` - res1 := stringutils.IsEscaped(str1, "#") + res1 := stringutils.IsEscaped(str1, "-") if !res1 { t.Error("Expected: ", str1, " == true", "; Got: ", str1, " == ", res1) } - str2 := "#" + str2 := "-" - res2 := stringutils.IsEscaped(str2, "#") + res2 := stringutils.IsEscaped(str2, "-") if res2 { t.Error("Expected: ", str2, " == false", "; Got: ", str2, " == ", res2) } - str3 := `#\#` + str3 := `-\-` - res3 := stringutils.Contains(str3, "#") + res3 := stringutils.Contains(str3, "-") if !res3 { t.Error("Expected: ", str3, " == true", "; Got: ", str3, " == ", res3) diff --git a/utils/request/request.go b/utils/request/request.go index ae5e8157..d3a5241c 100644 --- a/utils/request/request.go +++ b/utils/request/request.go @@ -1,4 +1,4 @@ -package req +package request import ( "bytes" @@ -97,8 +97,8 @@ func GetBody(req *http.Request) ([]byte, error) { return bodyBytes, nil } -func GetReqHeaders(req *http.Request) map[string]any { - data := map[string]any{} +func GetReqHeaders(req *http.Request) map[string][]string { + data := map[string][]string{} for key, value := range req.Header { data[key] = value @@ -107,14 +107,26 @@ func GetReqHeaders(req *http.Request) map[string]any { return data } -func GetReqBody(w http.ResponseWriter, req *http.Request) (Body, error) { +func ParseHeaders(headers map[string][]string) map[string]any { + generic := make(map[string]any, len(headers)) + + for i, header := range headers { + if len(header) == 1 { + generic[i] = header[0] + } else { + generic[i] = header + } + } + + return generic +} + +func GetReqBody(req *http.Request) (Body, error) { bytes, err := GetBody(req) var isEmpty bool if err != nil { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return Body{Empty: true}, err } @@ -129,16 +141,12 @@ func GetReqBody(w http.ResponseWriter, req *http.Request) (Body, error) { data, err = GetJsonData(bytes) if err != nil { - http.Error(w, "Bad Request: invalid JSON", http.StatusBadRequest) - return Body{Empty: true}, err } case Form: data, err = GetFormData(bytes) if err != nil { - http.Error(w, "Bad Request: invalid Form", http.StatusBadRequest) - return Body{Empty: true}, err } } diff --git a/utils/request/requestkeys/requestkeys.go b/utils/request/requestkeys/requestkeys.go new file mode 100644 index 00000000..39f88463 --- /dev/null +++ b/utils/request/requestkeys/requestkeys.go @@ -0,0 +1,61 @@ +package requestkeys + +import "github.com/codeshelldev/secured-signal-api/utils/request" + +type Field struct { + Prefix string + Key string +} + +var BodyPrefix = "@" +var HeaderPrefix = "#" + +func Parse(str string) Field { + prefix := str[1:] + key := str[:1] + + return Field{ + Prefix: prefix, + Key: key, + } +} + +func GetByField(field Field, data map[string]any) any { + key := field.Prefix + field.Key + + return data[key] +} + +func PrefixBody(body map[string]any) map[string]any { + res := map[string]any{} + + for key, value := range body { + res[BodyPrefix + key] = value + } + + return res +} + +func PrefixHeaders(headers map[string][]string) map[string][]string { + res := map[string][]string{} + + for key, value := range headers { + res[HeaderPrefix + key] = value + } + + return res +} + +func GetFromBodyAndHeaders(field Field, body map[string]any, headers map[string][]string) any { + body = PrefixBody(body) + headers = PrefixHeaders(headers) + + switch(field.Prefix) { + case BodyPrefix: + return GetByField(field, body) + case HeaderPrefix: + return GetByField(field, request.ParseHeaders(headers)) + } + + return nil +} \ No newline at end of file