30
30
if USE_HOST_DEPS :
31
31
print ("Using dependencies from host python" )
32
32
33
+ # Set epochs to train VGG model for accuracy tests
34
+ EPOCHS = 25
35
+
33
36
SUPPORTED_PYTHON_VERSIONS = ["3.7" , "3.8" , "3.9" , "3.10" ]
34
37
35
38
nox .options .sessions = [
@@ -63,31 +66,6 @@ def install_torch_trt(session):
63
66
session .run ("python" , "setup.py" , "develop" )
64
67
65
68
66
- def download_datasets (session ):
67
- print (
68
- "Downloading dataset to path" ,
69
- os .path .join (TOP_DIR , "examples/int8/training/vgg16" ),
70
- )
71
- session .chdir (os .path .join (TOP_DIR , "examples/int8/training/vgg16" ))
72
- session .run_always (
73
- "wget" , "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz" , external = True
74
- )
75
- session .run_always ("tar" , "-xvzf" , "cifar-10-binary.tar.gz" , external = True )
76
- session .run_always (
77
- "mkdir" ,
78
- "-p" ,
79
- os .path .join (TOP_DIR , "tests/accuracy/datasets/data" ),
80
- external = True ,
81
- )
82
- session .run_always (
83
- "cp" ,
84
- "-rpf" ,
85
- os .path .join (TOP_DIR , "examples/int8/training/vgg16/cifar-10-batches-bin" ),
86
- os .path .join (TOP_DIR , "tests/accuracy/datasets/data/cidar-10-batches-bin" ),
87
- external = True ,
88
- )
89
-
90
-
91
69
def train_model (session ):
92
70
session .chdir (os .path .join (TOP_DIR , "examples/int8/training/vgg16" ))
93
71
session .install ("-r" , "requirements.txt" )
@@ -107,14 +85,14 @@ def train_model(session):
107
85
"--ckpt-dir" ,
108
86
"vgg16_ckpts" ,
109
87
"--epochs" ,
110
- "25" ,
88
+ str ( EPOCHS ) ,
111
89
env = {"PYTHONPATH" : PYT_PATH },
112
90
)
113
91
114
92
session .run_always (
115
93
"python" ,
116
94
"export_ckpt.py" ,
117
- "vgg16_ckpts/ckpt_epoch25 .pth" ,
95
+ "vgg16_ckpts/ckpt_epoch" + str ( EPOCHS ) + " .pth" ,
118
96
env = {"PYTHONPATH" : PYT_PATH },
119
97
)
120
98
else :
@@ -130,10 +108,12 @@ def train_model(session):
130
108
"--ckpt-dir" ,
131
109
"vgg16_ckpts" ,
132
110
"--epochs" ,
133
- "25" ,
111
+ str ( EPOCHS ) ,
134
112
)
135
113
136
- session .run_always ("python" , "export_ckpt.py" , "vgg16_ckpts/ckpt_epoch25.pth" )
114
+ session .run_always (
115
+ "python" , "export_ckpt.py" , "vgg16_ckpts/ckpt_epoch" + str (EPOCHS ) + ".pth"
116
+ )
137
117
138
118
139
119
def finetune_model (session ):
@@ -156,17 +136,17 @@ def finetune_model(session):
156
136
"--ckpt-dir" ,
157
137
"vgg16_ckpts" ,
158
138
"--start-from" ,
159
- "25" ,
139
+ str ( EPOCHS ) ,
160
140
"--epochs" ,
161
- "26" ,
141
+ str ( EPOCHS + 1 ) ,
162
142
env = {"PYTHONPATH" : PYT_PATH },
163
143
)
164
144
165
145
# Export model
166
146
session .run_always (
167
147
"python" ,
168
148
"export_qat.py" ,
169
- "vgg16_ckpts/ckpt_epoch26 .pth" ,
149
+ "vgg16_ckpts/ckpt_epoch" + str ( EPOCHS + 1 ) + " .pth" ,
170
150
env = {"PYTHONPATH" : PYT_PATH },
171
151
)
172
152
else :
@@ -182,13 +162,17 @@ def finetune_model(session):
182
162
"--ckpt-dir" ,
183
163
"vgg16_ckpts" ,
184
164
"--start-from" ,
185
- "25" ,
165
+ str ( EPOCHS ) ,
186
166
"--epochs" ,
187
- "26" ,
167
+ str ( EPOCHS + 1 ) ,
188
168
)
189
169
190
170
# Export model
191
- session .run_always ("python" , "export_qat.py" , "vgg16_ckpts/ckpt_epoch26.pth" )
171
+ session .run_always (
172
+ "python" ,
173
+ "export_qat.py" ,
174
+ "vgg16_ckpts/ckpt_epoch" + str (EPOCHS + 1 ) + ".pth" ,
175
+ )
192
176
193
177
194
178
def cleanup (session ):
@@ -219,6 +203,19 @@ def run_base_tests(session):
219
203
session .run_always ("pytest" , test )
220
204
221
205
206
+ def run_model_tests (session ):
207
+ print ("Running model tests" )
208
+ session .chdir (os .path .join (TOP_DIR , "tests/py" ))
209
+ tests = [
210
+ "models" ,
211
+ ]
212
+ for test in tests :
213
+ if USE_HOST_DEPS :
214
+ session .run_always ("pytest" , test , env = {"PYTHONPATH" : PYT_PATH })
215
+ else :
216
+ session .run_always ("pytest" , test )
217
+
218
+
222
219
def run_accuracy_tests (session ):
223
220
print ("Running accuracy tests" )
224
221
session .chdir (os .path .join (TOP_DIR , "tests/py" ))
@@ -282,7 +279,7 @@ def run_dla_tests(session):
282
279
print ("Running DLA tests" )
283
280
session .chdir (os .path .join (TOP_DIR , "tests/py" ))
284
281
tests = [
285
- "test_api_dla.py" ,
282
+ "hw/ test_api_dla.py" ,
286
283
]
287
284
for test in tests :
288
285
if USE_HOST_DEPS :
@@ -322,21 +319,19 @@ def run_l0_dla_tests(session):
322
319
cleanup (session )
323
320
324
321
325
- def run_l1_accuracy_tests (session ):
322
+ def run_l1_model_tests (session ):
326
323
if not USE_HOST_DEPS :
327
324
install_deps (session )
328
325
install_torch_trt (session )
329
- download_datasets (session )
330
- train_model (session )
331
- run_accuracy_tests (session )
326
+ download_models (session )
327
+ run_model_tests (session )
332
328
cleanup (session )
333
329
334
330
335
331
def run_l1_int8_accuracy_tests (session ):
336
332
if not USE_HOST_DEPS :
337
333
install_deps (session )
338
334
install_torch_trt (session )
339
- download_datasets (session )
340
335
train_model (session )
341
336
finetune_model (session )
342
337
run_int8_accuracy_tests (session )
@@ -347,9 +342,6 @@ def run_l2_trt_compatibility_tests(session):
347
342
if not USE_HOST_DEPS :
348
343
install_deps (session )
349
344
install_torch_trt (session )
350
- download_models (session )
351
- download_datasets (session )
352
- train_model (session )
353
345
run_trt_compatibility_tests (session )
354
346
cleanup (session )
355
347
@@ -376,9 +368,9 @@ def l0_dla_tests(session):
376
368
377
369
378
370
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
379
- def l1_accuracy_tests (session ):
380
- """Checking accuracy performance on various usecases """
381
- run_l1_accuracy_tests (session )
371
+ def l1_model_tests (session ):
372
+ """When a user needs to test the functionality of standard models compilation and results """
373
+ run_l1_model_tests (session )
382
374
383
375
384
376
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
@@ -397,13 +389,3 @@ def l2_trt_compatibility_tests(session):
397
389
def l2_multi_gpu_tests (session ):
398
390
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
399
391
run_l2_multi_gpu_tests (session )
400
-
401
-
402
- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
403
- def download_test_models (session ):
404
- """Grab all the models needed for testing"""
405
- try :
406
- import torch
407
- except ModuleNotFoundError :
408
- install_deps (session )
409
- download_models (session )
0 commit comments