@@ -10,6 +10,30 @@ class MyDTypeTensorTP(MyDTypeTensor):
1010
1111aten = torch .ops .aten
1212
13+ def fill_defaults (args , n , defaults_tail ):
14+ """
15+ __torch_dispatch__ doesn't guarantee the number of arguments you are
16+ passed (e.g., defaulted arguments are not passed); but usually it is
17+ convenient to pad out the arguments list with defaults. This function
18+ helps you do that.
19+ Args:
20+ args: the list of positional arguments passed to __torch_dispatch__
21+ n: the number of arguments you are expecting to get
22+ defaults_tail: default values for the arguments, starting from the
23+ end of the list
24+ Example:
25+ >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
26+ [1, 2, 3, 4, 5]
27+ >>> fill_defaults([1, 2, 3], 5, [None, None, None])
28+ [1, 2, 3, None, None]]
29+ """
30+ if n - len (defaults_tail ) > len (args ):
31+ raise RuntimeError ("not enough defaults to fill arguments" )
32+ r = list (args )
33+ for i in range (len (args ), n ):
34+ r .append (defaults_tail [i - n + len (defaults_tail )])
35+ return r
36+
1337@implements ([aten ._to_copy .default , aten .clone .default ])
1438def _ (func , types , args , kwargs ):
1539 return return_and_correct_aliasing (
@@ -27,6 +51,16 @@ def _(func, types, args, kwargs):
2751 empty_like_layout_tensor = func (args [0 ].layout_tensor , * args [1 :], ** kwargs )
2852 return MyDTypeTensorTP (empty_like_layout_tensor , empty_like_layout_tensor .shape )
2953
54+ @implements ([aten .slice .Tensor ])
55+ def _ (func , types , args , kwargs ):
56+ self , dim , start , end , step = fill_defaults (args , 5 , [0 , None , None , 1 ])
57+ print ("slice:" , dim , start , end , step )
58+ if dim == 0 :
59+ assert step == 1
60+ return self .__class__ (aten .slice .Tensor (self .layout_tensor ), (end - start + 1 ,) + self .shape [1 :], self .dtype )
61+ return
62+
63+
3064class M (torch .nn .Module ):
3165 def __init__ (self , * args , ** kwargs ) -> None :
3266 super ().__init__ (* args , ** kwargs )
@@ -84,4 +118,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
84118 print ("input dtensor:" , input_dtensor )
85119
86120 m (input_dtensor )
87-
0 commit comments