Skip to content
Merged
125 changes: 51 additions & 74 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,50 +462,53 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
),
*(
[
ast.Assign(
targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.IfExp(
test=ast.Name(id="test_class_name", ctx=ast.Load()),
body=ast.BinOp(
left=ast.Name(id="test_class_name", ctx=ast.Load()),
op=ast.Add(),
right=ast.Constant(value="."),
),
orelse=ast.Constant(value=""),
),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
]
),
lineno=lineno + 9,
),
ast.Expr(
value=ast.Call(
func=ast.Name(id="print", ctx=ast.Load()),
args=[
ast.JoinedStr(
values=[
ast.Constant(value="!######"),
ast.Constant(value="!$######"),
ast.FormattedValue(
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.IfExp(
test=ast.Name(id="test_class_name", ctx=ast.Load()),
body=ast.BinOp(
left=ast.Name(id="test_class_name", ctx=ast.Load()),
op=ast.Add(),
right=ast.Constant(value="."),
),
orelse=ast.Constant(value=""),
),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1
),
ast.Constant(value="######!"),
ast.Constant(value="######$!"),
]
)
],
keywords=[],
)
)
),
]
if mode == TestingMode.BEHAVIOR
else []
),
ast.Assign(
targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10
Expand Down Expand Up @@ -598,56 +601,30 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
keywords=[],
)
),
*(
[
ast.Expr(
value=ast.Call(
func=ast.Name(id="print", ctx=ast.Load()),
args=[
ast.JoinedStr(
values=[
ast.Constant(value="!######"),
ast.FormattedValue(
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.IfExp(
test=ast.Name(id="test_class_name", ctx=ast.Load()),
body=ast.BinOp(
left=ast.Name(id="test_class_name", ctx=ast.Load()),
op=ast.Add(),
right=ast.Constant(value="."),
),
orelse=ast.Constant(value=""),
),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1
),
ast.Expr(
value=ast.Call(
func=ast.Name(id="print", ctx=ast.Load()),
args=[
ast.JoinedStr(
values=[
ast.Constant(value="!######"),
ast.FormattedValue(value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1),
*(
[
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="codeflash_duration", ctx=ast.Load()), conversion=-1
),
ast.Constant(value="######!"),
]
)
],
keywords=[],
if mode == TestingMode.PERFORMANCE
else []
),
ast.Constant(value="######!"),
]
)
)
]
if mode == TestingMode.PERFORMANCE
else []
],
keywords=[],
)
),
*(
[
Expand Down
6 changes: 3 additions & 3 deletions codeflash/verification/codeflash_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,8 @@ def wrapper(*args, **kwargs) -> None: # noqa: ANN002, ANN003

# Generate invocation id
invocation_id = f"{line_id}_{codeflash_test_index}"
print(
f"!######{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}######!"
)
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
# Connect to sqlite
codeflash_con = sqlite3.connect(f"{tmp_dir_path}_{codeflash_iteration}.sqlite")
codeflash_cur = codeflash_con.cursor()
Expand All @@ -131,6 +130,7 @@ def wrapper(*args, **kwargs) -> None: # noqa: ANN002, ANN003
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}######!")

# Capture instance state after initialization
if hasattr(args[0], "__dict__"):
Expand Down
56 changes: 35 additions & 21 deletions codeflash/verification/parse_test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def parse_func(file_path: Path) -> XMLParser:
return parse(file_path, xml_parser)


matches_re = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
cleaner_re = re.compile(r"!######.*?######!|-+\s*Captured\s+(Log|Out)\s*-+\n?")
matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######\$!\n")
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")


def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
Expand Down Expand Up @@ -265,12 +265,16 @@ def parse_test_xml(
timed_out = True

sys_stdout = testcase.system_out or ""
matches = matches_re.findall(sys_stdout)

if sys_stdout:
sys_stdout = cleaner_re.sub("", sys_stdout).strip()

if not matches or not len(matches):
begin_matches = list(matches_re_start.finditer(sys_stdout))
end_matches = {}
for match in matches_re_end.finditer(sys_stdout):
groups = match.groups()
if len(groups[5].split(":")) > 1:
iteration_id = groups[5].split(":")[0]
groups = groups[:5] + (iteration_id,)
end_matches[groups] = match

if not begin_matches or not begin_matches:
test_results.add(
FunctionTestInvocation(
loop_index=loop_index,
Expand All @@ -288,26 +292,36 @@ def parse_test_xml(
test_type=test_type,
return_value=None,
timed_out=timed_out,
stdout=sys_stdout,
stdout="",
)
)

else:
for match in matches:
split_val = match[5].split(":")
if len(split_val) > 1:
iteration_id = split_val[0]
runtime = int(split_val[1])
for match_index, match in enumerate(begin_matches):
groups = match.groups()
end_match = end_matches.get(groups)
iteration_id, runtime = groups[5], None
if end_match:
stdout = sys_stdout[match.end() : end_match.start()]
split_val = end_match.groups()[5].split(":")
if len(split_val) > 1:
iteration_id = split_val[0]
runtime = int(split_val[1])
else:
iteration_id, runtime = split_val[0], None
elif match_index == len(begin_matches) - 1:
stdout = sys_stdout[match.end() :]
else:
iteration_id, runtime = split_val[0], None
stdout = sys_stdout[match.end() : begin_matches[match_index + 1].start()]

test_results.add(
FunctionTestInvocation(
loop_index=int(match[4]),
loop_index=int(groups[4]),
id=InvocationId(
test_module_path=match[0],
test_class_name=None if match[1] == "" else match[1][:-1],
test_function_name=match[2],
function_getting_tested=match[3],
test_module_path=groups[0],
test_class_name=None if groups[1] == "" else groups[1][:-1],
test_function_name=groups[2],
function_getting_tested=groups[3],
iteration_id=iteration_id,
),
file_name=test_file_path,
Expand All @@ -317,7 +331,7 @@ def parse_test_xml(
test_type=test_type,
return_value=None,
timed_out=timed_out,
stdout=sys_stdout,
stdout=stdout,
)
)

Expand Down
Loading
Loading