diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index f1c6c5bd4..1ab233953 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2527,9 +2527,9 @@ def _get_current_num_elm( def _quantize_block( self, block: torch.nn.Module, - input_ids: list[torch.Tensor], + input_ids: Union[list[torch.Tensor], dict], input_others: dict, - q_input: Union[None, torch.Tensor] = None, + q_input: Union[torch.Tensor, dict, None] = None, device: Union[str, torch.device] = "cpu", ): """Quantize the weights of a given block of the model. @@ -2646,7 +2646,11 @@ def _quantize_block( else: lr_schedule = copy.deepcopy(self.lr_scheduler) - nsamples = len(input_ids) + if isinstance(input_ids, dict): # input_ids of Flux is dict + nsamples = len(input_ids["hidden_states"]) + else: + nsamples = len(input_ids) + pick_samples = self.batch_size * self.gradient_accumulate_steps pick_samples = min(nsamples, pick_samples) if self.sampler != "rand": diff --git a/auto_round/compressors/diffusion/compressor.py b/auto_round/compressors/diffusion/compressor.py index 3026204ac..5441d00b5 100644 --- a/auto_round/compressors/diffusion/compressor.py +++ b/auto_round/compressors/diffusion/compressor.py @@ -210,7 +210,7 @@ def _get_current_q_output( def _get_block_outputs( self, block: torch.nn.Module, - input_ids: torch.Tensor, + input_ids: Union[torch.Tensor, dict], input_others: torch.Tensor, bs: int, device: Union[str, torch.device], @@ -233,8 +233,11 @@ def _get_block_outputs( """ output = defaultdict(list) - nsamples = len(input_ids) output_config = output_configs.get(block.__class__.__name__, []) + if isinstance(input_ids, dict): + nsamples = len(input_ids["hidden_states"]) + else: + nsamples = len(input_ids) for i in range(0, nsamples, bs): end_index = min(nsamples, i + bs)