diff --git a/check50/__main__.py b/check50/__main__.py index 7b0e2a6..a675d6e 100644 --- a/check50/__main__.py +++ b/check50/__main__.py @@ -7,6 +7,7 @@ import logging import os import platform +import shutil import site from pathlib import Path import subprocess @@ -20,7 +21,7 @@ import requests import termcolor -from . import _exceptions, internal, renderer, __version__ +from . import _exceptions, internal, renderer, assertions, __version__ from .contextmanagers import nullcontext from .runner import CheckRunner @@ -273,10 +274,10 @@ def flush(self): def check_version(package_name=__package__, timeout=1): - """Check for newer version of the package on PyPI""" + """Check for newer version of the package on PyPI""" if not __version__: return - + try: current = packaging.version.parse(__version__) latest = max(requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=timeout).json()["releases"], key=packaging.version.parse) @@ -387,7 +388,21 @@ def main(): if not args.no_install_dependencies: install_dependencies(config["dependencies"]) - checks_file = (internal.check_dir / config["checks"]).resolve() + # Store the original checks file and leave as is + original_checks_file = (internal.check_dir / config["checks"]).resolve() + + # If the user has enabled the rewrite feature + if assertions.rewrite_enabled(str(original_checks_file)): + # Create a temporary copy of the checks file + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmp: + checks_file = Path(tmp.name) + shutil.copyfile(original_checks_file, checks_file) + + # Rewrite all assert statements in the copied checks file to check50_assert + assertions.rewrite(str(checks_file)) + else: + # Don't rewrite any assert statements and continue + checks_file = original_checks_file # Have lib50 decide which files to include included_files = lib50.files(config.get("files"))[0] diff --git a/check50/assertions/__init__.py b/check50/assertions/__init__.py new file mode 100644 index 0000000..3b77f65 --- /dev/null +++ b/check50/assertions/__init__.py @@ -0,0 +1 @@ +from .rewrite import rewrite, rewrite_enabled diff --git a/check50/assertions/rewrite.py b/check50/assertions/rewrite.py new file mode 100644 index 0000000..a4f92da --- /dev/null +++ b/check50/assertions/rewrite.py @@ -0,0 +1,262 @@ +import ast +import re + +def rewrite(path: str): + """ + A function that rewrites all instances of `assert` in a file to our own + `check50_assert` function that raises our own exceptions. + + :param path: The path to the file you wish to rewrite. + :type path: str + """ + with open(path) as f: + source = f.read() + + # Parse the tree and replace all instance of `assert`. + tree = ast.parse(source, filename=path) + transformer = _AssertionRewriter() + new_tree = transformer.visit(tree) + ast.fix_missing_locations(new_tree) + + # Insert `from check50.assertions.runtime import check50_assert` only if not already present + if not any( + isinstance(stmt, ast.ImportFrom) and stmt.module == "check50.assertions.runtime" + for stmt in new_tree.body + ): + # Create an import statement for check50_assert + import_stmt = ast.ImportFrom( + module="check50.assertions.runtime", + names=[ast.alias(name="check50_assert", asname=None)], + level=0 + ) + + # Prepend to the beginning of the file + new_tree.body.insert(0, import_stmt) + + modified_source = ast.unparse(new_tree) + + # Write to the file + with open(path, 'w') as f: + f.write(modified_source) + +def rewrite_enabled(path: str): + """ + Checks if the first line of the file contains a comment of the form: + + ``` + # ENABLE_CHECK50_ASSERT = 1 + ``` + + Ignores whitespace. + + :param path: The path to the file you wish to check. + :type path: str + """ + pattern = re.compile( + r"^#\s*ENABLE_CHECK50_ASSERT\s*=\s*(1|True)$", + re.IGNORECASE + ) + + with open(path, 'r') as f: + first_line = f.readline().strip() + return bool(pattern.match(first_line)) + + +class _AssertionRewriter(ast.NodeTransformer): + """ + Helper class to to wrap the conditions being tested by `assert` with a + function called `check50_assert`. + """ + def visit_Assert(self, node): + """ + An overwrite of the AST module's visit_Assert to inject our code in + place of the default assertion logic. + + :param node: The `assert` statement node being visited and transformed. + :type node: ast.Assert + """ + self.generic_visit(node) + cond_type = self._identify_comparison_type(node.test) + + # Begin adding a named parameter that determines the type of condition + keywords = [ast.keyword(arg="cond_type", value=ast.Constant(value=cond_type))] + + # Extract variable names and build context={"var": var, ...} + var_names = self._extract_names(node.test) + context_dict = self._make_context_dict(var_names) + + if var_names and context_dict.keys: + keywords.append(ast.keyword( + arg="context", + value=context_dict + )) + + # Set the left and right side of the conditional as strings for later + # evaluation (used when raising check50.Missing and check50.Mismatch) + if isinstance(node.test, ast.Compare) and node.test.comparators: + left_node = node.test.left + right_node = node.test.comparators[0] + + left_str = ast.unparse(left_node) + right_str = ast.unparse(right_node) + + # Only add to context if not literal constants + if not isinstance(left_node, ast.Constant): + context_dict.keys.append(ast.Constant(value=left_str)) + context_dict.values.append(ast.Constant(value=None)) + if not isinstance(right_node, ast.Constant): + context_dict.keys.append(ast.Constant(value=right_str)) + context_dict.values.append(ast.Constant(value=None)) + + + keywords.extend([ + ast.keyword(arg="left", value=ast.Constant(value=left_str)), + ast.keyword(arg="right", value=ast.Constant(value=right_str)) + ]) + + return ast.Expr( + value=ast.Call( + # Create a function called check50_assert + func=ast.Name(id="check50_assert", ctx=ast.Load()), + # Give it these postional arguments: + args=[ + # The string form of the condition + ast.Constant(value=ast.unparse(node.test)), + # The additional msg or exception that the user provided + node.msg or ast.Constant(value=None) + ], + # And these named parameters: + keywords=keywords + ) + ) + + + def _identify_comparison_type(self, test_node): + """ + Checks if a conditional is a comparison between two expressions. If so, + attempts to identify the comparison operator (e.g., `==`, `in`). Falls + back to "unknown" if the conditional is not a comparison or if the + operator is not recognized. + + :param test_node: The AST conditional node that is being identified. + :type test_node: ast.expr + """ + if isinstance(test_node, ast.Compare) and test_node.ops: + op = test_node.ops[0] # the operator in between the comparators + if isinstance(op, ast.Eq): + return "eq" + elif isinstance(op, ast.In): + return "in" + + return "unknown" + + def _extract_names(self, expr): + """ + Returns a set of the names of every variable, function + (including the modules or classes they're located under), and function + argument in a given AST expression. + + :param expr: An AST expression. + :type expr: ast.AST + """ + class NameExtractor(ast.NodeVisitor): + def __init__(self): + self.names = set() + self._in_func_chain = False # flag to track nested Calls and Names + + def visit_Call(self, node): + # Temporarily store whether we're already in a chain + already_in_chain = self._in_func_chain + + # If already_in_chain is False, we're at the top-most level of + # the Call node. Without this guard, callable classes/modules + # will also be included in the output. For instance, + # check50.run('./test') AND check50.run('./test').stdout() will + # be included. + if not already_in_chain: + # Grab the entire dotted function name + full_name = self._get_full_func_name(node) + self.names.add(full_name) + + # As we travel down the function's subtree, denote this flag as True + self._in_func_chain = True + self.visit(node.func) + self._in_func_chain = already_in_chain # Restore previous state + + # Now visit the arguments of this function + for arg in node.args: + self.visit(arg) + for kw in node.keywords: + self.visit(kw) + + def visit_Name(self, node): + if not self._in_func_chain: # ignore Names of modules/libraries + self.names.add(node.id) + # self.names.add(node.id) + + def _get_full_func_name(self, node): + """ + Grab the entire function name, including the module or class + in which the function was located, as well as the function + arguments. + + For instance, this function would return + ``` + "check50.run('./test').stdout()" + ``` + as opposed to + ``` + "stdout" + ``` + """ + def format_args(call_node): + # Positional arguments + args = [ast.unparse(arg) for arg in call_node.args] + # Keyword arguments + kwargs = [f"{kw.arg}={ast.unparse(kw.value)}" for kw in call_node.keywords] + all_args = args + kwargs + return f"({', '.join(all_args)})" + + parts = [] + # Apply the same operations for even nested function calls. + while isinstance(node, ast.Call): + func = node.func + arg_string = format_args(node) + + # Attributes inside of Calls signify a `.` attribute was used + if isinstance(func, ast.Attribute): + parts.append(func.attr + arg_string) + node = func.value # step into next node in chain + elif isinstance(func, ast.Name): + parts.append(func.id + arg_string) + return ".".join(reversed(parts)) + else: + return f"[DEBUG] failed to grab func name: {ast.unparse(func)}" + + if isinstance(node, ast.Name): + parts.append(node.id) + + return ".".join(reversed(parts)) + + extractor = NameExtractor() + extractor.visit(expr) + return extractor.names + + def _make_context_dict(self, name_set): + """ + Returns an AST dictionary in which the keys are the names of variables + and the values are the value from each respective variable. + + :param name_set: A set of known names of variables. + :type name_set: set[str] + """ + keys, values = [], [] + for name in name_set: + keys.append(ast.Constant(value=name)) + # Defer evaluation of the values until later, since we don't have + # access to function results at this point + values.append(ast.Constant(value=None)) + + return ast.Dict(keys=keys, values=values) + + diff --git a/check50/assertions/runtime.py b/check50/assertions/runtime.py new file mode 100644 index 0000000..fdf1a56 --- /dev/null +++ b/check50/assertions/runtime.py @@ -0,0 +1,131 @@ +from check50 import Failure, Missing, Mismatch + +def check50_assert(src, msg_or_exc=None, cond_type="unknown", left=None, right=None, context=None): + """ + Asserts a conditional statement. If the condition evaluates to True, + nothing happens. Otherwise, it will look for a message or exception that + follows the condition (seperated by a comma). If the msg_or_exc is not + a string, an exception, or it was not provided, it is silently ignored. + + In such cases, we attempt to determine which exception should be raised + based on the type of the conditional. If recognized, it raises either + check50.Mismatch or check50.Missing. If the conditional type is unknown or + unhandled, check50.Failure is raised with a default message. + + Used for rewriting assertion statements in check files. + + Note: + Exceptions from the check50 library are preferred, since they will be + handled gracefully and integrated into the check output. Native Python + exceptions are technically supported, but check50 will immediately + terminate on the user's end if the assertion fails. + + Example usage: + ``` + assert x in y + ``` + will be converted to + ``` + check50_assert(x in y, "x in y", None, "in", x, y) + ``` + + :param src: The source code string of the conditional expression \ + (e.g., 'x in y'), extracted from the AST. + :type src: str + :param msg_or_exc: The message or exception following the conditional in \ + the assertion statement. + :type msg_or_exc: str | BaseException | None + :param cond_type: The type of conditional, one of {"eq", "in", "unknown"} + :type cond_type: str + :param left: The left side of the conditional, if applicable + :type left: str | None + :param right: The right side of the conditional, if applicable + :type right: str | None + :param context: A collection of the conditional's variable names as keys. + :type context: dict + + :raises msg_or_exc: If msg_or_exc is an exception. + :raises check50.Mismatch: If no exception is provided and cond_type is "eq". + :raises check50.Missing: If no exception is provided and cond_type is "in". + :raises check50.Failure: If msg_or_exc is a string, or if cond_type is \ + unrecognized. + """ + # Evaluate all variables and functions within the context dict and generate + # a string of these values + context_str = None + if context or (left and right): + import inspect + for expr_str in context: + try: + # Grab the global and local variables as of now + caller_frame = inspect.currentframe().f_back + context[expr_str] = eval(expr_str, caller_frame.f_globals, caller_frame.f_locals) + except Exception as e: + context[expr_str] = f"[error evaluating: {e}]" + + # produces a string like "var1 = ..., var2 = ..., foo() = ..." + context_str = ", ".join(f"{k} = {repr(v)}" for k, v in (context or {}).items()) + + # Since we've memoized the functions and variables once, now try and + # evaluate the conditional by substituting the function calls/vars with + # their results + eval_src, eval_context = substitute_expressions(src, context) + cond = eval(eval_src, {}, eval_context) + + # Finally, quit if the condition evaluated to True. + if cond: + return + + # If `right` or `left` were evaluatable objects, their actual value will be stored in `context`. + # Otherwise, they're still just literals. + right = context.get(right) or right + left = context.get(left) or left + + # Since the condition didn't evaluate to True, now, we can raise special + # exceptions. + if isinstance(msg_or_exc, str): + raise Failure(msg_or_exc) + elif isinstance(msg_or_exc, BaseException): + raise msg_or_exc + elif cond_type == 'eq' and left and right: + help_msg = f"checked: {src}" + help_msg += f"\n where {context_str}" if context_str else "" + raise Mismatch(right, left, help=help_msg) + elif cond_type == 'in' and left and right: + help_msg = f"checked: {src}" + help_msg += f"\n where {context_str}" if context_str else "" + raise Missing(left, right, help=help_msg) + else: + help_msg = f"\n where {context_str}" if context_str else "" + raise Failure(f"check did not pass: {src}" + help_msg) + +def substitute_expressions(src: str, context: dict) -> tuple[str, dict]: + """ + Rewrites `src` by replacing each key in `context` with a placeholder variable name, + and builds a new context dict where those names map to pre-evaluated values. + + For instance, given a `src`: + ``` + check50.run('pwd').stdout() == actual + ``` + it will create a new `eval_src` as + ``` + __expr0 == __expr1 + ``` + and use the given context to define these variables: + ``` + eval_context = { + '__expr0': context['check50.run('pwd').stdout()'], + '__expr1': context['actual'] + } + ``` + """ + new_src = src + new_context = {} + + for i, expr in enumerate(sorted(context.keys(), key=len, reverse=True)): + placeholder = f"__expr{i}" + new_src = new_src.replace(expr, placeholder) + new_context[placeholder] = context[expr] + + return new_src, new_context diff --git a/setup.py b/setup.py index 098e3ba..7c3d64c 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ }, keywords=["check", "check50"], name="check50", - packages=["check50", "check50.renderer"], + packages=["check50", "check50.renderer", "check50.assertions"], python_requires=">= 3.8", entry_points={ "console_scripts": ["check50=check50.__main__:main"] diff --git a/tests/check50_tests.py b/tests/check50_tests.py index 12b2584..4c3dc1e 100644 --- a/tests/check50_tests.py +++ b/tests/check50_tests.py @@ -493,5 +493,15 @@ def test_successful_exit(self): self.assertEqual(process.returncode, 0) +class TestAssertionsRewrite(Base): + def test_assertions_rewrite_enabled(self): + process = pexpect.spawn(f"check50 --dev {CHECKS_DIRECTORY}/assertions_rewrite_enabled") + process.expect_exact(":)") + + def test_assertions_rewrite_disabled(self): + process = pexpect.spawn(f"check50 --dev {CHECKS_DIRECTORY}/assertions_rewrite_disabled") + process.expect_exact(":)") + + if __name__ == "__main__": unittest.main() diff --git a/tests/checks/assertions_rewrite_disabled/.cs50.yaml b/tests/checks/assertions_rewrite_disabled/.cs50.yaml new file mode 100644 index 0000000..be5ecce --- /dev/null +++ b/tests/checks/assertions_rewrite_disabled/.cs50.yaml @@ -0,0 +1,3 @@ +check50: + files: + - !exclude "*" diff --git a/tests/checks/assertions_rewrite_disabled/__init__.py b/tests/checks/assertions_rewrite_disabled/__init__.py new file mode 100644 index 0000000..9c44a49 --- /dev/null +++ b/tests/checks/assertions_rewrite_disabled/__init__.py @@ -0,0 +1,36 @@ +# ENABLE_CHECK50_ASSERT = 0 +import check50 + +@check50.check() +def foo(): + stdout = "Hello, world!" + try: + assert stdout is "Special cases aren't special enough to break the rules." + except AssertionError: + pass + + try: + assert stdout is "Although practicality beats purity.", "help msg goes here" + except AssertionError: + pass + + try: + assert stdout == "Errors should never pass silently." + except AssertionError: + pass + + try: + assert stdout in "Unless explicitly silenced." + except AssertionError: + pass + + try: + assert bar(qux()) in "In the face of ambiguity, refuse the temptation to guess." + except AssertionError: + pass + +def bar(baz): + return "Hello, world!" + +def qux(): + return diff --git a/tests/checks/assertions_rewrite_enabled/.cs50.yaml b/tests/checks/assertions_rewrite_enabled/.cs50.yaml new file mode 100644 index 0000000..be5ecce --- /dev/null +++ b/tests/checks/assertions_rewrite_enabled/.cs50.yaml @@ -0,0 +1,3 @@ +check50: + files: + - !exclude "*" diff --git a/tests/checks/assertions_rewrite_enabled/__init__.py b/tests/checks/assertions_rewrite_enabled/__init__.py new file mode 100644 index 0000000..bc10555 --- /dev/null +++ b/tests/checks/assertions_rewrite_enabled/__init__.py @@ -0,0 +1,42 @@ +# ENABLE_CHECK50_ASSERT = 1 +import check50 + +@check50.check() +def foo(): + stdout = "Hello, world!" + try: + assert stdout is "Beautiful is better than ugly." + except check50.Failure: + pass + + try: + assert stdout is "Explicit is better than implicit.", "help msg goes here" + except check50.Failure: + pass + + try: + assert stdout == "Simple is better than complex." + except check50.Mismatch: + pass + + try: + assert stdout in "Complex is better than complicated." + except check50.Missing: + pass + + try: + assert stdout in "Flat is better than nested.", check50.Mismatch("Flat is better than nested.", stdout) + except check50.Mismatch: + pass + + try: + assert bar(qux()) in "Readability counts." + except check50.Missing: + pass + + +def bar(baz): + return "Hello, world!" + +def qux(): + return