
Commit c565348

r-barnes authored and facebook-github-bot committed
More types for fbgemm_gpu (#585)
Summary: Pull Request resolved: #585
Reviewed By: xush6528
Differential Revision: D27550481
fbshipit-source-id: 3eccfa359cd8f983b7c74a7c2f22c75024779f3b
1 parent 4c43051 commit c565348

3 files changed: +50 additions, -58 deletions

fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py

Lines changed: 12 additions & 14 deletions
@@ -20,6 +20,7 @@
     SparseType,
     SplitTableBatchedEmbeddingBagsCodegen,
 )
+from torch import Tensor

 logging.basicConfig(level=logging.DEBUG)

@@ -43,8 +44,8 @@ def get_device() -> torch.device:
 # Merged indices with shape (T, B, L) -> (flattened indices with shape
 # (T * B * L), offsets with shape (T * B + 1))
 def get_table_batched_offsets_from_dense(
-    merged_indices: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    merged_indices: Tensor,
+) -> Tuple[Tensor, Tensor]:
     (T, B, L) = merged_indices.size()
     lengths = np.ones((T, B)) * L
     flat_lengths = lengths.flatten()
@@ -67,7 +68,7 @@ def generate_requests(
     alpha: float = 1.0,
     weights_precision: SparseType = SparseType.FP32,
     weighted: bool = False,
-) -> List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]:
+) -> List[Tuple[Tensor, Tensor, Optional[Tensor]]]:
     if alpha <= 1.0:
         all_indices = torch.randint(
             low=0,
@@ -111,9 +112,8 @@ def generate_requests(


 def benchmark_requests(
-    requests: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    f: Callable,
+    requests: List[Tuple[Tensor, Tensor, Optional[Tensor]]],
+    func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
 ) -> float:
     if torch.cuda.is_available():
         torch.cuda.synchronize()
@@ -123,7 +123,7 @@ def benchmark_requests(
     else:
         start_time = time.time()
     for (indices, offsets, weights) in requests:
-        f(indices, offsets, weights)
+        func(indices, offsets, weights)
     if torch.cuda.is_available():
         end_event.record()
         torch.cuda.synchronize()
@@ -133,11 +133,9 @@ def benchmark_requests(


 def benchmark_pipelined_requests(
-    requests: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    f: Callable,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    g: Callable,
+    requests: List[Tuple[Tensor, Tensor, Optional[Tensor]]],
+    func1: Callable[[Tensor, Tensor, Optional[Tensor]], None],
+    func2: Callable[[Tensor, Tensor, Optional[Tensor]], None],
 ) -> Tuple[float, float]:
     torch.cuda.synchronize()
     start_events = [
@@ -152,10 +150,10 @@ def benchmark_pipelined_requests(
         requests, start_events, end_events
     ):
         start_event[0].record()
-        f(indices, offsets, indices_weights)
+        func1(indices, offsets, indices_weights)
         end_event[0].record()
         start_event[1].record()
-        g(indices, offsets, indices_weights)
+        func2(indices, offsets, indices_weights)
         end_event[1].record()
     torch.cuda.synchronize()
     return (
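Note: the bare `Callable` annotations above are replaced with fully parameterized signatures. As a rough, self-contained illustration of why that helps (this sketch is not part of the commit; `run_requests` and the dummy request data are made up for the example), a type checker can now verify the callback's arity and argument types at the call site:

from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor


def run_requests(
    requests: List[Tuple[Tensor, Tensor, Optional[Tensor]]],
    func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
) -> None:
    # Because func's parameter and return types are spelled out, passing a
    # callback with the wrong arity or argument types is a static error
    # instead of a failure at benchmark time.
    for (indices, offsets, weights) in requests:
        func(indices, offsets, weights)


requests = [(torch.arange(4), torch.tensor([0, 2, 4]), None)]
run_requests(requests, lambda indices, offsets, weights: indices.float().sum())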

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py

Lines changed: 18 additions & 24 deletions
@@ -27,6 +27,10 @@
 INT8_EMB_ROW_DIM_OFFSET = 8


+class DoesNotHavePrefix(Exception):
+    pass
+
+
 class EmbeddingLocation(enum.IntEnum):
     DEVICE = 0
     MANAGED = 1
@@ -420,10 +424,9 @@ def __init__( # noqa C901

         self.step = 0

-    # pyre-fixme[3]: Return type must be annotated.
-    def get_states(self, prefix: str):
+    def get_states(self, prefix: str) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
         if not hasattr(self, f"{prefix}_physical_placements"):
-            return None
+            raise DoesNotHavePrefix()
         dev_param = getattr(self, f"{prefix}_dev")
         host_param = getattr(self, f"{prefix}_host")
         uvm_param = getattr(self, f"{prefix}_uvm")
@@ -437,14 +440,13 @@ def get_states(self, prefix: str):
             torch.tensor(offsets, dtype=torch.int64),
         )

-    # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
-    # `typing.List` to avoid runtime subscripting errors.
-    def get_all_states(self) -> List:
+    def get_all_states(self) -> List[Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
         all_states = []
         for prefix in ["weights", "momentum1", "momentum2"]:
-            states = self.get_states(prefix)
-            if states:
-                all_states.append(states)
+            try:
+                all_states.append(self.get_states(prefix))
+            except DoesNotHavePrefix:
+                pass
         return all_states

     def forward(
@@ -741,16 +743,11 @@ def split_optimizer_states(self) -> List[Tuple[torch.Tensor]]:
         """

         def get_optimizer_states(
-            # pyre-fixme[2]: Parameter must be annotated.
-            state_dev,
-            # pyre-fixme[2]: Parameter must be annotated.
-            state_host,
-            # pyre-fixme[2]: Parameter must be annotated.
-            state_uvm,
-            # pyre-fixme[2]: Parameter must be annotated.
-            state_offsets,
-            # pyre-fixme[2]: Parameter must be annotated.
-            state_placements,
+            state_dev: Tensor,
+            state_host: Tensor,
+            state_uvm: Tensor,
+            state_offsets: Tensor,
+            state_placements: Tensor,
             rowwise: bool,
         ) -> List[torch.Tensor]:
             splits = []
@@ -872,9 +869,7 @@ def flush(self) -> None:
             self.stochastic_rounding,
         )

-    # pyre-fixme[2]: Parameter must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _apply_split(self, split, prefix, dtype: torch.dtype, enforce_hbm: bool = False) -> None:
+    def _apply_split(self, split: SplitState, prefix: str, dtype: torch.dtype, enforce_hbm: bool = False) -> None:
         setattr(self, f"{prefix}_physical_placements", split.placements)
         setattr(self, f"{prefix}_physical_offsets", split.offsets)

@@ -1184,8 +1179,7 @@ def __init__(
                 row for (row, _) in embedding_specs[:t]
             )

-        # pyre-fixme[4]: Attribute must be annotated.
-        self.weights_physical_offsets = weights_offsets
+        self.weights_physical_offsets: List[int] = weights_offsets
         weights_offsets = [weights_offsets[t] for t in feature_table_map]
         self.register_buffer(
             "weights_offsets",

fbgemm_gpu/test/split_table_batched_embeddings_test.py

Lines changed: 20 additions & 20 deletions
@@ -9,17 +9,33 @@

 import copy
 import unittest
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, TypeVar

 import fbgemm_gpu.split_table_batched_embeddings_ops as split_table_batched_embeddings_ops
 import hypothesis.strategies as st
 import numpy as np
 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops import OptimType, SparseType
+from torch import Tensor
 from hypothesis import HealthCheck, Verbosity, assume, given, settings


 MAX_EXAMPLES = 40
+Deviceable = TypeVar("Deviceable", torch.nn.EmbeddingBag, Tensor)
+
+
+def b_indices(
+    b: torch.nn.EmbeddingBag,
+    x: torch.Tensor,
+    per_sample_weights: Optional[torch.Tensor] = None,
+    use_cpu: bool = False
+) -> torch.Tensor:
+    (indices, offsets) = get_offsets_from_dense(x)
+    return b(
+        to_device(indices, use_cpu),
+        to_device(offsets, use_cpu),
+        per_sample_weights=per_sample_weights,
+    )


 def div_round_up(a: int, b: int) -> int:
@@ -35,8 +51,9 @@ def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.T
         ),
     )

-
-def to_device(t: torch.Tensor, use_cpu: bool) -> torch.Tensor:
+def to_device(t: Deviceable, use_cpu: bool) -> Deviceable:
+    # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor,
+    #  torch.nn.EmbeddingBag]`.
     return t.cpu() if use_cpu else t.cuda()

@@ -239,11 +256,9 @@ def test_forward(
            xws = [xw.half() for xw in xws]

        fs = (
-            # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
-                # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
@@ -270,7 +285,6 @@ def test_forward(
            cc = torch.jit.script(cc)

        for t in range(T):
-            # pyre-fixme[16]: `Tensor` has no attribute `weight`.
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
@@ -385,19 +399,16 @@ def test_backward_dense(
            xws = [xw.half() for xw in xws]

        fs = (
-            # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
-                # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
        )
        gos = [torch.randn_like(f) for f in fs]
        [f.backward(go) for (f, go) in zip(fs, gos)]

-        # pyre-fixme[16]: `Tensor` has no attribute `weight`.
        grad_weights = torch.cat([b.weight.grad.view(-1) for b in bs])
        if weights_precision == SparseType.FP16 and not use_cpu:
            grad_weights = grad_weights.half()
@@ -570,7 +581,6 @@ def test_backward_sgd( # noqa C901
        feature_table_map = list(range(T))
        if exact:
            table_to_replicate = T // 2
-            # pyre-fixme[6]: Expected `HalfTensor` for 2nd param but got `Tensor`.
            bs.insert(table_to_replicate, bs[table_to_replicate])
            feature_table_map.insert(table_to_replicate, table_to_replicate)

@@ -598,11 +608,9 @@ def test_backward_sgd( # noqa C901
            xws = [xw.half() for xw in xws]

        fs = (
-            # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
-                # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
@@ -613,7 +621,6 @@ def test_backward_sgd( # noqa C901
        lr = 0.05
        if exact:
            del bs[table_to_replicate]
-        # pyre-fixme[16]: `Tensor` has no attribute `weight`.
        new_weights = [(b.weight - b.weight.grad * lr) for b in bs]

        cc = split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(
@@ -782,7 +789,6 @@ def test_backward_adagrad( # noqa C901
        if exact:
            # autograd with shared embedding only works for exact
            table_to_replicate = T // 2
-            # pyre-fixme[6]: Expected `HalfTensor` for 2nd param but got `Tensor`.
            bs.insert(table_to_replicate, bs[table_to_replicate])
            feature_table_map.insert(table_to_replicate, table_to_replicate)

@@ -805,11 +811,9 @@ def test_backward_adagrad( # noqa C901
            xws = [xw.half() for xw in xws]

        fs = (
-            # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
-                # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
@@ -839,7 +843,6 @@ def test_backward_adagrad( # noqa C901
        if exact:
            del bs[table_to_replicate]
        for t in range(T):
-            # pyre-fixme[16]: `Tensor` has no attribute `weight`.
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
@@ -1162,11 +1165,9 @@ def test_backward_optimizers( # noqa C901
        xws_acc_type = copy.deepcopy(xws)

        fs = (
-            # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
            [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)]
            if not weighted
            else [
-                # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
                b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu)
                for (b, x, xw) in zip(bs, xs, xws)
            ]
@@ -1214,7 +1215,6 @@ def test_backward_optimizers( # noqa C901
        )

        for t in range(T):
-            # pyre-fixme[16]: `Tensor` has no attribute `weight`.
            cc.split_embedding_weights()[t].data.copy_(bs[t].weight)

        x = torch.cat([x.view(1, B, L) for x in xs], dim=0)
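Note: the test helper `to_device` is now generic over a value-constrained `TypeVar`, so it returns the same kind of object it receives (an `EmbeddingBag` stays an `EmbeddingBag`, a `Tensor` stays a `Tensor`), and downstream accesses such as `.weight` type-check without per-call fixmes. A small standalone sketch of that pattern (assumes only `torch`; the embedding sizes here are arbitrary, not taken from the tests):

from typing import TypeVar

import torch
from torch import Tensor

# Value-constrained TypeVar: the return type of to_device mirrors whichever
# of the two allowed argument types was passed in.
Deviceable = TypeVar("Deviceable", torch.nn.EmbeddingBag, Tensor)


def to_device(t: Deviceable, use_cpu: bool) -> Deviceable:
    # Both Tensor and nn.EmbeddingBag expose .cpu()/.cuda(); a checker may
    # still want a suppression here, as the commit does with pyre-fixme[7].
    return t.cpu() if use_cpu else t.cuda()


emb = to_device(torch.nn.EmbeddingBag(10, 4, mode="sum"), use_cpu=True)
print(emb.weight.shape)  # emb is known to be an EmbeddingBag, so .weight checks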
