99
1010import copy
1111import unittest
12- from typing import Any , Callable , List , Optional , Tuple
12+ from typing import Any , Callable , List , Optional , Tuple , TypeVar
1313
1414import fbgemm_gpu .split_table_batched_embeddings_ops as split_table_batched_embeddings_ops
1515import hypothesis .strategies as st
1616import numpy as np
1717import torch
1818from fbgemm_gpu .split_table_batched_embeddings_ops import OptimType , SparseType
19+ from torch import Tensor
1920from hypothesis import HealthCheck , Verbosity , assume , given , settings
2021
2122
2223MAX_EXAMPLES = 40
24+ Deviceable = TypeVar ("Deviceable" , torch .nn .EmbeddingBag , Tensor )
25+
26+
27+ def b_indices (
28+ b : torch .nn .EmbeddingBag ,
29+ x : torch .Tensor ,
30+ per_sample_weights : Optional [torch .Tensor ] = None ,
31+ use_cpu : bool = False
32+ ) -> torch .Tensor :
33+ (indices , offsets ) = get_offsets_from_dense (x )
34+ return b (
35+ to_device (indices , use_cpu ),
36+ to_device (offsets , use_cpu ),
37+ per_sample_weights = per_sample_weights ,
38+ )
2339
2440
2541def div_round_up (a : int , b : int ) -> int :
@@ -35,8 +51,9 @@ def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.T
3551 ),
3652 )
3753
38-
39- def to_device (t : torch .Tensor , use_cpu : bool ) -> torch .Tensor :
54+ def to_device (t : Deviceable , use_cpu : bool ) -> Deviceable :
55+ # pyre-fixme[7]: Expected `Deviceable` but got `Union[Tensor,
56+ # torch.nn.EmbeddingBag]`.
4057 return t .cpu () if use_cpu else t .cuda ()
4158
4259
@@ -239,11 +256,9 @@ def test_forward(
239256 xws = [xw .half () for xw in xws ]
240257
241258 fs = (
242- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
243259 [b_indices (b , x , use_cpu = use_cpu ) for (b , x ) in zip (bs , xs )]
244260 if not weighted
245261 else [
246- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
247262 b_indices (b , x , per_sample_weights = xw .view (- 1 ), use_cpu = use_cpu )
248263 for (b , x , xw ) in zip (bs , xs , xws )
249264 ]
@@ -270,7 +285,6 @@ def test_forward(
270285 cc = torch .jit .script (cc )
271286
272287 for t in range (T ):
273- # pyre-fixme[16]: `Tensor` has no attribute `weight`.
274288 cc .split_embedding_weights ()[t ].data .copy_ (bs [t ].weight )
275289
276290 x = torch .cat ([x .view (1 , B , L ) for x in xs ], dim = 0 )
@@ -385,19 +399,16 @@ def test_backward_dense(
385399 xws = [xw .half () for xw in xws ]
386400
387401 fs = (
388- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
389402 [b_indices (b , x , use_cpu = use_cpu ) for (b , x ) in zip (bs , xs )]
390403 if not weighted
391404 else [
392- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
393405 b_indices (b , x , per_sample_weights = xw .view (- 1 ), use_cpu = use_cpu )
394406 for (b , x , xw ) in zip (bs , xs , xws )
395407 ]
396408 )
397409 gos = [torch .randn_like (f ) for f in fs ]
398410 [f .backward (go ) for (f , go ) in zip (fs , gos )]
399411
400- # pyre-fixme[16]: `Tensor` has no attribute `weight`.
401412 grad_weights = torch .cat ([b .weight .grad .view (- 1 ) for b in bs ])
402413 if weights_precision == SparseType .FP16 and not use_cpu :
403414 grad_weights = grad_weights .half ()
@@ -570,7 +581,6 @@ def test_backward_sgd( # noqa C901
570581 feature_table_map = list (range (T ))
571582 if exact :
572583 table_to_replicate = T // 2
573- # pyre-fixme[6]: Expected `HalfTensor` for 2nd param but got `Tensor`.
574584 bs .insert (table_to_replicate , bs [table_to_replicate ])
575585 feature_table_map .insert (table_to_replicate , table_to_replicate )
576586
@@ -598,11 +608,9 @@ def test_backward_sgd( # noqa C901
598608 xws = [xw .half () for xw in xws ]
599609
600610 fs = (
601- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
602611 [b_indices (b , x , use_cpu = use_cpu ) for (b , x ) in zip (bs , xs )]
603612 if not weighted
604613 else [
605- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
606614 b_indices (b , x , per_sample_weights = xw .view (- 1 ), use_cpu = use_cpu )
607615 for (b , x , xw ) in zip (bs , xs , xws )
608616 ]
@@ -613,7 +621,6 @@ def test_backward_sgd( # noqa C901
613621 lr = 0.05
614622 if exact :
615623 del bs [table_to_replicate ]
616- # pyre-fixme[16]: `Tensor` has no attribute `weight`.
617624 new_weights = [(b .weight - b .weight .grad * lr ) for b in bs ]
618625
619626 cc = split_table_batched_embeddings_ops .SplitTableBatchedEmbeddingBagsCodegen (
@@ -782,7 +789,6 @@ def test_backward_adagrad( # noqa C901
782789 if exact :
783790 # autograd with shared embedding only works for exact
784791 table_to_replicate = T // 2
785- # pyre-fixme[6]: Expected `HalfTensor` for 2nd param but got `Tensor`.
786792 bs .insert (table_to_replicate , bs [table_to_replicate ])
787793 feature_table_map .insert (table_to_replicate , table_to_replicate )
788794
@@ -805,11 +811,9 @@ def test_backward_adagrad( # noqa C901
805811 xws = [xw .half () for xw in xws ]
806812
807813 fs = (
808- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
809814 [b_indices (b , x , use_cpu = use_cpu ) for (b , x ) in zip (bs , xs )]
810815 if not weighted
811816 else [
812- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
813817 b_indices (b , x , per_sample_weights = xw .view (- 1 ), use_cpu = use_cpu )
814818 for (b , x , xw ) in zip (bs , xs , xws )
815819 ]
@@ -839,7 +843,6 @@ def test_backward_adagrad( # noqa C901
839843 if exact :
840844 del bs [table_to_replicate ]
841845 for t in range (T ):
842- # pyre-fixme[16]: `Tensor` has no attribute `weight`.
843846 cc .split_embedding_weights ()[t ].data .copy_ (bs [t ].weight )
844847
845848 x = torch .cat ([x .view (1 , B , L ) for x in xs ], dim = 0 )
@@ -1162,11 +1165,9 @@ def test_backward_optimizers( # noqa C901
11621165 xws_acc_type = copy .deepcopy (xws )
11631166
11641167 fs = (
1165- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
11661168 [b_indices (b , x , use_cpu = use_cpu ) for (b , x ) in zip (bs , xs )]
11671169 if not weighted
11681170 else [
1169- # pyre-fixme[6]: Expected `(...) -> Any` for 1st param but got `Tensor`.
11701171 b_indices (b , x , per_sample_weights = xw .view (- 1 ), use_cpu = use_cpu )
11711172 for (b , x , xw ) in zip (bs , xs , xws )
11721173 ]
@@ -1214,7 +1215,6 @@ def test_backward_optimizers( # noqa C901
12141215 )
12151216
12161217 for t in range (T ):
1217- # pyre-fixme[16]: `Tensor` has no attribute `weight`.
12181218 cc .split_embedding_weights ()[t ].data .copy_ (bs [t ].weight )
12191219
12201220 x = torch .cat ([x .view (1 , B , L ) for x in xs ], dim = 0 )
0 commit comments