     pack_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
+from torchao.dtypes.utils import is_device
 
 aten = torch.ops.aten
 
@@ -542,12 +543,20 @@ def linear_forward_int4(
 ):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
-    c = torch.ops.aten._weight_int4pack_mm(
-        x.to(precision),
-        weight_int4pack,
-        groupsize,
-        scales_and_zeros.to(scales_precision),
-    ).to(dtype=x.dtype)
+    if is_device(x.device.type, "cpu"):
+        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(scales_precision),
+        ).to(dtype=x.dtype)
+    else:
+        c = torch.ops.aten._weight_int4pack_mm(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(scales_precision),
+        ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
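For reference, a minimal standalone sketch (not part of the patch) of the CPU path this hunk dispatches to, assuming a PyTorch build that ships the `_for_cpu` ops (2.6+ / recent nightlies); the sizes and group size below are made up for illustration:

import torch

# Illustrative sizes (assumptions, not taken from the patch).
M, K, N, groupsize = 4, 256, 64, 32

x = torch.randn(M, K, dtype=torch.bfloat16)

# int4 values in [0, 15] stored as int32, the layout the CPU packing op expects.
w_int4 = torch.randint(0, 16, (N, K), dtype=torch.int32)

# CPU packing yields a uint8 tensor of shape [N, K // 2] (two nibbles per byte);
# the second argument mirrors the CUDA op's inner_k_tiles parameter.
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(w_int4, 2)

# Tinygemm-style scales/zeros layout: [K // groupsize, N, 2], matching x's dtype here.
scales_and_zeros = torch.ones(K // groupsize, N, 2, dtype=torch.bfloat16)

c = torch.ops.aten._weight_int4pack_mm_for_cpu(x, weight_int4pack, groupsize, scales_and_zeros)
print(c.shape)  # torch.Size([4, 64])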
@@ -596,19 +605,32 @@ def __init__(
         assert (
             in_features % (inner_k_tiles * 16) == 0
         ), "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.zeros(
-                (
-                    out_features // 8,
-                    in_features // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
+        if is_device(device.type, "cpu"):
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features,
+                        in_features // 2,
+                    ),
+                    dtype=torch.uint8,
+                    device=device,
                 ),
-                dtype=torch.int32,
-                device=device,
-            ),
-        )
+            )
+        else:
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features // 8,
+                        in_features // (inner_k_tiles * 16),
+                        32,
+                        inner_k_tiles // 2,
+                    ),
+                    dtype=torch.int32,
+                    device=device,
+                ),
+            )
         self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
@@ -765,9 +787,14 @@ def _create_quantized_state_dict(
                     self.precision,  # dtype for scales_and_zeros
                 )
                 # TODO: just get the device from mod.weight.device?
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                    w_int4x8.to(self.device), self.inner_k_tiles
-                )
+                if is_device(w_int4x8.device.type, "cpu"):
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
+                else:
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                     self.device
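As a sanity check on the two layouts the `weight` buffer registered in `__init__` has to hold, a short sketch under the same PyTorch 2.6+ assumption (sizes are illustrative):

import torch

out_features, in_features = 64, 256

# Quantized int4 values in [0, 15], stored as int32.
w_int4 = torch.randint(0, 16, (out_features, in_features), dtype=torch.int32)

# CPU packing: uint8 of shape [out_features, in_features // 2],
# matching the CPU branch of the buffer allocated in __init__ above.
packed_cpu = torch.ops.aten._convert_weight_to_int4pack_for_cpu(w_int4, 2)
print(packed_cpu.dtype, packed_cpu.shape)  # torch.uint8 torch.Size([64, 128])

# The CUDA op (_convert_weight_to_int4pack) instead produces the int32 tile layout
# [out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2]
# that the else branch allocates.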
@@ -851,9 +878,14 @@ def make_names_and_values_dict_func(q, qparams):
             # how much we need to pad the weight
             delta_k = int((new_k - k) / 2)
             q = q.to(self.device)
-            final_q = torch.ops.aten._convert_weight_to_int4pack(
-                F.pad(q, pad=(0, delta_k)), inner_k_tiles
-            )
+            if is_device(self.device.type, "cpu"):
+                final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
+            else:
+                final_q = torch.ops.aten._convert_weight_to_int4pack(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
             scales = qparams[0].to(torch.bfloat16).to(self.device)
             zeros = qparams[1].to(torch.bfloat16).to(self.device)
             scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)