Skip to content

Commit c48ba9b

Browse files
jianyuhfacebook-github-bot
authored andcommitted
Fix the int4 table batched embedding benchmark with mixed dim (#609)
Summary: Pull Request resolved: #609 Fix the bench with "--mixed" dimension as reported by xcliang in https://www.internalfb.com/diff/D28248236 (0fe80ee014b936733278a77d0a24c9fe9a431c31)?dst_version_fbid=834500627143449&transaction_fbid=331571331658562 Differential Revision: D28466825 fbshipit-source-id: ac4725f37d89a3ecd2bcccd564b0424173f81230
1 parent 0fe80ee commit c48ba9b

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,8 @@ def cpu( # noqa C901
791791
feature_requires_grad = None
792792
if mixed:
793793
Ds = [
794-
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
794+
# int4 table batched emb op can only handle mixed D where D is multiple of 8
795+
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8)
795796
for _ in range(T)
796797
]
797798
D = np.average(Ds)
@@ -905,8 +906,9 @@ def int4_device( # noqa C901
905906
else:
906907
feature_requires_grad = None
907908
if mixed:
909+
# int4 table batched emb op can only handle mixed D where D is multiple of 8
908910
Ds = [
909-
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
911+
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8)
910912
for _ in range(T)
911913
]
912914
D = np.average(Ds)

0 commit comments

Comments
 (0)