 from typing import Tuple

 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
-    TestCase,
     run_tests,
 )

-from torchao.prototype.moe_quant.utils import MoEQuantConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
     PerRow,
     PerTensor,
     quantize_,
 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
+from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     _is_fbgemm_genai_gpu_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
 )

 torch._dynamo.config.cache_size_limit = 128

-class Experts(nn.Module):
-    def __init__(
-        self,
-        num_local_experts: int,
-        dim: int,
-        hidden_dim: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> None:
-        super().__init__()
-
-        self.num_local_experts = num_local_experts
-        self.dim = dim
-
-        self.w1: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w2: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                hidden_dim,
-                dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w3: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-    def forward(
-        self,
-        routed_in_egD: torch.Tensor,  # noqa: N803
-    ) -> torch.Tensor:
-        e = self.num_local_experts
-        D = self.dim
-
-        x_egD = routed_in_egD.view(e, -1, D)
-
-        middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
-        out_egD = torch.bmm(middle_out_egF, self.w2)
-        out_egD = out_egD.view(-1, D)
-
-        return out_egD
-
-
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
         super().__init__()
@@ -115,7 +52,7 @@ def forward(self, x):
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
-class TestFloat8Tensor(TestCase):
+class TestFloat8Tensor(TorchAOIntegrationTestCase):
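+    # The self._test_* helpers called below are shared integration checks
+    # (presumably provided by TorchAOIntegrationTestCase) that replace the
+    # test bodies previously inlined in this file.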
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -340,45 +277,8 @@ def test_slice_preserves_aliasing(self, granularity):

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_and_copy_similar_to_vllm(self, granularity):
-        # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
-        # the test is similar to the linked code, but with some hardcoded arguments
-        # and does not use tensor parallelism
-
-        dtype = torch.bfloat16
-        device = "cuda"
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
-        quantize_(l, config)
-
-        # high level, we do a narrow for both param.data and the loaded_weights
-        # and do inplace copy_ to copy from the loaded_weights into param.data
-
-        # simulate loaded_weight
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        # making the weight different
-        dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
-            requires_grad=False,
-        )
-        quantize_(dummy_l, config)
-
-        output_dim = 0
-        shard_size = 512
-        for tp_rank in [0, 1]:
-            start_idx = tp_rank * shard_size
-            param = l.weight
-            param_data = param.data
-            param_data = param_data.narrow(output_dim, start_idx, shard_size)
-            orig_value = param_data.qdata[0][0].item()
-            loaded_weight = dummy_l.weight
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-
-            # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-            assert orig_value != loaded_weight.qdata[0][0]
-            param_data.copy_(loaded_weight)
-            # making sure param.data is updated to loaded_weight
-            assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
-            assert param_data.scale[0] == loaded_weight.scale[0]
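+        # The shared helper below presumably exercises the same vLLM-style
+        # flow that was inlined here before: narrow() both param.data and the
+        # loaded weight along the output dim, copy_() the shard in place, and
+        # check that qdata and scale match afterwards.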
+        self._test_slice_and_copy_similar_to_vllm(config)

     @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_bmm(self):
@@ -494,122 +394,10 @@ def test_cat(self, granularity, sizes):

     @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_moe_weight_reshape_ops(self):
497- """This is testing the op call sequence in saving and loading quantization
498- checkpoints in llama-models for llama4
499- (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
500- """
501397 # only per row quantization is supported for bmm
502398 granularity = PerRow ()
503- dtype = torch .bfloat16
504- device = "cuda"
505-
506- bmm_config = Float8DynamicActivationFloat8WeightConfig (granularity = granularity )
507- moe_config = MoEQuantConfig (bmm_config )
508-
509- batch_size = 4
510- num_experts = 2
511- input_dim = 64
512- dim = 128
513- hidden_dim = 256
514-
515- moe1 = Experts (num_experts , dim , hidden_dim , dtype , device )
516- moe2 = Experts (num_experts , dim , hidden_dim , dtype , device )
517- moe_combined = Experts (num_experts , dim , 2 * hidden_dim , dtype , device )
518- input = torch .randn (batch_size , input_dim , dim , dtype = dtype , device = device )
519-
520- moes = [moe1 , moe2 ]
521-
522- for moe in moes :
523- moe (input )
524-
525- def filter_fn (module , fqn ):
526- return isinstance (module , Experts )
527-
528- # need to transpose before quantizing
529- moe .w1 = torch .nn .Parameter (
530- moe .w1 .transpose (1 , 2 ).contiguous (), requires_grad = False
531- )
532- moe .w2 = torch .nn .Parameter (
533- moe .w2 .transpose (1 , 2 ).contiguous (), requires_grad = False
534- )
535- moe .w3 = torch .nn .Parameter (
536- moe .w3 .transpose (1 , 2 ).contiguous (), requires_grad = False
537- )
538-
539- quantize_ (moe , moe_config , filter_fn = filter_fn )
540-
541- # make sure it runs
542- before = moe (input )
543-
544- # transposing for resharding support since only 2D resharding is supported
545- new_last_dim = moe .w1 .shape [- 2 ]
546- moe .w1 = torch .nn .Parameter (
547- moe .w1 .transpose (1 , 2 ).reshape (- 1 , new_last_dim ), requires_grad = False
548- )
549- new_last_dim = moe .w2 .shape [- 2 ]
550- moe .w2 = torch .nn .Parameter (
551- moe .w2 .transpose (1 , 2 ).reshape (- 1 , new_last_dim ), requires_grad = False
552- )
553- new_last_dim = moe .w3 .shape [- 2 ]
554- moe .w3 = torch .nn .Parameter (
555- moe .w3 .transpose (1 , 2 ).reshape (- 1 , new_last_dim ), requires_grad = False
556- )
557-
558- moe .w1 = torch .nn .Parameter (
559- moe .w1 .unflatten (0 , (num_experts , - 1 )).squeeze (dim = 0 ),
560- requires_grad = False ,
561- )
562- moe .w2 = torch .nn .Parameter (
563- moe .w2 .unflatten (0 , (num_experts , - 1 )).squeeze (dim = 0 ),
564- requires_grad = False ,
565- )
566- moe .w3 = torch .nn .Parameter (
567- moe .w3 .unflatten (0 , (num_experts , - 1 )).squeeze (dim = 0 ),
568- requires_grad = False ,
569- )
570-
571- # transpose again to recover the original weights
572- moe .w1 = torch .nn .Parameter (moe .w1 .transpose (1 , 2 ), requires_grad = False )
573- moe .w2 = torch .nn .Parameter (moe .w2 .transpose (1 , 2 ), requires_grad = False )
574- moe .w3 = torch .nn .Parameter (moe .w3 .transpose (1 , 2 ), requires_grad = False )
575-
576- # make sure it runs
577- after = moe (input )
578-
579- self .assertEqual (before , after )
580-
581- state_dicts = [moe1 .state_dict (), moe2 .state_dict ()]
582- # align the scale parameter so they can be concatenated
583- for key in ["w1" , "w2" , "w3" ]:
584- weights = [st [key ] for st in state_dicts ]
585- for i in range (1 , len (weights )):
586- weights [i ].scale = weights [0 ].scale
587-
588- def process_key (key : str ) -> torch .Tensor :
589- tensors = [s [key ] for s in state_dicts ]
590- # Note: we have a hacky implementation for cat in user codebase
591- # since it is not implemented correctly before
592- if key == "w2" :
593- return torch .cat (tensors , dim = - 1 )
594- else :
595- return torch .cat (tensors , dim = - 2 )
596-
597- new_state_dict = {}
598- for key in ["w1" , "w2" , "w3" ]:
599- new_state_dict [key ] = process_key (key )
600-
601- moe_combined .w1 = torch .nn .Parameter (
602- moe_combined .w1 .transpose (1 , 2 ), requires_grad = False
603- )
604- moe_combined .w2 = torch .nn .Parameter (
605- moe_combined .w2 .transpose (1 , 2 ), requires_grad = False
606- )
607- moe_combined .w3 = torch .nn .Parameter (
608- moe_combined .w3 .transpose (1 , 2 ), requires_grad = False
609- )
610- moe_combined .load_state_dict (new_state_dict , assign = True )
611- # make sure it runs
612- moe_combined (input )
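+        # The shared helper below presumably covers what was inlined here
+        # before: the transpose/reshape/unflatten op sequence hit when saving
+        # and loading llama4 MoE quantization checkpoints (see llama-models),
+        # plus concatenating expert weights and load_state_dict(assign=True).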
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        self._test_moe_weight_reshape_ops(config)


 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)