From 17274d8eec3c650a102761674cca25662e6a42f1 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Sun, 16 May 2021 11:13:09 -0700 Subject: [PATCH] Fix the int4 table batched embedding benchmark with mixed dim Summary: Fix the bench with "--mixed" dimension as reported by xcliang in https://www.internalfb.com/diff/D28248236 (https://github.com/pytorch/FBGEMM/commit/0fe80ee014b936733278a77d0a24c9fe9a431c31)?dst_version_fbid=834500627143449&transaction_fbid=331571331658562 Differential Revision: D28466825 fbshipit-source-id: a343a3d629c8fc751fc228321f5f43ec28065f99 --- .../bench/split_table_batched_embeddings_benchmark.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index ebd4e90a59..a08627aad0 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -791,7 +791,8 @@ def cpu( # noqa C901 feature_requires_grad = None if mixed: Ds = [ - div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) + # int4 table batched emb op can only handle mixed D where D is multiple of 8 + div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8) for _ in range(T) ] D = np.average(Ds) @@ -905,8 +906,9 @@ def int4_device( # noqa C901 else: feature_requires_grad = None if mixed: + # int4 table batched emb op can only handle mixed D where D is multiple of 8 Ds = [ - div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) + div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8) for _ in range(T) ] D = np.average(Ds)