diff --git a/testdata/dnn/onnx/data/input_PReLU_slope.npy b/testdata/dnn/onnx/data/input_PReLU_slope.npy new file mode 100644 index 000000000..b850fc49b Binary files /dev/null and b/testdata/dnn/onnx/data/input_PReLU_slope.npy differ diff --git a/testdata/dnn/onnx/data/output_PReLU_slope.npy b/testdata/dnn/onnx/data/output_PReLU_slope.npy new file mode 100644 index 000000000..e8d0d5d20 Binary files /dev/null and b/testdata/dnn/onnx/data/output_PReLU_slope.npy differ diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py index 942bf7b39..3b996fb12 100644 --- a/testdata/dnn/onnx/generate_onnx_models.py +++ b/testdata/dnn/onnx/generate_onnx_models.py @@ -8,6 +8,7 @@ import numpy as np import os.path import onnx +import onnxsim import google.protobuf.text_format import io @@ -72,6 +73,14 @@ def save_onnx_data_and_model(input, output, name, operation, *args, **kwargs): model = onnx.helper.make_model(graph, producer_name=name) onnx.save(model, models_files) +def simplify(name, rename=False, **kwargs): + model, check = onnxsim.simplify(name, **kwargs) + assert check, "couldn't valide" + name = name[:-5] + if rename: + name += '_optimized' + onnx.save(model, name + '.onnx') + torch.manual_seed(0) np.random.seed(0) @@ -127,6 +136,18 @@ def save_onnx_data_and_model(input, output, name, operation, *args, **kwargs): relu = nn.ReLU(inplace=True) save_data_and_model("ReLU", input, relu) +class PReLU_slope(nn.Module): + def __init__(self, *args, **kwargs): + super(PReLU_slope, self).__init__() + + def forward(self, x): + return nn.PReLU()(x) + +model = PReLU_slope() +input_ = Variable(torch.randn(1, 1, 5, 5, dtype=torch.float32)) +save_data_and_model("PReLU_slope", input_, model, export_params=True) +simplify('models/PReLU_slope.onnx', False) + input = Variable(torch.randn(2, 3)) dropout = nn.Dropout() diff --git a/testdata/dnn/onnx/models/PReLU_slope.onnx b/testdata/dnn/onnx/models/PReLU_slope.onnx new file mode 100644 index 000000000..b0d72218b Binary files /dev/null and b/testdata/dnn/onnx/models/PReLU_slope.onnx differ