@@ -92,6 +92,23 @@ def save_data_and_onnx_model(name, input_np, output_np, onnx_model):
9292 with open (models_files , 'wb' ) as file :
9393 file .write (model_def .SerializeToString ())
9494
95+ def save_data_and_onnx_model_multy_inputs (name , input_list , output_np , onnx_model ):
96+ for index in range (len (input_list )):
97+ print (name + " input " + str (index )+ " has sizes" , input_list [index ].shape )
98+ input_files = os .path .join ("data" , "input_" + name + "_" + str (index ))
99+ np .save (input_files , input_list [index ])
100+
101+ print (name + " output has sizes" , output_np .shape )
102+ print ()
103+ output_files = os .path .join ("data" , "output_" + name )
104+ np .save (output_files , np .ascontiguousarray (output_np .data ))
105+
106+ models_files = os .path .join ("models" , name + ".onnx" )
107+
108+ onnx_model_pb = onnx ._serialize (onnx_model )
109+ model_def = assertONNXExpected (onnx_model_pb )
110+ with open (models_files , 'wb' ) as file :
111+ file .write (model_def .SerializeToString ())
95112
96113def simplify (name , rename = False , ** kwargs ):
97114 model , check = onnxsim .simplify (name , ** kwargs )
@@ -2091,3 +2108,23 @@ def gemm_reference_implementation(A: np.ndarray, B: np.ndarray, C: Optional[np.n
20912108
20922109output_np = np .sum (input_np , axis = 1 , keepdims = 1 )
20932110save_data_and_onnx_model ("reduce_sum_axis_dynamic_batch" , input_np , output_np , onnx_model )
2111+
2112+
2113+ # ########################## DivBroadcast ##########################
2114+ input_np = np .random .rand (1 , 4 ).astype ("float32" )
2115+ input2_np = np .random .rand (1 , 1 ).astype (np .float32 )
2116+ inputs = [onnx .helper .make_tensor_value_info ("input1" , onnx .mapping .NP_TYPE_TO_TENSOR_TYPE [input_np .dtype ], shape = input_np .shape ), \
2117+ onnx .helper .make_tensor_value_info ("input2" , onnx .mapping .NP_TYPE_TO_TENSOR_TYPE [input2_np .dtype ], shape = input2_np .shape )]
2118+
2119+ outputs = [onnx .helper .make_tensor_value_info ("output" , onnx .TensorProto .FLOAT , shape = (1 , 4 ))]
2120+
2121+ nodes = [onnx .helper .make_node ("Div" , ["input1" , "input2" ], ["output" ])]
2122+
2123+ graph = onnx .helper .make_graph (nodes ,
2124+ "div_test" ,
2125+ inputs ,
2126+ outputs )
2127+ onnx_model = onnx .helper .make_model (graph )
2128+
2129+ output_np = input_np / input2_np
2130+ save_data_and_onnx_model_multy_inputs ("div_test_1x1" , [input_np , input2_np ], output_np , onnx_model )
0 commit comments