diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index ebd4e90a59..a08627aad0 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -791,7 +791,8 @@ def cpu( # noqa C901 feature_requires_grad = None if mixed: Ds = [ - div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) + # int4 table batched emb op can only handle mixed D where D is multiple of 8 + div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8) for _ in range(T) ] D = np.average(Ds) @@ -905,8 +906,9 @@ def int4_device( # noqa C901 else: feature_requires_grad = None if mixed: + # int4 table batched emb op can only handle mixed D where D is multiple of 8 Ds = [ - div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) + div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8) for _ in range(T) ] D = np.average(Ds)