 
 
 aten = torch.ops.aten
-
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
 
 # https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
 # NOTE: power-1 is linear
@@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape
         )
 
     def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
35+ """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507
36+
37+ Args
38+ codes: quantized and packed 4-bit data stored as uint8.
39+ scale: scale data for block-wise quantization.
40+ qmap: lookup table that maps between quantized value (code) and float value.
41+ signed: whether the tensor is signed or unsigned.
42+ shape: shape of original float tensor.
43+
44+ NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size).
45+ Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
46+ Given `codes` and `scale`, `block_size` is calculated as `codes.numel() * 2 // scale.numel()`.
47+ The extra `* 2` is because `codes` is 4-bit data packed in 8-bit storage.
48+ """
         assert codes.dtype is torch.uint8
         assert codes.ndim == 1  # flattened buffer
+        assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
         self.signed = signed
         self._shape = shape
-
-    @property
-    def block_size(self):
-        return self.codes.numel() * 2 // self.scale.numel()
+        self.block_size = codes.numel() * 2 // scale.numel()
 
     def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed, self._shape]
@@ -113,9 +126,37 @@ def _(func, *args, **kwargs):
     return func(*args, **kwargs)
 
 
+# this is needed for DTensor.from_local() and for flattening tensor
 @OptimState4bit.implements(aten.view.default)
 def _(func, *args, **kwargs):
     x, shape = args
-    if len(shape) > 1 or shape[0] != -1:
-        raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]")
-    return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
+
+    if tuple(x.shape) == tuple(shape):
+        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape)
+
+    if len(shape) == 1 and shape[0] == -1:
+        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
+
+    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+
+# this is needed for DTensor.full_tensor()
+@OptimState4bit.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimState4bit):
+        raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
+
+    codes = func(x.codes, *args[1:], **kwargs)
+    scale = func(x.scale, *args[1:], **kwargs)
+
+    # adjust the first dim
+    shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:]
+
+    # assume tensors from all ranks have the same signedness
+    return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
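
The `block_size` bookkeeping described in the docstring can be checked with plain PyTorch. A minimal sketch, with made-up sizes and no torchao dependency: a 4096-element float state quantized with block size 128 packs two 4-bit codes per uint8, giving 2048 packed bytes and 32 block scales, so `codes.numel() * 2 // scale.numel()` recovers 128.

```python
import torch

# hypothetical sizes for illustration only
numel = 4096       # elements in the original float optimizer state
block_size = 128   # quantization block size

# 4-bit codes are packed two per uint8, so the packed buffer holds numel // 2 bytes
codes = torch.zeros(numel // 2, dtype=torch.uint8)   # 2048 packed bytes
# one scale per block of the flattened tensor
scale = torch.ones(numel // block_size)              # 32 block scales

# same formula the subclass uses; the `* 2` undoes the 4-bit-in-uint8 packing
assert codes.numel() * 2 // scale.numel() == block_size   # 2048 * 2 // 32 == 128
```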
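The first-dim adjustment in the all-gather handler can be sanity-checked without a process group. A minimal sketch, with made-up shard sizes and `repeat()` standing in for the actual `all_gather_into_tensor` collective: gathering over 4 ranks grows the packed `codes` buffer 4x, so the formula scales the logical first dimension from 256 to 1024 while the trailing dimensions stay unchanged.

```python
import torch

# hypothetical per-rank sizes for illustration only
world_size = 4
local_shape = (256, 32)                                      # logical shape of the local shard
local_codes = torch.zeros(256 * 32 // 2, dtype=torch.uint8)  # 4-bit packed -> 4096 bytes

# all_gather_into_tensor concatenates shards along dim 0; repeat() plays that role here
gathered_codes = local_codes.repeat(world_size)              # 16384 bytes

# same first-dim adjustment as in the handler above
full_shape = (local_shape[0] * gathered_codes.numel() // local_codes.numel(),) + local_shape[1:]
assert full_shape == (1024, 32)
```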