From 7e942654b7ad6c69e023c04f6175bd1f3fa496e1 Mon Sep 17 00:00:00 2001 From: Julie Bareeva Date: Fri, 23 Jul 2021 10:00:00 +0300 Subject: [PATCH] add test data for LSTM layer with non-zero hidden params --- testdata/dnn/layers/lstm.hidden.B.npy | Bin 0 -> 176 bytes testdata/dnn/layers/lstm.hidden.R.npy | Bin 0 -> 224 bytes testdata/dnn/layers/lstm.hidden.W.npy | Bin 0 -> 272 bytes testdata/dnn/layers/lstm.hidden.c0.npy | Bin 0 -> 188 bytes testdata/dnn/layers/lstm.hidden.h0.npy | Bin 0 -> 188 bytes testdata/dnn/layers/lstm.hidden.input.npy | Bin 0 -> 288 bytes testdata/dnn/layers/lstm.hidden.output.npy | Bin 0 -> 248 bytes testdata/dnn/onnx/data/input_hidden_lstm.npy | Bin 0 -> 288 bytes .../dnn/onnx/data/input_hidden_lstm_bi.npy | Bin 0 -> 288 bytes testdata/dnn/onnx/data/output_hidden_lstm.npy | Bin 0 -> 248 bytes .../dnn/onnx/data/output_hidden_lstm_bi.npy | Bin 0 -> 368 bytes testdata/dnn/onnx/generate_onnx_models.py | 27 ++++++++++++++++++ testdata/dnn/onnx/models/hidden_lstm.onnx | Bin 0 -> 3809 bytes testdata/dnn/onnx/models/hidden_lstm_bi.onnx | Bin 0 -> 6069 bytes 14 files changed, 27 insertions(+) create mode 100644 testdata/dnn/layers/lstm.hidden.B.npy create mode 100644 testdata/dnn/layers/lstm.hidden.R.npy create mode 100644 testdata/dnn/layers/lstm.hidden.W.npy create mode 100644 testdata/dnn/layers/lstm.hidden.c0.npy create mode 100644 testdata/dnn/layers/lstm.hidden.h0.npy create mode 100644 testdata/dnn/layers/lstm.hidden.input.npy create mode 100644 testdata/dnn/layers/lstm.hidden.output.npy create mode 100644 testdata/dnn/onnx/data/input_hidden_lstm.npy create mode 100644 testdata/dnn/onnx/data/input_hidden_lstm_bi.npy create mode 100644 testdata/dnn/onnx/data/output_hidden_lstm.npy create mode 100644 testdata/dnn/onnx/data/output_hidden_lstm_bi.npy create mode 100644 testdata/dnn/onnx/models/hidden_lstm.onnx create mode 100644 testdata/dnn/onnx/models/hidden_lstm_bi.onnx diff --git a/testdata/dnn/layers/lstm.hidden.B.npy b/testdata/dnn/layers/lstm.hidden.B.npy new file mode 100644 index 0000000000000000000000000000000000000000..ca178d67555d0a8f20aa87fe27309a5601b9618b GIT binary patch literal 176 zcmbR27wQ`j$;jZwP_3Yzl3JWxq;934Zj)xBuA`uymS0p-l$aNvUzCyx5_e0?DNY57 z7iT0EqyqUGhB^vHCYlPh3NXNR;QKL~hTeU4*NX4&%h);9w&6<8?jw6&?k#wKYMOQ`0be@9cDK# ziFH34`)<1xg&limd~MzrlK93>>%qHy9}d>qmh^w$D<*N-&VobL?iXjDoyDmDyJ^~d X`xsXL*!yzk=Y5B?7uem|{K^ghP7GQv literal 0 HcmV?d00001 diff --git a/testdata/dnn/layers/lstm.hidden.W.npy b/testdata/dnn/layers/lstm.hidden.W.npy new file mode 100644 index 0000000000000000000000000000000000000000..f35a32424de7950e1c6a27c0d4ea58e1847f0509 GIT binary patch literal 272 zcmV+r0q_2kPE}1%Spfh>0DB-OWMy+>awj?lfb?Jx$?Z8 zo6|nf8gsp%a)Uh30!+PVOz%C;^t(PR0SrHaxw}3rQw%>a2P?h*#RR{JBaFTxN69__ WM|Hcl?UKGahT^^qQVPG^efho&9(PCp literal 0 HcmV?d00001 diff --git a/testdata/dnn/layers/lstm.hidden.c0.npy b/testdata/dnn/layers/lstm.hidden.c0.npy new file mode 100644 index 0000000000000000000000000000000000000000..1f43a8dce5620d13fdf8a4261d03f38736a9eb30 GIT binary patch literal 188 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-Itr#b3dWi`3bhL40j^VCzist@3G9EwvES}oWXk?Rt%Lh5m+ZA$o43n8 lX?lTuU5uN3cBP9Q`;mqAzjT)F-`eW6KaYF$9{+dd`vF={IRF3v literal 0 HcmV?d00001 diff --git a/testdata/dnn/layers/lstm.hidden.h0.npy b/testdata/dnn/layers/lstm.hidden.h0.npy new file mode 100644 index 0000000000000000000000000000000000000000..889bdd4cacf3b418bd1dac5d35312ce936a23b09 GIT binary patch literal 188 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-Itr#b3dWi`3bhL40j?PPqxJzd)Ao55av#{%B51$xV%L6M*3J8?n$Oui l?Fh4XQMB04-F$AJ#9oR0$I`9r*!`pT&wS8hzgnhqKL89@Hx2** literal 0 HcmV?d00001 diff --git a/testdata/dnn/layers/lstm.hidden.input.npy b/testdata/dnn/layers/lstm.hidden.input.npy new file mode 100644 index 0000000000000000000000000000000000000000..4b36d6f6f6dff71544b9630bd97507bfccc40765 GIT binary patch literal 288 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itr#b3MQI53bhL40j@q(3HyY&qxP#GFxr3i72j{SOK`t^bGE&uexLme zsipQRRr~kb*1GM}k$Pe`d(le!gS{8_Gwk8ne{IdzeNNZ@@7pdsVgF>`Mf(`8b?=X? zTDLbwcmLkDNBZ{HW{2%pJ(+2>)u n9sie?cE37e_n%?*uorU5-G3%7*lyjqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itr#b3dWi`3bhL40j`;UUfZ3#D6w1P=G8s7Gv@Cp+WTqOi%_-QyvMcd zY}Ksn_I_Jo8|_iFFJZ-9yX=pNyF(_l?EAgY$7av|t$Ve@L+y@VII_3;ZiUUZxzqO9 yb-37h6dtr|f3VC>{TiFC{JI>g>wiwzHooq*d&P9YRzM_TpL_JIy%)`I*#H2?30i6Z literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/input_hidden_lstm.npy b/testdata/dnn/onnx/data/input_hidden_lstm.npy new file mode 100644 index 0000000000000000000000000000000000000000..d0a66a9949e28958dbcb521df8948a43f5c1bcbe GIT binary patch literal 288 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itr#b3MQI53bhL40j^f*3H!SmChx!I-LNlK_1->F*GT(jbwm5yGYj_} zjAuV!!MN1!%a(=v{~K(!+kJShU3vaj+sWDU?E60T+1HrW?mMKscE6j*&HXF_g7&e> z-uB(TlJ?QO$M>H)wr-!F{#$z|?$~{c#qIVt@d?;JdfQ@mRmo^yz#-lJw+##IRegWi n*)Y$tZ&z4rt37Y#{&g4U?CxE_vOkn#?mnitQ~SM&Z`lI?Rt;tU literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/input_hidden_lstm_bi.npy b/testdata/dnn/onnx/data/input_hidden_lstm_bi.npy new file mode 100644 index 0000000000000000000000000000000000000000..4068441637de8ffc78f4a78ce6b488436672ac10 GIT binary patch literal 288 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itr#b3MQI53bhL40j|n5uWc(_&)eH(zO`$pTV~(>Y>C}XrVD!?zYnsz zabULnX}je8QNI@L`)0O!-@WNC?5;Hax3kYox4rb@yj^$Gz5SOPzuG-M2{XYXHp#o1oldZs-9eZ+19 literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/output_hidden_lstm.npy b/testdata/dnn/onnx/data/output_hidden_lstm.npy new file mode 100644 index 0000000000000000000000000000000000000000..1de0fa377b1b9f89819bc6f8543d0d836fb3a809 GIT binary patch literal 248 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itr#b3dWi`3bhL40WRmC9QIw#!uDyc3U=I{oc0e^3fjjVtF>MVy?X?%}5|laZ uU*zf7t)73;E^IoNy=k|)-Sp0jc59z-+0PKtuw%)%Xtz>{+x{i9h8+Om8c5Or literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/output_hidden_lstm_bi.npy b/testdata/dnn/onnx/data/output_hidden_lstm_bi.npy new file mode 100644 index 0000000000000000000000000000000000000000..1b3ebdc3d8a515bfcc3636f9d8aa7aeb2e25155c GIT binary patch literal 368 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itr#b3TB!*3bhL40j?acqqbt)yKN)XEA6}!YwfB`J@zS{zq)t+?nAce zcW&5Pi#OVRT~cQE?YZB+1yvXKo)Oq$d#mT7ZO)ZOyJE*$y9cK|_Dwu^Z!g>BwYGg~ z2W^jCt+or$ZL(u2aNQTY^}}8#^KG^}9$&GQpU`0UTC3XbOM%C}eyfLjc^j+uUX2ak z`?=TNE}$pRHsni(O=)faUf24*y?>@z?&W^!Z&z_U*j6L1*rv31#@?I@6ZY=>=Cijl z-oFnFPm?y<6fT>*_nq*fy$dd6?)}x~ TZ+GT!oNeyIB{s$rX72?6M`x7Q literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py index 234118f73..6187bbe54 100644 --- a/testdata/dnn/onnx/generate_onnx_models.py +++ b/testdata/dnn/onnx/generate_onnx_models.py @@ -769,6 +769,33 @@ def forward(self, x): lstm = LSTM(features, hidden, batch, bidirectional=True) save_data_and_model("lstm_bidirectional", input, lstm) + + +class HiddenLSTM(nn.Module): + def __init__(self, input_size, hidden_size, num_layers=1, is_bidirectional=False): + super().__init__() + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bi_coeff = 2 if is_bidirectional else 1 + self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, + num_layers=num_layers, bidirectional=is_bidirectional) + + def forward(self, t): + h_0 = torch.ones(self.num_layers * self.bi_coeff, t.size(1), + self.hidden_size) + c_0 = torch.ones(self.num_layers * self.bi_coeff, t.size(1), + self.hidden_size) + return self.lstm(t, (h_0, c_0))[0] + +input = torch.randn(seq_len, batch, features) +hidden_lstm = HiddenLSTM(features, hidden, num_layers=3, is_bidirectional=False) +save_data_and_model("hidden_lstm", input, hidden_lstm, version=11, export_params=True) + +input = torch.randn(seq_len, batch, features) +hidden_lstm = HiddenLSTM(features, hidden, num_layers=3, is_bidirectional=True) +save_data_and_model("hidden_lstm_bi", input, hidden_lstm, version=11, export_params=True) + + class MatMul(nn.Module): def __init__(self): super(MatMul, self).__init__() diff --git a/testdata/dnn/onnx/models/hidden_lstm.onnx b/testdata/dnn/onnx/models/hidden_lstm.onnx new file mode 100644 index 0000000000000000000000000000000000000000..62152f83c5e823bd1ecb7537fa12ce3a140cfe18 GIT binary patch literal 3809 zcmb7{3sh4_8po56kZ_4ggUiA6u`G+NMrD(G^VHlgu2yh4h^Xiy$2Fj)1rI`m&~<(6 zrYpLN&sDVeD4TlJwXI#MWmSaxC0bvl9LquLzRrX`raDL$YZzPoGW8-Q%gS*UxKt7yqn6Le%y2sF+Gv@7K zUQJER&B=FV=zSVD~_6{cY5^lP& zA+?oKyig|;GS4BTg+hG*p^((!#2}g?yOPF#*OC)3YQ>D~toaVR&Opk`kNgMI3)eRW zSbtzf2BA<0#*bNH6bj7&b{voq!)Ib3%}5$y%VO%y$kLNha`&w#a2R+(j9R|aQ6eHC zNLu((m`Ss!f^=G~TAnx~Ws05FMoamDMn#2YW-VCY$g$^VEp6&1gd5hfOhmB$hj;z3_!^WAMNBV>kVgkUny!3<5_ z^}fr%1Y%}jV&t6{C<8GwATgd9vM=BHM_Q5?)B9{P{g^TGxL8QXXbCe6L5MIzY~aNV zuha7#p@(LS-OT8Jgqd!QsgD``Qb08Hxvok?@w-v0VB3o0I#5Q@_9kyh($lI`P zHk;aX3^!qjH*$KAt@hIfIFuXi-JE=xU9LULB^ICL+{u5!5oz-QBFa!-nF`t!+d$*# zN-4;nl2MzmS$yzwrPn9ZR88M?1v`VTVb~N2SF+|1d;jG;T>WMlT)z4VmUw?)r)~A2 zsG_B+bk_x_Rp)b_yp`5kNg|h-*vy`9t%SKR1=|j9t%iqj9XK~89HSO~X}!1dARA^@ zLWwvNlK(o0#c`W)dc!)Hc6N~WguV=4?F`3|N1O0QT@q$bZo$mL?>WsPVtQoP`E&{C zqs2*@D(f&0NY$ zqzbiv0r!Yoob^4aS7+_SMe7E8=Poz~52xNlXGaJ3&!f3$_BC?-9WhymBck96&V8VD zocLF(QTq3}a8xeyx=Vh_RX-HlB(E;U=-qnnf~KJf@10n|x;C|Q5s??btQZat zLpFlBIvKW)8jPY3A8@yKe1|7Td7%E2uLl&(%ookqe#YwFfN&}QD|9wjBkcPShu%C7 zdkRx<{j3hSb#5lCuX_y@bKsx-K@vi3Q#v1m;Rv)%lD^Y((hH? z{zbF!MH`v;0WNQQALwLAiyS44+8BF@N3D1?@fm=mBb^+T4iF6$HJ0f`_$r0nSYoKvZJ!aE+ zu=Ddce7kfbj#zk+^=#OIg(d}r>2I;3e_7bL!!@w~7c$#%?jvyKm9ok~MevjJo3QA9 zD-NL_;qqnwLEn{A?A#7Rw!c}$wxP1?5K|<_#KuG9GD~fz>NbK@^9J~2b!<%0GPoG?9<;?(;DdKZ x0{n6+Rz0`@o=-1wX}$kxsi-i8Sf!*S3JGa4(xmf-@xM4ss>eSHnji@a{Xc}kdb9ul literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/models/hidden_lstm_bi.onnx b/testdata/dnn/onnx/models/hidden_lstm_bi.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fe5b959beabbdcca306dcebaea99488a2aa5815e GIT binary patch literal 6069 zcmbtY30M=?+Ma|ELc$UPD#lhRm%323NoEoPX1<~Rs7PICrIq3WE?7_m1WMIfO)d4> z7LigHv@Q{?prWGUR+w+7V&y6=Xe}b#YE(q36j4!ess1x8n%aA5``me+Ih;(+cfRw! z-+SIUK_*o>B+g4oNSYR-w$l#Q2W)s%-b-#D8=v@Iib|~0s~yJ0Oi7H6)Og#sb>tNP z)S#CCUqXCx%9Qw&NUgW5v+wIEw@;lC_g=KGlhPqXCL-D|i%24WMb6K~-v&O~Pc0iU zB_$?0DN^SxZSVOiaGofNQmdY0B~q*X>r5F{c3K0!%+5I?wK^L* z)Sdjg@ci<^>Cg>dT#39bQfC`U{+&qDHj?ThQfwpE0ek@{Ilp3_F|C0V8m$0;Qp?80 z#ZHTk)av;v3%&N`>1~r3x=8W|U35{L=W_$|F zzY|z;hX?SQ@#6PE1IS zj?`(r9iM*kb(2dHqmyO_5`p4CQ6L}ZAP*K6U|VXPt-O3Y<<(ne$IHvJvkIMmpLAS; z+qM zZUkLdO5;7M?Gkc!$eGV8=a*>-3U7WzXQy6 zBaoUdBA){$$XVyF7aCqXI)6$4#v5llFjCtIY>rS;eym=n5gMKey-p{HnWsT${DinR z^t}8?jiy7)r0#zbbK8`pMr$i(l7Er?nV3n9ZJ|k9as5p(^U38r9-RP+d?sd6|2r{L zT7R3GlU)hxxtdct+igI0Wns3J<2|8}x#hc^ZVqAjI*CrgUu-_;-kzqW&i_w@*oGY{g9J|6I(JaE^QW z;-_rI^*6;#kwVc){#{Iwx@Tgh{GN(grEELN4w@O8GAMd(VnR}iue)%j7|KVGA+k=s zpFN~PWOf~ufA&!SpGF#NMrPW{#4;zDL@DAY_VWud+XX#&0M@vprgvOU;fXcFuxY** zdiAMeOG2m8j|yarSHN$kgtP+~IP;dNsdyP!cAqsZ8*a}@`rUzWxy@vz0VC9dj%g$PyF?0b+#BagDYo{=!bOM~I z=ku$3(VmiHC_g<82lf;J5q=SCYO^5dRVkP6;lkZq=gI}2zrj8}_8s`tDLBi%J0M*Y zhXwrx!Ma8grml9uV!yHA?OlaOPVdCcidA^}j03Z(TOTAg{P4{UzvZ|j-JMw&p=4sZ z-Di^&65LQ#1YXKZCTqPbH~8aSSohgTw3MuXg|Bm%I!VD)?X*M58)w*U&F7KWd;<%W zRwOZXxI!4@Aib= z?b(3qQcRG!YXevgZK2c3j^UM8uVDQbn<1)uwyES_26Rd%@Q}}XI(=yt9PzBcx*n|% zo@r+Dq7B&i?MJZqgc~D%xDB%Ve!|9$91HshH*V{^M>wkKQy7qti(3sDAlq4qQ)gzO z--6pH9=aB+cN1BkMVFzn`2d7%yb76FE5K?014#XCI|PT5c&$$<2Ht*^mE^WS!K~rv zQfr2njy}QjKg*aZi4*h99DAl}(tTXu9*KR&oWzW~3Z|EvBWH|Ag4Bzy%)&AKaY$(k zZd6LS=6)U!r6k#$;VaQf+@L+uq9Cm2br@y70?m&KSj#vumz`nH1rTy3B!4mZ{vhHh z^K2)~KcqzDzV9Bs*1Qp&j*FRNp&K!@AOMJ)b70^r4ouBe8K?ZP5^njdfEVuXf-R;8 zkiB;>PCqvcUg+tI&5Ac@r_dc>4VlHRI&uI~WIkw|F%g@8@&NPRH(+{PCT2S<#I!q0 z@bvk$@OAWJEZqMFP#FX>8<%4E>Mzh*GlmTvGmSkKRF6>!cfsnu5*pskrG3&K!h?u^ zVf{)GGiumTcImAR^vI)2kvOej!hcnOi(bhc8QULyD1)7uw z^tk;i^t|uFHO6jd%XU|=4Qm((yY()3wO)ea0R=3PQB-^={si2*+Xtt9A!1r9OVQlz z0PQmbNdi}9M?ooO)c#_AqN>6kN%p!^2Xg@6EoZyX|*%=({4MiHDf;%^n0JS#>wbO zdp2X=C-38ue-Fg*4_&#ylXpy&y-(4E5AjX^>i~7M6BDF;5%S`~(PwNjHhZzya;+Q_ zXCr&Y>lFKq`D5@bT!1aP*CA+D1ri=-1GfBI!hF}NVj_m@z`%{C>A5>Tr#-Vj!{I|m zVaiKhFm>7Y_&ZzSkBz*o#72CElG@Qwqmytxxl*W$E<($}MQmP=Y&x;(A)InthU$t& zOsn4uyN9m9lF{|-z%AFna=;!-!p74LIX|Fg(?d8B$%2(y#nw+t$FBpC@9%YR~K>~pa{UseZTU1yL0A2w@idOz1V8q zLuY<;3DPdcVVY6KQ9aLM-f~CgTBbEX!wd3u#JG)Ld3<^hSftYj;B(xi|D^<#v-v17&eyjl5 z7ujs%j4OD0?gTVq8v7p0J8cgqZvO-Y_edNK#V{Y2o!y7oBVy5c=S>_qrV-5-3*eS% z6~vg=)AjQu}N%~5XL3iWEp{^wY*mLc6z;R3oU5hHnf zJlt6);x=Vf;lh!NA@LuZ=<)%XaB`hHcWQ==%Nn*HsytZu_Mkmi85;$6Pt{}fzb`=b z$NtdbqhiP+H||NW3%9lQIFuUg7~>23A!kV`#64(5>h4d_@VSD^c9+16P2ZyW$C(&D z`4l}lIR=Zn6|t5*FSCX<3hqeac5Eq00P82gSSTxp{N2Amj3}QrKM7z1w+}%}{~N{O z{%2`%`9&;#-5I-Gx`Uas_X8bRjCbuFn4r{~P;fksZg83gbvvuEeC>EJd!@0));L&F zS_ztkW-t#J$y#@~v7baAM&sGH(8sd`rmJRS^@@X#A3qZx6*yw`3{RAXZ$p_h0~&@E z&;fhOA-f_KJg2&Flka^4n(gP=>t8R%h{HvYzxphUZV++SpWb38mfV8Gd($Dt@jj@& z3PIX!GitPtpC6#@isx*_K=HPsxUb^+FIr8fXP<|cF6_g!m&G9cp#s837(jXX2|f6I zcP6~J9_Ft4nQrL?SYs>$tG|f0{`xZ=@G>BA!eK~W_+_^0%MzIQ+ExsW8AU&`TMgoO z#EjY5o$IZ#K+M;#prlC2EgfP9`X&puF3+W7u5;+0d>#|?g4suFhU1%;^FdknI&BUe zX}Z42fg`RK7llrWXET$+*v!d|tiA6mHuClD;RNOx3R8Ad0}Si;8~W7L;iaZ5I8^coCG37^`9;CS p&uhhqr~f@x$(1sZ%1O?jVENmgza>+93O{sosy}`u9V%8T{vX>hTE74Q literal 0 HcmV?d00001