diff --git a/.github/workflows/serial-tests.yml b/.github/workflows/serial-tests.yml
new file mode 100644
index 0000000..bdfa1f9
--- /dev/null
+++ b/.github/workflows/serial-tests.yml
@@ -0,0 +1,32 @@
+name: Serial CPU Tests
+
+on:
+  push:
+    branches:
+      - develop
+  pull_request:
+    branches:
+      - develop
+
+jobs:
+  serial-tests:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        problem: ["00_dense_la_lu_decomp", "01_dense_la_solve", "02_dense_la_gemm", "03_dense_la_axpy", "04_dense_la_gemv", "05_fft_inverse_fft", "06_fft_dft", "07_fft_fft_conjugate", "08_fft_split_fft", "09_fft_fft_out_of_place", "10_geometry_convex_hull", "11_geometry_convex_hull_perimeter", "12_geometry_smallest_triangle", "13_geometry_closest_pair_2d", "14_geometry_closest_pair_1d", "15_graph_edge_count", "16_graph_largest_component", "17_graph_highest_degree", "18_graph_count_components", "19_graph_shortest_path", "20_histogram_pixel_histogram", "21_histogram_bin_0-100", "22_histogram_count_quadrants", "23_histogram_first_letter_counts", "24_histogram_count_quartile", "25_reduce_xor", "26_reduce_product_of_inverses", "27_reduce_average", "28_reduce_smallest_odd_number", "29_reduce_sum_of_min_of_pairs", "30_scan_prefix_sum", "31_scan_scan_with_min_function", "32_scan_sum_of_prefix_sum_array", "33_scan_reverse_prefix_sum", "34_scan_largest_contiguous_subarray_sum", "35_search_search_for_last_struct_by_key", "36_search_check_if_array_contains_value", "37_search_find_the_closest_number_to_pi", "38_search_find_the_first_even_number", "39_search_xor_contains", "40_sort_sort_an_array_of_complex_numbers_by_magnitude", "41_sort_k-th_smallest_element", "42_sort_sorted_ranks", "43_sort_sort_an_array_of_structs_by_key", "44_sort_sort_non-zero_elements", "45_sparse_la_sparse_solve", "46_sparse_la_spmm", "47_sparse_la_spmv", "48_sparse_la_sparse_axpy", "49_sparse_la_sparse_lu_decomp", "50_stencil_xor_kernel", "51_stencil_edge_kernel", "52_stencil_1d_jacobi_3-point_stencil", "53_stencil_2d_jacobi_5-point_stencil", "54_stencil_game_of_life", "55_transform_relu", "56_transform_negate_odds", "57_transform_inverse_offset", "58_transform_squaring", "59_transform_map_function"]
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.x'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install tqdm
+
+      - name: Run CPU test for ${{ matrix.problem }}
+        run: bash test/test-serial.bash "${{ matrix.problem }}"
\ No newline at end of file
diff --git a/prompts/create-serial-tests.py b/prompts/create-serial-tests.py
index 29e0adc..fdcd6c0 100644
--- a/prompts/create-serial-tests.py
+++ b/prompts/create-serial-tests.py
@@ -24,7 +24,7 @@ def get_return_type(code: str) -> str:
     # then return the type
     lines = code.split('\n')
     for line in lines:
-        if line.strip().endswith(') {'):
+        if "NO_INLINE correct" in line and line.strip().endswith(') {'):
             return line.split()[0]
 
 def main():
@@ -45,7 +45,8 @@ def main():
             continue
 
         baseline = get_file_contents(baseline_fpath)
-        impl = get_substr_after_first_of(baseline, ') {')
+        func_start = get_substr_after_first_of(baseline, 'NO_INLINE correct')
+        impl = get_substr_after_first_of(func_start, ') {')
         return_type = get_return_type(baseline)
         prompt['outputs'] = [
             impl,
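Note on the create-serial-tests.py change: previously the implementation was cut at the first `) {` in the whole file, which grabs the wrong body whenever a helper function precedes the reference one. Anchoring on `NO_INLINE correct` assumes every baseline declares its reference function on a single line of the form `<return type> NO_INLINE correct<Name>(...) {`. A minimal sketch of the extraction under that assumption; `get_substr_after_first_of` is defined elsewhere in this file, and the stand-in below is a guess at its behavior, not the repository's definition:

    def get_substr_after_first_of(s: str, substr: str) -> str:
        # hypothetical stand-in: everything after the first occurrence of substr
        return s[s.index(substr) + len(substr):]

    baseline = (
        "int helper(int y) {\n"      # would fool the old first-") {" logic
        "    return y * 2;\n"
        "}\n"
        "int NO_INLINE correctSum(std::vector<int> const& x) {\n"
        "    return std::accumulate(x.begin(), x.end(), 0);\n"
        "}\n"
    )
    func_start = get_substr_after_first_of(baseline, "NO_INLINE correct")
    impl = get_substr_after_first_of(func_start, ") {")
    # impl is now the body of correctSum; get_return_type likewise skips
    # helper(), whose line lacks "NO_INLINE correct", and yields "int".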
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 0000000..6dba1b9
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,3 @@
+# Tests
+
+Tests for the benchmark. Currently only the sequential CPU capabilities are tested.
\ No newline at end of file
diff --git a/test/test-serial.bash b/test/test-serial.bash
new file mode 100644
index 0000000..776cc22
--- /dev/null
+++ b/test/test-serial.bash
@@ -0,0 +1,48 @@
+#!/bin/bash
+# Uses the baseline implementations to test the CPU capabilities of the system.
+
+# usage: bash test/test-serial.bash [problem]
+# With no argument every problem is run, but per-problem validation is skipped.
+if [ $# -eq 0 ]; then
+    echo "No problem specified. Using default: 'all'."
+    PROBLEM_ARG=""
+else
+    PROBLEM_ARG="--problem $1"
+fi
+
+# First, use the baseline implementations to mimic LLM outputs.
+python prompts/create-serial-tests.py drivers/cpp/benchmarks prompts/generation-prompts.json serial-generations.json
+
+# make sure the model drivers are built
+cd drivers
+cd cpp
+make
+cd ..
+
+# Run the drivers using these generations
+python run-all.py \
+    ../serial-generations.json \
+    --output results.json \
+    --launch-configs launch-configs.json \
+    --problem-sizes problem-sizes.json \
+    --yes-to-all \
+    --include-models serial \
+    ${PROBLEM_ARG} \
+    --build-timeout 60 \
+    --run-timeout 120 \
+    --log info
+
+# check results (validation is per problem, so it needs a problem name)
+cd ..
+if [ $# -eq 0 ]; then
+    echo "No problem specified; skipping result validation."
+    exit 0
+fi
+python test/validate-test-results.py \
+    --results drivers/results.json \
+    --problem "$1" \
+    --expected-write 3 \
+    --expected-source-valid 3 \
+    --expected-build 2 \
+    --expected-run 2 \
+    --expected-correct 1
\ No newline at end of file
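The five --expected-* counts encode the intended funnel for the synthetic generations. Only the first element of prompt['outputs'] (the correct baseline body) is visible in the create-serial-tests.py hunk earlier, so the reading below is an inference from these counts, not something the diff states:

    # assumed per-problem funnel for three synthetic generations
    expected = {
        "write": 3,         # all three sources are written to disk
        "source_valid": 3,  # all three parse as valid source
        "build": 2,         # one variant is meant to fail compilation
        "run": 2,           # both variants that build should also run
        "correct": 1,       # only the true baseline produces valid output
    }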
diff --git a/test/validate-test-results.py b/test/validate-test-results.py
new file mode 100644
index 0000000..cc191a5
--- /dev/null
+++ b/test/validate-test-results.py
@@ -0,0 +1,119 @@
+""" Checks if the expected test results are present in the output JSON file.
+    usage: python test/validate-test-results.py \
+        --results <results json> \
+        --problem <problem name> \
+        --expected-write <count> \
+        --expected-source-valid <count> \
+        --expected-build <count> \
+        --expected-run <count> \
+        --expected-correct <count>
+"""
+from argparse import ArgumentParser
+import json
+from collections import Counter
+
+
+def parse_args():
+    parser = ArgumentParser(description="Validate test results.")
+    parser.add_argument(
+        "--results",
+        type=str,
+        required=True,
+        help="Path to the results JSON file.",
+    )
+    parser.add_argument(
+        "--problem",
+        type=str,
+        required=True,
+        help="Name of the problem to validate.",
+    )
+    parser.add_argument(
+        "--expected-write",
+        type=int,
+        required=True,
+        help="Expected number of successful source writes.",
+    )
+    parser.add_argument(
+        "--expected-source-valid",
+        type=int,
+        required=True,
+        help="Expected number of valid sources.",
+    )
+    parser.add_argument(
+        "--expected-build",
+        type=int,
+        required=True,
+        help="Expected number of successful builds.",
+    )
+    parser.add_argument(
+        "--expected-run",
+        type=int,
+        required=True,
+        help="Expected number of successful runs.",
+    )
+    parser.add_argument(
+        "--expected-correct",
+        type=int,
+        required=True,
+        help="Expected number of correct outputs.",
+    )
+
+    return parser.parse_args()
+
+
+def validate_outputs(outputs, expected_counts):
+    # Tally how many generations passed each stage of the pipeline
+    actual_counts = Counter()
+
+    for output in outputs:
+        if output.get("source_write_success", False):
+            actual_counts["write"] += 1
+        if output.get("is_source_valid", False):
+            actual_counts["source_valid"] += 1
+        if output.get("did_build", False):
+            actual_counts["build"] += 1
+        if output.get("did_all_run", False):
+            actual_counts["run"] += 1
+        if output.get("are_all_valid", False):
+            actual_counts["correct"] += 1
+
+    for key, expected in expected_counts.items():
+        actual = actual_counts[key]
+        if actual != expected:
+            print(f"Expected {expected} for {key}, but got {actual}.")
+            return False
+    return True
+
+
+def main():
+    args = parse_args()
+
+    # Load the results JSON file
+    with open(args.results, "r") as f:
+        results = json.load(f)
+
+    # Validate the results
+    expected_counts = {
+        "write": args.expected_write,
+        "source_valid": args.expected_source_valid,
+        "build": args.expected_build,
+        "run": args.expected_run,
+        "correct": args.expected_correct,
+    }
+
+    # Find the entry for the requested problem
+    matching = [r for r in results if r["name"] == args.problem]
+    if not matching:
+        print(f"No results found for problem {args.problem}.")
+        return 1
+
+    if not validate_outputs(matching[0]["outputs"], expected_counts):
+        print(f"Validation failed for problem {args.problem}.")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    # Propagate the exit status so the calling test script fails on a mismatch
+    raise SystemExit(main())
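For reference, a sketch of the results.json shape this script consumes, inferred only from the fields read above; any other keys the drivers write are not shown here:

    results = [
        {
            "name": "00_dense_la_lu_decomp",
            "outputs": [
                {
                    "source_write_success": True,  # source written to disk
                    "is_source_valid": True,       # source parsed as valid
                    "did_build": True,             # compiled successfully
                    "did_all_run": True,           # every configured run finished
                    "are_all_valid": True,         # every run produced correct output
                },
                # ... one entry per synthetic generation
            ],
        },
        # ... one entry per problem
    ]

With the three-generation funnel sketched earlier, validate_outputs(results[0]["outputs"], expected) succeeds only when the tallies match those counts exactly.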