diff --git a/docs/basic_template.md b/docs/basic_template.md
index 451711b..ef4494b 100644
--- a/docs/basic_template.md
+++ b/docs/basic_template.md
@@ -11,7 +11,7 @@
|Assumptions|
|-----------|
-{tm.assumptions:repeat:|{{item}}|
+{tm.assumptions:repeat:|{{item}}|
}
@@ -62,6 +62,15 @@ Name|Description|Classification
{{item.mitigations}}
References
{{item.references}}
+ Comment
+ {{item.specific_comment}}
+ CWEs
+ {{item.cwes}}
+ TTPs
+ {{item.ttps}}
+ CAPECs
+ {{item.capecs}}
+
diff --git a/pytm/plugins/base/base_plugin.py b/pytm/plugins/base/base_plugin.py
new file mode 100644
index 0000000..88495a1
--- /dev/null
+++ b/pytm/plugins/base/base_plugin.py
@@ -0,0 +1,11 @@
+""" Base class for all plugin types """
+
+
+class BasePlugin():
+ """ Base class for plugins """
+
+ def __init__(self):
+ raise NotImplementedError("Plugin needs an __init__ function")
+
+ def get_name(self):
+ return self.name
\ No newline at end of file
diff --git a/pytm/plugins/base/rule_plugin.py b/pytm/plugins/base/rule_plugin.py
new file mode 100644
index 0000000..aed2d00
--- /dev/null
+++ b/pytm/plugins/base/rule_plugin.py
@@ -0,0 +1,133 @@
+from pytm.plugins.base.base_plugin import BasePlugin
+
+
+class PluginThreat():
+ """ A threat description """
+ # TODO: Untangle the main code and use the Threat class from there here.
+
+ cwes = []
+ capecs = []
+ ttps = []
+
+ def __init__(self, element, comment, **kwargs):
+
+ self.data = {"SID": kwargs.get("SID"),
+ "description": kwargs.get("description", ""),
+ "condition": kwargs.get("condition", ""),
+ "target": kwargs.get("target", []),
+ "details": kwargs.get("details", ""),
+ "severity": kwargs.get("severity", ""),
+ "mitigations": kwargs.get("mitigations", ""),
+ "example": kwargs.get("example", ""),
+ "references": " ".join(kwargs.get("reference_list", [])),
+ "specific_comment": kwargs.get("specific_comment", ""),
+
+ "cwes": kwargs.get("cwes", []),
+ "ttps": kwargs.get("ttps", []),
+ "capecs": kwargs.get("capecs", [])
+ }
+ self.element = element
+ self.comment = comment
+
+ def to_threatfile_format(self):
+ """ Returns data in threatfile format """
+ return self.data
+
+
+class RuleResult():
+ """ Can collect a large variety of detection results. Can be extended beyond threats. This is the reason there is a whole class here to collect that """
+
+ def __init__(self) -> None:
+ self._threats = []
+
+ def add_threat(self, element, comment, **kwargs):
+ self._threats.append(PluginThreat(element, comment, **kwargs))
+
+ def get_threats(self):
+ return self._threats
+
+
+
+class RulePlugin(BasePlugin):
+ """ A rule matching plugin base
+
+ """
+
+ sid = None
+
+ cwes = []
+ capecs = []
+ ttps = []
+
+ def __init__(self):
+ self.result = RuleResult()
+ self.elements = []
+
+ ### Entry points
+
+ def threat_check(self, elements):
+ """ Calls the plugin function to check threats after abstracting internals away """
+
+ self.elements = elements
+
+ self.threat_match()
+
+ ### Generic functions
+
+ def get_type(self, element):
+ """ returns a type string for an element """
+
+ # TODO: Move that to the classes
+ if str((type(element))) == "":
+ return "Boundary"
+ if str((type(element))) == "":
+ return "Datastore"
+ if str((type(element))) == "":
+ return "Dataflow"
+ if str((type(element))) == "":
+ return "Server"
+ if str((type(element))) == "":
+ return "Actor"
+
+ def get_elements(self):
+ return self.elements
+
+ ### Threat things
+
+ def add_threat(self, element, comment):
+ """ Add a threat to the results
+
+ @param element: the threat is tied to
+ @param comment: used comment for this threat
+
+ """
+ data = {"SID": self.SID,
+ "description": self.description,
+ "condition": self.condition,
+ "target": self.target,
+ "details": self.details,
+ "severity": self.severity,
+ "mitigations": self.mitigations,
+ "example": self.example,
+ "reference_list": self.reference_list,
+ "specific_comment": comment,
+
+ "ttps": self.ttps,
+ "capecs": self.capecs,
+ "cwes": self.cwes,
+ }
+
+ self.result.add_threat(element, comment, **data)
+
+ def get_threats(self):
+ """ Read threats from the collection """
+ return self.result.get_threats()
+
+ def get_id(self):
+ return self.SID
+
+ def get_description(self):
+ return self.description
+
+
+
diff --git a/pytm/plugins/rules/specific_sql_injection.py b/pytm/plugins/rules/specific_sql_injection.py
new file mode 100644
index 0000000..7e013e5
--- /dev/null
+++ b/pytm/plugins/rules/specific_sql_injection.py
@@ -0,0 +1,61 @@
+from pytm.plugins.base.rule_plugin import RulePlugin, RuleResult
+from pytm import DatastoreType
+
+
+class StrictSQLInjectionRulePlugin(RulePlugin):
+ # Boilerplate
+ name = "strict_sql_injection"
+ description = "A strict SQL injection rule"
+
+ SID = "EXP01"
+ details = "A SQL datastore is connected to a web server which does not sanitize inputs. This web server can be accessed by an actor"
+ LikelihoodOfAttack = "High"
+ severity = "High"
+ condition = "A SQL datastore is connected to a web server which does not sanitize inputs. This web server can be accessed by an actor"
+ prerequisites = ""
+ mitigations = "Sanitize input to protect the SQL server. Use PreparedStatements"
+ example = ""
+ reference_list = []
+ target = []
+
+ cwes = ["89", "1286"]
+ capecs = ["66"]
+ ttps = ["T1190"]
+
+ def __init__(self):
+ super().__init__()
+ self.plugin_path = __file__
+
+ def connected_elements(self, element):
+ """ Lists all elements connected by Dataflows to this element """
+ res = []
+
+ for a_dataflow in self.get_elements():
+ if self.get_type(a_dataflow) == "Dataflow":
+ if a_dataflow.source == element:
+ res.append(a_dataflow.sink)
+ if a_dataflow.sink == element:
+ res.append(a_dataflow.source)
+ return res
+
+ def threat_match(self):
+ """ Specific SQL injection test. Extra specific to test the power of plugin rules.
+
+ A SQL datastore is connected to a web server which does not sanitize inputs. This web server can be accessed by an actor .
+ """
+ for a_database in self.get_elements():
+ if self.get_type(a_database) == "Datastore" and a_database.type == DatastoreType.SQL:
+ servers_connected_to_database = self.connected_elements(a_database)
+ for a_webserver in servers_connected_to_database:
+ # Is connected to a web server which does not sanitize input
+ if self.get_type(a_webserver) == "Server" and a_webserver.controls.sanitizesInput == False:
+ users_connected_to_server = self.connected_elements(a_webserver)
+ # Check all connections of this web server, is a user connected (="Actor")
+ for a_user in users_connected_to_server:
+ if self.get_type(a_user) == "Actor":
+ self.add_threat(a_database, comment = f"The user '{a_user.name}' could run SQL injection attacks on '{a_database.name}' via '{a_webserver.name}'")
+
+
+
+
+
diff --git a/pytm/pytm.py b/pytm/pytm.py
index e6e005f..13c82f2 100644
--- a/pytm/pytm.py
+++ b/pytm/pytm.py
@@ -9,6 +9,7 @@
import uuid
import html
import copy
+from glob import glob
from collections import Counter, defaultdict
from collections.abc import Iterable
@@ -20,6 +21,9 @@
from textwrap import indent, wrap
from weakref import WeakKeyDictionary
from datetime import datetime
+import straight.plugin # type: ignore
+from straight.plugin.manager import PluginManager as StraightPluginManager # type: ignore
+from pytm.plugins.base.rule_plugin import RulePlugin
from pydal import DAL, Field
@@ -594,6 +598,7 @@ class Threat:
example = varString("")
references = varString("")
target = ()
+ specific_comment = varString("")
def __init__(self, **kwargs):
self.id = kwargs["SID"]
@@ -610,6 +615,10 @@ def __init__(self, **kwargs):
self.mitigations = kwargs.get("mitigations", "")
self.example = kwargs.get("example", "")
self.references = kwargs.get("references", "")
+ self.specific_comment = kwargs.get("specific_comment", "") # A detailed comment for this specific occurence
+ self.cwes = kwargs.get("cwes", [])
+ self.capecs = kwargs.get("capecs", [])
+ self.ttps = kwargs.get("ttps", [])
def _safeset(self, attr, value):
try:
@@ -658,6 +667,7 @@ class Finding:
""",
)
cvss = varString("", required=False, doc="The CVSS score and/or vector")
+ specific_comment = ""
def __init__(
self,
@@ -679,6 +689,10 @@ def __init__(
"example",
"references",
"condition",
+ "specific_comment",
+ "cwes",
+ "capecs",
+ "ttps"
]
threat = kwargs.pop("threat", None)
if threat:
@@ -731,6 +745,7 @@ class TM:
_actors = []
_assets = []
_threats = []
+ _rule_plugins = []
_boundaries = []
_data = []
_threatsExcluded = []
@@ -757,6 +772,7 @@ class TM:
required=False,
doc="A list of assumptions about the design/model.",
)
+ finding_count = 0
def __init__(self, name, **kwargs):
for key, value in kwargs.items():
@@ -764,6 +780,7 @@ def __init__(self, name, **kwargs):
self.name = name
self._sf = SuperFormatter()
self._add_threats()
+ self._rule_plugins = self._add_rule_plugins()
# make sure generated diagrams do not change, makes sense if they're commited
random.seed(0)
@@ -780,8 +797,36 @@ def reset(cls):
def _init_threats(self):
TM._threats = []
+ TM._rule_plugins = []
self._add_threats()
+
+ def _add_rule_plugins(self):
+ """ Returns a list plugins
+
+ :return: A list of instantiated plugins
+ """
+
+
+ res = []
+
+ def get_handlers(a_plugin: StraightPluginManager):
+ return a_plugin.produce()
+
+ plugin_dirs = set()
+ for a_glob in glob("pytm/plugins/rules/*.py", recursive=True):
+ plugin_dirs.add(os.path.dirname(a_glob))
+
+ for a_dir in plugin_dirs:
+ plugins = straight.plugin.load(a_dir, subclasses=RulePlugin)
+
+ handlers = get_handlers(plugins)
+
+ for plugin in handlers:
+ res.append(plugin)
+
+ return res
+
def _add_threats(self):
try:
with open(self.threatsFile, "r", encoding="utf8") as threat_file:
@@ -793,7 +838,7 @@ def _add_threats(self):
TM._threats.append(Threat(**i))
def resolve(self):
- finding_count = 0
+ self.finding_count = 0
findings = []
elements = defaultdict(list)
for e in TM._elements:
@@ -817,8 +862,8 @@ def resolve(self):
if t.id in TM._threatsExcluded:
continue
- finding_count += 1
- f = Finding(e, id=str(finding_count), threat=t)
+ self.finding_count += 1
+ f = Finding(e, id=str(self.finding_count), threat=t)
logger.debug(f"new finding: {f}")
findings.append(f)
elements[e].append(f)
@@ -826,6 +871,16 @@ def resolve(self):
for e, findings in elements.items():
e.findings = findings
+ def resolve_plugins(self):
+ for plugin in self._rule_plugins:
+ plugin.threat_check(self._elements)
+
+ for t in plugin.get_threats():
+ self.finding_count += 1
+ f = Finding(t.element, id=str(self.finding_count), threat=Threat(**t.to_threatfile_format()))
+ self.findings.append(f)
+ # TODO: Allow exclusion of threats
+
def check(self):
if self.description is None:
raise ValueError(
@@ -1074,6 +1129,7 @@ def _process(self):
or result.stale_days is not None
):
self.resolve()
+ self.resolve_plugins()
if result.sqldump is not None:
self.sqlDump(result.sqldump)
@@ -1096,6 +1152,8 @@ def _process(self):
if result.list is True:
[print("{} - {}".format(t.id, t.description)) for t in TM._threats]
+ print("Plugins:")
+ [print("{} - {}".format(p.get_id(), p.get_description())) for p in self._rule_plugins]
if result.stale_days is not None:
print(self._stale(result.stale_days))
@@ -1917,8 +1975,8 @@ def encode_element_threat_data(obj):
v = getattr(o, a)
if (type(v) is not list or (type(v) is list and len(v) != 0)):
c._safeset(a, v)
-
- encoded_elements.append(c)
+
+ encoded_elements.append(c)
return encoded_elements
diff --git a/requirements.txt b/requirements.txt
index d522381..e0e4315 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
pydal>=20200714.1
+straight.plugin==1.5.0
\ No newline at end of file
diff --git a/tests/output.json b/tests/output.json
index bfd2627..38e31b2 100644
--- a/tests/output.json
+++ b/tests/output.json
@@ -912,6 +912,7 @@
"usesEnvironmentVariables": false
}
],
+ "finding_count": 0,
"findings": [],
"flows": [
{
diff --git a/tests/test_pytmfunc.py b/tests/test_pytmfunc.py
index ab1a8d4..2c92899 100644
--- a/tests/test_pytmfunc.py
+++ b/tests/test_pytmfunc.py
@@ -299,6 +299,7 @@ def test_json_dumps(self):
dir_path = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(dir_path, "output.json")) as x:
expected = x.read().strip()
+ expected_data = json.loads(expected)
TM.reset()
tm = TM(
"my test tm", description="aaa", threatsFile="pytm/threatlib/threats.json"
@@ -331,7 +332,13 @@ def test_json_dumps(self):
x.write(output)
self.maxDiff = None
- self.assertEqual(output, expected)
+
+ # Plugins are flexible this must be removed from dict to be able to do unit tests
+ with open(os.path.join(output_path, "output_current.json")) as fh:
+ output_data = json.load(fh)
+ output_data.pop("rule_plugins")
+
+ self.assertEqual(output_data, expected_data)
def test_json_loads(self):
random.seed(0)