Skip to content

Commit f6912d2

Browse files
author
Kareem Zidane
committed
use statement factory
1 parent 9302a1e commit f6912d2

File tree

4 files changed

+105
-101
lines changed

4 files changed

+105
-101
lines changed

src/cs50/_statement.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,23 @@
1616
)
1717

1818

19+
def statement_factory(dialect):
20+
sql_sanitizer = SQLSanitizer(dialect)
21+
22+
def statement(sql, *args, **kwargs):
23+
return Statement(sql_sanitizer, sql, *args, **kwargs)
24+
25+
return statement
26+
27+
1928
class Statement:
2029
"""Parses a SQL statement and replaces the placeholders with the corresponding parameters"""
2130

22-
def __init__(self, dialect, sql, *args, **kwargs):
31+
def __init__(self, sql_sanitizer, sql, *args, **kwargs):
2332
if len(args) > 0 and len(kwargs) > 0:
2433
raise RuntimeError("cannot pass both positional and named parameters")
2534

26-
self._sql_sanitizer = SQLSanitizer(dialect)
35+
self._sql_sanitizer = sql_sanitizer
2736

2837
self._args = self._get_escaped_args(args)
2938
self._kwargs = self._get_escaped_kwargs(kwargs)

src/cs50/sql.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import termcolor
88

99
from ._session import Session
10-
from ._statement import Statement
10+
from ._statement import statement_factory
1111
from ._sql_util import fetch_select_result
1212

1313
_logger = logging.getLogger("cs50")
@@ -18,13 +18,14 @@ class SQL:
1818

1919
def __init__(self, url, **engine_kwargs):
2020
self._session = Session(url, **engine_kwargs)
21-
self._dialect = self._session.get_bind().dialect
22-
self._is_postgres = self._dialect.name in {"postgres", "postgresql"}
21+
dialect = self._session.get_bind().dialect
22+
self._is_postgres = dialect.name in {"postgres", "postgresql"}
23+
self._sanitized_statement = statement_factory(dialect)
2324
self._autocommit = False
2425

2526
def execute(self, sql, *args, **kwargs):
2627
"""Execute a SQL statement."""
27-
statement = Statement(self._dialect, sql, *args, **kwargs)
28+
statement = self._sanitized_statement(sql, *args, **kwargs)
2829
if statement.is_transaction_start():
2930
self._autocommit = False
3031

@@ -53,7 +54,6 @@ def execute(self, sql, *args, **kwargs):
5354

5455
return ret
5556

56-
5757
def _execute(self, statement):
5858
# Catch SQLAlchemy warnings
5959
with warnings.catch_warnings():

tests/test_cs50.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,29 @@ def test_get_string_empty_input(self, mock_get_input):
1414
self.assertEqual(get_string("Answer: "), "")
1515
mock_get_input.assert_called_with("Answer: ")
1616

17-
1817
@patch("cs50.cs50._get_input", return_value="test")
1918
def test_get_string_nonempty_input(self, mock_get_input):
2019
"""Returns the provided non-empty input"""
2120
self.assertEqual(get_string("Answer: "), "test")
2221
mock_get_input.assert_called_with("Answer: ")
2322

24-
2523
@patch("cs50.cs50._get_input", side_effect=EOFError)
2624
def test_get_string_eof(self, mock_get_input):
2725
"""Returns None on EOF"""
2826
self.assertIs(get_string("Answer: "), None)
2927
mock_get_input.assert_called_with("Answer: ")
3028

31-
3229
def test_get_string_invalid_prompt(self):
3330
"""Raises TypeError when prompt is not str"""
3431
with self.assertRaises(TypeError):
3532
get_string(1)
3633

37-
3834
@patch("cs50.cs50.get_string", return_value=None)
3935
def test_get_int_eof(self, mock_get_string):
4036
"""Returns None on EOF"""
4137
self.assertIs(_get_int("Answer: "), None)
4238
mock_get_string.assert_called_with("Answer: ")
4339

44-
4540
def test_get_int_valid_input(self):
4641
"""Returns the provided integer input"""
4742

@@ -62,7 +57,6 @@ def assert_equal(return_value, expected_value):
6257
for return_value, expected_value in values:
6358
assert_equal(return_value, expected_value)
6459

65-
6660
def test_get_int_invalid_input(self):
6761
"""Raises ValueError when input is invalid base-10 int"""
6862

@@ -90,14 +84,12 @@ def assert_raises_valueerror(return_value):
9084
for return_value in return_values:
9185
assert_raises_valueerror(return_value)
9286

93-
9487
@patch("cs50.cs50.get_string", return_value=None)
9588
def test_get_float_eof(self, mock_get_string):
9689
"""Returns None on EOF"""
9790
self.assertIs(_get_float("Answer: "), None)
9891
mock_get_string.assert_called_with("Answer: ")
9992

100-
10193
def test_get_float_valid_input(self):
10294
"""Returns the provided integer input"""
10395
def assert_equal(return_value, expected_value):
@@ -121,7 +113,6 @@ def assert_equal(return_value, expected_value):
121113
for return_value, expected_value in values:
122114
assert_equal(return_value, expected_value)
123115

124-
125116
def test_get_float_invalid_input(self):
126117
"""Raises ValueError when input is invalid float"""
127118

0 commit comments

Comments
 (0)