From 169017792f8bced53d2970416e6a1d1b07b05818 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 13 Jun 2025 16:56:25 +0800 Subject: [PATCH 1/7] compare llama-bench: add option to plot --- scripts/compare-llama-bench.py | 163 +++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index a1013c3b7a66d..b741efd35f81c 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -122,11 +122,23 @@ parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed") parser.add_argument("-s", "--show", help=help_s) parser.add_argument("--verbose", action="store_true", help="increase output verbosity") +parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)") +parser.add_argument("--plot_x", help="parameter to use as x-axis for plotting (default: n_depth)", default="n_depth") known_args, unknown_args = parser.parse_known_args() logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) +# Check for matplotlib if plotting is requested +if known_args.plot: + try: + import matplotlib.pyplot as plt + import matplotlib + matplotlib.use('Agg') + except ImportError as e: + print("matplotlib is required for --plot.") + raise e + if known_args.check: # Check if all required Python libraries are installed. Would have failed earlier if not. sys.exit(0) @@ -600,6 +612,157 @@ def valid_format(data_files: list[str]) -> bool: headers = [PRETTY_NAMES[p] for p in show] headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] +if known_args.plot: + def create_performance_plot(table_data, headers, baseline_name, compare_name, output_file, plot_x_param): + + data_headers = headers[:-4] #Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup) + plot_x_index = None + plot_x_label = plot_x_param + + if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]: + pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param) + if pretty_name in data_headers: + plot_x_index = data_headers.index(pretty_name) + plot_x_label = pretty_name + elif plot_x_param in data_headers: + plot_x_index = data_headers.index(plot_x_param) + plot_x_label = plot_x_param + else: + logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}") + logger.error(f"To plot by '{plot_x_param}', include it in --show parameter or ensure it varies in your data.") + return + + grouped_data = {} + + for i, row in enumerate(table_data): + group_key_parts = [] + test_name = row[-4] + + if plot_x_param in ["n_prompt", "n_gen", "n_depth"]: + for j, val in enumerate(row[:-4]): + header_name = data_headers[j] + if val is not None and str(val).strip(): + group_key_parts.append(f"{header_name}={val}") + + if plot_x_param == "n_prompt": + assert "pp" in test_name, f"n_prompt test name {test_name} does not contain 'pp'" + base_test = test_name.split("@")[0] + x_value = base_test + elif plot_x_param == "n_gen" and "tg" in test_name: + assert "tg" in test_name, f"n_gen test name {test_name} does not contain 'tg'" + x_value = test_name.split("@")[0] + elif plot_x_param == "n_depth" and "@d" in test_name: + assert "@d" in test_name, f"n_depth test name {test_name} does not contain '@d'" + base_test = test_name.split("@d")[0] + x_value = int(test_name.split("@d")[1]) + else: + base_test = test_name + + if base_test.strip(): + group_key_parts.append(f"Test={base_test}") + else: + for j, val in enumerate(row[:-4]): + if j != plot_x_index: + header_name = data_headers[j] + if val is not None and str(val).strip(): + group_key_parts.append(f"{header_name}={val}") + else: + x_value = val + + group_key_parts.append(f"Test={test_name}") + + group_key = tuple(sorted(group_key_parts)) + + if group_key not in grouped_data: + grouped_data[group_key] = [] + + grouped_data[group_key].append({ + 'x_value': x_value, + 'baseline': float(row[-3]), + 'compare': float(row[-2]), + 'speedup': float(row[-1]) + }) + + if not grouped_data: + logger.error("No data available for plotting") + return + + + def make_axes(num_groups, max_cols=2, base_size=(8, 4)): + from math import ceil + cols = 1 if num_groups == 1 else min(max_cols, num_groups) + rows = ceil(num_groups / cols) + + # scale figure size by grid dimensions + w, h = base_size + fig, ax_arr = plt.subplots(rows, cols, + figsize=(w * cols, h * rows), + squeeze=False) + + axes = ax_arr.flatten()[:num_groups] + return fig, axes + + num_groups = len(grouped_data) + fig, axes = make_axes(num_groups) + + plot_idx = 0 + + for group_key, points in grouped_data.items(): + if plot_idx >= len(axes): + break + ax = axes[plot_idx] + + try: + points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0) + x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted] + except ValueError: + points_sorted = sorted(points, key=lambda p: group_key) + x_values = [p['x_value'] for p in points_sorted] + + baseline_vals = [p['baseline'] for p in points_sorted] + compare_vals = [p['compare'] for p in points_sorted] + + ax.plot(x_values, baseline_vals, 'o-', color='skyblue', + label=f'{baseline_name}', linewidth=2, markersize=6) + ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8, + label=f'{compare_name}', linewidth=2, markersize=6) + + if plot_x_param == "n_depth" and max(x_values) > 0 and max(x_values) > min(x_values) * 4: + ax.set_xscale('log', base=2) + unique_x = sorted(set(x_values)) + ax.set_xticks(unique_x) + ax.set_xticklabels([str(int(x)) for x in unique_x]) + + title_parts = [] + for part in group_key: + if '=' in part: + key, value = part.split('=', 1) + title_parts.append(f"{key}: {value}") + + title = ', '.join(title_parts) if title_parts else "Performance Comparison" + + ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold') + ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold') + ax.set_title(title, fontsize=12, fontweight='bold') + ax.legend(loc='best', fontsize=10) + ax.grid(True, alpha=0.3) + + plot_idx += 1 + + for i in range(plot_idx, len(axes)): + axes[i].set_visible(False) + + fig.suptitle(f'Performance Comparison: {compare_name} vs {baseline_name}', + fontsize=14, fontweight='bold') + fig.subplots_adjust(top=1) + + + plt.tight_layout() + plt.savefig(output_file, dpi=300, bbox_inches='tight') + plt.close() + + create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x) + print(tabulate( # noqa: NP100 table, headers=headers, From 5426c875f821ba70a800a4453067cb7034cd2744 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 13 Jun 2025 20:29:07 +0800 Subject: [PATCH 2/7] Address review comments: convert case + add type hints --- scripts/compare-llama-bench.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index b741efd35f81c..91033986b1337 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -123,7 +123,7 @@ parser.add_argument("-s", "--show", help=help_s) parser.add_argument("--verbose", action="store_true", help="increase output verbosity") parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)") -parser.add_argument("--plot_x", help="parameter to use as x-axis for plotting (default: n_depth)", default="n_depth") +parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth") known_args, unknown_args = parser.parse_known_args() @@ -136,7 +136,7 @@ import matplotlib matplotlib.use('Agg') except ImportError as e: - print("matplotlib is required for --plot.") + logger.error("matplotlib is required for --plot.") raise e if known_args.check: @@ -613,9 +613,9 @@ def valid_format(data_files: list[str]) -> bool: headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] if known_args.plot: - def create_performance_plot(table_data, headers, baseline_name, compare_name, output_file, plot_x_param): + def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str): - data_headers = headers[:-4] #Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup) + data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup) plot_x_index = None plot_x_label = plot_x_param @@ -687,7 +687,6 @@ def create_performance_plot(table_data, headers, baseline_name, compare_name, ou logger.error("No data available for plotting") return - def make_axes(num_groups, max_cols=2, base_size=(8, 4)): from math import ceil cols = 1 if num_groups == 1 else min(max_cols, num_groups) @@ -696,8 +695,8 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): # scale figure size by grid dimensions w, h = base_size fig, ax_arr = plt.subplots(rows, cols, - figsize=(w * cols, h * rows), - squeeze=False) + figsize=(w * cols, h * rows), + squeeze=False) axes = ax_arr.flatten()[:num_groups] return fig, axes @@ -739,7 +738,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): key, value = part.split('=', 1) title_parts.append(f"{key}: {value}") - title = ', '.join(title_parts) if title_parts else "Performance Comparison" + title = ', '.join(title_parts) if title_parts else "Performance comparison" ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold') ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold') @@ -752,11 +751,10 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): for i in range(plot_idx, len(axes)): axes[i].set_visible(False) - fig.suptitle(f'Performance Comparison: {compare_name} vs {baseline_name}', - fontsize=14, fontweight='bold') + fig.suptitle(f'Performance comparison: {compare_name} vs {baseline_name}', + fontsize=14, fontweight='bold') fig.subplots_adjust(top=1) - plt.tight_layout() plt.savefig(output_file, dpi=300, bbox_inches='tight') plt.close() From deeaecf5d8a4b1979c0c5fd4005e943289ef05d9 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 13 Jun 2025 21:25:28 +0800 Subject: [PATCH 3/7] Add matplotlib to requirements --- .../requirements-compare-llama-bench.txt | 1 + scripts/compare-llama-bench.py | 21 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/requirements/requirements-compare-llama-bench.txt b/requirements/requirements-compare-llama-bench.txt index e0aaa32043ce2..d87e897e17199 100644 --- a/requirements/requirements-compare-llama-bench.txt +++ b/requirements/requirements-compare-llama-bench.txt @@ -1,2 +1,3 @@ tabulate~=0.9.0 GitPython~=3.1.43 +matplotlib~=3.10.0 diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 91033986b1337..bdbe060efc305 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -129,7 +129,6 @@ logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) -# Check for matplotlib if plotting is requested if known_args.plot: try: import matplotlib.pyplot as plt @@ -511,7 +510,6 @@ def valid_format(data_files: list[str]) -> bool: name_compare = bench_data.get_commit_name(hexsha8_compare) - # If the user provided columns to group the results by, use them: if known_args.show is not None: show = known_args.show.split(",") @@ -556,6 +554,14 @@ def valid_format(data_files: list[str]) -> bool: show.remove(prop) except ValueError: pass + + # add plot_x parameter to if it's not already there + if known_args.plot: + for k, v in PRETTY_NAMES.items(): + if v == known_args.plot_x and k not in show: + show.append(k) + break + rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) if not rows_show: @@ -629,7 +635,6 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas plot_x_label = plot_x_param else: logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}") - logger.error(f"To plot by '{plot_x_param}', include it in --show parameter or ensure it varies in your data.") return grouped_data = {} @@ -671,7 +676,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas group_key_parts.append(f"Test={test_name}") - group_key = tuple(sorted(group_key_parts)) + group_key = tuple(group_key_parts) if group_key not in grouped_data: grouped_data[group_key] = [] @@ -692,7 +697,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): cols = 1 if num_groups == 1 else min(max_cols, num_groups) rows = ceil(num_groups / cols) - # scale figure size by grid dimensions + # Scale figure size by grid dimensions w, h = base_size fig, ax_arr = plt.subplots(rows, cols, figsize=(w * cols, h * rows), @@ -726,7 +731,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8, label=f'{compare_name}', linewidth=2, markersize=6) - if plot_x_param == "n_depth" and max(x_values) > 0 and max(x_values) > min(x_values) * 4: + if plot_x_param == "n_depth" and min(x_values) > 0 and max(x_values) > min(x_values) * 4: ax.set_xscale('log', base=2) unique_x = sorted(set(x_values)) ax.set_xticks(unique_x) @@ -741,7 +746,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): title = ', '.join(title_parts) if title_parts else "Performance comparison" ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold') - ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold') + ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold') ax.set_title(title, fontsize=12, fontweight='bold') ax.legend(loc='best', fontsize=10) ax.grid(True, alpha=0.3) @@ -751,7 +756,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): for i in range(plot_idx, len(axes)): axes[i].set_visible(False) - fig.suptitle(f'Performance comparison: {compare_name} vs {baseline_name}', + fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}', fontsize=14, fontweight='bold') fig.subplots_adjust(top=1) From 3915a8de79776cce62db582313b0f9f469238623 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 13 Jun 2025 21:55:58 +0800 Subject: [PATCH 4/7] fix tests --- scripts/compare-llama-bench.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index bdbe060efc305..3d42c9e57d01f 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -19,6 +19,7 @@ print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100 raise e + logger = logging.getLogger("compare-llama-bench") # All llama-bench SQL fields @@ -129,14 +130,6 @@ logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) -if known_args.plot: - try: - import matplotlib.pyplot as plt - import matplotlib - matplotlib.use('Agg') - except ImportError as e: - logger.error("matplotlib is required for --plot.") - raise e if known_args.check: # Check if all required Python libraries are installed. Would have failed earlier if not. @@ -620,6 +613,13 @@ def valid_format(data_files: list[str]) -> bool: if known_args.plot: def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str): + try: + import matplotlib.pyplot as plt + import matplotlib + matplotlib.use('Agg') + except ImportError as e: + logger.error("matplotlib is required for --plot.") + raise e data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup) plot_x_index = None @@ -643,6 +643,9 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas group_key_parts = [] test_name = row[-4] + base_test = "" + x_value = None + if plot_x_param in ["n_prompt", "n_gen", "n_depth"]: for j, val in enumerate(row[:-4]): header_name = data_headers[j] From 8228393e9511b8c0842cc170e7e5be2b528a9d17 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 14 Jun 2025 01:14:09 +0800 Subject: [PATCH 5/7] Improve comment and fix assert condition for test --- scripts/compare-llama-bench.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 3d42c9e57d01f..22140733d93f8 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -548,7 +548,7 @@ def valid_format(data_files: list[str]) -> bool: except ValueError: pass - # add plot_x parameter to if it's not already there + # Add plot_x parameter to parameters to show if it's not already present: if known_args.plot: for k, v in PRETTY_NAMES.items(): if v == known_args.plot_x and k not in show: @@ -652,19 +652,16 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas if val is not None and str(val).strip(): group_key_parts.append(f"{header_name}={val}") - if plot_x_param == "n_prompt": - assert "pp" in test_name, f"n_prompt test name {test_name} does not contain 'pp'" + if plot_x_param == "n_prompt" and "pp" in test_name: base_test = test_name.split("@")[0] x_value = base_test elif plot_x_param == "n_gen" and "tg" in test_name: - assert "tg" in test_name, f"n_gen test name {test_name} does not contain 'tg'" x_value = test_name.split("@")[0] elif plot_x_param == "n_depth" and "@d" in test_name: - assert "@d" in test_name, f"n_depth test name {test_name} does not contain '@d'" base_test = test_name.split("@d")[0] x_value = int(test_name.split("@d")[1]) else: - base_test = test_name + assert False if base_test.strip(): group_key_parts.append(f"Test={base_test}") From e79049704063f0a5b92642476f04963bc7882635 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 14 Jun 2025 02:11:50 +0800 Subject: [PATCH 6/7] Add back default test_name, add --plot_log_scale --- scripts/compare-llama-bench.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 22140733d93f8..73ce1498395aa 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -125,6 +125,7 @@ parser.add_argument("--verbose", action="store_true", help="increase output verbosity") parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)") parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth") +parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)") known_args, unknown_args = parser.parse_known_args() @@ -612,7 +613,7 @@ def valid_format(data_files: list[str]) -> bool: headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] if known_args.plot: - def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str): + def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False): try: import matplotlib.pyplot as plt import matplotlib @@ -661,7 +662,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas base_test = test_name.split("@d")[0] x_value = int(test_name.split("@d")[1]) else: - assert False + base_test = test_name if base_test.strip(): group_key_parts.append(f"Test={base_test}") @@ -731,7 +732,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8, label=f'{compare_name}', linewidth=2, markersize=6) - if plot_x_param == "n_depth" and min(x_values) > 0 and max(x_values) > min(x_values) * 4: + if log_scale and min(x_values) > 0: ax.set_xscale('log', base=2) unique_x = sorted(set(x_values)) ax.set_xticks(unique_x) @@ -764,7 +765,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): plt.savefig(output_file, dpi=300, bbox_inches='tight') plt.close() - create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x) + create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale) print(tabulate( # noqa: NP100 table, From 530de45086da4daa234e56bd4d3679ed0d092639 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 14 Jun 2025 11:52:12 +0800 Subject: [PATCH 7/7] use log_scale regardless of x_values --- scripts/compare-llama-bench.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 73ce1498395aa..30e3cf8649e8a 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -732,7 +732,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)): ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8, label=f'{compare_name}', linewidth=2, markersize=6) - if log_scale and min(x_values) > 0: + if log_scale: ax.set_xscale('log', base=2) unique_x = sorted(set(x_values)) ax.set_xticks(unique_x)