22import torch_tensorrt as torchtrt
33import torch
44import torchvision .models as models
5+ import os
56
67def find_repo_root (max_depth = 10 ):
78 dir_path = os .path .dirname (os .path .realpath (__file__ ))
@@ -22,7 +23,7 @@ class TestStandardTensorInput(unittest.TestCase):
2223 def test_compile (self ):
2324
2425 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
25- self .model = torch .jit .load (MODULE_DIR + "/standard_tensor_input .jit.pt" ).eval ().to ("cuda" )
26+ self .model = torch .jit .load (MODULE_DIR + "/standard_tensor_input_scripted .jit.pt" ).eval ().to ("cuda" )
2627
2728 compile_spec = {
2829 "inputs" : [torchtrt .Input (self .input .shape ),
@@ -41,7 +42,7 @@ class TestTupleInput(unittest.TestCase):
4142 def test_compile (self ):
4243
4344 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
44- self .model = torch .jit .load (MODULE_DIR + "/tuple_input .jit.pt" ).eval ().to ("cuda" )
45+ self .model = torch .jit .load (MODULE_DIR + "/tuple_input_scripted .jit.pt" ).eval ().to ("cuda" )
4546
4647 compile_spec = {
4748 "input_signature" : ((torchtrt .Input (self .input .shape ), torchtrt .Input (self .input .shape )),),
@@ -61,7 +62,7 @@ class TestListInput(unittest.TestCase):
6162 def test_compile (self ):
6263
6364 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
64- self .model = torch .jit .load (MODULE_DIR + "/list_input .jit.pt" ).eval ().to ("cuda" )
65+ self .model = torch .jit .load (MODULE_DIR + "/list_input_scripted .jit.pt" ).eval ().to ("cuda" )
6566
6667
6768 compile_spec = {
@@ -81,7 +82,7 @@ class TestTupleInputOutput(unittest.TestCase):
8182 def test_compile (self ):
8283
8384 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
84- self .model = torch .jit .load (MODULE_DIR + "/tuple_input_output .jit.pt" ).eval ().to ("cuda" )
85+ self .model = torch .jit .load (MODULE_DIR + "/tuple_input_output_scripted .jit.pt" ).eval ().to ("cuda" )
8586
8687
8788 compile_spec = {
@@ -103,7 +104,7 @@ class TestListInputOutput(unittest.TestCase):
103104 def test_compile (self ):
104105
105106 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
106- self .model = torch .jit .load (MODULE_DIR + "/list_input_output .jit.pt" ).eval ().to ("cuda" )
107+ self .model = torch .jit .load (MODULE_DIR + "/list_input_output_scripted .jit.pt" ).eval ().to ("cuda" )
107108
108109
109110 compile_spec = {
@@ -126,7 +127,7 @@ class TestListInputTupleOutput(unittest.TestCase):
126127 def test_compile (self ):
127128
128129 self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
129- self .model = torch .jit .load (MODULE_DIR + "/list_input_tuple_output .jit.pt" ).eval ().to ("cuda" )
130+ self .model = torch .jit .load (MODULE_DIR + "/list_input_tuple_output_scripted .jit.pt" ).eval ().to ("cuda" )
130131
131132
132133 compile_spec = {
0 commit comments