|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +# |
| 4 | +# Copyright (c) 2022 Intel Corporation |
| 5 | +# |
| 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 7 | +# you may not use this file except in compliance with the License. |
| 8 | +# You may obtain a copy of the License at |
| 9 | +# |
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | +# |
| 12 | +# Unless required by applicable law or agreed to in writing, software |
| 13 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | +# See the License for the specific language governing permissions and |
| 16 | +# limitations under the License. |
| 17 | + |
| 18 | +import unittest |
| 19 | +import numpy as np |
| 20 | +import os |
| 21 | +import subprocess |
| 22 | +import shutil |
| 23 | +import time |
| 24 | +import torch |
| 25 | +from datasets import load_dataset |
| 26 | +from transformers import BertForSequenceClassification |
| 27 | +from nlp_toolkit.backends.neural_engine.compile import compile |
| 28 | + |
| 29 | + |
| 30 | +class TestDispatcherTuningAcc(unittest.TestCase): |
| 31 | + |
| 32 | + @classmethod |
| 33 | + def setUpClass(self): |
| 34 | + code = """ |
| 35 | +import time |
| 36 | +import math |
| 37 | +import os |
| 38 | +import sys |
| 39 | +import numpy as np |
| 40 | +from transformers import AutoTokenizer |
| 41 | +from datasets import load_from_disk, load_metric, load_dataset |
| 42 | +from nlp_toolkit.backends.neural_engine.compile.graph import Graph |
| 43 | +class MRPCDataSet(): |
| 44 | + def __init__(self, batch_size, data_dir, tokenizer_dir): |
| 45 | + self.batch_size = batch_size |
| 46 | + dataset = load_dataset('glue', 'mrpc', cache_dir=data_dir,split='validation') |
| 47 | + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) |
| 48 | + self.dataset = dataset.map(lambda e: tokenizer(e['sentence1'], e['sentence2'], |
| 49 | + truncation=False, padding='do_not_pad'), batched=True) |
| 50 | +
|
| 51 | + def __getitem__(self, idx): |
| 52 | + start = idx * self.batch_size |
| 53 | + end = start + self.batch_size |
| 54 | + if end > len(self.dataset): |
| 55 | + input_ids_data = self.dataset[start:]['input_ids'] |
| 56 | + segment_ids_data = self.dataset[start:]['token_type_ids'] |
| 57 | + input_mask_data = self.dataset[start:]['attention_mask'] |
| 58 | + label_data = self.dataset[start:]['label'] |
| 59 | + else: |
| 60 | + input_ids_data = self.dataset[start:end]['input_ids'] |
| 61 | + segment_ids_data = self.dataset[start:end]['token_type_ids'] |
| 62 | + input_mask_data = self.dataset[start:end]['attention_mask'] |
| 63 | + label_data = self.dataset[start:end]['label'] |
| 64 | +
|
| 65 | + sample_size = len(input_ids_data) if isinstance(input_ids_data, list) else 1 |
| 66 | +
|
| 67 | + return [np.array(input_ids_data).reshape(sample_size, -1).astype('int32'), |
| 68 | + np.array(segment_ids_data).reshape(sample_size, -1).astype('int32'), |
| 69 | + np.array(input_mask_data).reshape(sample_size, -1).astype('int32')], \ |
| 70 | + np.array(label_data).reshape(sample_size, -1).astype('int32') |
| 71 | +
|
| 72 | + def __len__(self): |
| 73 | + return math.ceil(len(self.dataset)/self.batch_size) |
| 74 | +
|
| 75 | +def load_model(engine_model_path): |
| 76 | + model = Graph() |
| 77 | + model.graph_init(os.path.join(engine_model_path, "conf.yaml"), |
| 78 | + os.path.join(engine_model_path, "model.bin")) |
| 79 | + return model |
| 80 | +
|
| 81 | +def run(): |
| 82 | + os.environ['GLOG_minloglevel'] = '2' |
| 83 | + dataset = MRPCDataSet(1, "/home/tensorflow/.cache/nlp_toolkit/glue_mrpc_data", "/tf_dataset2/models/nlp_toolkit/bert_mini_mrpc") |
| 84 | + model = load_model("ir") |
| 85 | + metric = load_metric('glue', 'mrpc') |
| 86 | + log_path = sys.argv[1] |
| 87 | + for idx in range(len(dataset)): |
| 88 | + inputs = dataset[idx][0] |
| 89 | + labels = dataset[idx][1] |
| 90 | + predictions = model.inference(inputs) |
| 91 | + predictions = list(predictions.values())[0] |
| 92 | + predictions = np.argmax(predictions, axis=1) |
| 93 | + metric.add_batch( |
| 94 | + predictions=predictions, |
| 95 | + references=labels, |
| 96 | + ) |
| 97 | + eval_metric = metric.compute() |
| 98 | + acc = eval_metric.get("accuracy") |
| 99 | + with open(log_path, 'w') as f: |
| 100 | + f.write(format(acc, '.4f')) |
| 101 | +
|
| 102 | +if __name__ == "__main__": |
| 103 | + run() |
| 104 | +
|
| 105 | +""" |
| 106 | + with open('run.py', 'w', encoding='utf-8') as f: |
| 107 | + f.write(code) |
| 108 | + mrpc_dataset = load_dataset('glue', |
| 109 | + 'mrpc', |
| 110 | + cache_dir='/home/tensorflow/.cache/nlp_toolkit/glue_mrpc_data', |
| 111 | + split='validation') |
| 112 | + # export onnx model |
| 113 | + torch_model = BertForSequenceClassification.from_pretrained( |
| 114 | + '/tf_dataset2/models/nlp_toolkit/bert_mini_mrpc') |
| 115 | + with torch.no_grad(): |
| 116 | + inputs = { |
| 117 | + 'input_ids': torch.ones(1, 128, dtype=torch.int32), |
| 118 | + 'attention_mask': torch.ones(1, 128, dtype=torch.int32), |
| 119 | + 'token_type_ids': torch.ones(1, 128, dtype=torch.int32) |
| 120 | + } |
| 121 | + outputs = torch_model(**inputs) |
| 122 | + |
| 123 | + symbolic_names = {0: 'batch_size', 1: 'max_seq_len'} |
| 124 | + torch.onnx.export( |
| 125 | + torch_model, |
| 126 | + (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']), |
| 127 | + "onnx_fp32.onnx", |
| 128 | + opset_version=11, |
| 129 | + do_constant_folding=True, |
| 130 | + input_names=['input_ids', 'input_mask', 'segment_ids'], |
| 131 | + output_names=['output'], |
| 132 | + dynamic_axes={ |
| 133 | + 'input_ids': symbolic_names, |
| 134 | + 'input_mask': symbolic_names, |
| 135 | + 'segment_ids': symbolic_names |
| 136 | + }) |
| 137 | + graph = compile("onnx_fp32.onnx") |
| 138 | + graph.save() |
| 139 | + self.dispatch_table_dir = os.path.join( |
| 140 | + os.environ['HOME'], '.cache/neural_engine_workspace/engine_dispatch_table.txt') |
| 141 | + |
| 142 | + @classmethod |
| 143 | + def tearDownClass(self): |
| 144 | + os.remove("run.py") |
| 145 | + shutil.rmtree("./data", ignore_errors=True) |
| 146 | + os.remove("onnx_fp32.onnx") |
| 147 | + shutil.rmtree("./ir", ignore_errors=True) |
| 148 | + for i in range(7): |
| 149 | + try: |
| 150 | + os.remove("log" + str(i) + "_dt0.txt") |
| 151 | + os.remove("log" + str(i) + "_dt1.txt") |
| 152 | + os.remove("log" + str(i) + "_dt2.txt") |
| 153 | + except: |
| 154 | + continue |
| 155 | + if os.path.exists(self.dispatch_table_dir): |
| 156 | + os.remove(self.dispatch_table_dir) |
| 157 | + |
| 158 | + def test_dispatcher_tuning_sharing_acc(self): |
| 159 | + if os.path.exists(self.dispatch_table_dir): |
| 160 | + os.remove(self.dispatch_table_dir) |
| 161 | + cmd = "numactl -l -C 0-3 python run.py log0_dt0.txt & " \ |
| 162 | + "numactl -l -C 4-7 python run.py log1_dt0.txt & " \ |
| 163 | + "numactl -l -C 8-11 python run.py log2_dt0.txt & " \ |
| 164 | + "numactl -l -C 12-15 python run.py log3_dt0.txt &" \ |
| 165 | + "numactl -l -C 16-19 python run.py log4_dt0.txt &" \ |
| 166 | + "numactl -l -C 20-23 python run.py log5_dt0.txt &" \ |
| 167 | + "numactl -l -C 24-27 python run.py log6_dt0.txt" |
| 168 | + # close dispatcher and tuning |
| 169 | + process = subprocess.Popen(cmd, shell=True) # nosec |
| 170 | + process.wait() |
| 171 | + if process.returncode != 0: |
| 172 | + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) |
| 173 | + |
| 174 | + # wait all threads end |
| 175 | + for i in range(7): |
| 176 | + log_exist = os.path.exists("log" + str(i) + "_dt0.txt") |
| 177 | + time_exit = 0 |
| 178 | + while not log_exist: |
| 179 | + time.sleep(1) |
| 180 | + time_exit += 1 |
| 181 | + log_exist = os.path.exists("log" + str(i) + "_dt0.txt") |
| 182 | + if time_exit >= 600: |
| 183 | + break |
| 184 | + |
| 185 | + acc_off = [] |
| 186 | + for i in range(7): |
| 187 | + with open("log" + str(i) + "_dt0.txt", 'r') as f: |
| 188 | + acc_off.append(float(f.readline().strip())) |
| 189 | + |
| 190 | + # open kernel tuning |
| 191 | + os.environ['ENGINE_DISPATCHER_TUNING_ON'] = '1' |
| 192 | + os.environ['INST_NUM'] = '7' |
| 193 | + cmd = "numactl -l -C 0-3 python run.py log0_dt1.txt & " \ |
| 194 | + "numactl -l -C 4-7 python run.py log1_dt1.txt & " \ |
| 195 | + "numactl -l -C 8-11 python run.py log2_dt1.txt & " \ |
| 196 | + "numactl -l -C 12-15 python run.py log3_dt1.txt &" \ |
| 197 | + "numactl -l -C 16-19 python run.py log4_dt1.txt &" \ |
| 198 | + "numactl -l -C 20-23 python run.py log5_dt1.txt &" \ |
| 199 | + "numactl -l -C 24-27 python run.py log6_dt1.txt" |
| 200 | + process = subprocess.Popen(cmd, shell=True) # nosec |
| 201 | + process.wait() |
| 202 | + if process.returncode != 0: |
| 203 | + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) |
| 204 | + |
| 205 | + # wait all threads end |
| 206 | + for i in range(7): |
| 207 | + log_exist = os.path.exists("log" + str(i) + "_dt1.txt") |
| 208 | + time_exit = 0 |
| 209 | + while not log_exist: |
| 210 | + time.sleep(1) |
| 211 | + time_exit += 1 |
| 212 | + log_exist = os.path.exists("log" + str(i) + "_dt1.txt") |
| 213 | + if (time_exit >= 600): |
| 214 | + break |
| 215 | + |
| 216 | + acc_tuning = [] |
| 217 | + for i in range(7): |
| 218 | + with open("log" + str(i) + "_dt1.txt", 'r') as f: |
| 219 | + acc_tuning.append(float(f.readline().strip())) |
| 220 | + |
| 221 | + # use dispatch table after tuning |
| 222 | + del os.environ['ENGINE_DISPATCHER_TUNING_ON'] |
| 223 | + cmd = "numactl -l -C 0-3 python run.py log0_dt2.txt & " \ |
| 224 | + "numactl -l -C 4-7 python run.py log1_dt2.txt & " \ |
| 225 | + "numactl -l -C 8-11 python run.py log2_dt2.txt & " \ |
| 226 | + "numactl -l -C 12-15 python run.py log3_dt2.txt &" \ |
| 227 | + "numactl -l -C 16-19 python run.py log4_dt2.txt &" \ |
| 228 | + "numactl -l -C 20-23 python run.py log5_dt2.txt &" \ |
| 229 | + "numactl -l -C 24-27 python run.py log6_dt2.txt" |
| 230 | + process = subprocess.Popen(cmd, shell=True) # nosec |
| 231 | + process.wait() |
| 232 | + if process.returncode != 0: |
| 233 | + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) |
| 234 | + |
| 235 | + # wait all threads end |
| 236 | + for i in range(7): |
| 237 | + log_exist = os.path.exists("log" + str(i) + "_dt2.txt") |
| 238 | + time_limit = 0 |
| 239 | + while not log_exist: |
| 240 | + time.sleep(1) |
| 241 | + time_exit += 1 |
| 242 | + log_exist = os.path.exists("log" + str(i) + "_dt2.txt") |
| 243 | + if time_limit >= 600: |
| 244 | + break |
| 245 | + |
| 246 | + acc_dispatcher = [] |
| 247 | + for i in range(7): |
| 248 | + with open("log" + str(i) + "_dt2.txt", 'r') as f: |
| 249 | + acc_dispatcher.append(float(f.readline().strip())) |
| 250 | + |
| 251 | + self.assertListEqual(acc_tuning, acc_off) |
| 252 | + self.assertListEqual(acc_dispatcher, acc_off) |
| 253 | + |
| 254 | + |
| 255 | +if __name__ == "__main__": |
| 256 | + unittest.main() |
0 commit comments