diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 37f2add4abbd..d33497d4987b 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -8,7 +8,8 @@ import sys from collections.abc import Sequence -from typing import Callable, Final, Optional +from typing import Callable, Final, Optional, cast +from typing_extensions import TypeGuard from mypy.argmap import map_actuals_to_formals from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind @@ -185,6 +186,7 @@ from mypyc.primitives.str_ops import ( str_check_if_true, str_eq, + str_eq_literal, str_ssize_t_size_op, unicode_compare, ) @@ -1551,9 +1553,33 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) - def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: """Compare two strings""" if op == "==": + # We can specialize this case if one or both values are string literals + literal_fastpath = False + + def is_string_literal(value: Value) -> TypeGuard[LoadLiteral]: + return isinstance(value, LoadLiteral) and is_str_rprimitive(value.type) + + if is_string_literal(lhs): + if is_string_literal(rhs): + # we can optimize out the check entirely in some constant-folded cases + return self.true() if lhs.value == rhs.value else self.false() + + # if lhs argument is a string literal, switch sides to match the specializer C API + lhs, rhs = rhs, lhs + literal_fastpath = True + elif is_string_literal(rhs): + literal_fastpath = True + + if literal_fastpath: + literal_string = cast(str, cast(LoadLiteral, rhs).value) + literal_length = Integer(len(literal_string), c_pyssize_t_rprimitive, line) + return self.primitive_op(str_eq_literal, [lhs, rhs, literal_length], line) + return self.primitive_op(str_eq, [lhs, rhs], line) + elif op == "!=": - eq = self.primitive_op(str_eq, [lhs, rhs], line) + # perform a standard equality check, then negate + eq = self.compare_strings(lhs, rhs, "==", line) return self.add(ComparisonOp(eq, self.false(), 
ComparisonOp.EQ, line)) # TODO: modify 'str' to use same interface as 'compare_bytes' as it would avoid diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index e9dfd8de3683..ca9f7ba30277 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -734,6 +734,7 @@ static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) { #define BOTHSTRIP 2 char CPyStr_Equal(PyObject *str1, PyObject *str2); +char CPyStr_EqualLiteral(PyObject *str, PyObject *literal_str, Py_ssize_t literal_length); PyObject *CPyStr_Build(Py_ssize_t len, ...); PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index); PyObject *CPyStr_GetItemUnsafe(PyObject *str, Py_ssize_t index); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 337ef14fc955..abc0de5db2d8 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -64,20 +64,33 @@ make_bloom_mask(int kind, const void* ptr, Py_ssize_t len) #undef BLOOM_UPDATE } -// Adapted from CPython 3.13.1 (_PyUnicode_Equal) -char CPyStr_Equal(PyObject *str1, PyObject *str2) { - if (str1 == str2) { - return 1; - } - Py_ssize_t len = PyUnicode_GET_LENGTH(str1); - if (PyUnicode_GET_LENGTH(str2) != len) +static char _CPyStr_Equal_NoIdentCheck(PyObject *str1, PyObject *str2, Py_ssize_t str2_length) { + // This helper function only exists to deduplicate code in CPyStr_Equal and CPyStr_EqualLiteral + Py_ssize_t str1_length = PyUnicode_GET_LENGTH(str1); + if (str1_length != str2_length) return 0; int kind = PyUnicode_KIND(str1); if (PyUnicode_KIND(str2) != kind) return 0; const void *data1 = PyUnicode_DATA(str1); const void *data2 = PyUnicode_DATA(str2); - return memcmp(data1, data2, len * kind) == 0; + return memcmp(data1, data2, str1_length * kind) == 0; +} + +// Adapted from CPython 3.13.1 (_PyUnicode_Equal) +char CPyStr_Equal(PyObject *str1, PyObject *str2) { + if (str1 == str2) { + return 1; + } + Py_ssize_t str2_length = PyUnicode_GET_LENGTH(str2); + return _CPyStr_Equal_NoIdentCheck(str1, str2, str2_length); +} + +char 
CPyStr_EqualLiteral(PyObject *str, PyObject *literal_str, Py_ssize_t literal_length) { + if (str == literal_str) { + return 1; + } + return _CPyStr_Equal_NoIdentCheck(str, literal_str, literal_length); } PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) { diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index a8f4e4df74c2..d39f1f872763 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -88,6 +88,14 @@ error_kind=ERR_NEVER, ) +str_eq_literal = custom_primitive_op( + name="str_eq_literal", + c_function_name="CPyStr_EqualLiteral", + arg_types=[str_rprimitive, str_rprimitive, c_pyssize_t_rprimitive], + return_type=bool_rprimitive, + error_kind=ERR_NEVER, +) + unicode_compare = custom_op( arg_types=[str_rprimitive, str_rprimitive], return_type=c_int_rprimitive, diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index a98b3a7d3dcf..b2313ccba911 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -2302,7 +2302,7 @@ def SetAttr.__setattr__(self, key, val): r12 :: bit L0: r0 = 'regular_attr' - r1 = CPyStr_Equal(key, r0) + r1 = CPyStr_EqualLiteral(key, r0, 12) if r1 goto L1 else goto L2 :: bool L1: r2 = unbox(int, val) @@ -2310,7 +2310,7 @@ L1: goto L6 L2: r4 = 'class_var' - r5 = CPyStr_Equal(key, r4) + r5 = CPyStr_EqualLiteral(key, r4, 9) if r5 goto L3 else goto L4 :: bool L3: r6 = builtins :: module diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index e0c014f07813..e7a330951ab0 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -410,7 +410,7 @@ L2: k = r8 v = r7 r9 = 'name' - r10 = CPyStr_Equal(k, r9) + r10 = CPyStr_EqualLiteral(k, r9, 4) if r10 goto L3 else goto L4 :: bool L3: name = v diff --git a/mypyc/test-data/irbuild-unreachable.test b/mypyc/test-data/irbuild-unreachable.test index a4f1ef8c7dba..8eafede66b56 100644 --- 
a/mypyc/test-data/irbuild-unreachable.test +++ b/mypyc/test-data/irbuild-unreachable.test @@ -20,7 +20,7 @@ L0: r2 = CPyObject_GetAttr(r0, r1) r3 = cast(str, r2) r4 = 'x' - r5 = CPyStr_Equal(r3, r4) + r5 = CPyStr_EqualLiteral(r3, r4, 1) if r5 goto L2 else goto L1 :: bool L1: r6 = r5 @@ -54,7 +54,7 @@ L0: r2 = CPyObject_GetAttr(r0, r1) r3 = cast(str, r2) r4 = 'x' - r5 = CPyStr_Equal(r3, r4) + r5 = CPyStr_EqualLiteral(r3, r4, 1) if r5 goto L2 else goto L1 :: bool L1: r6 = r5