Commit b01c349

gsethi523 authored and facebook-github-bot committed
fix gpu device bug in split_table_batched_embeddings_benchmark (#590)
Summary:
Pull Request resolved: #590

Fixes a bug in the get_device() function: on a GPU machine (torch.cuda.is_available() == True) it returned torch.cuda.current_device instead of torch.cuda.current_device(). The return value of get_device() was therefore a function object rather than a device, causing the benchmark to crash at the first instance of device placement based on get_device().

Reviewed By: bilgeacun, jianyuh

Differential Revision: D27606621

fbshipit-source-id: 0fcfae496f44092bdd5c29f8966e26ef7852110f
1 parent a98ad84 commit b01c349

File tree

1 file changed

+1
-1
lines changed


fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def div_round_up(a: int, b: int) -> int:
 
 def get_device() -> torch.device:
     return (
-        torch.cuda.current_device if torch.cuda.is_available() else torch.device("cpu")
+        torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu")
     )
 
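The failure mode described in the commit summary is a generic Python pitfall: referencing a function object instead of calling it. A minimal pure-Python sketch of the pattern (the names below are hypothetical stand-ins, not the FBGEMM code — `current_device` here mimics `torch.cuda.current_device()` by returning a device index):

```python
def current_device():
    # Stand-in for torch.cuda.current_device(): returns a device index.
    return 0

def get_device_buggy():
    # Bug: missing parentheses, so the function object itself is returned.
    return current_device

def get_device_fixed():
    # Fix: calling the function returns the actual device index.
    return current_device()

# The buggy version hands back something callable, not a device index;
# any downstream code expecting a device value will then fail.
print(callable(get_device_buggy()))  # True
print(get_device_fixed())            # 0
```

This is why the crash only surfaced at the first use of the return value: constructing and returning the function object succeeds silently, and the type mismatch is only observed when device placement is attempted.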