auto_round/envs.py (1 change: 1 addition & 0 deletions)

@@ -23,6 +23,7 @@
 environment_variables: dict[str, Callable[[], Any]] = {
     # this is used for configuring the default logging level
     "AR_LOG_LEVEL": lambda: os.getenv("AR_LOG_LEVEL", "INFO").upper(),
+    "AR_ENABLE_COMPILE_PACKING": lambda: os.getenv("AR_ENABLE_COMPILE_PACKING", "0").lower() in ("1", "true", "yes"),
     "AR_USE_MODELSCOPE": lambda: os.getenv("AR_USE_MODELSCOPE", "False").lower() in ["1", "true"],
 }

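Note: the new flag follows the same lazy, case-insensitive lookup pattern as the existing entries, so it can be toggled per process before auto_round first reads it. A minimal sketch, assuming auto_round.envs resolves attributes from environment_variables on access (as the existing AR_LOG_LEVEL entry suggests); the os.environ assignment and assert are illustrative, not part of this PR:

import os

# Illustrative: set the flag before auto_round.envs resolves it.
# "1", "true", and "yes" (case-insensitive) map to True; anything else is False.
os.environ["AR_ENABLE_COMPILE_PACKING"] = "yes"

import auto_round.envs as envs

assert envs.AR_ENABLE_COMPILE_PACKING is True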
auto_round/export/export_to_autoround/qlinear_fp.py (22 changes: 17 additions & 5 deletions)

@@ -27,21 +27,20 @@
 # limitations under the License.
 
 import math
-from logging import getLogger
 
 import numpy as np
 import torch
 import torch.nn as nn
 import transformers
 
+import auto_round.envs as envs
 from auto_round.compressors.utils import BackendDataType, is_mx_fp, is_nv_fp
 from auto_round.data_type.mxfp import FP32_EXPONENT_BIAS, FP32_MIN_NORMAL
 from auto_round.data_type.nvfp import cast_to_fp4, get_reciprocal
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
-from auto_round.utils import get_packing_device
+from auto_round.utils import get_packing_device, logger
 
 # from auto_round.utils import get_weight_compress_dtype
-logger = getLogger(__name__)
 E8M0_EXPONENT_BIAS = 127
 E8M0_EXPONENT_NAN_VAL = 255

@@ -202,15 +201,28 @@ def pack_fp4_to_uint8_cpu(x: torch.Tensor) -> torch.Tensor:


 # Adapted from https://github.com/neuralmagic/compressed-tensors/pull/400
-@torch.compile(fullgraph=True, dynamic=True)
+
+
+def _get_packing_fn():
+    if envs.AR_ENABLE_COMPILE_PACKING:
+        logger.warning_once(
+            "Compiled FP4 to UINT8 packing may be incompatible with multi-threading."
+            " Disable it by setting AR_ENABLE_COMPILE_PACKING=0"
+        )
+        return torch.compile(fullgraph=True, dynamic=True)(_pack_fp4_to_uint8)
+    else:
+        return torch.compiler.disable()(_pack_fp4_to_uint8)
+
+
 def pack_fp4_to_uint8_cuda(x: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor with values in the fp4 range into uint8.
 
     :param x: tensor to pack
     returns: a packed tensor in uint8
     """
-    return _pack_fp4_to_uint8(x)
+    pack_fn = _get_packing_fn()
+    return pack_fn(x)
 
 
 def _pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
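For reference, a minimal sketch of the new dispatch in use, assuming the import path mirrors the file location and a CUDA device is available; the input values are illustrative FP4-representable numbers (real callers pass the output of cast_to_fp4):

import torch

from auto_round.export.export_to_autoround.qlinear_fp import pack_fp4_to_uint8_cuda

# Illustrative FP4-representable values; two 4-bit codes pack into one uint8,
# so the last dimension is halved.
x = torch.tensor([[0.5, -1.0, 2.0, -6.0]], device="cuda")
packed = pack_fp4_to_uint8_cuda(x)
print(packed.dtype, packed.shape)  # torch.uint8, torch.Size([1, 2])

With AR_ENABLE_COMPILE_PACKING unset (the default "0"), the call goes through torch.compiler.disable(), which sidesteps the multi-threading issue named in the warning; setting the flag routes the same call through torch.compile(fullgraph=True, dynamic=True).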