From 1b671e2b6b205385a0a10dc4d0b0240653e8c44e Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Fri, 9 Apr 2021 19:37:55 -0400 Subject: [PATCH 01/47] refactor, fix scoped session --- setup.py | 2 +- src/cs50/__init__.py | 20 +- src/cs50/_logger.py | 48 ++++ src/cs50/_session.py | 80 ++++++ src/cs50/_statement.py | 269 +++++++++++++++++++ src/cs50/cs50.py | 170 +++++------- src/cs50/sql.py | 582 +++++------------------------------------ tests/test_cs50.py | 151 +++++++++++ 8 files changed, 684 insertions(+), 638 deletions(-) create mode 100644 src/cs50/_logger.py create mode 100644 src/cs50/_session.py create mode 100644 src/cs50/_statement.py create mode 100644 tests/test_cs50.py diff --git a/setup.py b/setup.py index 550e65d..a5b8fb7 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.4" + version="7.0.0" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index aaec161..f04da00 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,20 +1,6 @@ -import logging -import os -import sys +from ._logger import _setup_logger +_setup_logger() - -# Disable cs50 logger by default -logging.getLogger("cs50").disabled = True - -# Import cs50_* -from .cs50 import get_char, get_float, get_int, get_string -try: - from .cs50 import get_long -except ImportError: - pass - -# Hook into flask importing +from .cs50 import get_float, get_int, get_string from . import flask - -# Wrap SQLAlchemy from .sql import SQL diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py new file mode 100644 index 0000000..46f0821 --- /dev/null +++ b/src/cs50/_logger.py @@ -0,0 +1,48 @@ +import logging +import os.path +import re +import sys +import traceback + +import termcolor + + +def _setup_logger(): + _logger = logging.getLogger("cs50") + _logger.disabled = True + _logger.setLevel(logging.DEBUG) + + # Log messages once + _logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(levelname)s: %(message)s") + formatter.formatException = lambda exc_info: _formatException(*exc_info) + handler.setFormatter(formatter) + _logger.addHandler(handler) + + +def _formatException(type, value, tb): + """ + Format traceback, darkening entries from global site-packages directories + and user-specific site-packages directory. + https://stackoverflow.com/a/46071447/5156190 + """ + + # Absolute paths to site-packages + packages = tuple(os.path.join(os.path.abspath(p), "") for p in sys.path[1:]) + + # Highlight lines not referring to files in site-packages + lines = [] + for line in traceback.format_exception(type, value, tb): + matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) + if matches and matches.group(1).startswith(packages): + lines += line + else: + matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) + lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) + return "".join(lines).rstrip() + + diff --git a/src/cs50/_session.py b/src/cs50/_session.py new file mode 100644 index 0000000..4d1a2a9 --- /dev/null +++ b/src/cs50/_session.py @@ -0,0 +1,80 @@ +import os + +import sqlalchemy +import sqlalchemy.orm +import sqlite3 + +class Session: + def __init__(self, url, **engine_kwargs): + self._url = url + if _is_sqlite_url(self._url): + _assert_sqlite_file_exists(self._url) + + self._engine = _create_engine(self._url, **engine_kwargs) + self._is_postgres = self._engine.url.get_backend_name() in {"postgres", "postgresql"} + _setup_on_connect(self._engine) + self._session = _create_scoped_session(self._engine) + + + def is_postgres(self): + return self._is_postgres + + + def execute(self, statement): + return self._session.execute(sqlalchemy.text(str(statement))) + + + def __getattr__(self, attr): + return getattr(self._session, attr) + + +def _is_sqlite_url(url): + return url.startswith("sqlite:///") + + +def _assert_sqlite_file_exists(url): + path = url[len("sqlite:///"):] + if not os.path.exists(path): + raise RuntimeError(f"does not exist: {path}") + if not os.path.isfile(path): + raise RuntimeError(f"not a file: {path}") + + +def _create_engine(url, **kwargs): + try: + engine = sqlalchemy.create_engine(url, **kwargs) + except sqlalchemy.exc.ArgumentError: + raise RuntimeError(f"invalid URL: {url}") from None + + engine.execution_options(autocommit=False) + return engine + + +def _setup_on_connect(engine): + def connect(dbapi_connection, _): + _disable_auto_begin_commit(dbapi_connection) + if _is_sqlite_connection(dbapi_connection): + _enable_sqlite_foreign_key_constraints(dbapi_connection) + + sqlalchemy.event.listen(engine, "connect", connect) + + +def _create_scoped_session(engine): + session_factory = sqlalchemy.orm.sessionmaker(bind=engine) + return sqlalchemy.orm.scoping.scoped_session(session_factory) + + +def _disable_auto_begin_commit(dbapi_connection): + # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves + # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + dbapi_connection.isolation_level = None + + +def _is_sqlite_connection(dbapi_connection): + return isinstance(dbapi_connection, sqlite3.Connection) + + +def _enable_sqlite_foreign_key_constraints(dbapi_connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py new file mode 100644 index 0000000..7519b1e --- /dev/null +++ b/src/cs50/_statement.py @@ -0,0 +1,269 @@ +import collections +import datetime +import enum +import re + +import sqlalchemy +import sqlparse + + +class Statement: + def __init__(self, dialect, sql, *args, **kwargs): + if len(args) > 0 and len(kwargs) > 0: + raise RuntimeError("cannot pass both positional and named parameters") + + self._dialect = dialect + self._sql = sql + self._args = args + self._kwargs = kwargs + + self._statement = self._parse() + self._command = self._parse_command() + self._tokens = self._bind_params() + + def _parse(self): + formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip() + parsed_statements = sqlparse.parse(formatted_statements) + num_of_statements = len(parsed_statements) + if num_of_statements == 0: + raise RuntimeError("missing statement") + elif num_of_statements > 1: + raise RuntimeError("too many statements at once") + + return parsed_statements[0] + + + def _parse_command(self): + for token in self._statement: + if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: + token_value = token.value.upper() + if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: + command = token_value + break + else: + command = None + + return command + + + def _bind_params(self): + tokens = self._tokenize() + paramstyle, placeholders = self._parse_placeholders(tokens) + if paramstyle in [Paramstyle.FORMAT, Paramstyle.QMARK]: + tokens = self._bind_format_or_qmark(placeholders, tokens) + elif paramstyle == Paramstyle.NUMERIC: + tokens = self._bind_numeric(placeholders, tokens) + if paramstyle in [Paramstyle.NAMED, Paramstyle.PYFORMAT]: + tokens = self._bind_named_or_pyformat(placeholders, tokens) + + tokens = _escape_verbatim_colons(tokens) + return tokens + + + def _tokenize(self): + return list(self._statement.flatten()) + + + def _parse_placeholders(self, tokens): + paramstyle = None + placeholders = collections.OrderedDict() + for index, token in enumerate(tokens): + if _is_placeholder(token): + _paramstyle, name = _parse_placeholder(token) + if paramstyle is None: + paramstyle = _paramstyle + elif _paramstyle != paramstyle: + raise RuntimeError("inconsistent paramstyle") + + placeholders[index] = name + + if paramstyle is None: + paramstyle = self._default_paramstyle() + + return paramstyle, placeholders + + + def _default_paramstyle(self): + paramstyle = None + if self._args: + paramstyle = Paramstyle.QMARK + elif self._kwargs: + paramstyle = Paramstyle.NAMED + + return paramstyle + + + def _bind_format_or_qmark(self, placeholders, tokens): + if len(placeholders) != len(self._args): + _placeholders = ", ".join([str(token) for token in placeholders.values()]) + _args = ", ".join([str(self._escape(arg)) for arg in self._args]) + if len(placeholders) < len(self._args): + raise RuntimeError(f"fewer placeholders ({_placeholders}) than values ({_args})") + + raise RuntimeError(f"more placeholders ({_placeholders}) than values ({_args})") + + for arg_index, token_index in enumerate(placeholders.keys()): + tokens[token_index] = self._escape(self._args[arg_index]) + + return tokens + + + def _bind_numeric(self, placeholders, tokens): + unused_arg_indices = set(range(len(self._args))) + for token_index, num in placeholders.items(): + if num >= len(self._args): + raise RuntimeError(f"missing value for placeholder ({num + 1})") + + tokens[token_index] = self._escape(self._args[num]) + unused_arg_indices.remove(num) + + if len(unused_arg_indices) > 0: + unused_args = ", ".join([str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)]) + raise RuntimeError(f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})") + + return tokens + + + def _bind_named_or_pyformat(self, placeholders, tokens): + unused_params = set(self._kwargs.keys()) + for token_index, param_name in placeholders.items(): + if param_name not in self._kwargs: + raise RuntimeError(f"missing value for placeholder ({param_name})") + + tokens[token_index] = self._escape(self._kwargs[param_name]) + unused_params.remove(param_name) + + if len(unused_params) > 0: + raise RuntimeError("unused value{'' if len(unused_params) == 1 else 's'} ({', '.join(sorted(unused_params))})") + + return tokens + + + def _escape(self, value): + """ + Escapes value using engine's conversion function. + https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor + """ + + if isinstance(value, (list, tuple)): + return self._escape_iterable(value) + + if isinstance(value, bool): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) + + if isinstance(value, bytes): + if self._dialect.name in ["mysql", "sqlite"]: + # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") + if self._dialect.name in ["postgres", "postgresql"]: + # https://dba.stackexchange.com/a/203359 + return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") + + raise RuntimeError(f"unsupported value: {value}") + + if isinstance(value, datetime.date): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d"))) + + if isinstance(value, datetime.datetime): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + + if isinstance(value, datetime.time): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%H:%M:%S"))) + + if isinstance(value, float): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Float().literal_processor(self._dialect)(value)) + + if isinstance(value, int): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) + + if isinstance(value, str): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._dialect)(value)) + + if value is None: + return sqlparse.sql.Token( + sqlparse.tokens.Keyword, + sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) + + raise RuntimeError(f"unsupported value: {value}") + + + def _escape_iterable(self, iterable): + return sqlparse.sql.TokenList( + sqlparse.parse(", ".join([str(self._escape(v)) for v in iterable]))) + + + def get_command(self): + return self._command + + + def __str__(self): + return "".join([str(token) for token in self._tokens]) + + +def _is_placeholder(token): + return token.ttype == sqlparse.tokens.Name.Placeholder + + +def _parse_placeholder(token): + if token.value == "?": + return Paramstyle.QMARK, None + + # E.g., :1 + matches = re.search(r"^:([1-9]\d*)$", token.value) + if matches: + return Paramstyle.NUMERIC, int(matches.group(1)) - 1 + + # E.g., :foo + matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) + if matches: + return Paramstyle.NAMED, matches.group(1) + + if token.value == "%s": + return Paramstyle.FORMAT, None + + # E.g., %(foo) + matches = re.search(r"%\((\w+)\)s$", token.value) + if matches: + return Paramstyle.PYFORMAT, matches.group(1) + + raise RuntimeError(f"{token.value}: invalid placeholder") + + +def _escape_verbatim_colons(tokens): + for token in tokens: + if _is_string_literal(token): + token.value = re.sub("(^'|\s+):", r"\1\:", token.value) + elif _is_identifier(token): + token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) + + return tokens + + +def _is_string_literal(token): + return token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] + + +def _is_identifier(token): + return token.ttype == sqlparse.tokens.Literal.String.Symbol + + +class Paramstyle(enum.Enum): + FORMAT = enum.auto() + NAMED = enum.auto() + NUMERIC = enum.auto() + PYFORMAT = enum.auto() + QMARK = enum.auto() diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 1d7b6ea..573d862 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -1,98 +1,6 @@ -from __future__ import print_function - -import inspect -import logging -import os import re import sys -from distutils.sysconfig import get_python_lib -from os.path import abspath, join -from termcolor import colored -from traceback import format_exception - - -# Configure default logging handler and formatter -# Prevent flask, werkzeug, etc from adding default handler -logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) - -try: - # Patch formatException - logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) -except IndexError: - pass - -# Configure cs50 logger -_logger = logging.getLogger("cs50") -_logger.setLevel(logging.DEBUG) - -# Log messages once -_logger.propagate = False - -handler = logging.StreamHandler() -handler.setLevel(logging.DEBUG) - -formatter = logging.Formatter("%(levelname)s: %(message)s") -formatter.formatException = lambda exc_info: _formatException(*exc_info) -handler.setFormatter(formatter) -_logger.addHandler(handler) - - -class _flushfile(): - """ - Disable buffering for standard output and standard error. - - http://stackoverflow.com/a/231216 - """ - - def __init__(self, f): - self.f = f - - def __getattr__(self, name): - return object.__getattribute__(self.f, name) - - def write(self, x): - self.f.write(x) - self.f.flush() - - -sys.stderr = _flushfile(sys.stderr) -sys.stdout = _flushfile(sys.stdout) - - -def _formatException(type, value, tb): - """ - Format traceback, darkening entries from global site-packages directories - and user-specific site-packages directory. - - https://stackoverflow.com/a/46071447/5156190 - """ - - # Absolute paths to site-packages - packages = tuple(join(abspath(p), "") for p in sys.path[1:]) - - # Highlight lines not referring to files in site-packages - lines = [] - for line in format_exception(type, value, tb): - matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) - if matches and matches.group(1).startswith(packages): - lines += line - else: - matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3)) - return "".join(lines).rstrip() - - -sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) - - -def eprint(*args, **kwargs): - raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!") - - -def get_char(prompt): - raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!") - def get_float(prompt): """ @@ -101,14 +9,21 @@ def get_float(prompt): prompted to retry. If line can't be read, return None. """ while True: - s = get_string(prompt) - if s is None: - return None - if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s): - try: - return float(s) - except (OverflowError, ValueError): - pass + try: + return _get_float(prompt) + except (OverflowError, ValueError): + pass + + +def _get_float(prompt): + s = get_string(prompt) + if s is None: + return + + if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s): + return float(s) + + raise ValueError(f"invalid float literal: {s}") def get_int(prompt): @@ -118,14 +33,21 @@ def get_int(prompt): can't be read, return None. """ while True: - s = get_string(prompt) - if s is None: - return None - if re.search(r"^[+-]?\d+$", s): - try: - return int(s, 10) - except ValueError: - pass + try: + return _get_int(prompt) + except (MemoryError, ValueError): + pass + + +def _get_int(prompt): + s = get_string(prompt) + if s is None: + return + + if re.search(r"^[+-]?\d+$", s): + return int(s, 10) + + raise ValueError(f"invalid int literal for base 10: {s}") def get_string(prompt): @@ -137,7 +59,35 @@ def get_string(prompt): """ if type(prompt) is not str: raise TypeError("prompt must be of type str") + try: - return input(prompt) + return _get_input(prompt) except EOFError: - return None + return + + +def _get_input(prompt): + return input(prompt) + + +class _flushfile(): + """ + Disable buffering for standard output and standard error. + http://stackoverflow.com/a/231216 + """ + + def __init__(self, f): + self.f = f + + def __getattr__(self, name): + return object.__getattribute__(self.f, name) + + def write(self, x): + self.f.write(x) + self.f.flush() + +def disable_buffering(): + sys.stderr = _flushfile(sys.stderr) + sys.stdout = _flushfile(sys.stdout) + +disable_buffering() diff --git a/src/cs50/sql.py b/src/cs50/sql.py index f95b347..b778601 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,545 +1,107 @@ -def _enable_logging(f): - """Enable logging of SQL statements when Flask is in use.""" +import decimal +import logging +import warnings - import logging - import functools +import sqlalchemy +import termcolor - @functools.wraps(f) - def decorator(*args, **kwargs): +from ._session import Session +from ._statement import Statement - # Infer whether Flask is installed - try: - import flask - except ModuleNotFoundError: - return f(*args, **kwargs) - - # Enable logging - disabled = logging.getLogger("cs50").disabled - if flask.current_app: - logging.getLogger("cs50").disabled = False - try: - return f(*args, **kwargs) - finally: - logging.getLogger("cs50").disabled = disabled - - return decorator - - -class SQL(object): - """Wrap SQLAlchemy to provide a simple SQL API.""" - - def __init__(self, url, **kwargs): - """ - Create instance of sqlalchemy.engine.Engine. +_logger = logging.getLogger("cs50") - URL should be a string that indicates database dialect and connection arguments. - http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine - http://docs.sqlalchemy.org/en/latest/dialects/index.html - """ +class SQL: + def __init__(self, url, **engine_kwargs): + self._session = Session(url, **engine_kwargs) + self._autocommit = False + self._test_database() - # Lazily import - import logging - import os - import re - import sqlalchemy - import sqlalchemy.orm - import sqlite3 - # Require that file already exist for SQLite - matches = re.search(r"^sqlite:///(.+)$", url) - if matches: - if not os.path.exists(matches.group(1)): - raise RuntimeError("does not exist: {}".format(matches.group(1))) - if not os.path.isfile(matches.group(1)): - raise RuntimeError("not a file: {}".format(matches.group(1))) + def _test_database(self): + self.execute("SELECT 1") - # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) - # Get logger - self._logger = logging.getLogger("cs50") - - # Listener for connections - def connect(dbapi_connection, connection_record): - - # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves - # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl - dbapi_connection.isolation_level = None - - # Enable foreign key constraints - if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() - - # Register listener - sqlalchemy.event.listen(self._engine, "connect", connect) - - # Autocommit by default - self._autocommit = True - - # Test database - disabled = self._logger.disabled - self._logger.disabled = True - try: - self.execute("SELECT 1") - except sqlalchemy.exc.OperationalError as e: - e = RuntimeError(_parse_exception(e)) - e.__cause__ = None - raise e - finally: - self._logger.disabled = disabled - - def __del__(self): - """Disconnect from database.""" - self._disconnect() - - def _disconnect(self): - """Close database connection.""" - if hasattr(self, "_session"): - self._session.remove() - delattr(self, "_session") - - @_enable_logging def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" + statement = Statement(self._session.get_bind().dialect, sql, *args, **kwargs) + command = statement.get_command() + if command in ["BEGIN", "START"]: + self._autocommit = False - # Lazily import - import decimal - import re - import sqlalchemy - import sqlparse - import termcolor - import warnings + if self._autocommit: + self._session.execute("BEGIN") - # Parse statement, stripping comments and then leading/trailing whitespace - statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip()) + result = self._execute(statement) - # Allow only one statement at a time, since SQLite doesn't support multiple - # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute - if len(statements) > 1: - raise RuntimeError("too many statements at once") - elif len(statements) == 0: - raise RuntimeError("missing statement") + if self._autocommit: + self._session.execute("COMMIT") + self._session.remove() - # Ensure named and positional parameters are mutually exclusive - if len(args) > 0 and len(kwargs) > 0: - raise RuntimeError("cannot pass both positional and named parameters") + if command in ["COMMIT", "ROLLBACK"]: + self._autocommit = True + self._session.remove() - # Infer command from (unflattened) statement - for token in statements[0]: - if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - token_value = token.value.upper() - if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: - command = token_value - break + if command == "SELECT": + ret = _fetch_select_result(result) + elif command == "INSERT": + if self._session.is_postgres(): + ret = self._get_last_val() + else: + ret = result.lastrowid if result.rowcount == 1 else None + elif command in ["DELETE", "UPDATE"]: + ret = result.rowcount else: - command = None - - # Flatten statement - tokens = list(statements[0].flatten()) - - # Validate paramstyle - placeholders = {} - paramstyle = None - for index, token in enumerate(tokens): - - # If token is a placeholder - if token.ttype == sqlparse.tokens.Name.Placeholder: - - # Determine paramstyle, name - _paramstyle, name = _parse_placeholder(token) - - # Remember paramstyle - if not paramstyle: - paramstyle = _paramstyle - - # Ensure paramstyle is consistent - elif _paramstyle != paramstyle: - raise RuntimeError("inconsistent paramstyle") - - # Remember placeholder's index, name - placeholders[index] = name - - # If no placeholders - if not paramstyle: - - # Error-check like qmark if args - if args: - paramstyle = "qmark" - - # Error-check like named if kwargs - elif kwargs: - paramstyle = "named" - - # In case of errors - _placeholders = ", ".join([str(tokens[index]) for index in placeholders]) - _args = ", ".join([str(self._escape(arg)) for arg in args]) - - # qmark - if paramstyle == "qmark": - - # Validate number of placeholders - if len(placeholders) != len(args): - if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) - else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) - - # Escape values - for i, index in enumerate(placeholders.keys()): - tokens[index] = self._escape(args[i]) - - # numeric - elif paramstyle == "numeric": - - # Escape values - for index, i in placeholders.items(): - if i >= len(args): - raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args))) - tokens[index] = self._escape(args[i]) - - # Check if any values unused - indices = set(range(len(args))) - set(placeholders.values()) - if indices: - raise RuntimeError("unused {} ({})".format( - "value" if len(indices) == 1 else "values", - ", ".join([str(self._escape(args[index])) for index in indices]))) - - # named - elif paramstyle == "named": - - # Escape values - for index, name in placeholders.items(): - if name not in kwargs: - raise RuntimeError("missing value for placeholder (:{})".format(name)) - tokens[index] = self._escape(kwargs[name]) - - # Check if any keys unused - keys = kwargs.keys() - placeholders.values() - if keys: - raise RuntimeError("unused values ({})".format(", ".join(keys))) - - # format - elif paramstyle == "format": - - # Validate number of placeholders - if len(placeholders) != len(args): - if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) - else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) - - # Escape values - for i, index in enumerate(placeholders.keys()): - tokens[index] = self._escape(args[i]) - - # pyformat - elif paramstyle == "pyformat": - - # Escape values - for index, name in placeholders.items(): - if name not in kwargs: - raise RuntimeError("missing value for placeholder (%{}s)".format(name)) - tokens[index] = self._escape(kwargs[name]) - - # Check if any keys unused - keys = kwargs.keys() - placeholders.values() - if keys: - raise RuntimeError("unused {} ({})".format( - "value" if len(keys) == 1 else "values", - ", ".join(keys))) - - # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape - # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text - for index, token in enumerate(tokens): - - # In string literal - # https://www.sqlite.org/lang_keywords.html - if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]: - token.value = re.sub("(^'|\s+):", r"\1\:", token.value) - - # In identifier - # https://www.sqlite.org/lang_keywords.html - elif token.ttype == sqlparse.tokens.Literal.String.Symbol: - token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) - - # Join tokens into statement - statement = "".join([str(token) for token in tokens]) - - # Connect to database - try: - - # Infer whether Flask is installed - import flask + ret = True - # Infer whether app is defined - assert flask.current_app + return ret - # If no sessions for any databases yet - if not hasattr(flask.g, "_sessions"): - setattr(flask.g, "_sessions", {}) - sessions = getattr(flask.g, "_sessions") - - # If no session yet for this database - # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data - # https://stackoverflow.com/a/34010159 - if self not in sessions: - - # Connect to database - sessions[self] = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - - # Remove session later - if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: - flask.current_app.teardown_appcontext(_teardown_appcontext) - - # Use this session - session = sessions[self] - - except (ModuleNotFoundError, AssertionError): - - # If no connection yet - if not hasattr(self, "_session"): - self._session = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - - # Use this session - session = self._session + def _execute(self, statement): # Catch SQLAlchemy warnings with warnings.catch_warnings(): - # Raise exceptions for warnings warnings.simplefilter("error") - - # Prepare, execute statement try: + return self._session.execute(statement) + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(termcolor.colored(str(statement), "yellow")) + raise ValueError(exc.orig) from None + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + self._session.remove() + _logger.debug(termcolor.colored(statement, "red")) + raise RuntimeError(exc.orig) from None - # Join tokens into statement, abbreviating binary data as - _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) - - # Check for start of transaction - if command in ["BEGIN", "START"]: - self._autocommit = False - - # Execute statement - if self._autocommit: - session.execute(sqlalchemy.text("BEGIN")) - result = session.execute(sqlalchemy.text(statement)) - if self._autocommit: - session.execute(sqlalchemy.text("COMMIT")) - - # Check for end of transaction - if command in ["COMMIT", "ROLLBACK"]: - self._autocommit = True - - # Return value - ret = True - - # If SELECT, return result set as list of dict objects - if command == "SELECT": - - # Coerce types - rows = [dict(row) for row in result.fetchall()] - for row in rows: - for column in row: - - # Coerce decimal.Decimal objects to float objects - # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - if type(row[column]) is decimal.Decimal: - row[column] = float(row[column]) - - # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes - elif type(row[column]) is memoryview: - row[column] = bytes(row[column]) - - # Rows to be returned - ret = rows - - # If INSERT, return primary key value for a newly inserted row (or None if none) - elif command == "INSERT": - if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: - try: - result = session.execute("SELECT LASTVAL()") - ret = result.first()[0] - except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session - ret = None - else: - ret = result.lastrowid if result.rowcount == 1 else None - - # If DELETE or UPDATE, return number of rows matched - elif command in ["DELETE", "UPDATE"]: - ret = result.rowcount - - # If constraint violated, return None - except sqlalchemy.exc.IntegrityError as e: - self._logger.debug(termcolor.colored(statement, "yellow")) - e = ValueError(e.orig) - e.__cause__ = None - raise e - - # If user error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: - self._disconnect() - self._logger.debug(termcolor.colored(statement, "red")) - e = RuntimeError(e.orig) - e.__cause__ = None - raise e - - # Return value - else: - self._logger.debug(termcolor.colored(_statement, "green")) - return ret - - def _escape(self, value): - """ - Escapes value using engine's conversion function. - - https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor - """ - - # Lazily import - import sqlparse + _logger.debug(termcolor.colored(str(statement), "green")) - def __escape(value): - # Lazily import - import datetime - import sqlalchemy - - # bool - if type(value) is bool: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) - - # bytes - elif type(value) is bytes: - if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html - elif self._engine.url.get_backend_name() == "postgresql": - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359 - else: - raise RuntimeError("unsupported value: {}".format(value)) - - # datetime.date - elif type(value) is datetime.date: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) - - # datetime.datetime - elif type(value) is datetime.datetime: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) - - # datetime.time - elif type(value) is datetime.time: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) - - # float - elif type(value) is float: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) - - # int - elif type(value) is int: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) - - # str - elif type(value) is str: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) - - # None - elif value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self._engine.dialect)(value)) - - # Unsupported value - else: - raise RuntimeError("unsupported value: {}".format(value)) - - # Escape value(s), separating with commas as needed - if type(value) in [list, tuple]: - return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) - else: - return __escape(value) - - -def _parse_exception(e): - """Parses an exception, returns its message.""" - - # Lazily import - import re - - # MySQL - matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) - if matches: - return matches.group(1) - - # PostgreSQL - matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e)) - if matches: - return matches.group(1) - - # SQLite - matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e)) - if matches: - return matches.group(1) - - # Default - return str(e) - - -def _parse_placeholder(token): - """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder.""" - - # Lazily load - import re - import sqlparse - - # Validate token - if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder: - raise TypeError() - - # qmark - if token.value == "?": - return "qmark", None + def _get_last_val(self): + try: + return self._session.execute("SELECT LASTVAL()").first()[0] + except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session + return None - # numeric - matches = re.search(r"^:([1-9]\d*)$", token.value) - if matches: - return "numeric", int(matches.group(1)) - 1 - # named - matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) - if matches: - return "named", matches.group(1) + def init_app(self, app): + @app.teardown_appcontext + def shutdown_session(res_or_exc): + self._session.remove() + return res_or_exc - # format - if token.value == "%s": - return "format", None + logging.getLogger("cs50").disabled = False - # pyformat - matches = re.search(r"%\((\w+)\)s$", token.value) - if matches: - return "pyformat", matches.group(1) - # Invalid - raise RuntimeError("{}: invalid placeholder".format(token.value)) +def _fetch_select_result(result): + rows = [dict(row) for row in result.fetchall()] + for row in rows: + for column in row: + # Coerce decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + if isinstance(row[column], decimal.Decimal): + row[column] = float(row[column]) + # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes + elif isinstance(row[column], memoryview): + row[column] = bytes(row[column]) -def _teardown_appcontext(exception=None): - """Closes context's database connection, if any.""" - import flask - for session in flask.g.pop("_sessions", {}).values(): - session.remove() + return rows diff --git a/tests/test_cs50.py b/tests/test_cs50.py new file mode 100644 index 0000000..a58424d --- /dev/null +++ b/tests/test_cs50.py @@ -0,0 +1,151 @@ +import math +import sys +import unittest + +from unittest.mock import patch + +from cs50.cs50 import get_string, _get_int, _get_float + + +class TestCS50(unittest.TestCase): + @patch("cs50.cs50._get_input", return_value="") + def test_get_string_empty_input(self, mock_get_input): + """Returns empty string when input is empty""" + self.assertEqual(get_string("Answer: "), "") + mock_get_input.assert_called_with("Answer: ") + + + @patch("cs50.cs50._get_input", return_value="test") + def test_get_string_nonempty_input(self, mock_get_input): + """Returns the provided non-empty input""" + self.assertEqual(get_string("Answer: "), "test") + mock_get_input.assert_called_with("Answer: ") + + + @patch("cs50.cs50._get_input", side_effect=EOFError) + def test_get_string_eof(self, mock_get_input): + """Returns None on EOF""" + self.assertIs(get_string("Answer: "), None) + mock_get_input.assert_called_with("Answer: ") + + + def test_get_string_invalid_prompt(self): + """Raises TypeError when prompt is not str""" + with self.assertRaises(TypeError): + get_string(1) + + + @patch("cs50.cs50.get_string", return_value=None) + def test_get_int_eof(self, mock_get_string): + """Returns None on EOF""" + self.assertIs(_get_int("Answer: "), None) + mock_get_string.assert_called_with("Answer: ") + + + def test_get_int_valid_input(self): + """Returns the provided integer input""" + + def assert_equal(return_value, expected_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + self.assertEqual(_get_int("Answer: "), expected_value) + mock_get_string.assert_called_with("Answer: ") + + values = [ + ("0", 0), + ("50", 50), + ("+50", 50), + ("+42", 42), + ("-42", -42), + ("42", 42), + ] + + for return_value, expected_value in values: + assert_equal(return_value, expected_value) + + + def test_get_int_invalid_input(self): + """Raises ValueError when input is invalid base-10 int""" + + def assert_raises_valueerror(return_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + with self.assertRaises(ValueError): + _get_int("Answer: ") + + mock_get_string.assert_called_with("Answer: ") + + return_values = [ + "++50", + "--50", + "50+", + "50-", + " 50", + " +50", + " -50", + "50 ", + "ab50", + "50ab", + "ab50ab", + ] + + for return_value in return_values: + assert_raises_valueerror(return_value) + + + @patch("cs50.cs50.get_string", return_value=None) + def test_get_float_eof(self, mock_get_string): + """Returns None on EOF""" + self.assertIs(_get_float("Answer: "), None) + mock_get_string.assert_called_with("Answer: ") + + + def test_get_float_valid_input(self): + """Returns the provided integer input""" + def assert_equal(return_value, expected_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + f = _get_float("Answer: ") + self.assertTrue(math.isclose(f, expected_value)) + mock_get_string.assert_called_with("Answer: ") + + values = [ + (".0", 0.0), + ("0.", 0.0), + (".42", 0.42), + ("42.", 42.0), + ("50", 50.0), + ("+50", 50.0), + ("-50", -50.0), + ("+3.14", 3.14), + ("-3.14", -3.14), + ] + + for return_value, expected_value in values: + assert_equal(return_value, expected_value) + + + def test_get_float_invalid_input(self): + """Raises ValueError when input is invalid float""" + + def assert_raises_valueerror(return_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + with self.assertRaises(ValueError): + _get_float("Answer: ") + + mock_get_string.assert_called_with("Answer: ") + + return_values = [ + ".", + "..5", + "a.5", + ".5a" + "0.5a", + "a0.42", + " .42", + "3.14 ", + "++3.14", + "3.14+", + "--3.14", + "3.14--", + ] + + for return_value in return_values: + assert_raises_valueerror(return_value) From d23ed8a9bdd2bbf529021904aa6c98b640781033 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Fri, 9 Apr 2021 19:44:21 -0400 Subject: [PATCH 02/47] remove unused import --- src/cs50/flask.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 324ec30..a0e077a 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -2,18 +2,17 @@ import pkgutil import sys +from distutils.version import StrictVersion +from werkzeug.middleware.proxy_fix import ProxyFix + def _wrap_flask(f): if f is None: return - from distutils.version import StrictVersion - from .cs50 import _formatException - if f.__version__ < StrictVersion("1.0"): return if os.getenv("CS50_IDE_TYPE") == "online": - from werkzeug.middleware.proxy_fix import ProxyFix _flask_init_before = f.Flask.__init__ def _flask_init_after(self, *args, **kwargs): _flask_init_before(self, *args, **kwargs) From b5f030083e30b31acd6056aee2097462c430a731 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Fri, 9 Apr 2021 19:49:44 -0400 Subject: [PATCH 03/47] fix logger --- src/cs50/sql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index b778601..d5c8d49 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -63,7 +63,7 @@ def _execute(self, statement): # Raise exceptions for warnings warnings.simplefilter("error") try: - return self._session.execute(statement) + result = self._session.execute(statement) except sqlalchemy.exc.IntegrityError as exc: _logger.debug(termcolor.colored(str(statement), "yellow")) raise ValueError(exc.orig) from None @@ -73,6 +73,7 @@ def _execute(self, statement): raise RuntimeError(exc.orig) from None _logger.debug(termcolor.colored(str(statement), "green")) + return result def _get_last_val(self): From 022a3da3151c7c82315e12bcb2d87fabe61e4600 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Fri, 9 Apr 2021 20:01:11 -0400 Subject: [PATCH 04/47] remove test_database, rename param --- src/cs50/sql.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d5c8d49..a1f7dbd 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -15,11 +15,6 @@ class SQL: def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) self._autocommit = False - self._test_database() - - - def _test_database(self): - self.execute("SELECT 1") def execute(self, sql, *args, **kwargs): @@ -85,9 +80,8 @@ def _get_last_val(self): def init_app(self, app): @app.teardown_appcontext - def shutdown_session(res_or_exc): + def shutdown_session(_): self._session.remove() - return res_or_exc logging.getLogger("cs50").disabled = False From a3b32c45bb64308eed38b53680e5997e00cbeac8 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Fri, 9 Apr 2021 20:12:13 -0400 Subject: [PATCH 05/47] fix exception formatting --- src/cs50/_logger.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py index 46f0821..c489111 100644 --- a/src/cs50/_logger.py +++ b/src/cs50/_logger.py @@ -8,6 +8,16 @@ def _setup_logger(): + # Configure default logging handler and formatter + # Prevent flask, werkzeug, etc from adding default handler + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) + + try: + # Patch formatException + logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) + except IndexError: + pass + _logger = logging.getLogger("cs50") _logger.disabled = True _logger.setLevel(logging.DEBUG) @@ -23,6 +33,8 @@ def _setup_logger(): handler.setFormatter(formatter) _logger.addHandler(handler) + sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) + def _formatException(type, value, tb): """ @@ -44,5 +56,3 @@ def _formatException(type, value, tb): matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) return "".join(lines).rstrip() - - From 663e6bdf919853d8518d140e3b9c5da51edca0ed Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Sat, 10 Apr 2021 00:06:25 -0400 Subject: [PATCH 06/47] simplify _session, execute --- src/cs50/_session.py | 16 +++++----------- src/cs50/_statement.py | 24 ++++++++++++++---------- src/cs50/sql.py | 25 +++++++++++++++++-------- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 4d1a2a9..441371a 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -6,18 +6,12 @@ class Session: def __init__(self, url, **engine_kwargs): - self._url = url - if _is_sqlite_url(self._url): - _assert_sqlite_file_exists(self._url) + if _is_sqlite_url(url): + _assert_sqlite_file_exists(url) - self._engine = _create_engine(self._url, **engine_kwargs) - self._is_postgres = self._engine.url.get_backend_name() in {"postgres", "postgresql"} - _setup_on_connect(self._engine) - self._session = _create_scoped_session(self._engine) - - - def is_postgres(self): - return self._is_postgres + engine = _create_engine(url, **engine_kwargs) + _setup_on_connect(engine) + self._session = _create_scoped_session(engine) def execute(self, statement): diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 7519b1e..d6ba10d 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -24,10 +24,10 @@ def __init__(self, dialect, sql, *args, **kwargs): def _parse(self): formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) - num_of_statements = len(parsed_statements) - if num_of_statements == 0: + statement_count = len(parsed_statements) + if statement_count == 0: raise RuntimeError("missing statement") - elif num_of_statements > 1: + elif statement_count > 1: raise RuntimeError("too many statements at once") return parsed_statements[0] @@ -35,9 +35,9 @@ def _parse(self): def _parse_command(self): for token in self._statement: - if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: + if _is_command_token(token): token_value = token.value.upper() - if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: + if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: command = token_value break else: @@ -49,11 +49,11 @@ def _parse_command(self): def _bind_params(self): tokens = self._tokenize() paramstyle, placeholders = self._parse_placeholders(tokens) - if paramstyle in [Paramstyle.FORMAT, Paramstyle.QMARK]: + if paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: tokens = self._bind_format_or_qmark(placeholders, tokens) elif paramstyle == Paramstyle.NUMERIC: tokens = self._bind_numeric(placeholders, tokens) - if paramstyle in [Paramstyle.NAMED, Paramstyle.PYFORMAT]: + if paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: tokens = self._bind_named_or_pyformat(placeholders, tokens) tokens = _escape_verbatim_colons(tokens) @@ -154,10 +154,10 @@ def _escape(self, value): sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) if isinstance(value, bytes): - if self._dialect.name in ["mysql", "sqlite"]: + if self._dialect.name in {"mysql", "sqlite"}: # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") - if self._dialect.name in ["postgres", "postgresql"]: + if self._dialect.name in {"postgres", "postgresql"}: # https://dba.stackexchange.com/a/203359 return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") @@ -235,7 +235,7 @@ def _parse_placeholder(token): if token.value == "%s": return Paramstyle.FORMAT, None - # E.g., %(foo) + # E.g., %(foo)s matches = re.search(r"%\((\w+)\)s$", token.value) if matches: return Paramstyle.PYFORMAT, matches.group(1) @@ -253,6 +253,10 @@ def _escape_verbatim_colons(tokens): return tokens +def _is_command_token(token): + return token.ttype in {sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} + + def _is_string_literal(token): return token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] diff --git a/src/cs50/sql.py b/src/cs50/sql.py index a1f7dbd..64aa83d 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -14,14 +14,16 @@ class SQL: def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) + self._dialect = self._session.get_bind().dialect + self._is_postgres = self._dialect in {"postgres", "postgresql"} self._autocommit = False def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" - statement = Statement(self._session.get_bind().dialect, sql, *args, **kwargs) + statement = Statement(self._dialect, sql, *args, **kwargs) command = statement.get_command() - if command in ["BEGIN", "START"]: + if command in {"BEGIN", "START"}: self._autocommit = False if self._autocommit: @@ -33,18 +35,15 @@ def execute(self, sql, *args, **kwargs): self._session.execute("COMMIT") self._session.remove() - if command in ["COMMIT", "ROLLBACK"]: + if command in {"COMMIT", "ROLLBACK"}: self._autocommit = True self._session.remove() if command == "SELECT": ret = _fetch_select_result(result) elif command == "INSERT": - if self._session.is_postgres(): - ret = self._get_last_val() - else: - ret = result.lastrowid if result.rowcount == 1 else None - elif command in ["DELETE", "UPDATE"]: + ret = self._last_row_id_or_none(result) + elif command in {"DELETE", "UPDATE"}: ret = result.rowcount else: ret = True @@ -71,6 +70,16 @@ def _execute(self, statement): return result + def _last_row_id_or_none(self, result): + if self.is_postgres(): + return self._get_last_val() + return result.lastrowid if result.rowcount == 1 else None + + + def is_postgres(self): + return self._is_postgres + + def _get_last_val(self): try: return self._session.execute("SELECT LASTVAL()").first()[0] From f8afbccdb15f860dbcefb69febd4458ad3dc8673 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Sat, 10 Apr 2021 07:39:19 -0400 Subject: [PATCH 07/47] fix is_postgres --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 64aa83d..74ec9b2 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -15,7 +15,7 @@ class SQL: def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) self._dialect = self._session.get_bind().dialect - self._is_postgres = self._dialect in {"postgres", "postgresql"} + self._is_postgres = self._dialect.name in {"postgres", "postgresql"} self._autocommit = False From 62b998265a5a8928b4d03962e7e9df65d4fc2ea1 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Sat, 10 Apr 2021 07:47:29 -0400 Subject: [PATCH 08/47] abstract away engine creation --- src/cs50/_session.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 441371a..3aff4f7 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -5,13 +5,12 @@ import sqlite3 class Session: + """Wraps a SQLAlchemy scoped session""" def __init__(self, url, **engine_kwargs): if _is_sqlite_url(url): _assert_sqlite_file_exists(url) - engine = _create_engine(url, **engine_kwargs) - _setup_on_connect(engine) - self._session = _create_scoped_session(engine) + self._session = _create_session(url, **engine_kwargs) def execute(self, statement): @@ -34,6 +33,12 @@ def _assert_sqlite_file_exists(url): raise RuntimeError(f"not a file: {path}") +def _create_session(url, **engine_kwargs): + engine = _create_engine(url, **engine_kwargs) + _setup_on_connect(engine) + return _create_scoped_session(engine) + + def _create_engine(url, **kwargs): try: engine = sqlalchemy.create_engine(url, **kwargs) From da613bec033538cf4df818479a92b12ef47867f6 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Sat, 10 Apr 2021 09:01:58 -0400 Subject: [PATCH 09/47] fix pylint errors --- src/cs50/__init__.py | 9 ++++-- src/cs50/_flask.py | 38 ++++++++++++++++++++++++++ src/cs50/_logger.py | 17 ++++++++---- src/cs50/_session.py | 6 +++- src/cs50/_statement.py | 62 +++++++++++++++++++++++------------------- src/cs50/cs50.py | 48 +++++++++++++++++--------------- src/cs50/flask.py | 37 ------------------------- src/cs50/sql.py | 12 ++++---- 8 files changed, 126 insertions(+), 103 deletions(-) create mode 100644 src/cs50/_flask.py delete mode 100644 src/cs50/flask.py diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index f04da00..b75f415 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,6 +1,9 @@ -from ._logger import _setup_logger -_setup_logger() +"""Exposes API, wraps flask, and sets up logging""" from .cs50 import get_float, get_int, get_string -from . import flask from .sql import SQL +from ._logger import _setup_logger +from ._flask import _wrap_flask + +_setup_logger() +_wrap_flask() diff --git a/src/cs50/_flask.py b/src/cs50/_flask.py new file mode 100644 index 0000000..d65a8a5 --- /dev/null +++ b/src/cs50/_flask.py @@ -0,0 +1,38 @@ +"""Hooks into flask importing to support X-Forwarded-Proto header in online IDE""" + +import os +import pkgutil +import sys + +from distutils.version import StrictVersion +from werkzeug.middleware.proxy_fix import ProxyFix + + +def _wrap_flask(): + if "flask" in sys.modules: + _support_x_forwarded_proto(sys.modules["flask"]) + else: + flask_loader = pkgutil.get_loader('flask') + if flask_loader: + _exec_module_before = flask_loader.exec_module + + def _exec_module_after(*args, **kwargs): + _exec_module_before(*args, **kwargs) + _support_x_forwarded_proto(sys.modules["flask"]) + + flask_loader.exec_module = _exec_module_after + + +def _support_x_forwarded_proto(flask_module): + if flask_module is None: + return + + if flask_module.__version__ < StrictVersion("1.0"): + return + + if os.getenv("CS50_IDE_TYPE") == "online": + _flask_init_before = flask_module.Flask.__init__ + def _flask_init_after(self, *args, **kwargs): + _flask_init_before(self, *args, **kwargs) + self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy + flask_module.Flask.__init__ = _flask_init_after diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py index c489111..df021a3 100644 --- a/src/cs50/_logger.py +++ b/src/cs50/_logger.py @@ -1,3 +1,5 @@ +"""Sets up logging for cs50 library""" + import logging import os.path import re @@ -14,7 +16,8 @@ def _setup_logger(): try: # Patch formatException - logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) + formatter = logging.root.handlers[0].formatter + formatter.formatException = lambda exc_info: _format_exception(*exc_info) except IndexError: pass @@ -29,14 +32,15 @@ def _setup_logger(): handler.setLevel(logging.DEBUG) formatter = logging.Formatter("%(levelname)s: %(message)s") - formatter.formatException = lambda exc_info: _formatException(*exc_info) + formatter.formatException = lambda exc_info: _format_exception(*exc_info) handler.setFormatter(formatter) _logger.addHandler(handler) - sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) + sys.excepthook = lambda type_, value, exc_tb: print( + _format_exception(type_, value, exc_tb), file=sys.stderr) -def _formatException(type, value, tb): +def _format_exception(type_, value, exc_tb): """ Format traceback, darkening entries from global site-packages directories and user-specific site-packages directory. @@ -48,11 +52,12 @@ def _formatException(type, value, tb): # Highlight lines not referring to files in site-packages lines = [] - for line in traceback.format_exception(type, value, tb): + for line in traceback.format_exception(type_, value, exc_tb): matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) if matches and matches.group(1).startswith(packages): lines += line else: matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) + lines.append( + matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) return "".join(lines).rstrip() diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 3aff4f7..cd23453 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -1,8 +1,10 @@ +"""Wraps a SQLAlchemy scoped session""" + import os +import sqlite3 import sqlalchemy import sqlalchemy.orm -import sqlite3 class Session: """Wraps a SQLAlchemy scoped session""" @@ -14,6 +16,8 @@ def __init__(self, url, **engine_kwargs): def execute(self, statement): + """Converts statement to str and executes it""" + # pylint: disable=no-member return self._session.execute(sqlalchemy.text(str(statement))) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index d6ba10d..7a38c90 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,3 +1,5 @@ +"""Parses a SQL statement and binds its parameters""" + import collections import datetime import enum @@ -8,6 +10,7 @@ class Statement: + """Parses and binds a SQL statement""" def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -21,13 +24,14 @@ def __init__(self, dialect, sql, *args, **kwargs): self._command = self._parse_command() self._tokens = self._bind_params() + def _parse(self): formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) statement_count = len(parsed_statements) if statement_count == 0: raise RuntimeError("missing statement") - elif statement_count > 1: + if statement_count > 1: raise RuntimeError("too many statements at once") return parsed_statements[0] @@ -49,11 +53,11 @@ def _parse_command(self): def _bind_params(self): tokens = self._tokenize() paramstyle, placeholders = self._parse_placeholders(tokens) - if paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: + if paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: tokens = self._bind_format_or_qmark(placeholders, tokens) - elif paramstyle == Paramstyle.NUMERIC: + elif paramstyle == _Paramstyle.NUMERIC: tokens = self._bind_numeric(placeholders, tokens) - if paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: + if paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: tokens = self._bind_named_or_pyformat(placeholders, tokens) tokens = _escape_verbatim_colons(tokens) @@ -86,9 +90,9 @@ def _parse_placeholders(self, tokens): def _default_paramstyle(self): paramstyle = None if self._args: - paramstyle = Paramstyle.QMARK + paramstyle = _Paramstyle.QMARK elif self._kwargs: - paramstyle = Paramstyle.NAMED + paramstyle = _Paramstyle.NAMED return paramstyle @@ -118,8 +122,10 @@ def _bind_numeric(self, placeholders, tokens): unused_arg_indices.remove(num) if len(unused_arg_indices) > 0: - unused_args = ", ".join([str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)]) - raise RuntimeError(f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})") + unused_args = ", ".join( + [str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)]) + raise RuntimeError( + f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})") return tokens @@ -134,7 +140,9 @@ def _bind_named_or_pyformat(self, placeholders, tokens): unused_params.remove(param_name) if len(unused_params) > 0: - raise RuntimeError("unused value{'' if len(unused_params) == 1 else 's'} ({', '.join(sorted(unused_params))})") + joined_unused_params = ", ".join(sorted(unused_params)) + raise RuntimeError( + f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") return tokens @@ -144,7 +152,7 @@ def _escape(self, value): Escapes value using engine's conversion function. https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor """ - + # pylint: disable=too-many-return-statements if isinstance(value, (list, tuple)): return self._escape_iterable(value) @@ -163,20 +171,18 @@ def _escape(self, value): raise RuntimeError(f"unsupported value: {value}") + string_processor = sqlalchemy.types.String().literal_processor(self._dialect) if isinstance(value, datetime.date): return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d"))) + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d"))) if isinstance(value, datetime.datetime): return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S"))) if isinstance(value, datetime.time): return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%H:%M:%S"))) + sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S"))) if isinstance(value, float): return sqlparse.sql.Token( @@ -189,9 +195,7 @@ def _escape(self, value): sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) if isinstance(value, str): - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._dialect)(value)) + return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) if value is None: return sqlparse.sql.Token( @@ -207,6 +211,7 @@ def _escape_iterable(self, iterable): def get_command(self): + """Returns statement command (e.g., SELECT) or None""" return self._command @@ -220,25 +225,25 @@ def _is_placeholder(token): def _parse_placeholder(token): if token.value == "?": - return Paramstyle.QMARK, None + return _Paramstyle.QMARK, None # E.g., :1 matches = re.search(r"^:([1-9]\d*)$", token.value) if matches: - return Paramstyle.NUMERIC, int(matches.group(1)) - 1 + return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 # E.g., :foo matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) if matches: - return Paramstyle.NAMED, matches.group(1) + return _Paramstyle.NAMED, matches.group(1) if token.value == "%s": - return Paramstyle.FORMAT, None + return _Paramstyle.FORMAT, None # E.g., %(foo)s matches = re.search(r"%\((\w+)\)s$", token.value) if matches: - return Paramstyle.PYFORMAT, matches.group(1) + return _Paramstyle.PYFORMAT, matches.group(1) raise RuntimeError(f"{token.value}: invalid placeholder") @@ -246,15 +251,16 @@ def _parse_placeholder(token): def _escape_verbatim_colons(tokens): for token in tokens: if _is_string_literal(token): - token.value = re.sub("(^'|\s+):", r"\1\:", token.value) + token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value) elif _is_identifier(token): - token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) + token.value = re.sub(r"(^\"|\s+):", r"\1\:", token.value) return tokens def _is_command_token(token): - return token.ttype in {sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} + return token.ttype in { + sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} def _is_string_literal(token): @@ -265,7 +271,7 @@ def _is_identifier(token): return token.ttype == sqlparse.tokens.Literal.String.Symbol -class Paramstyle(enum.Enum): +class _Paramstyle(enum.Enum): FORMAT = enum.auto() NAMED = enum.auto() NUMERIC = enum.auto() diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 573d862..24c748b 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -1,3 +1,5 @@ +"""Exposes simple API for getting and validating user input""" + import re import sys @@ -16,14 +18,14 @@ def get_float(prompt): def _get_float(prompt): - s = get_string(prompt) - if s is None: - return + user_input = get_string(prompt) + if user_input is None: + return None - if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s): - return float(s) + if len(user_input) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", user_input): + return float(user_input) - raise ValueError(f"invalid float literal: {s}") + raise ValueError(f"invalid float literal: {user_input}") def get_int(prompt): @@ -40,14 +42,14 @@ def get_int(prompt): def _get_int(prompt): - s = get_string(prompt) - if s is None: - return + user_input = get_string(prompt) + if user_input is None: + return None - if re.search(r"^[+-]?\d+$", s): - return int(s, 10) + if re.search(r"^[+-]?\d+$", user_input): + return int(user_input, 10) - raise ValueError(f"invalid int literal for base 10: {s}") + raise ValueError(f"invalid int literal for base 10: {user_input}") def get_string(prompt): @@ -57,13 +59,13 @@ def get_string(prompt): as line endings. If user inputs only a line ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF). """ - if type(prompt) is not str: + if not isinstance(prompt, str): raise TypeError("prompt must be of type str") try: return _get_input(prompt) except EOFError: - return + return None def _get_input(prompt): @@ -76,18 +78,20 @@ class _flushfile(): http://stackoverflow.com/a/231216 """ - def __init__(self, f): - self.f = f + def __init__(self, stream): + self.stream = stream def __getattr__(self, name): - return object.__getattribute__(self.f, name) + return object.__getattribute__(self.stream, name) - def write(self, x): - self.f.write(x) - self.f.flush() + def write(self, data): + """Writes data to stream""" + self.stream.write(data) + self.stream.flush() -def disable_buffering(): +def disable_output_buffering(): + """Disables output buffering to prevent prompts from being buffered""" sys.stderr = _flushfile(sys.stderr) sys.stdout = _flushfile(sys.stdout) -disable_buffering() +disable_output_buffering() diff --git a/src/cs50/flask.py b/src/cs50/flask.py deleted file mode 100644 index a0e077a..0000000 --- a/src/cs50/flask.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import pkgutil -import sys - -from distutils.version import StrictVersion -from werkzeug.middleware.proxy_fix import ProxyFix - -def _wrap_flask(f): - if f is None: - return - - if f.__version__ < StrictVersion("1.0"): - return - - if os.getenv("CS50_IDE_TYPE") == "online": - _flask_init_before = f.Flask.__init__ - def _flask_init_after(self, *args, **kwargs): - _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy - f.Flask.__init__ = _flask_init_after - - -# If Flask was imported before cs50 -if "flask" in sys.modules: - _wrap_flask(sys.modules["flask"]) - -# If Flask wasn't imported -else: - flask_loader = pkgutil.get_loader('flask') - if flask_loader: - _exec_module_before = flask_loader.exec_module - - def _exec_module_after(*args, **kwargs): - _exec_module_before(*args, **kwargs) - _wrap_flask(sys.modules["flask"]) - - flask_loader.exec_module = _exec_module_after diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 74ec9b2..0510f17 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,3 +1,5 @@ +"""Wraps SQLAlchemy""" + import decimal import logging import warnings @@ -12,6 +14,7 @@ class SQL: + """Wraps SQLAlchemy""" def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) self._dialect = self._session.get_bind().dialect @@ -71,15 +74,11 @@ def _execute(self, statement): def _last_row_id_or_none(self, result): - if self.is_postgres(): + if self._is_postgres: return self._get_last_val() return result.lastrowid if result.rowcount == 1 else None - def is_postgres(self): - return self._is_postgres - - def _get_last_val(self): try: return self._session.execute("SELECT LASTVAL()").first()[0] @@ -88,8 +87,9 @@ def _get_last_val(self): def init_app(self, app): + """Registers a teardown_appcontext listener to remove session and enables logging""" @app.teardown_appcontext - def shutdown_session(_): + def _(_): self._session.remove() logging.getLogger("cs50").disabled = False From 4a593dd4ab27ade5978906b5d9be40ac3af78ed2 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 15:58:51 -0400 Subject: [PATCH 10/47] rename _parse_command --- src/cs50/_statement.py | 20 ++++++++++---------- src/cs50/sql.py | 12 ++++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 7a38c90..598b131 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -21,7 +21,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._kwargs = kwargs self._statement = self._parse() - self._command = self._parse_command() + self._operation_keyword = self._get_operation_keyword() self._tokens = self._bind_params() @@ -37,17 +37,17 @@ def _parse(self): return parsed_statements[0] - def _parse_command(self): + def _get_operation_keyword(self): for token in self._statement: - if _is_command_token(token): + if _is_operation_token(token): token_value = token.value.upper() if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: - command = token_value + operation_keyword = token_value break else: - command = None + operation_keyword = None - return command + return operation_keyword def _bind_params(self): @@ -210,9 +210,9 @@ def _escape_iterable(self, iterable): sqlparse.parse(", ".join([str(self._escape(v)) for v in iterable]))) - def get_command(self): - """Returns statement command (e.g., SELECT) or None""" - return self._command + def get_operation_keyword(self): + """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" + return self._operation_keyword def __str__(self): @@ -258,7 +258,7 @@ def _escape_verbatim_colons(tokens): return tokens -def _is_command_token(token): +def _is_operation_token(token): return token.ttype in { sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 0510f17..fca57d2 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -25,8 +25,8 @@ def __init__(self, url, **engine_kwargs): def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = Statement(self._dialect, sql, *args, **kwargs) - command = statement.get_command() - if command in {"BEGIN", "START"}: + operation_keyword = statement.get_operation_keyword() + if operation_keyword in {"BEGIN", "START"}: self._autocommit = False if self._autocommit: @@ -38,15 +38,15 @@ def execute(self, sql, *args, **kwargs): self._session.execute("COMMIT") self._session.remove() - if command in {"COMMIT", "ROLLBACK"}: + if operation_keyword in {"COMMIT", "ROLLBACK"}: self._autocommit = True self._session.remove() - if command == "SELECT": + if operation_keyword == "SELECT": ret = _fetch_select_result(result) - elif command == "INSERT": + elif operation_keyword == "INSERT": ret = self._last_row_id_or_none(result) - elif command in {"DELETE", "UPDATE"}: + elif operation_keyword in {"DELETE", "UPDATE"}: ret = result.rowcount else: ret = True From 6fcf7ed469ad8cdff1628e591eff3d7eb6767f34 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 16:01:36 -0400 Subject: [PATCH 11/47] rename _bind_params --- src/cs50/_statement.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 598b131..789acdf 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,4 +1,4 @@ -"""Parses a SQL statement and binds its parameters""" +"""Parses a SQL statement and replaces placeholders with parameters""" import collections import datetime @@ -10,7 +10,7 @@ class Statement: - """Parses and binds a SQL statement""" + """Parses a SQL statement and replaces placeholders with parameters""" def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -22,7 +22,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._statement = self._parse() self._operation_keyword = self._get_operation_keyword() - self._tokens = self._bind_params() + self._tokens = self._replace_placeholders_with_params() def _parse(self): @@ -50,15 +50,15 @@ def _get_operation_keyword(self): return operation_keyword - def _bind_params(self): + def _replace_placeholders_with_params(self): tokens = self._tokenize() paramstyle, placeholders = self._parse_placeholders(tokens) if paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - tokens = self._bind_format_or_qmark(placeholders, tokens) + tokens = self._replace_format_or_qmark_placeholders(placeholders, tokens) elif paramstyle == _Paramstyle.NUMERIC: - tokens = self._bind_numeric(placeholders, tokens) + tokens = self._replace_numeric_placeholders(placeholders, tokens) if paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - tokens = self._bind_named_or_pyformat(placeholders, tokens) + tokens = self._replace_named_or_pyformat_placeholders(placeholders, tokens) tokens = _escape_verbatim_colons(tokens) return tokens @@ -97,7 +97,7 @@ def _default_paramstyle(self): return paramstyle - def _bind_format_or_qmark(self, placeholders, tokens): + def _replace_format_or_qmark_placeholders(self, placeholders, tokens): if len(placeholders) != len(self._args): _placeholders = ", ".join([str(token) for token in placeholders.values()]) _args = ", ".join([str(self._escape(arg)) for arg in self._args]) @@ -112,7 +112,7 @@ def _bind_format_or_qmark(self, placeholders, tokens): return tokens - def _bind_numeric(self, placeholders, tokens): + def _replace_numeric_placeholders(self, placeholders, tokens): unused_arg_indices = set(range(len(self._args))) for token_index, num in placeholders.items(): if num >= len(self._args): @@ -130,7 +130,7 @@ def _bind_numeric(self, placeholders, tokens): return tokens - def _bind_named_or_pyformat(self, placeholders, tokens): + def _replace_named_or_pyformat_placeholders(self, placeholders, tokens): unused_params = set(self._kwargs.keys()) for token_index, param_name in placeholders.items(): if param_name not in self._kwargs: From a4989eb05e078f78ac5dc488300b804ae3ecb996 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 19:37:49 -0400 Subject: [PATCH 12/47] wrap flask wrapper the IDE proxy now handles forcing https --- src/cs50/__init__.py | 4 +--- src/cs50/_flask.py | 38 ------------------------------- tests/redirect/application.py | 12 ---------- tests/redirect/templates/foo.html | 1 - 4 files changed, 1 insertion(+), 54 deletions(-) delete mode 100644 src/cs50/_flask.py delete mode 100644 tests/redirect/application.py delete mode 100644 tests/redirect/templates/foo.html diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index b75f415..fa07171 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,9 +1,7 @@ -"""Exposes API, wraps flask, and sets up logging""" +"""Exposes API and sets up logging""" from .cs50 import get_float, get_int, get_string from .sql import SQL from ._logger import _setup_logger -from ._flask import _wrap_flask _setup_logger() -_wrap_flask() diff --git a/src/cs50/_flask.py b/src/cs50/_flask.py deleted file mode 100644 index d65a8a5..0000000 --- a/src/cs50/_flask.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Hooks into flask importing to support X-Forwarded-Proto header in online IDE""" - -import os -import pkgutil -import sys - -from distutils.version import StrictVersion -from werkzeug.middleware.proxy_fix import ProxyFix - - -def _wrap_flask(): - if "flask" in sys.modules: - _support_x_forwarded_proto(sys.modules["flask"]) - else: - flask_loader = pkgutil.get_loader('flask') - if flask_loader: - _exec_module_before = flask_loader.exec_module - - def _exec_module_after(*args, **kwargs): - _exec_module_before(*args, **kwargs) - _support_x_forwarded_proto(sys.modules["flask"]) - - flask_loader.exec_module = _exec_module_after - - -def _support_x_forwarded_proto(flask_module): - if flask_module is None: - return - - if flask_module.__version__ < StrictVersion("1.0"): - return - - if os.getenv("CS50_IDE_TYPE") == "online": - _flask_init_before = flask_module.Flask.__init__ - def _flask_init_after(self, *args, **kwargs): - _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy - flask_module.Flask.__init__ = _flask_init_after diff --git a/tests/redirect/application.py b/tests/redirect/application.py deleted file mode 100644 index 6aff187..0000000 --- a/tests/redirect/application.py +++ /dev/null @@ -1,12 +0,0 @@ -import cs50 -from flask import Flask, redirect, render_template - -app = Flask(__name__) - -@app.route("/") -def index(): - return redirect("/foo") - -@app.route("/foo") -def foo(): - return render_template("foo.html") diff --git a/tests/redirect/templates/foo.html b/tests/redirect/templates/foo.html deleted file mode 100644 index 257cc56..0000000 --- a/tests/redirect/templates/foo.html +++ /dev/null @@ -1 +0,0 @@ -foo From e8827bfa0b68c06b48822d03d30cf84c882b45bf Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 20:38:01 -0400 Subject: [PATCH 13/47] factor out sanitizer --- src/cs50/_sql_sanitizer.py | 86 +++++++++++++++++++ src/cs50/_statement.py | 167 ++++++++++--------------------------- 2 files changed, 132 insertions(+), 121 deletions(-) create mode 100644 src/cs50/_sql_sanitizer.py diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py new file mode 100644 index 0000000..c2f35c4 --- /dev/null +++ b/src/cs50/_sql_sanitizer.py @@ -0,0 +1,86 @@ +"""Escapes SQL values""" + +import datetime +import re + +import sqlalchemy +import sqlparse + + +class SQLSanitizer: + """Escapes SQL values""" + + def __init__(self, dialect): + self._dialect = dialect + + + def escape(self, value): + """ + Escapes value using engine's conversion function. + https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor + """ + # pylint: disable=too-many-return-statements + if isinstance(value, (list, tuple)): + return self.escape_iterable(value) + + if isinstance(value, bool): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) + + if isinstance(value, bytes): + if self._dialect.name in {"mysql", "sqlite"}: + # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") + if self._dialect.name in {"postgres", "postgresql"}: + # https://dba.stackexchange.com/a/203359 + return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") + + raise RuntimeError(f"unsupported value: {value}") + + string_processor = sqlalchemy.types.String().literal_processor(self._dialect) + if isinstance(value, datetime.date): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d"))) + + if isinstance(value, datetime.datetime): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S"))) + + if isinstance(value, datetime.time): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S"))) + + if isinstance(value, float): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Float().literal_processor(self._dialect)(value)) + + if isinstance(value, int): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) + + if isinstance(value, str): + return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) + + if value is None: + return sqlparse.sql.Token( + sqlparse.tokens.Keyword, + sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) + + raise RuntimeError(f"unsupported value: {value}") + + + def escape_iterable(self, iterable): + """Escapes a collection of values (e.g., list, tuple)""" + return sqlparse.sql.TokenList( + sqlparse.parse(", ".join([str(self.escape(v)) for v in iterable]))) + + +def escape_verbatim_colon(value): + """Escapes verbatim colon from a value so as it is not confused with a placeholder""" + + # E.g., ':foo, ":foo, :foo will be replaced with + # '\:foo, "\:foo, \:foo respectively + return re.sub(r"(^(?:'|\")|\s+):", r"\1\:", value) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 789acdf..7222f0e 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,40 +1,28 @@ """Parses a SQL statement and replaces placeholders with parameters""" import collections -import datetime import enum import re -import sqlalchemy import sqlparse +from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon + + class Statement: """Parses a SQL statement and replaces placeholders with parameters""" def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") - self._dialect = dialect - self._sql = sql + self._sql_sanitizer = SQLSanitizer(dialect) self._args = args self._kwargs = kwargs - - self._statement = self._parse() + self._statement = _parse(sql) self._operation_keyword = self._get_operation_keyword() - self._tokens = self._replace_placeholders_with_params() - - - def _parse(self): - formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip() - parsed_statements = sqlparse.parse(formatted_statements) - statement_count = len(parsed_statements) - if statement_count == 0: - raise RuntimeError("missing statement") - if statement_count > 1: - raise RuntimeError("too many statements at once") - - return parsed_statements[0] + self._tokens = self._tokenize() + self._replace_placeholders_with_params() def _get_operation_keyword(self): @@ -50,28 +38,26 @@ def _get_operation_keyword(self): return operation_keyword + def _tokenize(self): + return list(self._statement.flatten()) + + def _replace_placeholders_with_params(self): - tokens = self._tokenize() - paramstyle, placeholders = self._parse_placeholders(tokens) + paramstyle, placeholders = self._parse_placeholders() if paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - tokens = self._replace_format_or_qmark_placeholders(placeholders, tokens) + self._replace_format_or_qmark_placeholders(placeholders) elif paramstyle == _Paramstyle.NUMERIC: - tokens = self._replace_numeric_placeholders(placeholders, tokens) + self._replace_numeric_placeholders(placeholders) if paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - tokens = self._replace_named_or_pyformat_placeholders(placeholders, tokens) - - tokens = _escape_verbatim_colons(tokens) - return tokens - + self._replace_named_or_pyformat_placeholders(placeholders) - def _tokenize(self): - return list(self._statement.flatten()) + self._escape_verbatim_colons() - def _parse_placeholders(self, tokens): + def _parse_placeholders(self): paramstyle = None placeholders = collections.OrderedDict() - for index, token in enumerate(tokens): + for index, token in enumerate(self._tokens): if _is_placeholder(token): _paramstyle, name = _parse_placeholder(token) if paramstyle is None: @@ -97,46 +83,42 @@ def _default_paramstyle(self): return paramstyle - def _replace_format_or_qmark_placeholders(self, placeholders, tokens): + def _replace_format_or_qmark_placeholders(self, placeholders): if len(placeholders) != len(self._args): _placeholders = ", ".join([str(token) for token in placeholders.values()]) - _args = ", ".join([str(self._escape(arg)) for arg in self._args]) + _args = ", ".join([str(self._sql_sanitizer.escape(arg)) for arg in self._args]) if len(placeholders) < len(self._args): raise RuntimeError(f"fewer placeholders ({_placeholders}) than values ({_args})") raise RuntimeError(f"more placeholders ({_placeholders}) than values ({_args})") for arg_index, token_index in enumerate(placeholders.keys()): - tokens[token_index] = self._escape(self._args[arg_index]) + self._tokens[token_index] = self._sql_sanitizer.escape(self._args[arg_index]) - return tokens - - def _replace_numeric_placeholders(self, placeholders, tokens): - unused_arg_indices = set(range(len(self._args))) + def _replace_numeric_placeholders(self, placeholders): + unused_arg_idxs = set(range(len(self._args))) for token_index, num in placeholders.items(): if num >= len(self._args): raise RuntimeError(f"missing value for placeholder ({num + 1})") - tokens[token_index] = self._escape(self._args[num]) - unused_arg_indices.remove(num) + self._tokens[token_index] = self._sql_sanitizer.escape(self._args[num]) + unused_arg_idxs.remove(num) - if len(unused_arg_indices) > 0: + if len(unused_arg_idxs) > 0: unused_args = ", ".join( - [str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)]) + [str(self._sql_sanitizer.escape(self._args[i])) for i in sorted(unused_arg_idxs)]) raise RuntimeError( - f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})") - - return tokens + f"unused value{'' if len(unused_arg_idxs) == 1 else 's'} ({unused_args})") - def _replace_named_or_pyformat_placeholders(self, placeholders, tokens): + def _replace_named_or_pyformat_placeholders(self, placeholders): unused_params = set(self._kwargs.keys()) for token_index, param_name in placeholders.items(): if param_name not in self._kwargs: raise RuntimeError(f"missing value for placeholder ({param_name})") - tokens[token_index] = self._escape(self._kwargs[param_name]) + self._tokens[token_index] = self._sql_sanitizer.escape(self._kwargs[param_name]) unused_params.remove(param_name) if len(unused_params) > 0: @@ -144,70 +126,11 @@ def _replace_named_or_pyformat_placeholders(self, placeholders, tokens): raise RuntimeError( f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") - return tokens - - - def _escape(self, value): - """ - Escapes value using engine's conversion function. - https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor - """ - # pylint: disable=too-many-return-statements - if isinstance(value, (list, tuple)): - return self._escape_iterable(value) - if isinstance(value, bool): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) - - if isinstance(value, bytes): - if self._dialect.name in {"mysql", "sqlite"}: - # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") - if self._dialect.name in {"postgres", "postgresql"}: - # https://dba.stackexchange.com/a/203359 - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") - - raise RuntimeError(f"unsupported value: {value}") - - string_processor = sqlalchemy.types.String().literal_processor(self._dialect) - if isinstance(value, datetime.date): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d"))) - - if isinstance(value, datetime.datetime): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S"))) - - if isinstance(value, datetime.time): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S"))) - - if isinstance(value, float): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._dialect)(value)) - - if isinstance(value, int): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) - - if isinstance(value, str): - return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) - - if value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) - - raise RuntimeError(f"unsupported value: {value}") - - - def _escape_iterable(self, iterable): - return sqlparse.sql.TokenList( - sqlparse.parse(", ".join([str(self._escape(v)) for v in iterable]))) + def _escape_verbatim_colons(self): + for token in self._tokens: + if _is_string_literal(token) or _is_identifier(token): + token.value = escape_verbatim_colon(token.value) def get_operation_keyword(self): @@ -219,6 +142,18 @@ def __str__(self): return "".join([str(token) for token in self._tokens]) +def _parse(sql): + formatted_statements = sqlparse.format(sql, strip_comments=True).strip() + parsed_statements = sqlparse.parse(formatted_statements) + statement_count = len(parsed_statements) + if statement_count == 0: + raise RuntimeError("missing statement") + if statement_count > 1: + raise RuntimeError("too many statements at once") + + return parsed_statements[0] + + def _is_placeholder(token): return token.ttype == sqlparse.tokens.Name.Placeholder @@ -248,16 +183,6 @@ def _parse_placeholder(token): raise RuntimeError(f"{token.value}: invalid placeholder") -def _escape_verbatim_colons(tokens): - for token in tokens: - if _is_string_literal(token): - token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value) - elif _is_identifier(token): - token.value = re.sub(r"(^\"|\s+):", r"\1\:", token.value) - - return tokens - - def _is_operation_token(token): return token.ttype in { sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} From bd330a703b6a29bf23d6f58749d5b0c857bbd32d Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 20:58:54 -0400 Subject: [PATCH 14/47] promote paramstyle and placeholders to instance variables --- src/cs50/_statement.py | 97 ++++++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 7222f0e..34f9247 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -6,12 +6,11 @@ import sqlparse - from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon class Statement: - """Parses a SQL statement and replaces placeholders with parameters""" + """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -19,23 +18,12 @@ def __init__(self, dialect, sql, *args, **kwargs): self._sql_sanitizer = SQLSanitizer(dialect) self._args = args self._kwargs = kwargs - self._statement = _parse(sql) - self._operation_keyword = self._get_operation_keyword() + self._statement = _format_and_parse(sql) self._tokens = self._tokenize() + self._paramstyle = self._get_paramstyle() + self._placeholders = self._get_placeholders() self._replace_placeholders_with_params() - - - def _get_operation_keyword(self): - for token in self._statement: - if _is_operation_token(token): - token_value = token.value.upper() - if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: - operation_keyword = token_value - break - else: - operation_keyword = None - - return operation_keyword + self._operation_keyword = self._get_operation_keyword() def _tokenize(self): @@ -43,34 +31,40 @@ def _tokenize(self): def _replace_placeholders_with_params(self): - paramstyle, placeholders = self._parse_placeholders() - if paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - self._replace_format_or_qmark_placeholders(placeholders) - elif paramstyle == _Paramstyle.NUMERIC: - self._replace_numeric_placeholders(placeholders) - if paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - self._replace_named_or_pyformat_placeholders(placeholders) + if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: + self._replace_format_or_qmark_placeholders() + elif self._paramstyle == _Paramstyle.NUMERIC: + self._replace_numeric_placeholders() + if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: + self._replace_named_or_pyformat_placeholders() self._escape_verbatim_colons() - def _parse_placeholders(self): + def _get_paramstyle(self): paramstyle = None + for token in self._tokens: + if _is_placeholder(token): + paramstyle, _ = _parse_placeholder(token) + break + + if paramstyle is None: + paramstyle = self._default_paramstyle() + + return paramstyle + + + def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): if _is_placeholder(token): - _paramstyle, name = _parse_placeholder(token) - if paramstyle is None: - paramstyle = _paramstyle - elif _paramstyle != paramstyle: + paramstyle, name = _parse_placeholder(token) + if paramstyle != self._paramstyle: raise RuntimeError("inconsistent paramstyle") placeholders[index] = name - if paramstyle is None: - paramstyle = self._default_paramstyle() - - return paramstyle, placeholders + return placeholders def _default_paramstyle(self): @@ -83,22 +77,22 @@ def _default_paramstyle(self): return paramstyle - def _replace_format_or_qmark_placeholders(self, placeholders): - if len(placeholders) != len(self._args): - _placeholders = ", ".join([str(token) for token in placeholders.values()]) + def _replace_format_or_qmark_placeholders(self): + if len(self._placeholders) != len(self._args): + placeholders = ", ".join([str(token) for token in self._placeholders.values()]) _args = ", ".join([str(self._sql_sanitizer.escape(arg)) for arg in self._args]) - if len(placeholders) < len(self._args): - raise RuntimeError(f"fewer placeholders ({_placeholders}) than values ({_args})") + if len(self._placeholders) < len(self._args): + raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({_args})") - raise RuntimeError(f"more placeholders ({_placeholders}) than values ({_args})") + raise RuntimeError(f"more placeholders ({placeholders}) than values ({_args})") - for arg_index, token_index in enumerate(placeholders.keys()): + for arg_index, token_index in enumerate(self._placeholders.keys()): self._tokens[token_index] = self._sql_sanitizer.escape(self._args[arg_index]) - def _replace_numeric_placeholders(self, placeholders): + def _replace_numeric_placeholders(self): unused_arg_idxs = set(range(len(self._args))) - for token_index, num in placeholders.items(): + for token_index, num in self._placeholders.items(): if num >= len(self._args): raise RuntimeError(f"missing value for placeholder ({num + 1})") @@ -112,9 +106,9 @@ def _replace_numeric_placeholders(self, placeholders): f"unused value{'' if len(unused_arg_idxs) == 1 else 's'} ({unused_args})") - def _replace_named_or_pyformat_placeholders(self, placeholders): + def _replace_named_or_pyformat_placeholders(self): unused_params = set(self._kwargs.keys()) - for token_index, param_name in placeholders.items(): + for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: raise RuntimeError(f"missing value for placeholder ({param_name})") @@ -133,6 +127,19 @@ def _escape_verbatim_colons(self): token.value = escape_verbatim_colon(token.value) + def _get_operation_keyword(self): + for token in self._statement: + if _is_operation_token(token): + token_value = token.value.upper() + if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: + operation_keyword = token_value + break + else: + operation_keyword = None + + return operation_keyword + + def get_operation_keyword(self): """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" return self._operation_keyword @@ -142,7 +149,7 @@ def __str__(self): return "".join([str(token) for token in self._tokens]) -def _parse(sql): +def _format_and_parse(sql): formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) statement_count = len(parsed_statements) From db4cce1c25f9b9a3349c696b13078073e1ed8cd6 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 21:09:22 -0400 Subject: [PATCH 15/47] pass token type/value around --- src/cs50/_statement.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 34f9247..0d62266 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -44,8 +44,8 @@ def _replace_placeholders_with_params(self): def _get_paramstyle(self): paramstyle = None for token in self._tokens: - if _is_placeholder(token): - paramstyle, _ = _parse_placeholder(token) + if _is_placeholder(token.ttype): + paramstyle, _ = _parse_placeholder(token.value) break if paramstyle is None: @@ -57,8 +57,8 @@ def _get_paramstyle(self): def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): - if _is_placeholder(token): - paramstyle, name = _parse_placeholder(token) + if _is_placeholder(token.ttype): + paramstyle, name = _parse_placeholder(token.value) if paramstyle != self._paramstyle: raise RuntimeError("inconsistent paramstyle") @@ -123,13 +123,13 @@ def _replace_named_or_pyformat_placeholders(self): def _escape_verbatim_colons(self): for token in self._tokens: - if _is_string_literal(token) or _is_identifier(token): + if _is_string_literal(token.ttype) or _is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) def _get_operation_keyword(self): for token in self._statement: - if _is_operation_token(token): + if _is_operation_token(token.ttype): token_value = token.value.upper() if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: operation_keyword = token_value @@ -161,46 +161,46 @@ def _format_and_parse(sql): return parsed_statements[0] -def _is_placeholder(token): - return token.ttype == sqlparse.tokens.Name.Placeholder +def _is_placeholder(ttype): + return ttype == sqlparse.tokens.Name.Placeholder -def _parse_placeholder(token): - if token.value == "?": +def _parse_placeholder(value): + if value == "?": return _Paramstyle.QMARK, None # E.g., :1 - matches = re.search(r"^:([1-9]\d*)$", token.value) + matches = re.search(r"^:([1-9]\d*)$", value) if matches: return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 # E.g., :foo - matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) + matches = re.search(r"^:([a-zA-Z]\w*)$", value) if matches: return _Paramstyle.NAMED, matches.group(1) - if token.value == "%s": + if value == "%s": return _Paramstyle.FORMAT, None # E.g., %(foo)s - matches = re.search(r"%\((\w+)\)s$", token.value) + matches = re.search(r"%\((\w+)\)s$", value) if matches: return _Paramstyle.PYFORMAT, matches.group(1) - raise RuntimeError(f"{token.value}: invalid placeholder") + raise RuntimeError(f"{value}: invalid placeholder") -def _is_operation_token(token): - return token.ttype in { +def _is_operation_token(ttype): + return ttype in { sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} -def _is_string_literal(token): - return token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] +def _is_string_literal(ttype): + return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] -def _is_identifier(token): - return token.ttype == sqlparse.tokens.Literal.String.Symbol +def _is_identifier(ttype): + return ttype == sqlparse.tokens.Literal.String.Symbol class _Paramstyle(enum.Enum): From 88adfb95eb760e1d7be01b3c55566560d81125a8 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 21:16:04 -0400 Subject: [PATCH 16/47] rename methods --- src/cs50/_statement.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 0d62266..931504d 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,4 +1,4 @@ -"""Parses a SQL statement and replaces placeholders with parameters""" +"""Parses a SQL statement and replaces the placeholders with the corresponding parameters""" import collections import enum @@ -22,7 +22,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._tokens = self._tokenize() self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() - self._replace_placeholders_with_params() + self._plugin_escaped_params() self._operation_keyword = self._get_operation_keyword() @@ -30,13 +30,13 @@ def _tokenize(self): return list(self._statement.flatten()) - def _replace_placeholders_with_params(self): + def _plugin_escaped_params(self): if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - self._replace_format_or_qmark_placeholders() + self._plugin_format_or_qmark_params() elif self._paramstyle == _Paramstyle.NUMERIC: - self._replace_numeric_placeholders() + self._plugin_numeric_params() if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - self._replace_named_or_pyformat_placeholders() + self._plugin_named_or_pyformat_params() self._escape_verbatim_colons() @@ -77,7 +77,7 @@ def _default_paramstyle(self): return paramstyle - def _replace_format_or_qmark_placeholders(self): + def _plugin_format_or_qmark_params(self): if len(self._placeholders) != len(self._args): placeholders = ", ".join([str(token) for token in self._placeholders.values()]) _args = ", ".join([str(self._sql_sanitizer.escape(arg)) for arg in self._args]) @@ -90,7 +90,7 @@ def _replace_format_or_qmark_placeholders(self): self._tokens[token_index] = self._sql_sanitizer.escape(self._args[arg_index]) - def _replace_numeric_placeholders(self): + def _plugin_numeric_params(self): unused_arg_idxs = set(range(len(self._args))) for token_index, num in self._placeholders.items(): if num >= len(self._args): @@ -106,7 +106,7 @@ def _replace_numeric_placeholders(self): f"unused value{'' if len(unused_arg_idxs) == 1 else 's'} ({unused_args})") - def _replace_named_or_pyformat_placeholders(self): + def _plugin_named_or_pyformat_params(self): unused_params = set(self._kwargs.keys()) for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: From a4a88108f652ca7fb253d056ac449adacc02b5a4 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 21:17:33 -0400 Subject: [PATCH 17/47] move escape_verbatim_colons to constructor --- src/cs50/_statement.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 931504d..5fc41e7 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -23,6 +23,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() self._plugin_escaped_params() + self._escape_verbatim_colons() self._operation_keyword = self._get_operation_keyword() @@ -38,8 +39,6 @@ def _plugin_escaped_params(self): if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: self._plugin_named_or_pyformat_params() - self._escape_verbatim_colons() - def _get_paramstyle(self): paramstyle = None From f618840dbf0918a060aa8bcbc712a2b4a1c665d2 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 21:24:39 -0400 Subject: [PATCH 18/47] reorder methods --- src/cs50/_statement.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 5fc41e7..dc4013b 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -31,15 +31,6 @@ def _tokenize(self): return list(self._statement.flatten()) - def _plugin_escaped_params(self): - if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - self._plugin_format_or_qmark_params() - elif self._paramstyle == _Paramstyle.NUMERIC: - self._plugin_numeric_params() - if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - self._plugin_named_or_pyformat_params() - - def _get_paramstyle(self): paramstyle = None for token in self._tokens: @@ -53,6 +44,16 @@ def _get_paramstyle(self): return paramstyle + def _default_paramstyle(self): + paramstyle = None + if self._args: + paramstyle = _Paramstyle.QMARK + elif self._kwargs: + paramstyle = _Paramstyle.NAMED + + return paramstyle + + def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): @@ -66,14 +67,13 @@ def _get_placeholders(self): return placeholders - def _default_paramstyle(self): - paramstyle = None - if self._args: - paramstyle = _Paramstyle.QMARK - elif self._kwargs: - paramstyle = _Paramstyle.NAMED - - return paramstyle + def _plugin_escaped_params(self): + if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: + self._plugin_format_or_qmark_params() + elif self._paramstyle == _Paramstyle.NUMERIC: + self._plugin_numeric_params() + if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: + self._plugin_named_or_pyformat_params() def _plugin_format_or_qmark_params(self): From 9a153fe9dbca0157958bb577458ec53505c03421 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 22:17:36 -0400 Subject: [PATCH 19/47] use else --- src/cs50/_statement.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index dc4013b..f0ed325 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -37,8 +37,7 @@ def _get_paramstyle(self): if _is_placeholder(token.ttype): paramstyle, _ = _parse_placeholder(token.value) break - - if paramstyle is None: + else: paramstyle = self._default_paramstyle() return paramstyle From 456cea56083d586edae6ccedfe525bf0f64c4f77 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 22:56:48 -0400 Subject: [PATCH 20/47] escape args and kwargs in constructor --- src/cs50/_statement.py | 54 +++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index f0ed325..d02f844 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -16,8 +16,8 @@ def __init__(self, dialect, sql, *args, **kwargs): raise RuntimeError("cannot pass both positional and named parameters") self._sql_sanitizer = SQLSanitizer(dialect) - self._args = args - self._kwargs = kwargs + self._args = self._get_escaped_args(args) + self._kwargs = self._get_escaped_kwargs(kwargs) self._statement = _format_and_parse(sql) self._tokens = self._tokenize() self._paramstyle = self._get_paramstyle() @@ -27,6 +27,14 @@ def __init__(self, dialect, sql, *args, **kwargs): self._operation_keyword = self._get_operation_keyword() + def _get_escaped_args(self, args): + return [self._sql_sanitizer.escape(arg) for arg in args] + + + def _get_escaped_kwargs(self, kwargs): + return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()} + + def _tokenize(self): return list(self._statement.flatten()) @@ -76,32 +84,33 @@ def _plugin_escaped_params(self): def _plugin_format_or_qmark_params(self): + self._assert_valid_arg_count() + for arg_index, token_index in enumerate(self._placeholders.keys()): + self._tokens[token_index] = self._args[arg_index] + + + def _assert_valid_arg_count(self): if len(self._placeholders) != len(self._args): - placeholders = ", ".join([str(token) for token in self._placeholders.values()]) - _args = ", ".join([str(self._sql_sanitizer.escape(arg)) for arg in self._args]) + placeholders = _get_human_readable_list(self._placeholders.values()) + args = _get_human_readable_list(self._args) if len(self._placeholders) < len(self._args): - raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({_args})") + raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({args})") - raise RuntimeError(f"more placeholders ({placeholders}) than values ({_args})") - - for arg_index, token_index in enumerate(self._placeholders.keys()): - self._tokens[token_index] = self._sql_sanitizer.escape(self._args[arg_index]) + raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") def _plugin_numeric_params(self): - unused_arg_idxs = set(range(len(self._args))) + unused_arg_indices = set(range(len(self._args))) for token_index, num in self._placeholders.items(): if num >= len(self._args): raise RuntimeError(f"missing value for placeholder ({num + 1})") - self._tokens[token_index] = self._sql_sanitizer.escape(self._args[num]) - unused_arg_idxs.remove(num) + self._tokens[token_index] = self._args[num] + unused_arg_indices.remove(num) - if len(unused_arg_idxs) > 0: - unused_args = ", ".join( - [str(self._sql_sanitizer.escape(self._args[i])) for i in sorted(unused_arg_idxs)]) - raise RuntimeError( - f"unused value{'' if len(unused_arg_idxs) == 1 else 's'} ({unused_args})") + if len(unused_arg_indices) > 0: + unused_args = _get_human_readable_list([self._args[i] for i in sorted(unused_arg_indices)]) + raise RuntimeError(f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") def _plugin_named_or_pyformat_params(self): @@ -110,11 +119,11 @@ def _plugin_named_or_pyformat_params(self): if param_name not in self._kwargs: raise RuntimeError(f"missing value for placeholder ({param_name})") - self._tokens[token_index] = self._sql_sanitizer.escape(self._kwargs[param_name]) + self._tokens[token_index] = self._kwargs[param_name] unused_params.remove(param_name) if len(unused_params) > 0: - joined_unused_params = ", ".join(sorted(unused_params)) + joined_unused_params = _get_human_readable_list(sorted(unused_params)) raise RuntimeError( f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") @@ -147,6 +156,9 @@ def __str__(self): return "".join([str(token) for token in self._tokens]) + + + def _format_and_parse(sql): formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) @@ -201,6 +213,10 @@ def _is_identifier(ttype): return ttype == sqlparse.tokens.Literal.String.Symbol +def _get_human_readable_list(iterable): + return ", ".join(str(v) for v in iterable) + + class _Paramstyle(enum.Enum): FORMAT = enum.auto() NAMED = enum.auto() From 713368f0d5b44c9ce22739410a96b77ae08c237b Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 22:59:37 -0400 Subject: [PATCH 21/47] reorder methods --- src/cs50/_statement.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index d02f844..50e6673 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -16,15 +16,19 @@ def __init__(self, dialect, sql, *args, **kwargs): raise RuntimeError("cannot pass both positional and named parameters") self._sql_sanitizer = SQLSanitizer(dialect) + self._args = self._get_escaped_args(args) self._kwargs = self._get_escaped_kwargs(kwargs) + self._statement = _format_and_parse(sql) self._tokens = self._tokenize() + + self._operation_keyword = self._get_operation_keyword() + self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() self._plugin_escaped_params() self._escape_verbatim_colons() - self._operation_keyword = self._get_operation_keyword() def _get_escaped_args(self, args): @@ -39,6 +43,19 @@ def _tokenize(self): return list(self._statement.flatten()) + def _get_operation_keyword(self): + for token in self._statement: + if _is_operation_token(token.ttype): + token_value = token.value.upper() + if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: + operation_keyword = token_value + break + else: + operation_keyword = None + + return operation_keyword + + def _get_paramstyle(self): paramstyle = None for token in self._tokens: @@ -134,19 +151,6 @@ def _escape_verbatim_colons(self): token.value = escape_verbatim_colon(token.value) - def _get_operation_keyword(self): - for token in self._statement: - if _is_operation_token(token.ttype): - token_value = token.value.upper() - if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: - operation_keyword = token_value - break - else: - operation_keyword = None - - return operation_keyword - - def get_operation_keyword(self): """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" return self._operation_keyword From 2187d95b51f9020c98cb6feb5102db375633196d Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 23:15:01 -0400 Subject: [PATCH 22/47] refactor logger --- src/cs50/_logger.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py index df021a3..1307e19 100644 --- a/src/cs50/_logger.py +++ b/src/cs50/_logger.py @@ -10,17 +10,26 @@ def _setup_logger(): - # Configure default logging handler and formatter - # Prevent flask, werkzeug, etc from adding default handler + _configure_default_logger() + _patch_root_handler_format_exception() + _configure_cs50_logger() + _patch_excepthook() + + +def _configure_default_logger(): + """Configure default handler and formatter to prevent flask and werkzeug from adding theirs""" logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) + +def _patch_root_handler_format_exception(): try: - # Patch formatException formatter = logging.root.handlers[0].formatter formatter.formatException = lambda exc_info: _format_exception(*exc_info) except IndexError: pass + +def _configure_cs50_logger(): _logger = logging.getLogger("cs50") _logger.disabled = True _logger.setLevel(logging.DEBUG) @@ -36,6 +45,8 @@ def _setup_logger(): handler.setFormatter(formatter) _logger.addHandler(handler) + +def _patch_excepthook(): sys.excepthook = lambda type_, value, exc_tb: print( _format_exception(type_, value, exc_tb), file=sys.stderr) From 758910e7374ae6d8b7a35636d6c122318648d9c7 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 12 Apr 2021 23:30:30 -0400 Subject: [PATCH 23/47] fix style --- src/cs50/_session.py | 4 ++-- src/cs50/_sql_sanitizer.py | 2 -- src/cs50/_statement.py | 25 +++++-------------------- src/cs50/cs50.py | 4 +++- src/cs50/sql.py | 6 +----- 5 files changed, 11 insertions(+), 30 deletions(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index cd23453..4c63b39 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -6,21 +6,21 @@ import sqlalchemy import sqlalchemy.orm + class Session: """Wraps a SQLAlchemy scoped session""" + def __init__(self, url, **engine_kwargs): if _is_sqlite_url(url): _assert_sqlite_file_exists(url) self._session = _create_session(url, **engine_kwargs) - def execute(self, statement): """Converts statement to str and executes it""" # pylint: disable=no-member return self._session.execute(sqlalchemy.text(str(statement))) - def __getattr__(self, attr): return getattr(self._session, attr) diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py index c2f35c4..f4ff3e0 100644 --- a/src/cs50/_sql_sanitizer.py +++ b/src/cs50/_sql_sanitizer.py @@ -13,7 +13,6 @@ class SQLSanitizer: def __init__(self, dialect): self._dialect = dialect - def escape(self, value): """ Escapes value using engine's conversion function. @@ -71,7 +70,6 @@ def escape(self, value): raise RuntimeError(f"unsupported value: {value}") - def escape_iterable(self, iterable): """Escapes a collection of values (e.g., list, tuple)""" return sqlparse.sql.TokenList( diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 50e6673..9f9fae8 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -11,6 +11,7 @@ class Statement: """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" + def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -30,19 +31,15 @@ def __init__(self, dialect, sql, *args, **kwargs): self._plugin_escaped_params() self._escape_verbatim_colons() - def _get_escaped_args(self, args): return [self._sql_sanitizer.escape(arg) for arg in args] - def _get_escaped_kwargs(self, kwargs): return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()} - def _tokenize(self): return list(self._statement.flatten()) - def _get_operation_keyword(self): for token in self._statement: if _is_operation_token(token.ttype): @@ -55,7 +52,6 @@ def _get_operation_keyword(self): return operation_keyword - def _get_paramstyle(self): paramstyle = None for token in self._tokens: @@ -67,7 +63,6 @@ def _get_paramstyle(self): return paramstyle - def _default_paramstyle(self): paramstyle = None if self._args: @@ -77,7 +72,6 @@ def _default_paramstyle(self): return paramstyle - def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): @@ -90,7 +84,6 @@ def _get_placeholders(self): return placeholders - def _plugin_escaped_params(self): if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: self._plugin_format_or_qmark_params() @@ -99,13 +92,11 @@ def _plugin_escaped_params(self): if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: self._plugin_named_or_pyformat_params() - def _plugin_format_or_qmark_params(self): self._assert_valid_arg_count() for arg_index, token_index in enumerate(self._placeholders.keys()): self._tokens[token_index] = self._args[arg_index] - def _assert_valid_arg_count(self): if len(self._placeholders) != len(self._args): placeholders = _get_human_readable_list(self._placeholders.values()) @@ -115,7 +106,6 @@ def _assert_valid_arg_count(self): raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") - def _plugin_numeric_params(self): unused_arg_indices = set(range(len(self._args))) for token_index, num in self._placeholders.items(): @@ -126,9 +116,10 @@ def _plugin_numeric_params(self): unused_arg_indices.remove(num) if len(unused_arg_indices) > 0: - unused_args = _get_human_readable_list([self._args[i] for i in sorted(unused_arg_indices)]) - raise RuntimeError(f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") - + unused_args = _get_human_readable_list( + [self._args[i] for i in sorted(unused_arg_indices)]) + raise RuntimeError( + f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") def _plugin_named_or_pyformat_params(self): unused_params = set(self._kwargs.keys()) @@ -144,25 +135,19 @@ def _plugin_named_or_pyformat_params(self): raise RuntimeError( f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") - def _escape_verbatim_colons(self): for token in self._tokens: if _is_string_literal(token.ttype) or _is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) - def get_operation_keyword(self): """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" return self._operation_keyword - def __str__(self): return "".join([str(token) for token in self._tokens]) - - - def _format_and_parse(sql): formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 24c748b..30d3515 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -36,7 +36,7 @@ def get_int(prompt): """ while True: try: - return _get_int(prompt) + return _get_int(prompt) except (MemoryError, ValueError): pass @@ -89,9 +89,11 @@ def write(self, data): self.stream.write(data) self.stream.flush() + def disable_output_buffering(): """Disables output buffering to prevent prompts from being buffered""" sys.stderr = _flushfile(sys.stderr) sys.stdout = _flushfile(sys.stdout) + disable_output_buffering() diff --git a/src/cs50/sql.py b/src/cs50/sql.py index fca57d2..8547aca 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -15,13 +15,13 @@ class SQL: """Wraps SQLAlchemy""" + def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) self._dialect = self._session.get_bind().dialect self._is_postgres = self._dialect.name in {"postgres", "postgresql"} self._autocommit = False - def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = Statement(self._dialect, sql, *args, **kwargs) @@ -53,7 +53,6 @@ def execute(self, sql, *args, **kwargs): return ret - def _execute(self, statement): # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -72,20 +71,17 @@ def _execute(self, statement): _logger.debug(termcolor.colored(str(statement), "green")) return result - def _last_row_id_or_none(self, result): if self._is_postgres: return self._get_last_val() return result.lastrowid if result.rowcount == 1 else None - def _get_last_val(self): try: return self._session.execute("SELECT LASTVAL()").first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session return None - def init_app(self, app): """Registers a teardown_appcontext listener to remove session and enables logging""" @app.teardown_appcontext From 32db777581af3929a69ab3aaded835cf7823e9a9 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 07:04:58 -0400 Subject: [PATCH 24/47] factor out utility functions --- src/cs50/_statement.py | 80 +++++-------------------------------- src/cs50/_statement_util.py | 72 +++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 70 deletions(-) create mode 100644 src/cs50/_statement_util.py diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 9f9fae8..c673719 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,12 +1,18 @@ """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" import collections -import enum -import re - -import sqlparse from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon +from ._statement_util import ( + _format_and_parse, + _get_human_readable_list, + _is_identifier, + _is_operation_token, + _is_placeholder, + _is_string_literal, + _Paramstyle, + _parse_placeholder, +) class Statement: @@ -146,69 +152,3 @@ def get_operation_keyword(self): def __str__(self): return "".join([str(token) for token in self._tokens]) - - -def _format_and_parse(sql): - formatted_statements = sqlparse.format(sql, strip_comments=True).strip() - parsed_statements = sqlparse.parse(formatted_statements) - statement_count = len(parsed_statements) - if statement_count == 0: - raise RuntimeError("missing statement") - if statement_count > 1: - raise RuntimeError("too many statements at once") - - return parsed_statements[0] - - -def _is_placeholder(ttype): - return ttype == sqlparse.tokens.Name.Placeholder - - -def _parse_placeholder(value): - if value == "?": - return _Paramstyle.QMARK, None - - # E.g., :1 - matches = re.search(r"^:([1-9]\d*)$", value) - if matches: - return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 - - # E.g., :foo - matches = re.search(r"^:([a-zA-Z]\w*)$", value) - if matches: - return _Paramstyle.NAMED, matches.group(1) - - if value == "%s": - return _Paramstyle.FORMAT, None - - # E.g., %(foo)s - matches = re.search(r"%\((\w+)\)s$", value) - if matches: - return _Paramstyle.PYFORMAT, matches.group(1) - - raise RuntimeError(f"{value}: invalid placeholder") - - -def _is_operation_token(ttype): - return ttype in { - sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} - - -def _is_string_literal(ttype): - return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] - - -def _is_identifier(ttype): - return ttype == sqlparse.tokens.Literal.String.Symbol - - -def _get_human_readable_list(iterable): - return ", ".join(str(v) for v in iterable) - - -class _Paramstyle(enum.Enum): - FORMAT = enum.auto() - NAMED = enum.auto() - NUMERIC = enum.auto() - PYFORMAT = enum.auto() - QMARK = enum.auto() diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py new file mode 100644 index 0000000..f299767 --- /dev/null +++ b/src/cs50/_statement_util.py @@ -0,0 +1,72 @@ +"""Utility functions used by _statement.py""" + +import enum +import re + +import sqlparse + + +class _Paramstyle(enum.Enum): + FORMAT = enum.auto() + NAMED = enum.auto() + NUMERIC = enum.auto() + PYFORMAT = enum.auto() + QMARK = enum.auto() + + +def _format_and_parse(sql): + formatted_statements = sqlparse.format(sql, strip_comments=True).strip() + parsed_statements = sqlparse.parse(formatted_statements) + statement_count = len(parsed_statements) + if statement_count == 0: + raise RuntimeError("missing statement") + if statement_count > 1: + raise RuntimeError("too many statements at once") + + return parsed_statements[0] + + +def _is_placeholder(ttype): + return ttype == sqlparse.tokens.Name.Placeholder + + +def _parse_placeholder(value): + if value == "?": + return _Paramstyle.QMARK, None + + # E.g., :1 + matches = re.search(r"^:([1-9]\d*)$", value) + if matches: + return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 + + # E.g., :foo + matches = re.search(r"^:([a-zA-Z]\w*)$", value) + if matches: + return _Paramstyle.NAMED, matches.group(1) + + if value == "%s": + return _Paramstyle.FORMAT, None + + # E.g., %(foo)s + matches = re.search(r"%\((\w+)\)s$", value) + if matches: + return _Paramstyle.PYFORMAT, matches.group(1) + + raise RuntimeError(f"{value}: invalid placeholder") + + +def _is_operation_token(ttype): + return ttype in { + sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} + + +def _is_string_literal(ttype): + return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] + + +def _is_identifier(ttype): + return ttype == sqlparse.tokens.Literal.String.Symbol + + +def _get_human_readable_list(iterable): + return ", ".join(str(v) for v in iterable) From 1f622834ddecd571a0c49190a190035f4bbb95e3 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 07:08:13 -0400 Subject: [PATCH 25/47] factor out session utility functions --- src/cs50/_session.py | 67 ++++----------------------------------- src/cs50/_session_util.py | 63 ++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 61 deletions(-) create mode 100644 src/cs50/_session_util.py diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 4c63b39..a67f02d 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -1,11 +1,14 @@ """Wraps a SQLAlchemy scoped session""" -import os -import sqlite3 - import sqlalchemy import sqlalchemy.orm +from ._session_util import ( + _is_sqlite_url, + _assert_sqlite_file_exists, + _create_session, +) + class Session: """Wraps a SQLAlchemy scoped session""" @@ -23,61 +26,3 @@ def execute(self, statement): def __getattr__(self, attr): return getattr(self._session, attr) - - -def _is_sqlite_url(url): - return url.startswith("sqlite:///") - - -def _assert_sqlite_file_exists(url): - path = url[len("sqlite:///"):] - if not os.path.exists(path): - raise RuntimeError(f"does not exist: {path}") - if not os.path.isfile(path): - raise RuntimeError(f"not a file: {path}") - - -def _create_session(url, **engine_kwargs): - engine = _create_engine(url, **engine_kwargs) - _setup_on_connect(engine) - return _create_scoped_session(engine) - - -def _create_engine(url, **kwargs): - try: - engine = sqlalchemy.create_engine(url, **kwargs) - except sqlalchemy.exc.ArgumentError: - raise RuntimeError(f"invalid URL: {url}") from None - - engine.execution_options(autocommit=False) - return engine - - -def _setup_on_connect(engine): - def connect(dbapi_connection, _): - _disable_auto_begin_commit(dbapi_connection) - if _is_sqlite_connection(dbapi_connection): - _enable_sqlite_foreign_key_constraints(dbapi_connection) - - sqlalchemy.event.listen(engine, "connect", connect) - - -def _create_scoped_session(engine): - session_factory = sqlalchemy.orm.sessionmaker(bind=engine) - return sqlalchemy.orm.scoping.scoped_session(session_factory) - - -def _disable_auto_begin_commit(dbapi_connection): - # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves - # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl - dbapi_connection.isolation_level = None - - -def _is_sqlite_connection(dbapi_connection): - return isinstance(dbapi_connection, sqlite3.Connection) - - -def _enable_sqlite_foreign_key_constraints(dbapi_connection): - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py new file mode 100644 index 0000000..c0cf33a --- /dev/null +++ b/src/cs50/_session_util.py @@ -0,0 +1,63 @@ +"""Utility functions used by _session.py""" + +import os +import sqlite3 + +import sqlalchemy + +def _is_sqlite_url(url): + return url.startswith("sqlite:///") + + +def _assert_sqlite_file_exists(url): + path = url[len("sqlite:///"):] + if not os.path.exists(path): + raise RuntimeError(f"does not exist: {path}") + if not os.path.isfile(path): + raise RuntimeError(f"not a file: {path}") + + +def _create_session(url, **engine_kwargs): + engine = _create_engine(url, **engine_kwargs) + _setup_on_connect(engine) + return _create_scoped_session(engine) + + +def _create_engine(url, **kwargs): + try: + engine = sqlalchemy.create_engine(url, **kwargs) + except sqlalchemy.exc.ArgumentError: + raise RuntimeError(f"invalid URL: {url}") from None + + engine.execution_options(autocommit=False) + return engine + + +def _setup_on_connect(engine): + def connect(dbapi_connection, _): + _disable_auto_begin_commit(dbapi_connection) + if _is_sqlite_connection(dbapi_connection): + _enable_sqlite_foreign_key_constraints(dbapi_connection) + + sqlalchemy.event.listen(engine, "connect", connect) + + +def _create_scoped_session(engine): + session_factory = sqlalchemy.orm.sessionmaker(bind=engine) + return sqlalchemy.orm.scoping.scoped_session(session_factory) + + +def _disable_auto_begin_commit(dbapi_connection): + # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves + # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + dbapi_connection.isolation_level = None + + +def _is_sqlite_connection(dbapi_connection): + return isinstance(dbapi_connection, sqlite3.Connection) + + +def _enable_sqlite_foreign_key_constraints(dbapi_connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() From c0534e27dc005dab42062c39f03751e075aa9b6a Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 07:11:26 -0400 Subject: [PATCH 26/47] factor out sql utility functions --- src/cs50/_sql_util.py | 18 ++++++++++++++++++ src/cs50/sql.py | 20 ++------------------ 2 files changed, 20 insertions(+), 18 deletions(-) create mode 100644 src/cs50/_sql_util.py diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py new file mode 100644 index 0000000..ea3edad --- /dev/null +++ b/src/cs50/_sql_util.py @@ -0,0 +1,18 @@ +"""Utility functions used by sql.py""" + +import decimal + +def fetch_select_result(result): + rows = [dict(row) for row in result.fetchall()] + for row in rows: + for column in row: + # Coerce decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + if isinstance(row[column], decimal.Decimal): + row[column] = float(row[column]) + + # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes + elif isinstance(row[column], memoryview): + row[column] = bytes(row[column]) + + return rows diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 8547aca..d823c8b 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,6 +1,5 @@ """Wraps SQLAlchemy""" -import decimal import logging import warnings @@ -9,6 +8,7 @@ from ._session import Session from ._statement import Statement +from ._sql_util import fetch_select_result _logger = logging.getLogger("cs50") @@ -43,7 +43,7 @@ def execute(self, sql, *args, **kwargs): self._session.remove() if operation_keyword == "SELECT": - ret = _fetch_select_result(result) + ret = fetch_select_result(result) elif operation_keyword == "INSERT": ret = self._last_row_id_or_none(result) elif operation_keyword in {"DELETE", "UPDATE"}: @@ -89,19 +89,3 @@ def _(_): self._session.remove() logging.getLogger("cs50").disabled = False - - -def _fetch_select_result(result): - rows = [dict(row) for row in result.fetchall()] - for row in rows: - for column in row: - # Coerce decimal.Decimal objects to float objects - # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - if isinstance(row[column], decimal.Decimal): - row[column] = float(row[column]) - - # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes - elif isinstance(row[column], memoryview): - row[column] = bytes(row[column]) - - return rows From 37fba4f5707aa5e57b41f893b2a6d5f209238a4d Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 07:16:38 -0400 Subject: [PATCH 27/47] remove underscore from util functions --- src/cs50/_session.py | 12 +++++----- src/cs50/_session_util.py | 6 ++--- src/cs50/_statement.py | 48 ++++++++++++++++++------------------- src/cs50/_statement_util.py | 26 ++++++++++---------- 4 files changed, 46 insertions(+), 46 deletions(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index a67f02d..0a30c36 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -4,9 +4,9 @@ import sqlalchemy.orm from ._session_util import ( - _is_sqlite_url, - _assert_sqlite_file_exists, - _create_session, + is_sqlite_url, + assert_sqlite_file_exists, + create_session, ) @@ -14,10 +14,10 @@ class Session: """Wraps a SQLAlchemy scoped session""" def __init__(self, url, **engine_kwargs): - if _is_sqlite_url(url): - _assert_sqlite_file_exists(url) + if is_sqlite_url(url): + assert_sqlite_file_exists(url) - self._session = _create_session(url, **engine_kwargs) + self._session = create_session(url, **engine_kwargs) def execute(self, statement): """Converts statement to str and executes it""" diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py index c0cf33a..3433fa9 100644 --- a/src/cs50/_session_util.py +++ b/src/cs50/_session_util.py @@ -5,11 +5,11 @@ import sqlalchemy -def _is_sqlite_url(url): +def is_sqlite_url(url): return url.startswith("sqlite:///") -def _assert_sqlite_file_exists(url): +def assert_sqlite_file_exists(url): path = url[len("sqlite:///"):] if not os.path.exists(path): raise RuntimeError(f"does not exist: {path}") @@ -17,7 +17,7 @@ def _assert_sqlite_file_exists(url): raise RuntimeError(f"not a file: {path}") -def _create_session(url, **engine_kwargs): +def create_session(url, **engine_kwargs): engine = _create_engine(url, **engine_kwargs) _setup_on_connect(engine) return _create_scoped_session(engine) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index c673719..ac83758 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -4,14 +4,14 @@ from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon from ._statement_util import ( - _format_and_parse, - _get_human_readable_list, - _is_identifier, - _is_operation_token, - _is_placeholder, - _is_string_literal, - _Paramstyle, - _parse_placeholder, + format_and_parse, + get_human_readable_list, + is_identifier, + is_operation_token, + is_placeholder, + is_string_literal, + Paramstyle, + parse_placeholder, ) @@ -27,7 +27,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._args = self._get_escaped_args(args) self._kwargs = self._get_escaped_kwargs(kwargs) - self._statement = _format_and_parse(sql) + self._statement = format_and_parse(sql) self._tokens = self._tokenize() self._operation_keyword = self._get_operation_keyword() @@ -48,7 +48,7 @@ def _tokenize(self): def _get_operation_keyword(self): for token in self._statement: - if _is_operation_token(token.ttype): + if is_operation_token(token.ttype): token_value = token.value.upper() if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: operation_keyword = token_value @@ -61,8 +61,8 @@ def _get_operation_keyword(self): def _get_paramstyle(self): paramstyle = None for token in self._tokens: - if _is_placeholder(token.ttype): - paramstyle, _ = _parse_placeholder(token.value) + if is_placeholder(token.ttype): + paramstyle, _ = parse_placeholder(token.value) break else: paramstyle = self._default_paramstyle() @@ -72,17 +72,17 @@ def _get_paramstyle(self): def _default_paramstyle(self): paramstyle = None if self._args: - paramstyle = _Paramstyle.QMARK + paramstyle = Paramstyle.QMARK elif self._kwargs: - paramstyle = _Paramstyle.NAMED + paramstyle = Paramstyle.NAMED return paramstyle def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): - if _is_placeholder(token.ttype): - paramstyle, name = _parse_placeholder(token.value) + if is_placeholder(token.ttype): + paramstyle, name = parse_placeholder(token.value) if paramstyle != self._paramstyle: raise RuntimeError("inconsistent paramstyle") @@ -91,11 +91,11 @@ def _get_placeholders(self): return placeholders def _plugin_escaped_params(self): - if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: + if self._paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: self._plugin_format_or_qmark_params() - elif self._paramstyle == _Paramstyle.NUMERIC: + elif self._paramstyle == Paramstyle.NUMERIC: self._plugin_numeric_params() - if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: + if self._paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: self._plugin_named_or_pyformat_params() def _plugin_format_or_qmark_params(self): @@ -105,8 +105,8 @@ def _plugin_format_or_qmark_params(self): def _assert_valid_arg_count(self): if len(self._placeholders) != len(self._args): - placeholders = _get_human_readable_list(self._placeholders.values()) - args = _get_human_readable_list(self._args) + placeholders = get_human_readable_list(self._placeholders.values()) + args = get_human_readable_list(self._args) if len(self._placeholders) < len(self._args): raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({args})") @@ -122,7 +122,7 @@ def _plugin_numeric_params(self): unused_arg_indices.remove(num) if len(unused_arg_indices) > 0: - unused_args = _get_human_readable_list( + unused_args = get_human_readable_list( [self._args[i] for i in sorted(unused_arg_indices)]) raise RuntimeError( f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") @@ -137,13 +137,13 @@ def _plugin_named_or_pyformat_params(self): unused_params.remove(param_name) if len(unused_params) > 0: - joined_unused_params = _get_human_readable_list(sorted(unused_params)) + joined_unused_params = get_human_readable_list(sorted(unused_params)) raise RuntimeError( f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") def _escape_verbatim_colons(self): for token in self._tokens: - if _is_string_literal(token.ttype) or _is_identifier(token.ttype): + if is_string_literal(token.ttype) or is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) def get_operation_keyword(self): diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py index f299767..81b79e1 100644 --- a/src/cs50/_statement_util.py +++ b/src/cs50/_statement_util.py @@ -6,7 +6,7 @@ import sqlparse -class _Paramstyle(enum.Enum): +class Paramstyle(enum.Enum): FORMAT = enum.auto() NAMED = enum.auto() NUMERIC = enum.auto() @@ -14,7 +14,7 @@ class _Paramstyle(enum.Enum): QMARK = enum.auto() -def _format_and_parse(sql): +def format_and_parse(sql): formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) statement_count = len(parsed_statements) @@ -26,47 +26,47 @@ def _format_and_parse(sql): return parsed_statements[0] -def _is_placeholder(ttype): +def is_placeholder(ttype): return ttype == sqlparse.tokens.Name.Placeholder -def _parse_placeholder(value): +def parse_placeholder(value): if value == "?": - return _Paramstyle.QMARK, None + return Paramstyle.QMARK, None # E.g., :1 matches = re.search(r"^:([1-9]\d*)$", value) if matches: - return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 + return Paramstyle.NUMERIC, int(matches.group(1)) - 1 # E.g., :foo matches = re.search(r"^:([a-zA-Z]\w*)$", value) if matches: - return _Paramstyle.NAMED, matches.group(1) + return Paramstyle.NAMED, matches.group(1) if value == "%s": - return _Paramstyle.FORMAT, None + return Paramstyle.FORMAT, None # E.g., %(foo)s matches = re.search(r"%\((\w+)\)s$", value) if matches: - return _Paramstyle.PYFORMAT, matches.group(1) + return Paramstyle.PYFORMAT, matches.group(1) raise RuntimeError(f"{value}: invalid placeholder") -def _is_operation_token(ttype): +def is_operation_token(ttype): return ttype in { sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} -def _is_string_literal(ttype): +def is_string_literal(ttype): return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] -def _is_identifier(ttype): +def is_identifier(ttype): return ttype == sqlparse.tokens.Literal.String.Symbol -def _get_human_readable_list(iterable): +def get_human_readable_list(iterable): return ", ".join(str(v) for v in iterable) From e06131c42a3e300cd1a8496dc95a9464381b8961 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 07:20:01 -0400 Subject: [PATCH 28/47] fix style --- src/cs50/_session_util.py | 1 + src/cs50/_sql_util.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py index 3433fa9..ed44eaa 100644 --- a/src/cs50/_session_util.py +++ b/src/cs50/_session_util.py @@ -5,6 +5,7 @@ import sqlalchemy + def is_sqlite_url(url): return url.startswith("sqlite:///") diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index ea3edad..238d979 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -2,6 +2,7 @@ import decimal + def fetch_select_result(result): rows = [dict(row) for row in result.fetchall()] for row in rows: From 86f981f04377c2f81be1fc012c9d59d91c85d5ca Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 07:21:03 -0400 Subject: [PATCH 29/47] reorder imports --- src/cs50/_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 0a30c36..c1ea426 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -4,9 +4,9 @@ import sqlalchemy.orm from ._session_util import ( - is_sqlite_url, assert_sqlite_file_exists, create_session, + is_sqlite_url, ) From 67a7f0c8559ecd719cab86012f443e4d36cd3f8a Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 11:57:03 -0400 Subject: [PATCH 30/47] remove manual tests --- tests/flask/application.py | 22 --------------- tests/flask/requirements.txt | 2 -- tests/flask/templates/error.html | 10 ------- tests/flask/templates/index.html | 10 ------- tests/foo.py | 48 -------------------------------- tests/mysql.py | 8 ------ tests/python.py | 8 ------ tests/sqlite.py | 44 ----------------------------- tests/tb.py | 10 ------- 9 files changed, 162 deletions(-) delete mode 100644 tests/flask/application.py delete mode 100644 tests/flask/requirements.txt delete mode 100644 tests/flask/templates/error.html delete mode 100644 tests/flask/templates/index.html delete mode 100644 tests/foo.py delete mode 100644 tests/mysql.py delete mode 100644 tests/python.py delete mode 100644 tests/sqlite.py delete mode 100644 tests/tb.py diff --git a/tests/flask/application.py b/tests/flask/application.py deleted file mode 100644 index 939a8f9..0000000 --- a/tests/flask/application.py +++ /dev/null @@ -1,22 +0,0 @@ -import requests -import sys -from flask import Flask, render_template - -sys.path.insert(0, "../../src") - -import cs50 -import cs50.flask - -app = Flask(__name__) - -db = cs50.SQL("sqlite:///../sqlite.db") - -@app.route("/") -def index(): - db.execute("SELECT 1") - """ - def f(): - res = requests.get("cs50.harvard.edu") - f() - """ - return render_template("index.html") diff --git a/tests/flask/requirements.txt b/tests/flask/requirements.txt deleted file mode 100644 index 7d0c101..0000000 --- a/tests/flask/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -cs50 -Flask diff --git a/tests/flask/templates/error.html b/tests/flask/templates/error.html deleted file mode 100644 index 3302040..0000000 --- a/tests/flask/templates/error.html +++ /dev/null @@ -1,10 +0,0 @@ - - - - - error - - - error - - diff --git a/tests/flask/templates/index.html b/tests/flask/templates/index.html deleted file mode 100644 index 2f6a145..0000000 --- a/tests/flask/templates/index.html +++ /dev/null @@ -1,10 +0,0 @@ - - - - - flask - - - flask - - diff --git a/tests/foo.py b/tests/foo.py deleted file mode 100644 index 7f32a00..0000000 --- a/tests/foo.py +++ /dev/null @@ -1,48 +0,0 @@ -import logging -import sys - -sys.path.insert(0, "../src") - -import cs50 - -""" -db = cs50.SQL("sqlite:///foo.db") - -logging.getLogger("cs50").disabled = False - -#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c") -db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)") - -db.execute("INSERT INTO bar VALUES (?)", "baz") -db.execute("INSERT INTO bar VALUES (?)", "qux") -db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")) -db.execute("DELETE FROM bar") -""" - -db = cs50.SQL("postgresql://postgres@localhost/test") - -""" -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("SELECT * FROM cs50")) - -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("SELECT * FROM cs50")) -""" - -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("INSERT INTO cs50 (val) VALUES('bar')")) -print(db.execute("INSERT INTO cs50 (val) VALUES('baz')")) -print(db.execute("SELECT * FROM cs50")) -try: - print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')")) -except Exception as e: - print(e) - pass -print(db.execute("INSERT INTO cs50 (val) VALUES('qux')")) -#print(db.execute("DELETE FROM cs50")) diff --git a/tests/mysql.py b/tests/mysql.py deleted file mode 100644 index 2a431c3..0000000 --- a/tests/mysql.py +++ /dev/null @@ -1,8 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -from cs50 import SQL - -db = SQL("mysql://root@localhost/test") -db.execute("SELECT 1") diff --git a/tests/python.py b/tests/python.py deleted file mode 100644 index 6a265cb..0000000 --- a/tests/python.py +++ /dev/null @@ -1,8 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -import cs50 - -i = cs50.get_int("Input: ") -print(f"Output: {i}") diff --git a/tests/sqlite.py b/tests/sqlite.py deleted file mode 100644 index 05c2cea..0000000 --- a/tests/sqlite.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging -import sys - -sys.path.insert(0, "../src") - -from cs50 import SQL - -logging.getLogger("cs50").disabled = False - -db = SQL("sqlite:///sqlite.db") -db.execute("SELECT 1") - -# TODO -#db.execute("SELECT * FROM Employee WHERE FirstName = ?", b'\x00') - -db.execute("SELECT * FROM Employee WHERE FirstName = ?", "' OR 1 = 1") - -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", "Andrew") -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew"]) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew",)) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew", "Nancy"]) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew", "Nancy")) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", []) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ()) - -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = ':Andrew :Adams'") - -db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", first="Andrew", last="Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", {"first": "Andrew", "last": "Adams"}) - -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", first="Andrew", last="Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", {"first": "Andrew", "last": "Adams"}) diff --git a/tests/tb.py b/tests/tb.py deleted file mode 100644 index 3ad8175..0000000 --- a/tests/tb.py +++ /dev/null @@ -1,10 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -import cs50 -import requests - -def f(): - res = requests.get("cs50.harvard.edu") -f() From 839b1f1baca132bc4edf83489f3856f7438bf6de Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 12:40:34 -0400 Subject: [PATCH 31/47] add statement tests, rollback on error in autocommit --- src/cs50/_sql_util.py | 8 ++ src/cs50/_statement.py | 3 +- src/cs50/_statement_util.py | 12 ++ src/cs50/sql.py | 14 ++- tests/test_statement.py | 213 ++++++++++++++++++++++++++++++++++++ 5 files changed, 244 insertions(+), 6 deletions(-) create mode 100644 tests/test_statement.py diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index 238d979..dbaff2e 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -3,6 +3,14 @@ import decimal +def is_transaction_start(keyword): + return keyword in {"BEGIN", "START"} + + +def is_transaction_end(keyword): + return keyword in {"COMMIT", "ROLLBACK"} + + def fetch_select_result(result): rows = [dict(row) for row in result.fetchall()] for row in rows: diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index ac83758..3347f61 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -10,6 +10,7 @@ is_operation_token, is_placeholder, is_string_literal, + operation_keywords, Paramstyle, parse_placeholder, ) @@ -50,7 +51,7 @@ def _get_operation_keyword(self): for token in self._statement: if is_operation_token(token.ttype): token_value = token.value.upper() - if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: + if token_value in operation_keywords: operation_keyword = token_value break else: diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py index 81b79e1..4ef092a 100644 --- a/src/cs50/_statement_util.py +++ b/src/cs50/_statement_util.py @@ -6,6 +6,18 @@ import sqlparse +operation_keywords = { + "BEGIN", + "COMMIT", + "DELETE", + "INSERT", + "ROLLBACK", + "SELECT", + "START", + "UPDATE" +} + + class Paramstyle(enum.Enum): FORMAT = enum.auto() NAMED = enum.auto() diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d823c8b..ae9f97e 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -8,7 +8,7 @@ from ._session import Session from ._statement import Statement -from ._sql_util import fetch_select_result +from ._sql_util import fetch_select_result, is_transaction_start, is_transaction_end _logger = logging.getLogger("cs50") @@ -26,7 +26,7 @@ def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = Statement(self._dialect, sql, *args, **kwargs) operation_keyword = statement.get_operation_keyword() - if operation_keyword in {"BEGIN", "START"}: + if is_transaction_start(operation_keyword): self._autocommit = False if self._autocommit: @@ -36,11 +36,9 @@ def execute(self, sql, *args, **kwargs): if self._autocommit: self._session.execute("COMMIT") - self._session.remove() - if operation_keyword in {"COMMIT", "ROLLBACK"}: + if is_transaction_end(operation_keyword): self._autocommit = True - self._session.remove() if operation_keyword == "SELECT": ret = fetch_select_result(result) @@ -51,8 +49,12 @@ def execute(self, sql, *args, **kwargs): else: ret = True + if self._autocommit: + self._session.remove() + return ret + def _execute(self, statement): # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -62,6 +64,8 @@ def _execute(self, statement): result = self._session.execute(statement) except sqlalchemy.exc.IntegrityError as exc: _logger.debug(termcolor.colored(str(statement), "yellow")) + if self._autocommit: + self._session.execute("ROLLBACK") raise ValueError(exc.orig) from None except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: self._session.remove() diff --git a/tests/test_statement.py b/tests/test_statement.py new file mode 100644 index 0000000..cbbafe8 --- /dev/null +++ b/tests/test_statement.py @@ -0,0 +1,213 @@ +import unittest + +from unittest.mock import patch + +from cs50._statement import Statement +from cs50._sql_sanitizer import SQLSanitizer + +class TestStatement(unittest.TestCase): + # TODO assert correct exception messages + def test_mutex_args_and_kwargs(self): + with self.assertRaises(RuntimeError): + Statement("", "", "test", foo="foo") + + with self.assertRaises(RuntimeError): + Statement("", "", "test", 1, 2, foo="foo", bar="bar") + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_valid_qmark_count(self, *_): + Statement("", "SELECT * FROM test WHERE id = ?", 1) + Statement("", "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') + Statement("", "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_invalid_qmark_count(self, *_): + def assert_invalid_count(sql, *args): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): + Statement("", sql, *args) + + statements = [ + ("SELECT * FROM test WHERE id = ?", ()), + ("SELECT * FROM test WHERE id = ?", (1, "test")), + ("SELECT * FROM test WHERE id = ? AND val = ?", (1,)), + ("SELECT * FROM test WHERE id = ? AND val = ?", ()), + ("SELECT * FROM test WHERE id = ? AND val = ?", (1, "test", True)), + ] + + for sql, args in statements: + assert_invalid_count(sql, *args) + + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_valid_format_count(self, *_): + Statement("", "SELECT * FROM test WHERE id = %s", 1) + Statement("", "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') + Statement("", "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_invalid_format_count(self, *_): + def assert_invalid_count(sql, *args): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): + Statement("", sql, *args) + + statements = [ + ("SELECT * FROM test WHERE id = %s", ()), + ("SELECT * FROM test WHERE id = %s", (1, "test")), + ("SELECT * FROM test WHERE id = %s AND val = ?", (1,)), + ("SELECT * FROM test WHERE id = %s AND val = ?", ()), + ("SELECT * FROM test WHERE id = %s AND val = ?", (1, "test", True)), + ] + + for sql, args in statements: + assert_invalid_count(sql, *args) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_missing_numeric(self, *_): + def assert_missing_numeric(sql, *args): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): + Statement("", sql, *args) + + statements = [ + ("SELECT * FROM test WHERE id = :1", ()), + ("SELECT * FROM test WHERE id = :1 AND val = :2", ()), + ("SELECT * FROM test WHERE id = :1 AND val = :2", (1,)), + ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", ()), + ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1,)), + ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1, "test")), + ] + + for sql, args in statements: + assert_missing_numeric(sql, *args) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_unused_numeric(self, *_): + def assert_unused_numeric(sql, *args): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): + Statement("", sql, *args) + + statements = [ + ("SELECT * FROM test WHERE id = :1", (1, "test")), + ("SELECT * FROM test WHERE id = :1", (1, "test", True)), + ("SELECT * FROM test WHERE id = :1 AND val = :2", (1, "test", True)), + ] + + for sql, args in statements: + assert_unused_numeric(sql, *args) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_missing_named(self, *_): + def assert_missing_named(sql, **kwargs): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): + Statement("", sql, **kwargs) + + statements = [ + ("SELECT * FROM test WHERE id = :id", {}), + ("SELECT * FROM test WHERE id = :id AND val = :val", {}), + ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1, "val": "test"}), + ] + + for sql, kwargs in statements: + assert_missing_named(sql, **kwargs) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_unused_named(self, *_): + def assert_unused_named(sql, **kwargs): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): + Statement("", sql, **kwargs) + + statements = [ + ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}), + ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test", "is_valid": True}), + ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1, "val": "test", "is_valid": True}), + ] + + for sql, kwargs in statements: + assert_unused_named(sql, **kwargs) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_missing_pyformat(self, *_): + def assert_missing_pyformat(sql, **kwargs): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): + Statement("", sql, **kwargs) + + statements = [ + ("SELECT * FROM test WHERE id = %(id)s", {}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1, "val": "test"}), + ] + + for sql, kwargs in statements: + assert_missing_pyformat(sql, **kwargs) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_unused_pyformat(self, *_): + def assert_unused_pyformat(sql, **kwargs): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): + Statement("", sql, **kwargs) + + statements = [ + ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}), + ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test", "is_valid": True}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1, "val": "test", "is_valid": True}), + ] + + for sql, kwargs in statements: + assert_unused_pyformat(sql, **kwargs) + + def test_multiple_statements(self): + def assert_raises_runtimeerror(sql): + with self.assertRaises(RuntimeError): + Statement("", sql) + + statements = [ + "SELECT 1; SELECT 2;", + "SELECT 1; SELECT 2", + "SELECT 1; SELECT 2; SELECT 3", + "SELECT 1; SELECT 2; SELECT 3;", + "SELECT 1;SELECT 2", + "select 1; select 2", + "select 1;select 2", + "DELETE FROM test; SELECT * FROM test", + ] + + for sql in statements: + assert_raises_runtimeerror(sql) + + def test_get_operation_keyword(self): + def test_raw_and_lowercase(sql, keyword): + statement = Statement("", sql) + self.assertEqual(statement.get_operation_keyword(), keyword) + + statement = Statement("", sql.lower()) + self.assertEqual(statement.get_operation_keyword(), keyword) + + + statements = [ + ("SELECT * FROM test", "SELECT"), + ("INSERT INTO test (id, val) VALUES (1, 'test')", "INSERT"), + ("DELETE FROM test", "DELETE"), + ("UPDATE test SET id = 2", "UPDATE"), + ("START TRANSACTION", "START"), + ("BEGIN", "BEGIN"), + ("COMMIT", "COMMIT"), + ("ROLLBACK", "ROLLBACK"), + ] + + for sql, keyword in statements: + test_raw_and_lowercase(sql, keyword) From 9302a1e8b8a82ac2a206cee0fb340f0dd82cf153 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 21:19:13 -0400 Subject: [PATCH 32/47] move operation check to Statement --- src/cs50/_sql_util.py | 8 ---- src/cs50/_statement.py | 20 +++++++-- src/cs50/sql.py | 13 +++--- tests/test_statement.py | 95 ++++++++++++++++++++++++----------------- 4 files changed, 79 insertions(+), 57 deletions(-) diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index dbaff2e..238d979 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -3,14 +3,6 @@ import decimal -def is_transaction_start(keyword): - return keyword in {"BEGIN", "START"} - - -def is_transaction_end(keyword): - return keyword in {"COMMIT", "ROLLBACK"} - - def fetch_select_result(result): rows = [dict(row) for row in result.fetchall()] for row in rows: diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 3347f61..2502284 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -147,9 +147,23 @@ def _escape_verbatim_colons(self): if is_string_literal(token.ttype) or is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) - def get_operation_keyword(self): - """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" - return self._operation_keyword + def is_transaction_start(self): + return self._operation_keyword in {"BEGIN", "START"} + + def is_transaction_end(self): + return self._operation_keyword in {"COMMIT", "ROLLBACK"} + + def is_delete(self): + return self._operation_keyword == "DELETE" + + def is_insert(self): + return self._operation_keyword == "INSERT" + + def is_select(self): + return self._operation_keyword == "SELECT" + + def is_update(self): + return self._operation_keyword == "UPDATE" def __str__(self): return "".join([str(token) for token in self._tokens]) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index ae9f97e..c0e41fd 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -8,7 +8,7 @@ from ._session import Session from ._statement import Statement -from ._sql_util import fetch_select_result, is_transaction_start, is_transaction_end +from ._sql_util import fetch_select_result _logger = logging.getLogger("cs50") @@ -25,8 +25,7 @@ def __init__(self, url, **engine_kwargs): def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = Statement(self._dialect, sql, *args, **kwargs) - operation_keyword = statement.get_operation_keyword() - if is_transaction_start(operation_keyword): + if statement.is_transaction_start(): self._autocommit = False if self._autocommit: @@ -37,14 +36,14 @@ def execute(self, sql, *args, **kwargs): if self._autocommit: self._session.execute("COMMIT") - if is_transaction_end(operation_keyword): + if statement.is_transaction_end(): self._autocommit = True - if operation_keyword == "SELECT": + if statement.is_select(): ret = fetch_select_result(result) - elif operation_keyword == "INSERT": + elif statement.is_insert(): ret = self._last_row_id_or_none(result) - elif operation_keyword in {"DELETE", "UPDATE"}: + elif statement.is_delete() or statement.is_update(): ret = result.rowcount else: ret = True diff --git a/tests/test_statement.py b/tests/test_statement.py index cbbafe8..fcee3b9 100644 --- a/tests/test_statement.py +++ b/tests/test_statement.py @@ -9,24 +9,24 @@ class TestStatement(unittest.TestCase): # TODO assert correct exception messages def test_mutex_args_and_kwargs(self): with self.assertRaises(RuntimeError): - Statement("", "", "test", foo="foo") + Statement(None, None, "test", foo="foo") with self.assertRaises(RuntimeError): - Statement("", "", "test", 1, 2, foo="foo", bar="bar") + Statement(None, None, "test", 1, 2, foo="foo", bar="bar") @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") def test_valid_qmark_count(self, *_): - Statement("", "SELECT * FROM test WHERE id = ?", 1) - Statement("", "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') - Statement("", "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) + Statement(None, "SELECT * FROM test WHERE id = ?", 1) + Statement(None, "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') + Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") def test_invalid_qmark_count(self, *_): def assert_invalid_count(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement("", sql, *args) + Statement(None, sql, *args) statements = [ ("SELECT * FROM test WHERE id = ?", ()), @@ -43,16 +43,16 @@ def assert_invalid_count(sql, *args): @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") def test_valid_format_count(self, *_): - Statement("", "SELECT * FROM test WHERE id = %s", 1) - Statement("", "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') - Statement("", "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) + Statement(None, "SELECT * FROM test WHERE id = %s", 1) + Statement(None, "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') + Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") def test_invalid_format_count(self, *_): def assert_invalid_count(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement("", sql, *args) + Statement(None, sql, *args) statements = [ ("SELECT * FROM test WHERE id = %s", ()), @@ -70,7 +70,7 @@ def assert_invalid_count(sql, *args): def test_missing_numeric(self, *_): def assert_missing_numeric(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement("", sql, *args) + Statement(None, sql, *args) statements = [ ("SELECT * FROM test WHERE id = :1", ()), @@ -89,7 +89,7 @@ def assert_missing_numeric(sql, *args): def test_unused_numeric(self, *_): def assert_unused_numeric(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement("", sql, *args) + Statement(None, sql, *args) statements = [ ("SELECT * FROM test WHERE id = :1", (1, "test")), @@ -105,7 +105,7 @@ def assert_unused_numeric(sql, *args): def test_missing_named(self, *_): def assert_missing_named(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement("", sql, **kwargs) + Statement(None, sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = :id", {}), @@ -124,7 +124,7 @@ def assert_missing_named(sql, **kwargs): def test_unused_named(self, *_): def assert_unused_named(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement("", sql, **kwargs) + Statement(None, sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}), @@ -140,7 +140,7 @@ def assert_unused_named(sql, **kwargs): def test_missing_pyformat(self, *_): def assert_missing_pyformat(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement("", sql, **kwargs) + Statement(None, sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = %(id)s", {}), @@ -159,7 +159,7 @@ def assert_missing_pyformat(sql, **kwargs): def test_unused_pyformat(self, *_): def assert_unused_pyformat(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement("", sql, **kwargs) + Statement(None, sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}), @@ -173,7 +173,7 @@ def assert_unused_pyformat(sql, **kwargs): def test_multiple_statements(self): def assert_raises_runtimeerror(sql): with self.assertRaises(RuntimeError): - Statement("", sql) + Statement(None, sql) statements = [ "SELECT 1; SELECT 2;", @@ -189,25 +189,42 @@ def assert_raises_runtimeerror(sql): for sql in statements: assert_raises_runtimeerror(sql) - def test_get_operation_keyword(self): - def test_raw_and_lowercase(sql, keyword): - statement = Statement("", sql) - self.assertEqual(statement.get_operation_keyword(), keyword) - - statement = Statement("", sql.lower()) - self.assertEqual(statement.get_operation_keyword(), keyword) - - - statements = [ - ("SELECT * FROM test", "SELECT"), - ("INSERT INTO test (id, val) VALUES (1, 'test')", "INSERT"), - ("DELETE FROM test", "DELETE"), - ("UPDATE test SET id = 2", "UPDATE"), - ("START TRANSACTION", "START"), - ("BEGIN", "BEGIN"), - ("COMMIT", "COMMIT"), - ("ROLLBACK", "ROLLBACK"), - ] - - for sql, keyword in statements: - test_raw_and_lowercase(sql, keyword) + def test_is_delete(self): + self.assertTrue(Statement(None, "DELETE FROM test").is_delete()) + self.assertTrue(Statement(None, "delete FROM test").is_delete()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_delete()) + self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete()) + + def test_is_insert(self): + self.assertTrue(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert()) + self.assertTrue(Statement(None, "insert INTO test (id, val) VALUES (1, 'test')").is_insert()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_insert()) + self.assertFalse(Statement(None, "DELETE FROM test").is_insert()) + + def test_is_select(self): + self.assertTrue(Statement(None, "SELECT * FROM test").is_select()) + self.assertTrue(Statement(None, "select * FROM test").is_select()) + self.assertFalse(Statement(None, "DELETE FROM test").is_select()) + self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_select()) + + def test_is_update(self): + self.assertTrue(Statement(None, "UPDATE test SET id = 2").is_update()) + self.assertTrue(Statement(None, "update test SET id = 2").is_update()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_update()) + self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_update()) + + def test_is_transaction_start(self): + self.assertTrue(Statement(None, "START TRANSACTION").is_transaction_start()) + self.assertTrue(Statement(None, "start TRANSACTION").is_transaction_start()) + self.assertTrue(Statement(None, "BEGIN").is_transaction_start()) + self.assertTrue(Statement(None, "begin").is_transaction_start()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_start()) + self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_start()) + + def test_is_transaction_end(self): + self.assertTrue(Statement(None, "COMMIT").is_transaction_end()) + self.assertTrue(Statement(None, "commit").is_transaction_end()) + self.assertTrue(Statement(None, "ROLLBACK").is_transaction_end()) + self.assertTrue(Statement(None, "rollback").is_transaction_end()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_end()) + self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_end()) From f6912d27ba1ed250519eff6c434720a092679783 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Tue, 13 Apr 2021 22:31:51 -0400 Subject: [PATCH 33/47] use statement factory --- src/cs50/_statement.py | 13 ++- src/cs50/sql.py | 10 +-- tests/test_cs50.py | 9 --- tests/test_statement.py | 174 ++++++++++++++++++++-------------------- 4 files changed, 105 insertions(+), 101 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 2502284..cc4cdb8 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -16,14 +16,23 @@ ) +def statement_factory(dialect): + sql_sanitizer = SQLSanitizer(dialect) + + def statement(sql, *args, **kwargs): + return Statement(sql_sanitizer, sql, *args, **kwargs) + + return statement + + class Statement: """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" - def __init__(self, dialect, sql, *args, **kwargs): + def __init__(self, sql_sanitizer, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") - self._sql_sanitizer = SQLSanitizer(dialect) + self._sql_sanitizer = sql_sanitizer self._args = self._get_escaped_args(args) self._kwargs = self._get_escaped_kwargs(kwargs) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index c0e41fd..10bffd6 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -7,7 +7,7 @@ import termcolor from ._session import Session -from ._statement import Statement +from ._statement import statement_factory from ._sql_util import fetch_select_result _logger = logging.getLogger("cs50") @@ -18,13 +18,14 @@ class SQL: def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) - self._dialect = self._session.get_bind().dialect - self._is_postgres = self._dialect.name in {"postgres", "postgresql"} + dialect = self._session.get_bind().dialect + self._is_postgres = dialect.name in {"postgres", "postgresql"} + self._sanitized_statement = statement_factory(dialect) self._autocommit = False def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" - statement = Statement(self._dialect, sql, *args, **kwargs) + statement = self._sanitized_statement(sql, *args, **kwargs) if statement.is_transaction_start(): self._autocommit = False @@ -53,7 +54,6 @@ def execute(self, sql, *args, **kwargs): return ret - def _execute(self, statement): # Catch SQLAlchemy warnings with warnings.catch_warnings(): diff --git a/tests/test_cs50.py b/tests/test_cs50.py index a58424d..dd0f14b 100644 --- a/tests/test_cs50.py +++ b/tests/test_cs50.py @@ -14,34 +14,29 @@ def test_get_string_empty_input(self, mock_get_input): self.assertEqual(get_string("Answer: "), "") mock_get_input.assert_called_with("Answer: ") - @patch("cs50.cs50._get_input", return_value="test") def test_get_string_nonempty_input(self, mock_get_input): """Returns the provided non-empty input""" self.assertEqual(get_string("Answer: "), "test") mock_get_input.assert_called_with("Answer: ") - @patch("cs50.cs50._get_input", side_effect=EOFError) def test_get_string_eof(self, mock_get_input): """Returns None on EOF""" self.assertIs(get_string("Answer: "), None) mock_get_input.assert_called_with("Answer: ") - def test_get_string_invalid_prompt(self): """Raises TypeError when prompt is not str""" with self.assertRaises(TypeError): get_string(1) - @patch("cs50.cs50.get_string", return_value=None) def test_get_int_eof(self, mock_get_string): """Returns None on EOF""" self.assertIs(_get_int("Answer: "), None) mock_get_string.assert_called_with("Answer: ") - def test_get_int_valid_input(self): """Returns the provided integer input""" @@ -62,7 +57,6 @@ def assert_equal(return_value, expected_value): for return_value, expected_value in values: assert_equal(return_value, expected_value) - def test_get_int_invalid_input(self): """Raises ValueError when input is invalid base-10 int""" @@ -90,14 +84,12 @@ def assert_raises_valueerror(return_value): for return_value in return_values: assert_raises_valueerror(return_value) - @patch("cs50.cs50.get_string", return_value=None) def test_get_float_eof(self, mock_get_string): """Returns None on EOF""" self.assertIs(_get_float("Answer: "), None) mock_get_string.assert_called_with("Answer: ") - def test_get_float_valid_input(self): """Returns the provided integer input""" def assert_equal(return_value, expected_value): @@ -121,7 +113,6 @@ def assert_equal(return_value, expected_value): for return_value, expected_value in values: assert_equal(return_value, expected_value) - def test_get_float_invalid_input(self): """Raises ValueError when input is invalid float""" diff --git a/tests/test_statement.py b/tests/test_statement.py index fcee3b9..91261cd 100644 --- a/tests/test_statement.py +++ b/tests/test_statement.py @@ -5,28 +5,29 @@ from cs50._statement import Statement from cs50._sql_sanitizer import SQLSanitizer + +@patch.object(SQLSanitizer, "escape", return_value="test") class TestStatement(unittest.TestCase): # TODO assert correct exception messages - def test_mutex_args_and_kwargs(self): + def test_mutex_args_and_kwargs(self, MockSQLSanitizer): with self.assertRaises(RuntimeError): - Statement(None, None, "test", foo="foo") + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ? AND val = :val", 1, val="test") with self.assertRaises(RuntimeError): - Statement(None, None, "test", 1, 2, foo="foo", bar="bar") + Statement(MockSQLSanitizer(), "SELECT * FROM test", "test", 1, 2, foo="foo", bar="bar") - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_valid_qmark_count(self, *_): - Statement(None, "SELECT * FROM test WHERE id = ?", 1) - Statement(None, "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') - Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) + def test_valid_qmark_count(self, MockSQLSanitizer, *_): + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ?", 1) + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') + Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_invalid_qmark_count(self, *_): + def test_invalid_qmark_count(self, MockSQLSanitizer, *_): def assert_invalid_count(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(None, sql, *args) + Statement(MockSQLSanitizer(), sql, *args) statements = [ ("SELECT * FROM test WHERE id = ?", ()), @@ -39,20 +40,18 @@ def assert_invalid_count(sql, *args): for sql, args in statements: assert_invalid_count(sql, *args) - - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_valid_format_count(self, *_): - Statement(None, "SELECT * FROM test WHERE id = %s", 1) - Statement(None, "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') - Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) + def test_valid_format_count(self, MockSQLSanitizer, *_): + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = %s", 1) + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') + Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_invalid_format_count(self, *_): + def test_invalid_format_count(self, MockSQLSanitizer, *_): def assert_invalid_count(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(None, sql, *args) + Statement(MockSQLSanitizer(), sql, *args) statements = [ ("SELECT * FROM test WHERE id = %s", ()), @@ -65,12 +64,11 @@ def assert_invalid_count(sql, *args): for sql, args in statements: assert_invalid_count(sql, *args) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_numeric(self, *_): + def test_missing_numeric(self, MockSQLSanitizer, *_): def assert_missing_numeric(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(None, sql, *args) + Statement(MockSQLSanitizer(), sql, *args) statements = [ ("SELECT * FROM test WHERE id = :1", ()), @@ -84,12 +82,11 @@ def assert_missing_numeric(sql, *args): for sql, args in statements: assert_missing_numeric(sql, *args) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_numeric(self, *_): + def test_unused_numeric(self, MockSQLSanitizer, *_): def assert_unused_numeric(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(None, sql, *args) + Statement(MockSQLSanitizer(), sql, *args) statements = [ ("SELECT * FROM test WHERE id = :1", (1, "test")), @@ -100,80 +97,82 @@ def assert_unused_numeric(sql, *args): for sql, args in statements: assert_unused_numeric(sql, *args) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_named(self, *_): + def test_missing_named(self, MockSQLSanitizer, *_): def assert_missing_named(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(None, sql, **kwargs) + Statement(MockSQLSanitizer(), sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = :id", {}), ("SELECT * FROM test WHERE id = :id AND val = :val", {}), ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1}), ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1, "val": "test"}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", + {"id": 1}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", + {"id": 1, "val": "test"}), ] for sql, kwargs in statements: assert_missing_named(sql, **kwargs) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_named(self, *_): + def test_unused_named(self, MockSQLSanitizer, *_): def assert_unused_named(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(None, sql, **kwargs) + Statement(MockSQLSanitizer(), sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}), ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test", "is_valid": True}), - ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1, "val": "test", "is_valid": True}), + ("SELECT * FROM test WHERE id = :id AND val = :val", + {"id": 1, "val": "test", "is_valid": True}), ] for sql, kwargs in statements: assert_unused_named(sql, **kwargs) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_pyformat(self, *_): + def test_missing_pyformat(self, MockSQLSanitizer, *_): def assert_missing_pyformat(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(None, sql, **kwargs) + Statement(MockSQLSanitizer(), sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = %(id)s", {}), ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {}), ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1}), ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1, "val": "test"}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", + {"id": 1}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", + {"id": 1, "val": "test"}), ] for sql, kwargs in statements: assert_missing_pyformat(sql, **kwargs) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_pyformat(self, *_): + def test_unused_pyformat(self, MockSQLSanitizer, *_): def assert_unused_pyformat(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(None, sql, **kwargs) + Statement(MockSQLSanitizer(), sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}), ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test", "is_valid": True}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1, "val": "test", "is_valid": True}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", + {"id": 1, "val": "test", "is_valid": True}), ] for sql, kwargs in statements: assert_unused_pyformat(sql, **kwargs) - def test_multiple_statements(self): + def test_multiple_statements(self, MockSQLSanitizer): def assert_raises_runtimeerror(sql): with self.assertRaises(RuntimeError): - Statement(None, sql) + Statement(MockSQLSanitizer(), sql) statements = [ "SELECT 1; SELECT 2;", @@ -189,42 +188,47 @@ def assert_raises_runtimeerror(sql): for sql in statements: assert_raises_runtimeerror(sql) - def test_is_delete(self): - self.assertTrue(Statement(None, "DELETE FROM test").is_delete()) - self.assertTrue(Statement(None, "delete FROM test").is_delete()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_delete()) - self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete()) - - def test_is_insert(self): - self.assertTrue(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert()) - self.assertTrue(Statement(None, "insert INTO test (id, val) VALUES (1, 'test')").is_insert()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_insert()) - self.assertFalse(Statement(None, "DELETE FROM test").is_insert()) - - def test_is_select(self): - self.assertTrue(Statement(None, "SELECT * FROM test").is_select()) - self.assertTrue(Statement(None, "select * FROM test").is_select()) - self.assertFalse(Statement(None, "DELETE FROM test").is_select()) - self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_select()) - - def test_is_update(self): - self.assertTrue(Statement(None, "UPDATE test SET id = 2").is_update()) - self.assertTrue(Statement(None, "update test SET id = 2").is_update()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_update()) - self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_update()) - - def test_is_transaction_start(self): - self.assertTrue(Statement(None, "START TRANSACTION").is_transaction_start()) - self.assertTrue(Statement(None, "start TRANSACTION").is_transaction_start()) - self.assertTrue(Statement(None, "BEGIN").is_transaction_start()) - self.assertTrue(Statement(None, "begin").is_transaction_start()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_start()) - self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_start()) - - def test_is_transaction_end(self): - self.assertTrue(Statement(None, "COMMIT").is_transaction_end()) - self.assertTrue(Statement(None, "commit").is_transaction_end()) - self.assertTrue(Statement(None, "ROLLBACK").is_transaction_end()) - self.assertTrue(Statement(None, "rollback").is_transaction_end()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_end()) - self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_end()) + def test_is_delete(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "DELETE FROM test").is_delete()) + self.assertTrue(Statement(MockSQLSanitizer(), "delete FROM test").is_delete()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_delete()) + self.assertFalse(Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete()) + + def test_is_insert(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert()) + self.assertTrue(Statement(MockSQLSanitizer(), + "insert INTO test (id, val) VALUES (1, 'test')").is_insert()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_insert()) + self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_insert()) + + def test_is_select(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_select()) + self.assertTrue(Statement(MockSQLSanitizer(), "select * FROM test").is_select()) + self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_select()) + self.assertFalse(Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val) VALUES (1, 'test')").is_select()) + + def test_is_update(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "UPDATE test SET id = 2").is_update()) + self.assertTrue(Statement(MockSQLSanitizer(), "update test SET id = 2").is_update()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_update()) + self.assertFalse(Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val) VALUES (1, 'test')").is_update()) + + def test_is_transaction_start(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "START TRANSACTION").is_transaction_start()) + self.assertTrue(Statement(MockSQLSanitizer(), "start TRANSACTION").is_transaction_start()) + self.assertTrue(Statement(MockSQLSanitizer(), "BEGIN").is_transaction_start()) + self.assertTrue(Statement(MockSQLSanitizer(), "begin").is_transaction_start()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_transaction_start()) + self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_transaction_start()) + + def test_is_transaction_end(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "COMMIT").is_transaction_end()) + self.assertTrue(Statement(MockSQLSanitizer(), "commit").is_transaction_end()) + self.assertTrue(Statement(MockSQLSanitizer(), "ROLLBACK").is_transaction_end()) + self.assertTrue(Statement(MockSQLSanitizer(), "rollback").is_transaction_end()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_transaction_end()) + self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_transaction_end()) From 08a66362335e7e832bb1d366711556a2aed9fe37 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Wed, 14 Apr 2021 08:05:45 -0400 Subject: [PATCH 34/47] use remove instead of rollback --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 10bffd6..0e7ee8a 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -64,7 +64,7 @@ def _execute(self, statement): except sqlalchemy.exc.IntegrityError as exc: _logger.debug(termcolor.colored(str(statement), "yellow")) if self._autocommit: - self._session.execute("ROLLBACK") + self._session.remove() raise ValueError(exc.orig) from None except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: self._session.remove() From 7f9b77c364c5893427db8e27b77427e7d5872f90 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Wed, 14 Apr 2021 08:06:24 -0400 Subject: [PATCH 35/47] rename _sanitized_statement --- src/cs50/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 0e7ee8a..4bf82e5 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -20,12 +20,12 @@ def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) dialect = self._session.get_bind().dialect self._is_postgres = dialect.name in {"postgres", "postgresql"} - self._sanitized_statement = statement_factory(dialect) + self._sanitize_statement = statement_factory(dialect) self._autocommit = False def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" - statement = self._sanitized_statement(sql, *args, **kwargs) + statement = self._sanitize_statement(sql, *args, **kwargs) if statement.is_transaction_start(): self._autocommit = False From 944934fa1048081293c385b7cd988d71cdc1496d Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Wed, 14 Apr 2021 08:21:01 -0400 Subject: [PATCH 36/47] abstract away catch_warnings --- src/cs50/_sql_util.py | 9 +++++++++ src/cs50/sql.py | 8 ++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index 238d979..52538ad 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -1,6 +1,8 @@ """Utility functions used by sql.py""" +import contextlib import decimal +import warnings def fetch_select_result(result): @@ -17,3 +19,10 @@ def fetch_select_result(result): row[column] = bytes(row[column]) return rows + + +@contextlib.contextmanager +def raise_errors_for_warnings(): + with warnings.catch_warnings(): + warnings.simplefilter("error") + yield diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 4bf82e5..0486214 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,14 +1,13 @@ """Wraps SQLAlchemy""" import logging -import warnings import sqlalchemy import termcolor from ._session import Session from ._statement import statement_factory -from ._sql_util import fetch_select_result +from ._sql_util import fetch_select_result, raise_errors_for_warnings _logger = logging.getLogger("cs50") @@ -55,10 +54,7 @@ def execute(self, sql, *args, **kwargs): return ret def _execute(self, statement): - # Catch SQLAlchemy warnings - with warnings.catch_warnings(): - # Raise exceptions for warnings - warnings.simplefilter("error") + with raise_errors_for_warnings(): try: result = self._session.execute(statement) except sqlalchemy.exc.IntegrityError as exc: From 006657e0a50aeba1f80a76a7016dc1b187f7a4d2 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Wed, 14 Apr 2021 09:48:20 -0400 Subject: [PATCH 37/47] remove BEGIN and COMMIT --- src/cs50/sql.py | 46 ++++++++++++++++++---------------------------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 0486214..9aab897 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -20,24 +20,31 @@ def __init__(self, url, **engine_kwargs): dialect = self._session.get_bind().dialect self._is_postgres = dialect.name in {"postgres", "postgresql"} self._sanitize_statement = statement_factory(dialect) - self._autocommit = False + self._outside_transaction = True def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = self._sanitize_statement(sql, *args, **kwargs) - if statement.is_transaction_start(): - self._autocommit = False - - if self._autocommit: - self._session.execute("BEGIN") + try: + with raise_errors_for_warnings(): + result = self._session.execute(statement) + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(termcolor.colored(str(statement), "yellow")) + if self._outside_transaction: + self._session.remove() + raise ValueError(exc.orig) from None + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + self._session.remove() + _logger.debug(termcolor.colored(statement, "red")) + raise RuntimeError(exc.orig) from None - result = self._execute(statement) + if statement.is_transaction_start(): + self._outside_transaction = False - if self._autocommit: - self._session.execute("COMMIT") + _logger.debug(termcolor.colored(str(statement), "green")) if statement.is_transaction_end(): - self._autocommit = True + self._outside_transaction = True if statement.is_select(): ret = fetch_select_result(result) @@ -48,28 +55,11 @@ def execute(self, sql, *args, **kwargs): else: ret = True - if self._autocommit: + if self._outside_transaction: self._session.remove() return ret - def _execute(self, statement): - with raise_errors_for_warnings(): - try: - result = self._session.execute(statement) - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(termcolor.colored(str(statement), "yellow")) - if self._autocommit: - self._session.remove() - raise ValueError(exc.orig) from None - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - self._session.remove() - _logger.debug(termcolor.colored(statement, "red")) - raise RuntimeError(exc.orig) from None - - _logger.debug(termcolor.colored(str(statement), "green")) - return result - def _last_row_id_or_none(self, result): if self._is_postgres: return self._get_last_val() From 36fb280771e1985d95bdd75e0410afcee1352035 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Wed, 14 Apr 2021 10:28:46 -0400 Subject: [PATCH 38/47] Revert "remove BEGIN and COMMIT" This reverts commit 006657e0a50aeba1f80a76a7016dc1b187f7a4d2. --- src/cs50/sql.py | 46 ++++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 9aab897..0486214 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -20,31 +20,24 @@ def __init__(self, url, **engine_kwargs): dialect = self._session.get_bind().dialect self._is_postgres = dialect.name in {"postgres", "postgresql"} self._sanitize_statement = statement_factory(dialect) - self._outside_transaction = True + self._autocommit = False def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = self._sanitize_statement(sql, *args, **kwargs) - try: - with raise_errors_for_warnings(): - result = self._session.execute(statement) - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(termcolor.colored(str(statement), "yellow")) - if self._outside_transaction: - self._session.remove() - raise ValueError(exc.orig) from None - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - self._session.remove() - _logger.debug(termcolor.colored(statement, "red")) - raise RuntimeError(exc.orig) from None - if statement.is_transaction_start(): - self._outside_transaction = False + self._autocommit = False + + if self._autocommit: + self._session.execute("BEGIN") - _logger.debug(termcolor.colored(str(statement), "green")) + result = self._execute(statement) + + if self._autocommit: + self._session.execute("COMMIT") if statement.is_transaction_end(): - self._outside_transaction = True + self._autocommit = True if statement.is_select(): ret = fetch_select_result(result) @@ -55,11 +48,28 @@ def execute(self, sql, *args, **kwargs): else: ret = True - if self._outside_transaction: + if self._autocommit: self._session.remove() return ret + def _execute(self, statement): + with raise_errors_for_warnings(): + try: + result = self._session.execute(statement) + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(termcolor.colored(str(statement), "yellow")) + if self._autocommit: + self._session.remove() + raise ValueError(exc.orig) from None + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + self._session.remove() + _logger.debug(termcolor.colored(statement, "red")) + raise RuntimeError(exc.orig) from None + + _logger.debug(termcolor.colored(str(statement), "green")) + return result + def _last_row_id_or_none(self, result): if self._is_postgres: return self._get_last_val() From a6668c093cbe005337aae64147582c1bb26e30ef Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Wed, 14 Apr 2021 12:33:50 -0400 Subject: [PATCH 39/47] rename methods --- src/cs50/_statement.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index cc4cdb8..2e286e9 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -44,7 +44,7 @@ def __init__(self, sql_sanitizer, sql, *args, **kwargs): self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() - self._plugin_escaped_params() + self._substitute_markers_with_escaped_params() self._escape_verbatim_colons() def _get_escaped_args(self, args): @@ -100,15 +100,15 @@ def _get_placeholders(self): return placeholders - def _plugin_escaped_params(self): + def _substitute_markers_with_escaped_params(self): if self._paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: - self._plugin_format_or_qmark_params() + self._substitute_format_or_qmark_markers() elif self._paramstyle == Paramstyle.NUMERIC: - self._plugin_numeric_params() + self._substitue_numeric_markers() if self._paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: - self._plugin_named_or_pyformat_params() + self._substitute_named_or_pyformat_markers() - def _plugin_format_or_qmark_params(self): + def _substitute_format_or_qmark_markers(self): self._assert_valid_arg_count() for arg_index, token_index in enumerate(self._placeholders.keys()): self._tokens[token_index] = self._args[arg_index] @@ -122,7 +122,7 @@ def _assert_valid_arg_count(self): raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") - def _plugin_numeric_params(self): + def _substitue_numeric_markers(self): unused_arg_indices = set(range(len(self._args))) for token_index, num in self._placeholders.items(): if num >= len(self._args): @@ -137,7 +137,7 @@ def _plugin_numeric_params(self): raise RuntimeError( f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") - def _plugin_named_or_pyformat_params(self): + def _substitute_named_or_pyformat_markers(self): unused_params = set(self._kwargs.keys()) for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: From 1440ae57a6437c2b1c0389958122c5524d43f5b9 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Wed, 14 Apr 2021 14:57:35 -0400 Subject: [PATCH 40/47] add docstrings --- src/cs50/__init__.py | 2 - src/cs50/_logger.py | 38 ++++++++-- src/cs50/_session.py | 14 ++-- src/cs50/_session_util.py | 10 ++- src/cs50/_sql_sanitizer.py | 25 +++++-- src/cs50/_sql_util.py | 14 +++- src/cs50/_statement.py | 72 ++++++++++++++++++- src/cs50/_statement_util.py | 19 ++++- src/cs50/sql.py | 137 +++++++++++++++++++++++++----------- 9 files changed, 261 insertions(+), 70 deletions(-) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index fa07171..e5ec787 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,5 +1,3 @@ -"""Exposes API and sets up logging""" - from .cs50 import get_float, get_int, get_string from .sql import SQL from ._logger import _setup_logger diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py index 1307e19..e7b03ca 100644 --- a/src/cs50/_logger.py +++ b/src/cs50/_logger.py @@ -1,4 +1,5 @@ -"""Sets up logging for cs50 library""" +"""Sets up logging for the library. +""" import logging import os.path @@ -9,6 +10,22 @@ import termcolor +def green(msg): + return _colored(msg, "green") + + +def red(msg): + return _colored(msg, "red") + + +def yellow(msg): + return _colored(msg, "yellow") + + +def _colored(msg, color): + return termcolor.colored(str(msg), color) + + def _setup_logger(): _configure_default_logger() _patch_root_handler_format_exception() @@ -17,11 +34,16 @@ def _setup_logger(): def _configure_default_logger(): - """Configure default handler and formatter to prevent flask and werkzeug from adding theirs""" + """Configures a default handler and formatter to prevent flask and werkzeug from adding theirs. + """ + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) def _patch_root_handler_format_exception(): + """Patches formatException for the root handler to use ``_format_exception``. + """ + try: formatter = logging.root.handlers[0].formatter formatter.formatException = lambda exc_info: _format_exception(*exc_info) @@ -30,6 +52,10 @@ def _patch_root_handler_format_exception(): def _configure_cs50_logger(): + """Disables the cs50 logger by default. Disables logging propagation to prevent messages from + being logged more than once. Sets the logging handler and formatter. + """ + _logger = logging.getLogger("cs50") _logger.disabled = True _logger.setLevel(logging.DEBUG) @@ -52,9 +78,8 @@ def _patch_excepthook(): def _format_exception(type_, value, exc_tb): - """ - Format traceback, darkening entries from global site-packages directories - and user-specific site-packages directory. + """Formats traceback, darkening entries from global site-packages directories and user-specific + site-packages directory. https://stackoverflow.com/a/46071447/5156190 """ @@ -69,6 +94,5 @@ def _format_exception(type_, value, exc_tb): lines += line else: matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append( - matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) + lines.append(matches.group(1) + yellow(matches.group(2)) + matches.group(3)) return "".join(lines).rstrip() diff --git a/src/cs50/_session.py b/src/cs50/_session.py index c1ea426..f28c30a 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -1,5 +1,3 @@ -"""Wraps a SQLAlchemy scoped session""" - import sqlalchemy import sqlalchemy.orm @@ -11,7 +9,8 @@ class Session: - """Wraps a SQLAlchemy scoped session""" + """Wraps a SQLAlchemy scoped session. + """ def __init__(self, url, **engine_kwargs): if is_sqlite_url(url): @@ -20,9 +19,16 @@ def __init__(self, url, **engine_kwargs): self._session = create_session(url, **engine_kwargs) def execute(self, statement): - """Converts statement to str and executes it""" + """Converts statement to str and executes it. + + :param statement: The SQL statement to be executed + """ + # pylint: disable=no-member return self._session.execute(sqlalchemy.text(str(statement))) def __getattr__(self, attr): + """Proxies any attributes to the underlying SQLAlchemy scoped session. + """ + return getattr(self._session, attr) diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py index ed44eaa..01983b5 100644 --- a/src/cs50/_session_util.py +++ b/src/cs50/_session_util.py @@ -1,4 +1,5 @@ -"""Utility functions used by _session.py""" +"""Utility functions used by _session.py. +""" import os import sqlite3 @@ -49,8 +50,11 @@ def _create_scoped_session(engine): def _disable_auto_begin_commit(dbapi_connection): - # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves - # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + """Disables the underlying API's own emitting of BEGIN and COMMIT so we can support manual + transactions. + https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + """ + dbapi_connection.isolation_level = None diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py index f4ff3e0..17fc5fa 100644 --- a/src/cs50/_sql_sanitizer.py +++ b/src/cs50/_sql_sanitizer.py @@ -1,5 +1,3 @@ -"""Escapes SQL values""" - import datetime import re @@ -8,15 +6,19 @@ class SQLSanitizer: - """Escapes SQL values""" + """Sanitizes SQL values. + """ def __init__(self, dialect): self._dialect = dialect def escape(self, value): - """ - Escapes value using engine's conversion function. + """Escapes value using engine's conversion function. https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor + + :param value: The value to be sanitized + + :returns: The sanitized value """ # pylint: disable=too-many-return-statements if isinstance(value, (list, tuple)): @@ -71,13 +73,22 @@ def escape(self, value): raise RuntimeError(f"unsupported value: {value}") def escape_iterable(self, iterable): - """Escapes a collection of values (e.g., list, tuple)""" + """Escapes each value in iterable and joins all the escaped values with ", ", formatted for + SQL's ``IN`` operator. + + :param: An iterable of values to be escaped + + :returns: A comma-separated list of escaped values from ``iterable`` + :rtype: :class:`sqlparse.sql.TokenList` + """ + return sqlparse.sql.TokenList( sqlparse.parse(", ".join([str(self.escape(v)) for v in iterable]))) def escape_verbatim_colon(value): - """Escapes verbatim colon from a value so as it is not confused with a placeholder""" + """Escapes verbatim colon from a value so as it is not confused with a parameter marker. + """ # E.g., ':foo, ":foo, :foo will be replaced with # '\:foo, "\:foo, \:foo respectively diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index 52538ad..0b0c27b 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -1,11 +1,18 @@ -"""Utility functions used by sql.py""" +"""Utility functions used by sql.py. +""" import contextlib import decimal import warnings -def fetch_select_result(result): +def process_select_result(result): + """Converts a SQLAlchemy result to a ``list`` of ``dict`` objects, each of which represents a + row in the result set. + + :param result: A SQLAlchemy result + :type result: :class:`sqlalchemy.engine.Result` + """ rows = [dict(row) for row in result.fetchall()] for row in rows: for column in row: @@ -23,6 +30,9 @@ def fetch_select_result(result): @contextlib.contextmanager def raise_errors_for_warnings(): + """Catches warnings and raises errors instead. + """ + with warnings.catch_warnings(): warnings.simplefilter("error") yield diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 2e286e9..79e77d8 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,5 +1,3 @@ -"""Parses a SQL statement and replaces the placeholders with the corresponding parameters""" - import collections from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon @@ -17,6 +15,13 @@ def statement_factory(dialect): + """Creates a sanitizer for ``dialect`` and injects it into ``Statement``, exposing a simpler + interface for ``Statement``. + + :param dialect: a SQLAlchemy dialect + :type dialect: :class:`sqlalchemy.engine.Dialect` + """ + sql_sanitizer = SQLSanitizer(dialect) def statement(sql, *args, **kwargs): @@ -26,9 +31,23 @@ def statement(sql, *args, **kwargs): class Statement: - """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" + """Parses a SQL statement and substitutes any parameter markers with their corresponding + placeholders. + """ def __init__(self, sql_sanitizer, sql, *args, **kwargs): + """ + :param sql_sanitizer: The SQL sanitizer used to sanitize the parameters + :type sql_sanitizer: :class:`_sql_sanitizer.SQLSanitizer` + + :param sql: The SQL statement + :type sql: str + + :param *args: Zero or more positional parameters to be substituted for the parameter markers + + :param *kwargs: Zero or more keyword arguments to be substituted for the parameter markers + """ + if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -54,9 +73,18 @@ def _get_escaped_kwargs(self, kwargs): return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()} def _tokenize(self): + """ + :returns: A flattened list of SQLParse tokens that represent the SQL statement + """ + return list(self._statement.flatten()) def _get_operation_keyword(self): + """ + :returns: The operation keyword of the SQL statement (e.g., ``SELECT``, ``DELETE``, etc) + :rtype: str + """ + for token in self._statement: if is_operation_token(token.ttype): token_value = token.value.upper() @@ -69,6 +97,11 @@ def _get_operation_keyword(self): return operation_keyword def _get_paramstyle(self): + """ + :returns: The paramstyle used in the SQL statement (if any) + :rtype: :class:_statement_util.Paramstyle`` + """ + paramstyle = None for token in self._tokens: if is_placeholder(token.ttype): @@ -80,6 +113,11 @@ def _get_paramstyle(self): return paramstyle def _default_paramstyle(self): + """ + :returns: If positional args were passed, returns ``Paramstyle.QMARK``; if keyword arguments + were passed, returns ``Paramstyle.NAMED``; otherwise, returns ``None`` + """ + paramstyle = None if self._args: paramstyle = Paramstyle.QMARK @@ -89,6 +127,12 @@ def _default_paramstyle(self): return paramstyle def _get_placeholders(self): + """ + :returns: A dict that maps the index of each parameter marker in the tokens list to the name + of that parameter marker (if applicable) or ``None`` + :rtype: dict + """ + placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): if is_placeholder(token.ttype): @@ -109,11 +153,18 @@ def _substitute_markers_with_escaped_params(self): self._substitute_named_or_pyformat_markers() def _substitute_format_or_qmark_markers(self): + """Substitutes format or qmark parameter markers with their corresponding parameters. + """ + self._assert_valid_arg_count() for arg_index, token_index in enumerate(self._placeholders.keys()): self._tokens[token_index] = self._args[arg_index] def _assert_valid_arg_count(self): + """Raises a ``RuntimeError`` if the number of arguments does not match the number of + placeholders. + """ + if len(self._placeholders) != len(self._args): placeholders = get_human_readable_list(self._placeholders.values()) args = get_human_readable_list(self._args) @@ -123,6 +174,10 @@ def _assert_valid_arg_count(self): raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") def _substitue_numeric_markers(self): + """Substitutes numeric parameter markers with their corresponding parameters. Raises a + ``RuntimeError`` if any parameters are missing or unused. + """ + unused_arg_indices = set(range(len(self._args))) for token_index, num in self._placeholders.items(): if num >= len(self._args): @@ -138,6 +193,10 @@ def _substitue_numeric_markers(self): f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") def _substitute_named_or_pyformat_markers(self): + """Substitutes named or pyformat parameter markers with their corresponding parameters. + Raises a ``RuntimeError`` if any parameters are missing or unused. + """ + unused_params = set(self._kwargs.keys()) for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: @@ -152,6 +211,10 @@ def _substitute_named_or_pyformat_markers(self): f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") def _escape_verbatim_colons(self): + """Escapes verbatim colons from string literal and identifier tokens so they aren't treated + as parameter markers. + """ + for token in self._tokens: if is_string_literal(token.ttype) or is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) @@ -175,4 +238,7 @@ def is_update(self): return self._operation_keyword == "UPDATE" def __str__(self): + """Joins the statement tokens into a string. + """ + return "".join([str(token) for token in self._tokens]) diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py index 4ef092a..34ca6ff 100644 --- a/src/cs50/_statement_util.py +++ b/src/cs50/_statement_util.py @@ -1,4 +1,5 @@ -"""Utility functions used by _statement.py""" +"""Utility functions used by _statement.py. +""" import enum import re @@ -19,6 +20,9 @@ class Paramstyle(enum.Enum): + """Represents the supported parameter marker styles. + """ + FORMAT = enum.auto() NAMED = enum.auto() NUMERIC = enum.auto() @@ -27,6 +31,15 @@ class Paramstyle(enum.Enum): def format_and_parse(sql): + """Formats and parses a SQL statement. Raises ``RuntimeError`` if ``sql`` represents more than + one statement. + + :param sql: The SQL statement to be formatted and parsed + :type sql: str + + :returns: A list of unflattened SQLParse tokens that represent the parsed statement + """ + formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) statement_count = len(parsed_statements) @@ -43,6 +56,10 @@ def is_placeholder(ttype): def parse_placeholder(value): + """ + :returns: A tuple of the paramstyle and the name of the parameter marker (if any) or ``None`` + :rtype: tuple + """ if value == "?": return Paramstyle.QMARK, None diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 0486214..974137c 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,46 +1,60 @@ -"""Wraps SQLAlchemy""" - import logging import sqlalchemy -import termcolor +from ._logger import green, red, yellow from ._session import Session from ._statement import statement_factory -from ._sql_util import fetch_select_result, raise_errors_for_warnings +from ._sql_util import process_select_result, raise_errors_for_warnings + _logger = logging.getLogger("cs50") class SQL: - """Wraps SQLAlchemy""" + """An API for executing SQL Statements. + """ + + def __init__(self, url): + """ + :param url: The database URL + """ - def __init__(self, url, **engine_kwargs): - self._session = Session(url, **engine_kwargs) - dialect = self._session.get_bind().dialect + self._session = Session(url) + dialect = self._get_dialect() self._is_postgres = dialect.name in {"postgres", "postgresql"} - self._sanitize_statement = statement_factory(dialect) + self._substitute_markers_with_params = statement_factory(dialect) self._autocommit = False + def _get_dialect(self): + return self._session.get_bind().dialect + def execute(self, sql, *args, **kwargs): - """Execute a SQL statement.""" - statement = self._sanitize_statement(sql, *args, **kwargs) - if statement.is_transaction_start(): - self._autocommit = False + """Executes a SQL statement. - if self._autocommit: - self._session.execute("BEGIN") + :param sql: a SQL statement, possibly with parameters markers + :type sql: str + :param *args: zero or more positional arguments to substitute the parameter markers with + :param **kwargs: zero or more keyword arguments to substitute the parameter markers with - result = self._execute(statement) + :returns: For ``SELECT``, a :py:class:`list` of :py:class:`dict` objects, each of which + represents a row in the result set; for ``INSERT``, the primary key of a newly inserted row + (or ``None`` if none); for ``UPDATE``, the number of rows updated; for ``DELETE``, the + number of rows deleted; for other statements, ``True``; on integrity errors, a + :py:class:`ValueError` is raised, on other errors, a :py:class:`RuntimeError` is raised - if self._autocommit: - self._session.execute("COMMIT") + """ - if statement.is_transaction_end(): - self._autocommit = True + statement = self._substitute_markers_with_params(sql, *args, **kwargs) + if statement.is_transaction_start(): + self._disable_autocommit() + + self._begin_transaction_in_autocommit_mode() + result = self._execute(statement) + self._commit_transaction_in_autocommit_mode() if statement.is_select(): - ret = fetch_select_result(result) + ret = process_select_result(result) elif statement.is_insert(): ret = self._last_row_id_or_none(result) elif statement.is_delete() or statement.is_update(): @@ -48,43 +62,84 @@ def execute(self, sql, *args, **kwargs): else: ret = True - if self._autocommit: - self._session.remove() + if statement.is_transaction_end(): + self._enable_autocommit() + self._shutdown_session_in_autocommit_mode() return ret + def _disable_autocommit(self): + self._autocommit = False + + def _begin_transaction_in_autocommit_mode(self): + if self._autocommit: + self._session.execute("BEGIN") + def _execute(self, statement): - with raise_errors_for_warnings(): - try: + """ + :param statement: a SQL statement represented as a ``str`` or a + :class:`_statement.Statement` + + :rtype: :class:`sqlalchemy.engine.Result` + """ + try: + with raise_errors_for_warnings(): result = self._session.execute(statement) - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(termcolor.colored(str(statement), "yellow")) - if self._autocommit: - self._session.remove() - raise ValueError(exc.orig) from None - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - self._session.remove() - _logger.debug(termcolor.colored(statement, "red")) - raise RuntimeError(exc.orig) from None - - _logger.debug(termcolor.colored(str(statement), "green")) - return result + # E.g., failed constraint + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(yellow(statement)) + self._shutdown_session_in_autocommit_mode() + raise ValueError(exc.orig) from None + # E.g., connection error or syntax error + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + self._shutdown_session() + _logger.debug(red(statement)) + raise RuntimeError(exc.orig) from None + + _logger.debug(green(statement)) + return result + + def _shutdown_session_in_autocommit_mode(self): + if self._autocommit: + self._shutdown_session() + + def _shutdown_session(self): + self._session.remove() + + def _commit_transaction_in_autocommit_mode(self): + if self._autocommit: + self._session.execute("COMMIT") + + def _enable_autocommit(self): + self._autocommit = True def _last_row_id_or_none(self, result): + """ + :param result: A SQLAlchemy result object + :type result: :class:`sqlalchemy.engine.Result` + + :returns: The ID of the last inserted row or ``None`` + """ + if self._is_postgres: - return self._get_last_val() + return self._postgres_lastval() return result.lastrowid if result.rowcount == 1 else None - def _get_last_val(self): + def _postgres_lastval(self): try: return self._session.execute("SELECT LASTVAL()").first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session return None def init_app(self, app): - """Registers a teardown_appcontext listener to remove session and enables logging""" + """Enables logging and registers a ``teardown_appcontext`` listener to remove the session. + + :param app: a Flask application instance + :type app: :class:`flask.Flask` + """ + @app.teardown_appcontext def _(_): - self._session.remove() + self._shutdown_session() logging.getLogger("cs50").disabled = False From 9e9cb0b34ebd44394f7406a97a2113adae015c4c Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Thu, 15 Apr 2021 19:06:14 -0400 Subject: [PATCH 41/47] update cs50 docstrings --- src/cs50/cs50.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 30d3515..11fa20a 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -5,11 +5,14 @@ def get_float(prompt): + """Reads a line of text from standard input and returns the equivalent float as precisely as + possible; if text does not represent a float, user is prompted to retry. If line can't be read, + returns None. + + :type prompt: str + """ - Read a line of text from standard input and return the equivalent float - as precisely as possible; if text does not represent a double, user is - prompted to retry. If line can't be read, return None. - """ + while True: try: return _get_float(prompt) @@ -29,11 +32,12 @@ def _get_float(prompt): def get_int(prompt): + """Reads a line of text from standard input and return the equivalent int; if text does not + represent an int, user is prompted to retry. If line can't be read, returns None. + + :type prompt: str """ - Read a line of text from standard input and return the equivalent int; - if text does not represent an int, user is prompted to retry. If line - can't be read, return None. - """ + while True: try: return _get_int(prompt) @@ -53,12 +57,13 @@ def _get_int(prompt): def get_string(prompt): + """Reads a line of text from standard input and returns it as a string, sans trailing line + ending. Supports CR (\r), LF (\n), and CRLF (\r\n) as line endings. If user inputs only a line + ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF). + + :type prompt: str """ - Read a line of text from standard input and return it as a string, - sans trailing line ending. Supports CR (\r), LF (\n), and CRLF (\r\n) - as line endings. If user inputs only a line ending, returns "", not None. - Returns None upon error or no input whatsoever (i.e., just EOF). - """ + if not isinstance(prompt, str): raise TypeError("prompt must be of type str") @@ -73,8 +78,7 @@ def _get_input(prompt): class _flushfile(): - """ - Disable buffering for standard output and standard error. + """ Disable buffering for standard output and standard error. http://stackoverflow.com/a/231216 """ @@ -91,7 +95,8 @@ def write(self, data): def disable_output_buffering(): - """Disables output buffering to prevent prompts from being buffered""" + """Disables output buffering to prevent prompts from being buffered. + """ sys.stderr = _flushfile(sys.stderr) sys.stdout = _flushfile(sys.stdout) From 789bb4015ae38c1e58942f92dd54a6ba5ce65200 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Thu, 15 Apr 2021 21:15:41 -0400 Subject: [PATCH 42/47] use assertAlmostEqual --- tests/test_cs50.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_cs50.py b/tests/test_cs50.py index dd0f14b..9a0faca 100644 --- a/tests/test_cs50.py +++ b/tests/test_cs50.py @@ -1,4 +1,3 @@ -import math import sys import unittest @@ -95,7 +94,7 @@ def test_get_float_valid_input(self): def assert_equal(return_value, expected_value): with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: f = _get_float("Answer: ") - self.assertTrue(math.isclose(f, expected_value)) + self.assertAlmostEqual(f, expected_value) mock_get_string.assert_called_with("Answer: ") values = [ From 05a4d9df5f59415592dcad413ed2c91a63cf7bfb Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Fri, 16 Apr 2021 02:15:05 -0400 Subject: [PATCH 43/47] enable autocommit by default --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 974137c..c38ce25 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -24,7 +24,7 @@ def __init__(self, url): dialect = self._get_dialect() self._is_postgres = dialect.name in {"postgres", "postgresql"} self._substitute_markers_with_params = statement_factory(dialect) - self._autocommit = False + self._autocommit = True def _get_dialect(self): return self._session.get_bind().dialect From 7bf7e1688c5162e7e1e1004841639d62fa38c4c7 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Fri, 16 Apr 2021 09:49:31 -0400 Subject: [PATCH 44/47] avoid using sessions --- setup.py | 2 +- src/cs50/_engine.py | 66 +++++++++++ src/cs50/_engine_util.py | 43 +++++++ src/cs50/_session.py | 34 ------ src/cs50/_session_util.py | 68 ----------- src/cs50/_sql_util.py | 13 +++ src/cs50/_statement.py | 2 +- src/cs50/sql.py | 132 ++++++++------------- tests/test_statement.py | 234 -------------------------------------- 9 files changed, 174 insertions(+), 420 deletions(-) create mode 100644 src/cs50/_engine.py create mode 100644 src/cs50/_engine_util.py delete mode 100644 src/cs50/_session.py delete mode 100644 src/cs50/_session_util.py delete mode 100644 tests/test_statement.py diff --git a/setup.py b/setup.py index a5b8fb7..de271f8 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor"], + install_requires=["Flask>=1.0", "SQLAlchemy<2", "sqlparse", "termcolor"], keywords="cs50", name="cs50", package_dir={"": "src"}, diff --git a/src/cs50/_engine.py b/src/cs50/_engine.py new file mode 100644 index 0000000..d74992c --- /dev/null +++ b/src/cs50/_engine.py @@ -0,0 +1,66 @@ +import threading + +from ._engine_util import create_engine + + +thread_local_data = threading.local() + + +class Engine: + """Wraps a SQLAlchemy engine. + """ + + def __init__(self, url): + self._engine = create_engine(url) + + def get_transaction_connection(self): + """ + :returns: A new connection with autocommit disabled (to be used for transactions). + """ + + _thread_local_connections()[self] = self._engine.connect().execution_options( + autocommit=False) + return self.get_existing_transaction_connection() + + def get_connection(self): + """ + :returns: A new connection with autocommit enabled + """ + + return self._engine.connect().execution_options(autocommit=True) + + def get_existing_transaction_connection(self): + """ + :returns: The transaction connection bound to this Engine instance, if one exists, or None. + """ + + return _thread_local_connections().get(self) + + def close_transaction_connection(self): + """Closes the transaction connection bound to this Engine instance, if one exists and + removes it. + """ + + connection = self.get_existing_transaction_connection() + if connection: + connection.close() + del _thread_local_connections()[self] + + def is_postgres(self): + return self._engine.dialect.name in {"postgres", "postgresql"} + + def __getattr__(self, attr): + return getattr(self._engine, attr) + +def _thread_local_connections(): + """ + :returns: A thread local dict to keep track of transaction connection. If one does not exist, + creates one. + """ + + try: + connections = thread_local_data.connections + except AttributeError: + connections = thread_local_data.connections = {} + + return connections diff --git a/src/cs50/_engine_util.py b/src/cs50/_engine_util.py new file mode 100644 index 0000000..c55b8f2 --- /dev/null +++ b/src/cs50/_engine_util.py @@ -0,0 +1,43 @@ +"""Utility functions used by _session.py. +""" + +import os +import sqlite3 + +import sqlalchemy + +sqlite_url_prefix = "sqlite:///" + + +def create_engine(url, **kwargs): + """Creates a new SQLAlchemy engine. If ``url`` is a URL for a SQLite database, makes sure that + the SQLite file exits and enables foreign key constraints. + """ + + try: + engine = sqlalchemy.create_engine(url, **kwargs) + except sqlalchemy.exc.ArgumentError: + raise RuntimeError(f"invalid URL: {url}") from None + + if _is_sqlite_url(url): + _assert_sqlite_file_exists(url) + sqlalchemy.event.listen(engine, "connect", _enable_sqlite_foreign_key_constraints) + + return engine + +def _is_sqlite_url(url): + return url.startswith(sqlite_url_prefix) + + +def _assert_sqlite_file_exists(url): + path = url[len(sqlite_url_prefix):] + if not os.path.exists(path): + raise RuntimeError(f"does not exist: {path}") + if not os.path.isfile(path): + raise RuntimeError(f"not a file: {path}") + + +def _enable_sqlite_foreign_key_constraints(dbapi_connection, _): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() diff --git a/src/cs50/_session.py b/src/cs50/_session.py deleted file mode 100644 index f28c30a..0000000 --- a/src/cs50/_session.py +++ /dev/null @@ -1,34 +0,0 @@ -import sqlalchemy -import sqlalchemy.orm - -from ._session_util import ( - assert_sqlite_file_exists, - create_session, - is_sqlite_url, -) - - -class Session: - """Wraps a SQLAlchemy scoped session. - """ - - def __init__(self, url, **engine_kwargs): - if is_sqlite_url(url): - assert_sqlite_file_exists(url) - - self._session = create_session(url, **engine_kwargs) - - def execute(self, statement): - """Converts statement to str and executes it. - - :param statement: The SQL statement to be executed - """ - - # pylint: disable=no-member - return self._session.execute(sqlalchemy.text(str(statement))) - - def __getattr__(self, attr): - """Proxies any attributes to the underlying SQLAlchemy scoped session. - """ - - return getattr(self._session, attr) diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py deleted file mode 100644 index 01983b5..0000000 --- a/src/cs50/_session_util.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Utility functions used by _session.py. -""" - -import os -import sqlite3 - -import sqlalchemy - - -def is_sqlite_url(url): - return url.startswith("sqlite:///") - - -def assert_sqlite_file_exists(url): - path = url[len("sqlite:///"):] - if not os.path.exists(path): - raise RuntimeError(f"does not exist: {path}") - if not os.path.isfile(path): - raise RuntimeError(f"not a file: {path}") - - -def create_session(url, **engine_kwargs): - engine = _create_engine(url, **engine_kwargs) - _setup_on_connect(engine) - return _create_scoped_session(engine) - - -def _create_engine(url, **kwargs): - try: - engine = sqlalchemy.create_engine(url, **kwargs) - except sqlalchemy.exc.ArgumentError: - raise RuntimeError(f"invalid URL: {url}") from None - - engine.execution_options(autocommit=False) - return engine - - -def _setup_on_connect(engine): - def connect(dbapi_connection, _): - _disable_auto_begin_commit(dbapi_connection) - if _is_sqlite_connection(dbapi_connection): - _enable_sqlite_foreign_key_constraints(dbapi_connection) - - sqlalchemy.event.listen(engine, "connect", connect) - - -def _create_scoped_session(engine): - session_factory = sqlalchemy.orm.sessionmaker(bind=engine) - return sqlalchemy.orm.scoping.scoped_session(session_factory) - - -def _disable_auto_begin_commit(dbapi_connection): - """Disables the underlying API's own emitting of BEGIN and COMMIT so we can support manual - transactions. - https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl - """ - - dbapi_connection.isolation_level = None - - -def _is_sqlite_connection(dbapi_connection): - return isinstance(dbapi_connection, sqlite3.Connection) - - -def _enable_sqlite_foreign_key_constraints(dbapi_connection): - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index 0b0c27b..2dbfecf 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -5,6 +5,8 @@ import decimal import warnings +import sqlalchemy + def process_select_result(result): """Converts a SQLAlchemy result to a ``list`` of ``dict`` objects, each of which represents a @@ -36,3 +38,14 @@ def raise_errors_for_warnings(): with warnings.catch_warnings(): warnings.simplefilter("error") yield + + +def postgres_lastval(connection): + """ + :returns: The ID of the last inserted row, if defined in this session, or None + """ + + try: + return connection.execute("SELECT LASTVAL()").first()[0] + except sqlalchemy.exc.OperationalError: + return None diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 79e77d8..2de956a 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -64,7 +64,7 @@ def __init__(self, sql_sanitizer, sql, *args, **kwargs): self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() self._substitute_markers_with_escaped_params() - self._escape_verbatim_colons() + # self._escape_verbatim_colons() def _get_escaped_args(self, args): return [self._sql_sanitizer.escape(arg) for arg in args] diff --git a/src/cs50/sql.py b/src/cs50/sql.py index c38ce25..d32c319 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -3,9 +3,9 @@ import sqlalchemy from ._logger import green, red, yellow -from ._session import Session +from ._engine import Engine from ._statement import statement_factory -from ._sql_util import process_select_result, raise_errors_for_warnings +from ._sql_util import postgres_lastval, process_select_result, raise_errors_for_warnings _logger = logging.getLogger("cs50") @@ -20,14 +20,8 @@ def __init__(self, url): :param url: The database URL """ - self._session = Session(url) - dialect = self._get_dialect() - self._is_postgres = dialect.name in {"postgres", "postgresql"} - self._substitute_markers_with_params = statement_factory(dialect) - self._autocommit = True - - def _get_dialect(self): - return self._session.get_bind().dialect + self._engine = Engine(url) + self._substitute_markers_with_params = statement_factory(self._engine.dialect) def execute(self, sql, *args, **kwargs): """Executes a SQL statement. @@ -46,73 +40,52 @@ def execute(self, sql, *args, **kwargs): """ statement = self._substitute_markers_with_params(sql, *args, **kwargs) - if statement.is_transaction_start(): - self._disable_autocommit() - - self._begin_transaction_in_autocommit_mode() - result = self._execute(statement) - self._commit_transaction_in_autocommit_mode() - - if statement.is_select(): - ret = process_select_result(result) - elif statement.is_insert(): - ret = self._last_row_id_or_none(result) - elif statement.is_delete() or statement.is_update(): - ret = result.rowcount + connection = self._engine.get_existing_transaction_connection() + if connection is None: + if statement.is_transaction_start(): + connection = self._engine.get_transaction_connection() + else: + connection = self._engine.get_connection() + elif statement.is_transaction_start(): + raise RuntimeError("nested transactions are not supported") + + return self._execute(statement, connection) + + def _execute(self, statement, connection): + with raise_errors_for_warnings(): + try: + result = connection.execute(str(statement)) + # E.g., failed constraint + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(yellow(statement)) + if self._engine.get_existing_transaction_connection() is None: + connection.close() + raise ValueError(exc.orig) from None + # E.g., connection error or syntax error + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + connection.close() + _logger.debug(red(statement)) + raise RuntimeError(exc.orig) from None + + _logger.debug(green(statement)) + + if statement.is_select(): + ret = process_select_result(result) + elif statement.is_insert(): + ret = self._last_row_id_or_none(result) + elif statement.is_delete() or statement.is_update(): + ret = result.rowcount + else: + ret = True + + if self._engine.get_existing_transaction_connection(): + if statement.is_transaction_end(): + self._engine.close_transaction_connection() else: - ret = True - - if statement.is_transaction_end(): - self._enable_autocommit() + connection.close() - self._shutdown_session_in_autocommit_mode() return ret - def _disable_autocommit(self): - self._autocommit = False - - def _begin_transaction_in_autocommit_mode(self): - if self._autocommit: - self._session.execute("BEGIN") - - def _execute(self, statement): - """ - :param statement: a SQL statement represented as a ``str`` or a - :class:`_statement.Statement` - - :rtype: :class:`sqlalchemy.engine.Result` - """ - try: - with raise_errors_for_warnings(): - result = self._session.execute(statement) - # E.g., failed constraint - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(yellow(statement)) - self._shutdown_session_in_autocommit_mode() - raise ValueError(exc.orig) from None - # E.g., connection error or syntax error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - self._shutdown_session() - _logger.debug(red(statement)) - raise RuntimeError(exc.orig) from None - - _logger.debug(green(statement)) - return result - - def _shutdown_session_in_autocommit_mode(self): - if self._autocommit: - self._shutdown_session() - - def _shutdown_session(self): - self._session.remove() - - def _commit_transaction_in_autocommit_mode(self): - if self._autocommit: - self._session.execute("COMMIT") - - def _enable_autocommit(self): - self._autocommit = True - def _last_row_id_or_none(self, result): """ :param result: A SQLAlchemy result object @@ -121,16 +94,10 @@ def _last_row_id_or_none(self, result): :returns: The ID of the last inserted row or ``None`` """ - if self._is_postgres: - return self._postgres_lastval() + if self._engine.is_postgres(): + return postgres_lastval(result.connection) return result.lastrowid if result.rowcount == 1 else None - def _postgres_lastval(self): - try: - return self._session.execute("SELECT LASTVAL()").first()[0] - except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session - return None - def init_app(self, app): """Enables logging and registers a ``teardown_appcontext`` listener to remove the session. @@ -140,6 +107,7 @@ def init_app(self, app): @app.teardown_appcontext def _(_): - self._shutdown_session() + self._engine.close_transaction_connection() + logging.getLogger("cs50").disabled = False diff --git a/tests/test_statement.py b/tests/test_statement.py deleted file mode 100644 index 91261cd..0000000 --- a/tests/test_statement.py +++ /dev/null @@ -1,234 +0,0 @@ -import unittest - -from unittest.mock import patch - -from cs50._statement import Statement -from cs50._sql_sanitizer import SQLSanitizer - - -@patch.object(SQLSanitizer, "escape", return_value="test") -class TestStatement(unittest.TestCase): - # TODO assert correct exception messages - def test_mutex_args_and_kwargs(self, MockSQLSanitizer): - with self.assertRaises(RuntimeError): - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ? AND val = :val", 1, val="test") - - with self.assertRaises(RuntimeError): - Statement(MockSQLSanitizer(), "SELECT * FROM test", "test", 1, 2, foo="foo", bar="bar") - - @patch.object(Statement, "_escape_verbatim_colons") - def test_valid_qmark_count(self, MockSQLSanitizer, *_): - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ?", 1) - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') - Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_invalid_qmark_count(self, MockSQLSanitizer, *_): - def assert_invalid_count(sql, *args): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(MockSQLSanitizer(), sql, *args) - - statements = [ - ("SELECT * FROM test WHERE id = ?", ()), - ("SELECT * FROM test WHERE id = ?", (1, "test")), - ("SELECT * FROM test WHERE id = ? AND val = ?", (1,)), - ("SELECT * FROM test WHERE id = ? AND val = ?", ()), - ("SELECT * FROM test WHERE id = ? AND val = ?", (1, "test", True)), - ] - - for sql, args in statements: - assert_invalid_count(sql, *args) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_valid_format_count(self, MockSQLSanitizer, *_): - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = %s", 1) - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') - Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_invalid_format_count(self, MockSQLSanitizer, *_): - def assert_invalid_count(sql, *args): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(MockSQLSanitizer(), sql, *args) - - statements = [ - ("SELECT * FROM test WHERE id = %s", ()), - ("SELECT * FROM test WHERE id = %s", (1, "test")), - ("SELECT * FROM test WHERE id = %s AND val = ?", (1,)), - ("SELECT * FROM test WHERE id = %s AND val = ?", ()), - ("SELECT * FROM test WHERE id = %s AND val = ?", (1, "test", True)), - ] - - for sql, args in statements: - assert_invalid_count(sql, *args) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_numeric(self, MockSQLSanitizer, *_): - def assert_missing_numeric(sql, *args): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(MockSQLSanitizer(), sql, *args) - - statements = [ - ("SELECT * FROM test WHERE id = :1", ()), - ("SELECT * FROM test WHERE id = :1 AND val = :2", ()), - ("SELECT * FROM test WHERE id = :1 AND val = :2", (1,)), - ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", ()), - ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1,)), - ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1, "test")), - ] - - for sql, args in statements: - assert_missing_numeric(sql, *args) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_numeric(self, MockSQLSanitizer, *_): - def assert_unused_numeric(sql, *args): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(MockSQLSanitizer(), sql, *args) - - statements = [ - ("SELECT * FROM test WHERE id = :1", (1, "test")), - ("SELECT * FROM test WHERE id = :1", (1, "test", True)), - ("SELECT * FROM test WHERE id = :1 AND val = :2", (1, "test", True)), - ] - - for sql, args in statements: - assert_unused_numeric(sql, *args) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_named(self, MockSQLSanitizer, *_): - def assert_missing_named(sql, **kwargs): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(MockSQLSanitizer(), sql, **kwargs) - - statements = [ - ("SELECT * FROM test WHERE id = :id", {}), - ("SELECT * FROM test WHERE id = :id AND val = :val", {}), - ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", - {"id": 1}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", - {"id": 1, "val": "test"}), - ] - - for sql, kwargs in statements: - assert_missing_named(sql, **kwargs) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_named(self, MockSQLSanitizer, *_): - def assert_unused_named(sql, **kwargs): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(MockSQLSanitizer(), sql, **kwargs) - - statements = [ - ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}), - ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test", "is_valid": True}), - ("SELECT * FROM test WHERE id = :id AND val = :val", - {"id": 1, "val": "test", "is_valid": True}), - ] - - for sql, kwargs in statements: - assert_unused_named(sql, **kwargs) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_pyformat(self, MockSQLSanitizer, *_): - def assert_missing_pyformat(sql, **kwargs): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(MockSQLSanitizer(), sql, **kwargs) - - statements = [ - ("SELECT * FROM test WHERE id = %(id)s", {}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", - {"id": 1}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", - {"id": 1, "val": "test"}), - ] - - for sql, kwargs in statements: - assert_missing_pyformat(sql, **kwargs) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_pyformat(self, MockSQLSanitizer, *_): - def assert_unused_pyformat(sql, **kwargs): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(MockSQLSanitizer(), sql, **kwargs) - - statements = [ - ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}), - ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test", "is_valid": True}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", - {"id": 1, "val": "test", "is_valid": True}), - ] - - for sql, kwargs in statements: - assert_unused_pyformat(sql, **kwargs) - - def test_multiple_statements(self, MockSQLSanitizer): - def assert_raises_runtimeerror(sql): - with self.assertRaises(RuntimeError): - Statement(MockSQLSanitizer(), sql) - - statements = [ - "SELECT 1; SELECT 2;", - "SELECT 1; SELECT 2", - "SELECT 1; SELECT 2; SELECT 3", - "SELECT 1; SELECT 2; SELECT 3;", - "SELECT 1;SELECT 2", - "select 1; select 2", - "select 1;select 2", - "DELETE FROM test; SELECT * FROM test", - ] - - for sql in statements: - assert_raises_runtimeerror(sql) - - def test_is_delete(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "DELETE FROM test").is_delete()) - self.assertTrue(Statement(MockSQLSanitizer(), "delete FROM test").is_delete()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_delete()) - self.assertFalse(Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete()) - - def test_is_insert(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert()) - self.assertTrue(Statement(MockSQLSanitizer(), - "insert INTO test (id, val) VALUES (1, 'test')").is_insert()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_insert()) - self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_insert()) - - def test_is_select(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_select()) - self.assertTrue(Statement(MockSQLSanitizer(), "select * FROM test").is_select()) - self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_select()) - self.assertFalse(Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val) VALUES (1, 'test')").is_select()) - - def test_is_update(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "UPDATE test SET id = 2").is_update()) - self.assertTrue(Statement(MockSQLSanitizer(), "update test SET id = 2").is_update()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_update()) - self.assertFalse(Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val) VALUES (1, 'test')").is_update()) - - def test_is_transaction_start(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "START TRANSACTION").is_transaction_start()) - self.assertTrue(Statement(MockSQLSanitizer(), "start TRANSACTION").is_transaction_start()) - self.assertTrue(Statement(MockSQLSanitizer(), "BEGIN").is_transaction_start()) - self.assertTrue(Statement(MockSQLSanitizer(), "begin").is_transaction_start()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_transaction_start()) - self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_transaction_start()) - - def test_is_transaction_end(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "COMMIT").is_transaction_end()) - self.assertTrue(Statement(MockSQLSanitizer(), "commit").is_transaction_end()) - self.assertTrue(Statement(MockSQLSanitizer(), "ROLLBACK").is_transaction_end()) - self.assertTrue(Statement(MockSQLSanitizer(), "rollback").is_transaction_end()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_transaction_end()) - self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_transaction_end()) From 0674b7c086946d0c87a912482934b4ffabaa1c04 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Fri, 16 Apr 2021 10:39:53 -0400 Subject: [PATCH 45/47] delete transaction connection on failure --- src/cs50/sql.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d32c319..64d30e3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -63,7 +63,10 @@ def _execute(self, statement, connection): raise ValueError(exc.orig) from None # E.g., connection error or syntax error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - connection.close() + if self._engine.get_existing_transaction_connection(): + self._engine.close_transaction_connection() + else: + connection.close() _logger.debug(red(statement)) raise RuntimeError(exc.orig) from None From ff9e69f5ac7519498b6552a4e3180b5fadee4b78 Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 26 Jul 2021 15:58:21 -0400 Subject: [PATCH 46/47] Deploy with GitHub Actions --- .github/workflows/main.yml | 46 ++++++++++++++++++++++++++++++++++++++ .travis.yml | 30 ------------------------- 2 files changed, 46 insertions(+), 30 deletions(-) create mode 100644 .github/workflows/main.yml delete mode 100644 .travis.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..0eb0e2c --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,46 @@ +on: push +jobs: + deploy: + runs-on: ubuntu-latest + services: + mysql: + image: mysql + env: + MYSQL_DATABASE: test + MYSQL_ALLOW_EMPTY_PASSWORD: yes + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + ports: + - 3306:3306 + postgres: + image: postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: test + ports: + - 5432:5432 + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: '3.6' + - name: Setup databases + run: | + python setup.py install + pip install mysqlclient + pip install psycopg2-binary + touch test.db test1.db + - name: Run tests + run: python tests/sql.py + - name: Install pypa/build + run: | + python -m pip install build --user + - name: Build a binary wheel and a source tarball + run: | + python -m build --sdist --wheel --outdir dist/ . + - name: Deploy to PyPI + if: ${{ github.ref == 'refs/heads/main' }} + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 0433f6a..0000000 --- a/.travis.yml +++ /dev/null @@ -1,30 +0,0 @@ -language: python -python: '3.6' -branches: - except: "/^v\\d/" -services: - - mysql - - postgresql -install: - - python setup.py install - - pip install mysqlclient - - pip install psycopg2-binary -before_script: - - mysql -e 'CREATE DATABASE IF NOT EXISTS test;' - - psql -c 'create database test;' -U postgres - - touch test.db test1.db -script: python tests/sql.py -deploy: - - provider: script - script: 'curl --fail --data "{ \"tag_name\": \"v$(python setup.py --version)\", - \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" - }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' - on: - branch: main - - provider: pypi - user: "$PYPI_USERNAME" - password: "$PYPI_PASSWORD" - on: main -notifications: - slack: - secure: lJklhcBVjDT6KzUNa3RFHXdXSeH7ytuuGrkZ5ZcR72CXMoTf2pMJTzPwRLWOp6lCSdDC9Y8MWLrcg/e33dJga4Jlp9alOmWqeqesaFjfee4st8vAsgNbv8/RajPH1gD2bnkt8oIwUzdHItdb5AucKFYjbH2g0d8ndoqYqUeBLrnsT1AP5G/Vi9OHC9OWNpR0FKaZIJE0Wt52vkPMH3sV2mFeIskByPB+56U5y547mualKxn61IVR/dhYBEtZQJuSvnwKHPOn9Pkk7cCa+SSSeTJ4w5LboY8T17otaYNauXo46i1bKIoGiBcCcrJyQHHiPQmcq/YU540MC5Wzt9YXUycmJzRi347oyQeDee27wV3XJlWMXuuhbtJiKCFny7BTQ160VATlj/dbwIzN99Ra6/BtTumv/6LyTdKIuVjdAkcN8dtdDW1nlrQ29zuPNCcXXzJ7zX7kQaOCUV1c2OrsbiH/0fE9nknUORn97txqhlYVi0QMS7764wFo6kg0vpmFQRkkQySsJl+TmgcZ01AlsJc2EMMWVuaj9Af9JU4/4yalqDiXIh1fOYYUZnLfOfWS+MsnI+/oLfqJFyMbrsQQTIjs+kTzbiEdhd2R4EZgusU/xRFWokS2NAvahexrRhRQ6tpAI+LezPrkNOR3aHiykBf+P9BkUa0wPp6V2Ayc6q0= From 6aabf0c6ebbbb7ccfe59416a500223b02b710f3d Mon Sep 17 00:00:00 2001 From: Kareem Zidane Date: Mon, 26 Jul 2021 16:01:03 -0400 Subject: [PATCH 47/47] Update sql.py --- tests/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sql.py b/tests/sql.py index e4757c7..89853a7 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -150,7 +150,7 @@ def tearDownClass(self): class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("mysql://root@localhost/test") + self.db = SQL("mysql://root@127.0.0.1/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") @@ -160,7 +160,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://postgres@localhost/test") + self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")