From 28f195d7f8f92817e59abc1afd39b754c217fa6e Mon Sep 17 00:00:00 2001 From: Smirnov Egor Date: Wed, 8 Sep 2021 18:45:49 +0300 Subject: [PATCH] add test for prelu negative slope access pattern --- testdata/dnn/onnx/data/input_PReLU_slope.npy | Bin 0 -> 228 bytes testdata/dnn/onnx/data/output_PReLU_slope.npy | Bin 0 -> 228 bytes testdata/dnn/onnx/generate_onnx_models.py | 21 ++++++++++++++++++ testdata/dnn/onnx/models/PReLU_slope.onnx | Bin 0 -> 180 bytes 4 files changed, 21 insertions(+) create mode 100644 testdata/dnn/onnx/data/input_PReLU_slope.npy create mode 100644 testdata/dnn/onnx/data/output_PReLU_slope.npy create mode 100644 testdata/dnn/onnx/models/PReLU_slope.onnx diff --git a/testdata/dnn/onnx/data/input_PReLU_slope.npy b/testdata/dnn/onnx/data/input_PReLU_slope.npy new file mode 100644 index 0000000000000000000000000000000000000000..b850fc49b46072258df1fea0d724391acb6e02ab GIT binary patch literal 228 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-ItoB!3PhSZ3bhJk0IpvrZtuM{tKPml{Qo|!8&3A$J>=})c+A}Q)Wcz) zlFe6p-M$@m(@s3JT{Smw|JOf?_5m?p_rG@L-#=#}hkar21v`88Q+EELSM2;$5AM(N cV{%xwpUo~dX2#w>XS?@r*nYz9u91X20HhjDMgRZ+ literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/output_PReLU_slope.npy b/testdata/dnn/onnx/data/output_PReLU_slope.npy new file mode 100644 index 0000000000000000000000000000000000000000..e8d0d5d207dbcf9efef4a2e6e6e3ebf93f0a69e3 GIT binary patch literal 228 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-ItoB!3PhSZ3bhJk0IpvrZtuA@tKPml{Qq998&3A$J>=})c+A}U)Wcz~ zlFe6p-M$@m(@s3JT{Smw-`78i_5m?p_q}%J-#2F>hkar21v`88Q+EELSM2;$5AMtJ cV{%xwpUo~dX2zaBXS?@p*nYz9u91X20HOp=JOBUy literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py index 942bf7b39..3b996fb12 100644 --- a/testdata/dnn/onnx/generate_onnx_models.py +++ b/testdata/dnn/onnx/generate_onnx_models.py @@ -8,6 +8,7 @@ import numpy as np import os.path import onnx +import onnxsim import google.protobuf.text_format import io @@ -72,6 +73,14 @@ def save_onnx_data_and_model(input, output, name, operation, *args, **kwargs): model = onnx.helper.make_model(graph, producer_name=name) onnx.save(model, models_files) +def simplify(name, rename=False, **kwargs): + model, check = onnxsim.simplify(name, **kwargs) + assert check, "couldn't valide" + name = name[:-5] + if rename: + name += '_optimized' + onnx.save(model, name + '.onnx') + torch.manual_seed(0) np.random.seed(0) @@ -127,6 +136,18 @@ def save_onnx_data_and_model(input, output, name, operation, *args, **kwargs): relu = nn.ReLU(inplace=True) save_data_and_model("ReLU", input, relu) +class PReLU_slope(nn.Module): + def __init__(self, *args, **kwargs): + super(PReLU_slope, self).__init__() + + def forward(self, x): + return nn.PReLU()(x) + +model = PReLU_slope() +input_ = Variable(torch.randn(1, 1, 5, 5, dtype=torch.float32)) +save_data_and_model("PReLU_slope", input_, model, export_params=True) +simplify('models/PReLU_slope.onnx', False) + input = Variable(torch.randn(2, 3)) dropout = nn.Dropout() diff --git a/testdata/dnn/onnx/models/PReLU_slope.onnx b/testdata/dnn/onnx/models/PReLU_slope.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b0d72218b8b7d1f5fbced9a2d4a29c2203a39d57 GIT binary patch literal 180 zcmd;J6Jjr@EXglQ&X8g@)U&jj&B!Io#hRH{P+G#pXez{LCdD2Ql$ui-Z>q!!W(Wy@ zRqJMDmguHd6yz6`XbEyKf}jAS6QijY3j;%gU6ef33?VTt5e`Nn0WKyEMkr