diff --git a/docs/basic_template.md b/docs/basic_template.md index 451711b..ef4494b 100644 --- a/docs/basic_template.md +++ b/docs/basic_template.md @@ -11,7 +11,7 @@ |Assumptions| |-----------| -{tm.assumptions:repeat:|{{item}}| +{tm.assumptions:repeat:|{{item}}| }   @@ -62,6 +62,15 @@ Name|Description|Classification

{{item.mitigations}}

References

{{item.references}}

+
Comment
+

{{item.specific_comment}}

+
CWEs
+

{{item.cwes}}

+
TTPs
+

{{item.ttps}}

+
CAPECs
+

{{item.capecs}}

+       diff --git a/pytm/plugins/base/base_plugin.py b/pytm/plugins/base/base_plugin.py new file mode 100644 index 0000000..88495a1 --- /dev/null +++ b/pytm/plugins/base/base_plugin.py @@ -0,0 +1,11 @@ +""" Base class for all plugin types """ + + +class BasePlugin(): + """ Base class for plugins """ + + def __init__(self): + raise NotImplementedError("Plugin needs an __init__ function") + + def get_name(self): + return self.name \ No newline at end of file diff --git a/pytm/plugins/base/rule_plugin.py b/pytm/plugins/base/rule_plugin.py new file mode 100644 index 0000000..aed2d00 --- /dev/null +++ b/pytm/plugins/base/rule_plugin.py @@ -0,0 +1,133 @@ +from pytm.plugins.base.base_plugin import BasePlugin + + +class PluginThreat(): + """ A threat description """ + # TODO: Untangle the main code and use the Threat class from there here. + + cwes = [] + capecs = [] + ttps = [] + + def __init__(self, element, comment, **kwargs): + + self.data = {"SID": kwargs.get("SID"), + "description": kwargs.get("description", ""), + "condition": kwargs.get("condition", ""), + "target": kwargs.get("target", []), + "details": kwargs.get("details", ""), + "severity": kwargs.get("severity", ""), + "mitigations": kwargs.get("mitigations", ""), + "example": kwargs.get("example", ""), + "references": " ".join(kwargs.get("reference_list", [])), + "specific_comment": kwargs.get("specific_comment", ""), + + "cwes": kwargs.get("cwes", []), + "ttps": kwargs.get("ttps", []), + "capecs": kwargs.get("capecs", []) + } + self.element = element + self.comment = comment + + def to_threatfile_format(self): + """ Returns data in threatfile format """ + return self.data + + +class RuleResult(): + """ Can collect a large variety of detection results. Can be extended beyond threats. This is the reason there is a whole class here to collect that """ + + def __init__(self) -> None: + self._threats = [] + + def add_threat(self, element, comment, **kwargs): + self._threats.append(PluginThreat(element, comment, **kwargs)) + + def get_threats(self): + return self._threats + + + +class RulePlugin(BasePlugin): + """ A rule matching plugin base + + """ + + sid = None + + cwes = [] + capecs = [] + ttps = [] + + def __init__(self): + self.result = RuleResult() + self.elements = [] + + ### Entry points + + def threat_check(self, elements): + """ Calls the plugin function to check threats after abstracting internals away """ + + self.elements = elements + + self.threat_match() + + ### Generic functions + + def get_type(self, element): + """ returns a type string for an element """ + + # TODO: Move that to the classes + if str((type(element))) == "": + return "Boundary" + if str((type(element))) == "": + return "Datastore" + if str((type(element))) == "": + return "Dataflow" + if str((type(element))) == "": + return "Server" + if str((type(element))) == "": + return "Actor" + + def get_elements(self): + return self.elements + + ### Threat things + + def add_threat(self, element, comment): + """ Add a threat to the results + + @param element: the threat is tied to + @param comment: used comment for this threat + + """ + data = {"SID": self.SID, + "description": self.description, + "condition": self.condition, + "target": self.target, + "details": self.details, + "severity": self.severity, + "mitigations": self.mitigations, + "example": self.example, + "reference_list": self.reference_list, + "specific_comment": comment, + + "ttps": self.ttps, + "capecs": self.capecs, + "cwes": self.cwes, + } + + self.result.add_threat(element, comment, **data) + + def get_threats(self): + """ Read threats from the collection """ + return self.result.get_threats() + + def get_id(self): + return self.SID + + def get_description(self): + return self.description + + + diff --git a/pytm/plugins/rules/specific_sql_injection.py b/pytm/plugins/rules/specific_sql_injection.py new file mode 100644 index 0000000..7e013e5 --- /dev/null +++ b/pytm/plugins/rules/specific_sql_injection.py @@ -0,0 +1,61 @@ +from pytm.plugins.base.rule_plugin import RulePlugin, RuleResult +from pytm import DatastoreType + + +class StrictSQLInjectionRulePlugin(RulePlugin): + # Boilerplate + name = "strict_sql_injection" + description = "A strict SQL injection rule" + + SID = "EXP01" + details = "A SQL datastore is connected to a web server which does not sanitize inputs. This web server can be accessed by an actor" + LikelihoodOfAttack = "High" + severity = "High" + condition = "A SQL datastore is connected to a web server which does not sanitize inputs. This web server can be accessed by an actor" + prerequisites = "" + mitigations = "Sanitize input to protect the SQL server. Use PreparedStatements" + example = "" + reference_list = [] + target = [] + + cwes = ["89", "1286"] + capecs = ["66"] + ttps = ["T1190"] + + def __init__(self): + super().__init__() + self.plugin_path = __file__ + + def connected_elements(self, element): + """ Lists all elements connected by Dataflows to this element """ + res = [] + + for a_dataflow in self.get_elements(): + if self.get_type(a_dataflow) == "Dataflow": + if a_dataflow.source == element: + res.append(a_dataflow.sink) + if a_dataflow.sink == element: + res.append(a_dataflow.source) + return res + + def threat_match(self): + """ Specific SQL injection test. Extra specific to test the power of plugin rules. + + A SQL datastore is connected to a web server which does not sanitize inputs. This web server can be accessed by an actor . + """ + for a_database in self.get_elements(): + if self.get_type(a_database) == "Datastore" and a_database.type == DatastoreType.SQL: + servers_connected_to_database = self.connected_elements(a_database) + for a_webserver in servers_connected_to_database: + # Is connected to a web server which does not sanitize input + if self.get_type(a_webserver) == "Server" and a_webserver.controls.sanitizesInput == False: + users_connected_to_server = self.connected_elements(a_webserver) + # Check all connections of this web server, is a user connected (="Actor") + for a_user in users_connected_to_server: + if self.get_type(a_user) == "Actor": + self.add_threat(a_database, comment = f"The user '{a_user.name}' could run SQL injection attacks on '{a_database.name}' via '{a_webserver.name}'") + + + + + diff --git a/pytm/pytm.py b/pytm/pytm.py index e6e005f..13c82f2 100644 --- a/pytm/pytm.py +++ b/pytm/pytm.py @@ -9,6 +9,7 @@ import uuid import html import copy +from glob import glob from collections import Counter, defaultdict from collections.abc import Iterable @@ -20,6 +21,9 @@ from textwrap import indent, wrap from weakref import WeakKeyDictionary from datetime import datetime +import straight.plugin # type: ignore +from straight.plugin.manager import PluginManager as StraightPluginManager # type: ignore +from pytm.plugins.base.rule_plugin import RulePlugin from pydal import DAL, Field @@ -594,6 +598,7 @@ class Threat: example = varString("") references = varString("") target = () + specific_comment = varString("") def __init__(self, **kwargs): self.id = kwargs["SID"] @@ -610,6 +615,10 @@ def __init__(self, **kwargs): self.mitigations = kwargs.get("mitigations", "") self.example = kwargs.get("example", "") self.references = kwargs.get("references", "") + self.specific_comment = kwargs.get("specific_comment", "") # A detailed comment for this specific occurence + self.cwes = kwargs.get("cwes", []) + self.capecs = kwargs.get("capecs", []) + self.ttps = kwargs.get("ttps", []) def _safeset(self, attr, value): try: @@ -658,6 +667,7 @@ class Finding: """, ) cvss = varString("", required=False, doc="The CVSS score and/or vector") + specific_comment = "" def __init__( self, @@ -679,6 +689,10 @@ def __init__( "example", "references", "condition", + "specific_comment", + "cwes", + "capecs", + "ttps" ] threat = kwargs.pop("threat", None) if threat: @@ -731,6 +745,7 @@ class TM: _actors = [] _assets = [] _threats = [] + _rule_plugins = [] _boundaries = [] _data = [] _threatsExcluded = [] @@ -757,6 +772,7 @@ class TM: required=False, doc="A list of assumptions about the design/model.", ) + finding_count = 0 def __init__(self, name, **kwargs): for key, value in kwargs.items(): @@ -764,6 +780,7 @@ def __init__(self, name, **kwargs): self.name = name self._sf = SuperFormatter() self._add_threats() + self._rule_plugins = self._add_rule_plugins() # make sure generated diagrams do not change, makes sense if they're commited random.seed(0) @@ -780,8 +797,36 @@ def reset(cls): def _init_threats(self): TM._threats = [] + TM._rule_plugins = [] self._add_threats() + + def _add_rule_plugins(self): + """ Returns a list plugins + + :return: A list of instantiated plugins + """ + + + res = [] + + def get_handlers(a_plugin: StraightPluginManager): + return a_plugin.produce() + + plugin_dirs = set() + for a_glob in glob("pytm/plugins/rules/*.py", recursive=True): + plugin_dirs.add(os.path.dirname(a_glob)) + + for a_dir in plugin_dirs: + plugins = straight.plugin.load(a_dir, subclasses=RulePlugin) + + handlers = get_handlers(plugins) + + for plugin in handlers: + res.append(plugin) + + return res + def _add_threats(self): try: with open(self.threatsFile, "r", encoding="utf8") as threat_file: @@ -793,7 +838,7 @@ def _add_threats(self): TM._threats.append(Threat(**i)) def resolve(self): - finding_count = 0 + self.finding_count = 0 findings = [] elements = defaultdict(list) for e in TM._elements: @@ -817,8 +862,8 @@ def resolve(self): if t.id in TM._threatsExcluded: continue - finding_count += 1 - f = Finding(e, id=str(finding_count), threat=t) + self.finding_count += 1 + f = Finding(e, id=str(self.finding_count), threat=t) logger.debug(f"new finding: {f}") findings.append(f) elements[e].append(f) @@ -826,6 +871,16 @@ def resolve(self): for e, findings in elements.items(): e.findings = findings + def resolve_plugins(self): + for plugin in self._rule_plugins: + plugin.threat_check(self._elements) + + for t in plugin.get_threats(): + self.finding_count += 1 + f = Finding(t.element, id=str(self.finding_count), threat=Threat(**t.to_threatfile_format())) + self.findings.append(f) + # TODO: Allow exclusion of threats + def check(self): if self.description is None: raise ValueError( @@ -1074,6 +1129,7 @@ def _process(self): or result.stale_days is not None ): self.resolve() + self.resolve_plugins() if result.sqldump is not None: self.sqlDump(result.sqldump) @@ -1096,6 +1152,8 @@ def _process(self): if result.list is True: [print("{} - {}".format(t.id, t.description)) for t in TM._threats] + print("Plugins:") + [print("{} - {}".format(p.get_id(), p.get_description())) for p in self._rule_plugins] if result.stale_days is not None: print(self._stale(result.stale_days)) @@ -1917,8 +1975,8 @@ def encode_element_threat_data(obj): v = getattr(o, a) if (type(v) is not list or (type(v) is list and len(v) != 0)): c._safeset(a, v) - - encoded_elements.append(c) + + encoded_elements.append(c) return encoded_elements diff --git a/requirements.txt b/requirements.txt index d522381..e0e4315 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ pydal>=20200714.1 +straight.plugin==1.5.0 \ No newline at end of file diff --git a/tests/output.json b/tests/output.json index bfd2627..38e31b2 100644 --- a/tests/output.json +++ b/tests/output.json @@ -912,6 +912,7 @@ "usesEnvironmentVariables": false } ], + "finding_count": 0, "findings": [], "flows": [ { diff --git a/tests/test_pytmfunc.py b/tests/test_pytmfunc.py index ab1a8d4..2c92899 100644 --- a/tests/test_pytmfunc.py +++ b/tests/test_pytmfunc.py @@ -299,6 +299,7 @@ def test_json_dumps(self): dir_path = os.path.dirname(os.path.realpath(__file__)) with open(os.path.join(dir_path, "output.json")) as x: expected = x.read().strip() + expected_data = json.loads(expected) TM.reset() tm = TM( "my test tm", description="aaa", threatsFile="pytm/threatlib/threats.json" @@ -331,7 +332,13 @@ def test_json_dumps(self): x.write(output) self.maxDiff = None - self.assertEqual(output, expected) + + # Plugins are flexible this must be removed from dict to be able to do unit tests + with open(os.path.join(output_path, "output_current.json")) as fh: + output_data = json.load(fh) + output_data.pop("rule_plugins") + + self.assertEqual(output_data, expected_data) def test_json_loads(self): random.seed(0)