 import torch
-from my_dtype_tensor_subclass import MyDTypeTensor
+from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 # a tensor subclass that supports tensor parallelism with DTensor
@@ -10,30 +10,6 @@ class MyDTypeTensorTP(MyDTypeTensor):
 
 aten = torch.ops.aten
 
-def fill_defaults(args, n, defaults_tail):
-    """
-    __torch_dispatch__ doesn't guarantee the number of arguments you are
-    passed (e.g., defaulted arguments are not passed); but usually it is
-    convenient to pad out the arguments list with defaults. This function
-    helps you do that.
-    Args:
-        args: the list of positional arguments passed to __torch_dispatch__
-        n: the number of arguments you are expecting to get
-        defaults_tail: default values for the arguments, starting from the
-            end of the list
-    Example:
-        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
-        [1, 2, 3, 4, 5]
-        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
-        [1, 2, 3, None, None]
-    """
-    if n - len(defaults_tail) > len(args):
-        raise RuntimeError("not enough defaults to fill arguments")
-    r = list(args)
-    for i in range(len(args), n):
-        r.append(defaults_tail[i - n + len(defaults_tail)])
-    return r
-
 @implements([aten._to_copy.default, aten.clone.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
@@ -51,20 +27,76 @@ def _(func, types, args, kwargs):
     empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs)
     return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape)
 
-@implements([aten.slice.Tensor])
+@implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-    print("slice:", dim, start, end, step)
-    if dim == 0:
-        assert step == 1
-        return self.__class__(aten.slice.Tensor(self.layout_tensor), (end - start + 1,) + self.shape[1:], self.dtype)
-    return
+    assert step == 1
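+    # aten.slice may pass an out-of-range end (e.g. sys.maxsize) to mean "to the end",
+    # so clamp it to the dimension size before computing the new shape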
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+    print("dim:", dim, "start:", start, " end:", end, " shape:", end - start)
+    print("manual shape:", (end - start,) + self.shape[1:])
+    return self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), (end - start,) + self.shape[1:], self.dtype)
+
+# this is needed for DTensor.from_local() and for flattening the tensor
+@implements(aten.view.default)
+def _(func, types, args, kwargs):
+    x, shape = args
+
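+    # only two cases are supported: an identity view (same shape) and a
+    # flatten to 1-D via shape=[-1]; anything else raises below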
+    if tuple(x.shape) == tuple(shape):
+        return x.__class__(x.layout_tensor, x.shape, x.dtype)
+
+    if len(shape) == 1 and shape[0] == -1:
+        return x.__class__(x.layout_tensor, (x.numel(),), x.dtype)
+
+    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
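+    # reverse the logical shape and transpose the underlying layout tensor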
+    tensor = args[0]
+    shape = tensor.shape[::-1]
+    new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+@implements(aten.addmm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[1],
+        args[2],
+        args[0],
+    )
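+    # dequantize and fall back to a dense matmul; aten.addmm(bias, mat1, mat2)
+    # computes bias + mat1 @ mat2, so the bias (args[0]) goes first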
+    transposed = weight_tensor.layout_tensor.transposed
+    weight_tensor = weight_tensor.dequantize()
+    if transposed:
+        weight_tensor = weight_tensor.t()
+    return aten.addmm(bias, input_tensor, weight_tensor)
+
+@implements(aten.mm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        None,
+    )
+    transposed = weight_tensor.layout_tensor.transposed
+    weight_tensor = weight_tensor.dequantize()
+    if transposed:
+        weight_tensor = weight_tensor.t()
+    return aten.mm(input_tensor, weight_tensor)
 
 
 class M(torch.nn.Module):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(1024, 1024)
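+        # bias=False keeps the matmul on the aten.mm path handled above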
+        self.linear = torch.nn.Linear(1024, 1024, bias=False, device="cuda")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
@@ -79,7 +111,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 torch.manual_seed(5)
 
 m = M()
-example_input = 100 * torch.randn(128, 1024)
+example_input = 100 * torch.randn(128, 1024, device="cuda")
 m(example_input)
 
 
@@ -103,7 +135,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 quantized_shard = quantized_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
 print("quantized shard:", quantized_shard)
 # Construct DTensor from local shard
-quantized_dtensor = DTensor.from_local(quantized_shard, device_mesh, [Shard(0)])
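+# NOTE: DTensor.from_local() calls aten.view on the local shard, which relies on the view override above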
+quantized_dtensor = DTensor.from_local(quantized_shard, mesh, [Shard(0)])
 print("quantized dtensor:", quantized_dtensor)
 
 # Replace parameter in module
@@ -117,4 +150,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 )
 print("input dtensor:", input_dtensor)
 
-m(input_dtensor)
+print("result:", m(input_dtensor))