|
| 1 | +package middlewares |
| 2 | + |
| 3 | +import ( |
| 4 | + "errors" |
| 5 | + "net/http" |
| 6 | + "strings" |
| 7 | + |
| 8 | + "github.com/codeshelldev/secured-signal-api/utils/config/structure" |
| 9 | + "github.com/codeshelldev/secured-signal-api/utils/jsonutils" |
| 10 | + log "github.com/codeshelldev/secured-signal-api/utils/logger" |
| 11 | + request "github.com/codeshelldev/secured-signal-api/utils/request" |
| 12 | +) |
| 13 | + |
| 14 | +var Policy Middleware = Middleware{ |
| 15 | + Name: "Policy", |
| 16 | + Use: policyHandler, |
| 17 | +} |
| 18 | + |
| 19 | +func policyHandler(next http.Handler) http.Handler { |
| 20 | + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { |
| 21 | + settings := getSettingsByReq(req) |
| 22 | + |
| 23 | + policies := settings.ACCESS.FIELD_POLOCIES |
| 24 | + |
| 25 | + if policies == nil { |
| 26 | + policies = getSettings("*").ACCESS.FIELD_POLOCIES |
| 27 | + } |
| 28 | + |
| 29 | + body, err := request.GetReqBody(w, req) |
| 30 | + |
| 31 | + if err != nil { |
| 32 | + log.Error("Could not get Request Body: ", err.Error()) |
| 33 | + } |
| 34 | + |
| 35 | + if body.Empty { |
| 36 | + body.Data = map[string]any{} |
| 37 | + } |
| 38 | + |
| 39 | + headerData := request.GetReqHeaders(req) |
| 40 | + |
| 41 | + shouldBlock, field := doBlock(body.Data, headerData, policies) |
| 42 | + |
| 43 | + if shouldBlock { |
| 44 | + log.Warn("User tried to use blocked field: ", field) |
| 45 | + http.Error(w, "Forbidden", http.StatusForbidden) |
| 46 | + return |
| 47 | + } |
| 48 | + |
| 49 | + next.ServeHTTP(w, req) |
| 50 | + }) |
| 51 | +} |
| 52 | + |
| 53 | +func getPolicies(policies map[string]structure.FieldPolicy) (map[string]structure.FieldPolicy, map[string]structure.FieldPolicy) { |
| 54 | + blockedFields := map[string]structure.FieldPolicy{} |
| 55 | + allowedFields := map[string]structure.FieldPolicy{} |
| 56 | + |
| 57 | + for field, policy := range policies { |
| 58 | + switch policy.Action { |
| 59 | + case "block": |
| 60 | + blockedFields[field] = policy |
| 61 | + case "allow": |
| 62 | + allowedFields[field] = policy |
| 63 | + } |
| 64 | + } |
| 65 | + |
| 66 | + return allowedFields, blockedFields |
| 67 | +} |
| 68 | + |
| 69 | +func getField(field string, body map[string]any, headers map[string]any) (any, error) { |
| 70 | + isHeader := strings.HasPrefix(field, "#") |
| 71 | + isBody := strings.HasPrefix(field, "@") |
| 72 | + |
| 73 | + fieldWithoutPrefix := field[1:] |
| 74 | + |
| 75 | + var value any |
| 76 | + |
| 77 | + if body[fieldWithoutPrefix] != nil && isBody { |
| 78 | + value = body[fieldWithoutPrefix] |
| 79 | + } else if headers[fieldWithoutPrefix] != nil && isHeader { |
| 80 | + value = headers[fieldWithoutPrefix] |
| 81 | + } |
| 82 | + |
| 83 | + if value != nil { |
| 84 | + return value, nil |
| 85 | + } |
| 86 | + |
| 87 | + return value, errors.New("field not found") |
| 88 | +} |
| 89 | + |
| 90 | +func doBlock(body map[string]any, headers map[string]any, policies map[string]structure.FieldPolicy) (bool, string) { |
| 91 | + if policies == nil { |
| 92 | + return false, "" |
| 93 | + } else if len(policies) <= 0 { |
| 94 | + return false, "" |
| 95 | + } |
| 96 | + |
| 97 | + allowed, blocked := getPolicies(policies) |
| 98 | + |
| 99 | + var cause string |
| 100 | + |
| 101 | + var isExplictlyAllowed, isExplicitlyBlocked bool |
| 102 | + |
| 103 | + for field, policy := range allowed { |
| 104 | + value, err := getField(field, body, headers) |
| 105 | + |
| 106 | + log.Dev("Checking ", field, "...") |
| 107 | + log.Dev("Got Value of ", jsonutils.ToJson(value)) |
| 108 | + |
| 109 | + if value == policy.Value && err == nil { |
| 110 | + isExplictlyAllowed = true |
| 111 | + cause = field |
| 112 | + break |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + for field, policy := range blocked { |
| 117 | + value, err := getField(field, body, headers) |
| 118 | + |
| 119 | + log.Dev("Checking ", field, "...") |
| 120 | + log.Dev("Got Value of ", jsonutils.ToJson(value)) |
| 121 | + |
| 122 | + if value == policy.Value && err == nil { |
| 123 | + isExplicitlyBlocked = true |
| 124 | + cause = field |
| 125 | + break |
| 126 | + } |
| 127 | + } |
| 128 | + |
| 129 | + // Block all except explicitly Allowed |
| 130 | + if len(blocked) == 0 && len(allowed) != 0 { |
| 131 | + return !isExplictlyAllowed, cause |
| 132 | + } |
| 133 | + |
| 134 | + // Allow all except explicitly Blocked |
| 135 | + if len(allowed) == 0 && len(blocked) != 0 { |
| 136 | + return isExplicitlyBlocked, cause |
| 137 | + } |
| 138 | + |
| 139 | + // Excplicitly Blocked except excplictly Allowed |
| 140 | + if len(blocked) != 0 && len(allowed) != 0 { |
| 141 | + return isExplicitlyBlocked && !isExplictlyAllowed, cause |
| 142 | + } |
| 143 | + |
| 144 | + // Block all |
| 145 | + return true, "" |
| 146 | +} |
0 commit comments