Commit 40cb9fe

Sasha Fomina authored and facebook-github-bot committed
Moving all tensor allocation from cpu to meta device in SplitTableBatchedEmbeddingBagsCodegen
Summary: Used profiler logs sorted by `cpu_memory_usage` in `embedding_bag_wprofiler_gpu_test.py` to add a `device` kwarg to all tensor allocation sites in `split_table_batched_embeddings_ops.py`, so that they can be materialized on the meta device. Some `torch.tensor` calls were switched out for `torch.zeros` calls (where appropriate) to avoid temporary allocations in CPU memory. Some `torch.tensor` calls were kept, with a temporary CPU memory allocation but final materialization on the meta device.

Reviewed By: xush6528

Differential Revision: D29566376

fbshipit-source-id: c01575127cb2392f95ec1d3712ad43803a373db5
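For context on the pattern the diff applies: `torch.tensor(data)` always builds its data in CPU memory first, whereas factory functions like `torch.zeros` take a `device` kwarg and allocate directly on the target device, including `meta` (shape/dtype metadata only, no storage). A minimal sketch of the difference, standalone and not taken from the patch:

```python
import torch

# torch.tensor copies Python data into a CPU tensor before any device
# transfer, so it always costs a temporary host allocation.
host_first = torch.tensor([0], dtype=torch.int64)

# torch.zeros accepts device= and allocates there directly; with the
# meta device no storage is allocated at all, only shape and dtype.
direct = torch.zeros(1, dtype=torch.int64, device="meta")

print(host_first.device)  # cpu
print(direct.device)      # meta
```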
1 parent 57f3478 commit 40cb9fe

File tree

1 file changed (+13, -13 lines)


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py

Lines changed: 13 additions & 13 deletions
@@ -196,7 +196,7 @@ def __init__( # noqa C901
         self.weights_precision = weights_precision
         self.record_cache_metrics = record_cache_metrics
         # NOTE: a placeholder to avoid multi-construction and make TorchScript work!
-        self.dummy_tensor: Tensor = torch.tensor(0)
+        self.dummy_tensor: Tensor = torch.zeros(0, device=device)

         self.embedding_specs = embedding_specs
         (rows, dims, locations, compute_devices) = zip(*embedding_specs)
@@ -373,30 +373,30 @@ def __init__( # noqa C901
                 prefix="momentum2",
                 dtype=torch.float32,
             )
-            self.register_buffer("iter", torch.tensor([0], dtype=torch.int64))
+            self.register_buffer("iter", torch.zeros(1, dtype=torch.int64, device=self.current_device))
         else:
             # NOTE: make TorchScript work!
             self.register_buffer(
-                "momentum2_dev", torch.tensor([0], dtype=torch.int64), persistent=False
+                "momentum2_dev", torch.zeros(1, dtype=torch.int64, device=self.current_device), persistent=False
             )
             self.register_buffer(
-                "momentum2_host", torch.tensor([0], dtype=torch.int64), persistent=False
+                "momentum2_host", torch.zeros(1, dtype=torch.int64, device=self.current_device), persistent=False
             )
             self.register_buffer(
-                "momentum2_uvm", torch.tensor([0], dtype=torch.int64), persistent=False
+                "momentum2_uvm", torch.zeros(1, dtype=torch.int64, device=self.current_device), persistent=False
             )
             self.register_buffer(
                 "momentum2_placements",
-                torch.tensor([0], dtype=torch.int64),
+                torch.zeros(1, dtype=torch.int64, device=self.current_device),
                 persistent=False,
             )
             self.register_buffer(
                 "momentum2_offsets",
-                torch.tensor([0], dtype=torch.int64),
+                torch.zeros(1, dtype=torch.int64, device=self.current_device),
                 persistent=False,
             )
             self.register_buffer(
-                "iter", torch.tensor([0], dtype=torch.int64), persistent=False
+                "iter", torch.zeros(1, dtype=torch.int64, device=self.current_device), persistent=False
             )

         cache_state = construct_cache_state(embedding_specs, self.feature_table_map)
@@ -983,27 +983,27 @@ def _apply_cache_state(
         # NOTE: make TorchScript work!
         self.register_buffer(
             "cache_hash_size_cumsum",
-            torch.tensor([0], dtype=torch.int64),
+            torch.zeros(1, dtype=torch.int64, device=self.current_device),
             persistent=False,
         )
         self.register_buffer(
             "total_cache_hash_size",
-            torch.tensor([0], dtype=torch.int64),
+            torch.zeros(1, dtype=torch.int64, device=self.current_device),
             persistent=False,
         )
         self.register_buffer(
             "cache_index_table_map",
-            torch.tensor([0], dtype=torch.int64),
+            torch.zeros(1, dtype=torch.int64, device=self.current_device),
             persistent=False,
         )
         self.register_buffer(
             "lxu_cache_state",
-            torch.tensor([0], dtype=torch.int64),
+            torch.zeros(1, dtype=torch.int64, device=self.current_device),
             persistent=False,
         )
         self.register_buffer(
             "lxu_state",
-            torch.tensor([0], dtype=torch.int64),
+            torch.zeros(1, dtype=torch.int64, device=self.current_device),
             persistent=False,
         )
         self.register_buffer(
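For reference, the placeholder-buffer pattern the diff applies throughout, shown on a toy module (the `Placeholder` class here is illustrative, not from `split_table_batched_embeddings_ops.py`; only the buffer name and allocation call mirror the patch):

```python
import torch

class Placeholder(torch.nn.Module):
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        self.current_device = device
        # Non-persistent placeholder registered directly on the target
        # device; torch.zeros never stages the data in CPU memory.
        self.register_buffer(
            "lxu_state",
            torch.zeros(1, dtype=torch.int64, device=self.current_device),
            persistent=False,
        )

m = Placeholder(torch.device("meta"))
print(m.lxu_state.device)  # meta
```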
