Skip to content

Commit c7db10c

Browse files
committed
Updates
1 parent 677aca7 commit c7db10c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+222
-217
lines changed
File renamed without changes.
File renamed without changes.

benchmarks/_models/llama/eval.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
from typing import List, Optional
99

1010
import torch
11-
from generate import (
12-
_load_model,
13-
device_sync,
14-
)
1511
from tokenizer import get_tokenizer
1612

1713
import torchao
18-
from torchao._models.llm.model import prepare_inputs_for_model
14+
from benchmarks._models.llama.model import prepare_inputs_for_model
15+
from benchmarks._models.utils import (
16+
_load_model,
17+
)
1918
from torchao.quantization import (
2019
PerRow,
2120
PerTensor,
@@ -28,7 +27,11 @@
2827
quantize_,
2928
uintx_weight_only,
3029
)
31-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
30+
from torchao.utils import (
31+
TORCH_VERSION_AT_LEAST_2_5,
32+
device_sync,
33+
unwrap_tensor_subclass,
34+
)
3235

3336

3437
def run_evaluation(
@@ -120,7 +123,7 @@ def run_evaluation(
120123
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
121124
if "int4wo" in quantization and "gptq" in quantization:
122125
# avoid circular imports
123-
from torchao._models._eval import MultiTensorInputRecorder
126+
from benchmarks._models._eval import MultiTensorInputRecorder
124127
from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer
125128

126129
groupsize = int(quantization.split("-")[-2])
@@ -172,7 +175,7 @@ def run_evaluation(
172175
if "autoround" in quantization:
173176
from transformers import AutoTokenizer
174177

175-
from torchao._models.llm.model import TransformerBlock
178+
from benchmarks._models.llama.model import TransformerBlock
176179
from torchao.prototype.autoround.autoround_llm import (
177180
quantize_model_with_autoround_,
178181
)
@@ -242,7 +245,7 @@ def run_evaluation(
242245
with torch.no_grad():
243246
print("Running evaluation ...")
244247
# avoid circular imports
245-
from torchao._models._eval import TransformerEvalWrapper
248+
from benchmarks._models._eval import TransformerEvalWrapper
246249

247250
TransformerEvalWrapper(
248251
model=model.to(device),

benchmarks/_models/llama/generate.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch._inductor.config
1515

1616
import torchao
17-
from torchao._models.utils import (
17+
from benchmarks._models.utils import (
1818
_load_model,
1919
decode_n_tokens,
2020
decode_one_token,
@@ -63,8 +63,8 @@ def device_timer(device):
6363
wd = Path(__file__).parent.parent.resolve()
6464
sys.path.append(str(wd))
6565

66-
from torchao._models.llm.model import Transformer, prepare_inputs_for_model
67-
from torchao._models.llm.tokenizer import get_tokenizer
66+
from benchmarks._models.llama.model import Transformer, prepare_inputs_for_model
67+
from benchmarks._models.llama.tokenizer import get_tokenizer
6868

6969

7070
def model_forward(model, x, input_pos):
@@ -382,7 +382,7 @@ def ffn_or_attn_only(mod, fqn):
382382
filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding),
383383
)
384384
elif quantization.startswith("awq"):
385-
from torchao._models._eval import TransformerEvalWrapper
385+
from benchmarks._models._eval import TransformerEvalWrapper
386386
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
387387

388388
if not TORCH_VERSION_AT_LEAST_2_3:
@@ -481,8 +481,8 @@ def ffn_or_attn_only(mod, fqn):
481481
model, float8_dynamic_activation_float8_weight(granularity=granularity)
482482
)
483483
elif "autoquant_v2" in quantization:
484-
from torchao._models._eval import InputRecorder
485-
from torchao._models.llm.model import prepare_inputs_for_model
484+
from benchmarks._models._eval import InputRecorder
485+
from benchmarks._models.llama.model import prepare_inputs_for_model
486486
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
487487

488488
calibration_seq_length = 256
@@ -571,8 +571,8 @@ def ffn_or_attn_only(mod, fqn):
571571
# do autoquantization
572572
model.finalize_autoquant()
573573
elif "autoquant" in quantization:
574-
from torchao._models._eval import InputRecorder
575-
from torchao._models.llm.model import prepare_inputs_for_model
574+
from benchmarks._models._eval import InputRecorder
575+
from benchmarks._models.llama.model import prepare_inputs_for_model
576576

577577
calibration_seq_length = 256
578578
inputs = (
File renamed without changes.

benchmarks/_models/llama/perf_profile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@
116116
import torch
117117
from torch.nn.attention import SDPBackend
118118

119-
from torchao._models.llm.model import Transformer
120-
from torchao._models.llm.tokenizer import get_tokenizer
119+
from benchmarks._models.llama.model import Transformer
120+
from benchmarks._models.llama.tokenizer import get_tokenizer
121121
from torchao.prototype.profiler import (
122122
CUDADeviceSpec,
123123
TransformerPerformanceCounter,
File renamed without changes.
File renamed without changes.

benchmarks/_models/sam/eval_combo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from metrics import calculate_miou, create_result_entry
1010

1111
import torchao
12-
from torchao._models.utils import (
12+
from benchmarks._models.utils import (
1313
get_arch_name,
1414
write_json_result_local,
1515
write_json_result_ossci,

torchao/_models/sam2/__init__.py renamed to benchmarks/_models/sam2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88
from hydra.core.global_hydra import GlobalHydra
99

1010
if not GlobalHydra.instance().is_initialized():
11-
initialize_config_module("torchao._models.sam2", version_base="1.2")
11+
initialize_config_module("benchmarks._models.sam2", version_base="1.2")

0 commit comments

Comments
 (0)