Skip to content
6 changes: 6 additions & 0 deletions internals/proxy/middlewares/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ func getEndpoints(endpoints []string) ([]string, []string) {
}

func isBlocked(endpoint string, endpoints []string) bool {
if endpoints == nil {
return false
} else if len(endpoints) <= 0 {
return false
}

allowed, blocked := getEndpoints(endpoints)

isExplicitlyBlocked := slices.ContainsFunc(blocked, func(try string) bool {
Expand Down
146 changes: 146 additions & 0 deletions internals/proxy/middlewares/policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package middlewares

import (
"errors"
"net/http"
"strings"

"github.com/codeshelldev/secured-signal-api/utils/config/structure"
"github.com/codeshelldev/secured-signal-api/utils/jsonutils"
log "github.com/codeshelldev/secured-signal-api/utils/logger"
request "github.com/codeshelldev/secured-signal-api/utils/request"
)

var Policy Middleware = Middleware{
Name: "Policy",
Use: policyHandler,
}

func policyHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
settings := getSettingsByReq(req)

policies := settings.ACCESS.FIELD_POLOCIES

if policies == nil {
policies = getSettings("*").ACCESS.FIELD_POLOCIES
}

body, err := request.GetReqBody(w, req)

if err != nil {
log.Error("Could not get Request Body: ", err.Error())
}

if body.Empty {
body.Data = map[string]any{}
}

headerData := request.GetReqHeaders(req)

shouldBlock, field := doBlock(body.Data, headerData, policies)

if shouldBlock {
log.Warn("User tried to use blocked field: ", field)
http.Error(w, "Forbidden", http.StatusForbidden)
return
}

next.ServeHTTP(w, req)
})
}

func getPolicies(policies map[string]structure.FieldPolicy) (map[string]structure.FieldPolicy, map[string]structure.FieldPolicy) {
blockedFields := map[string]structure.FieldPolicy{}
allowedFields := map[string]structure.FieldPolicy{}

for field, policy := range policies {
switch policy.Action {
case "block":
blockedFields[field] = policy
case "allow":
allowedFields[field] = policy
}
}

return allowedFields, blockedFields
}

func getField(field string, body map[string]any, headers map[string]any) (any, error) {
isHeader := strings.HasPrefix(field, "#")
isBody := strings.HasPrefix(field, "@")

fieldWithoutPrefix := field[1:]

var value any

if body[fieldWithoutPrefix] != nil && isBody {
value = body[fieldWithoutPrefix]
} else if headers[fieldWithoutPrefix] != nil && isHeader {
value = headers[fieldWithoutPrefix]
}

if value != nil {
return value, nil
}

return value, errors.New("field not found")
}

func doBlock(body map[string]any, headers map[string]any, policies map[string]structure.FieldPolicy) (bool, string) {
if policies == nil {
return false, ""
} else if len(policies) <= 0 {
return false, ""
}

allowed, blocked := getPolicies(policies)

var cause string

var isExplictlyAllowed, isExplicitlyBlocked bool

for field, policy := range allowed {
value, err := getField(field, body, headers)

log.Dev("Checking ", field, "...")
log.Dev("Got Value of ", jsonutils.ToJson(value))

if value == policy.Value && err == nil {
isExplictlyAllowed = true
cause = field
break
}
}

for field, policy := range blocked {
value, err := getField(field, body, headers)

log.Dev("Checking ", field, "...")
log.Dev("Got Value of ", jsonutils.ToJson(value))

if value == policy.Value && err == nil {
isExplicitlyBlocked = true
cause = field
break
}
}

// Block all except explicitly Allowed
if len(blocked) == 0 && len(allowed) != 0 {
return !isExplictlyAllowed, cause
}

// Allow all except explicitly Blocked
if len(allowed) == 0 && len(blocked) != 0 {
return isExplicitlyBlocked, cause
}

// Excplicitly Blocked except excplictly Allowed
if len(blocked) != 0 && len(allowed) != 0 {
return isExplicitlyBlocked && !isExplictlyAllowed, cause
}

// Block all
return true, ""
}
1 change: 1 addition & 0 deletions internals/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func (proxy Proxy) Init() http.Handler {
Use(m.Endpoints).
Use(m.Template).
Use(m.Mapping).
Use(m.Policy).
Use(m.Message).
Then(proxy.Use())

Expand Down
2 changes: 2 additions & 0 deletions utils/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ func normalizeKeys(config *koanf.Koanf) {
for _, key := range config.Keys() {
lower := strings.ToLower(key)

log.Dev("Lowering key: ", key)

data[lower] = config.Get(key)
}

Expand Down
8 changes: 7 additions & 1 deletion utils/config/structure/structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type SETTINGS struct {

type MESSAGE_SETTINGS struct {
VARIABLES map[string]any `koanf:"variables"`
FIELD_MAPPINGS map[string][]FieldMapping `koanf:"fieldMappings"`
FIELD_MAPPINGS map[string][]FieldMapping `koanf:"fieldmappings"`
TEMPLATE string `koanf:"template"`
}

Expand All @@ -31,4 +31,10 @@ type FieldMapping struct {

type ACCESS_SETTINGS struct {
ENDPOINTS []string `koanf:"endpoints"`
FIELD_POLOCIES map[string]FieldPolicy `koanf:"fieldpolicies"`
}

type FieldPolicy struct {
Value any `koanf:"value"`
Action string `koanf:"action"`
}
8 changes: 5 additions & 3 deletions utils/request/request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package req

import (
"bytes"
"encoding/json"
"errors"
"io"
Expand Down Expand Up @@ -85,12 +86,13 @@ func GetFormData(body []byte) (map[string]any, error) {
func GetBody(req *http.Request) ([]byte, error) {
bodyBytes, err := io.ReadAll(req.Body)

if err != nil {
req.Body.Close()
req.Body.Close()

req.Body = io.NopCloser(bytes.NewReader(bodyBytes))

if err != nil {
return nil, err
}
defer req.Body.Close()

return bodyBytes, nil
}
Expand Down