def test_tokenicer_save(self):
    """Quantize, save, and verify the validation file written with the model.

    The model in ``self.model`` is quantized on 32 tokenized samples taken
    from the local C4 shard and saved into a temporary directory. The test
    then asserts that the file named by ``VALIDATE_JSON_FILE_NAME`` exists
    in that directory, parses it into a ``ValidateConfig``, and checks that
    re-encoding every stored input with ``self.tokenizer`` (using
    ``VALIDATE_ENCODE_PARAMS``) reproduces the stored output ids.
    """
    # Function-scope imports keep this test self-contained: these names are
    # used only here.
    import json
    import tempfile

    from datasets import load_dataset
    from tokenicer.config import ValidateConfig
    from tokenicer.const import VALIDATE_ENCODE_PARAMS, VALIDATE_JSON_FILE_NAME

    traindata = load_dataset(
        "json",
        data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz",
        split="train",
    )
    calibration_dataset = [
        self.tokenizer(example["text"]) for example in traindata.select(range(32))
    ]

    self.model.quantize(calibration_dataset, batch_size=32)

    # The temporary directory is removed when the ``with`` block exits, so
    # the existence check and the read must both happen inside it.
    with tempfile.TemporaryDirectory() as tmpdir:
        self.model.save(tmpdir)
        validate_json_path = os.path.join(tmpdir, VALIDATE_JSON_FILE_NAME)

        self.assertTrue(
            os.path.isfile(validate_json_path),
            f"Save verify file failed: {validate_json_path} does not exist.",
        )

        with open(validate_json_path, "r", encoding="utf-8") as f:
            raw_config = json.load(f)

    config = ValidateConfig.from_dict(raw_config)

    # Compare entry by entry so a failure names the offending input and shows
    # both token lists. The first mismatch aborts the test.
    for entry in config.data:
        encoded = self.tokenizer.encode_plus(entry.input, **VALIDATE_ENCODE_PARAMS)
        # ``[0]`` takes the first row of the returned ids -- presumably the
        # encode params request batched tensors; confirm against tokenicer.
        tokenized = encoded["input_ids"].tolist()[0]
        self.assertEqual(
            entry.output,
            tokenized,
            msg=f"Tokenizer output mismatch for validation input {entry.input!r}.",
        )