diff --git a/internals/proxy/middlewares/endpoints.go b/internals/proxy/middlewares/endpoints.go index 649ebf4b..02df94bb 100644 --- a/internals/proxy/middlewares/endpoints.go +++ b/internals/proxy/middlewares/endpoints.go @@ -53,6 +53,12 @@ func getEndpoints(endpoints []string) ([]string, []string) { } func isBlocked(endpoint string, endpoints []string) bool { + if endpoints == nil { + return false + } else if len(endpoints) <= 0 { + return false + } + allowed, blocked := getEndpoints(endpoints) isExplicitlyBlocked := slices.ContainsFunc(blocked, func(try string) bool { diff --git a/internals/proxy/middlewares/policy.go b/internals/proxy/middlewares/policy.go new file mode 100644 index 00000000..b252c251 --- /dev/null +++ b/internals/proxy/middlewares/policy.go @@ -0,0 +1,146 @@ +package middlewares + +import ( + "errors" + "net/http" + "strings" + + "github.com/codeshelldev/secured-signal-api/utils/config/structure" + "github.com/codeshelldev/secured-signal-api/utils/jsonutils" + log "github.com/codeshelldev/secured-signal-api/utils/logger" + request "github.com/codeshelldev/secured-signal-api/utils/request" +) + +var Policy Middleware = Middleware{ + Name: "Policy", + Use: policyHandler, +} + +func policyHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + settings := getSettingsByReq(req) + + policies := settings.ACCESS.FIELD_POLOCIES + + if policies == nil { + policies = getSettings("*").ACCESS.FIELD_POLOCIES + } + + body, err := request.GetReqBody(w, req) + + if err != nil { + log.Error("Could not get Request Body: ", err.Error()) + } + + if body.Empty { + body.Data = map[string]any{} + } + + headerData := request.GetReqHeaders(req) + + shouldBlock, field := doBlock(body.Data, headerData, policies) + + if shouldBlock { + log.Warn("User tried to use blocked field: ", field) + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + next.ServeHTTP(w, req) + }) +} + +func getPolicies(policies map[string]structure.FieldPolicy) (map[string]structure.FieldPolicy, map[string]structure.FieldPolicy) { + blockedFields := map[string]structure.FieldPolicy{} + allowedFields := map[string]structure.FieldPolicy{} + + for field, policy := range policies { + switch policy.Action { + case "block": + blockedFields[field] = policy + case "allow": + allowedFields[field] = policy + } + } + + return allowedFields, blockedFields +} + +func getField(field string, body map[string]any, headers map[string]any) (any, error) { + isHeader := strings.HasPrefix(field, "#") + isBody := strings.HasPrefix(field, "@") + + fieldWithoutPrefix := field[1:] + + var value any + + if body[fieldWithoutPrefix] != nil && isBody { + value = body[fieldWithoutPrefix] + } else if headers[fieldWithoutPrefix] != nil && isHeader { + value = headers[fieldWithoutPrefix] + } + + if value != nil { + return value, nil + } + + return value, errors.New("field not found") +} + +func doBlock(body map[string]any, headers map[string]any, policies map[string]structure.FieldPolicy) (bool, string) { + if policies == nil { + return false, "" + } else if len(policies) <= 0 { + return false, "" + } + + allowed, blocked := getPolicies(policies) + + var cause string + + var isExplictlyAllowed, isExplicitlyBlocked bool + + for field, policy := range allowed { + value, err := getField(field, body, headers) + + log.Dev("Checking ", field, "...") + log.Dev("Got Value of ", jsonutils.ToJson(value)) + + if value == policy.Value && err == nil { + isExplictlyAllowed = true + cause = field + break + } + } + + for field, policy := range blocked { + value, err := getField(field, body, headers) + + log.Dev("Checking ", field, "...") + log.Dev("Got Value of ", jsonutils.ToJson(value)) + + if value == policy.Value && err == nil { + isExplicitlyBlocked = true + cause = field + break + } + } + + // Block all except explicitly Allowed + if len(blocked) == 0 && len(allowed) != 0 { + return !isExplictlyAllowed, cause + } + + // Allow all except explicitly Blocked + if len(allowed) == 0 && len(blocked) != 0 { + return isExplicitlyBlocked, cause + } + + // Excplicitly Blocked except excplictly Allowed + if len(blocked) != 0 && len(allowed) != 0 { + return isExplicitlyBlocked && !isExplictlyAllowed, cause + } + + // Block all + return true, "" +} diff --git a/internals/proxy/proxy.go b/internals/proxy/proxy.go index acf9c3f7..c74bf028 100644 --- a/internals/proxy/proxy.go +++ b/internals/proxy/proxy.go @@ -37,6 +37,7 @@ func (proxy Proxy) Init() http.Handler { Use(m.Endpoints). Use(m.Template). Use(m.Mapping). + Use(m.Policy). Use(m.Message). Then(proxy.Use()) diff --git a/utils/config/config.go b/utils/config/config.go index e0779624..5422647e 100644 --- a/utils/config/config.go +++ b/utils/config/config.go @@ -130,6 +130,8 @@ func normalizeKeys(config *koanf.Koanf) { for _, key := range config.Keys() { lower := strings.ToLower(key) + log.Dev("Lowering key: ", key) + data[lower] = config.Get(key) } diff --git a/utils/config/structure/structure.go b/utils/config/structure/structure.go index 2a768199..0fd2e428 100644 --- a/utils/config/structure/structure.go +++ b/utils/config/structure/structure.go @@ -20,7 +20,7 @@ type SETTINGS struct { type MESSAGE_SETTINGS struct { VARIABLES map[string]any `koanf:"variables"` - FIELD_MAPPINGS map[string][]FieldMapping `koanf:"fieldMappings"` + FIELD_MAPPINGS map[string][]FieldMapping `koanf:"fieldmappings"` TEMPLATE string `koanf:"template"` } @@ -31,4 +31,10 @@ type FieldMapping struct { type ACCESS_SETTINGS struct { ENDPOINTS []string `koanf:"endpoints"` + FIELD_POLOCIES map[string]FieldPolicy `koanf:"fieldpolicies"` +} + +type FieldPolicy struct { + Value any `koanf:"value"` + Action string `koanf:"action"` } \ No newline at end of file diff --git a/utils/request/request.go b/utils/request/request.go index 1c8fbd47..ae5e8157 100644 --- a/utils/request/request.go +++ b/utils/request/request.go @@ -1,6 +1,7 @@ package req import ( + "bytes" "encoding/json" "errors" "io" @@ -85,12 +86,13 @@ func GetFormData(body []byte) (map[string]any, error) { func GetBody(req *http.Request) ([]byte, error) { bodyBytes, err := io.ReadAll(req.Body) - if err != nil { - req.Body.Close() + req.Body.Close() + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + if err != nil { return nil, err } - defer req.Body.Close() return bodyBytes, nil }