Skip to content

Commit 294a5ca

Browse files
committed
Merge pull request #956 from rogday:split_expand
2 parents e2dc7f6 + 2702fbc commit 294a5ca

File tree

4 files changed

+41
-8
lines changed

4 files changed

+41
-8
lines changed
248 Bytes
Binary file not shown.
-32 Bytes
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -434,16 +434,49 @@ def forward(self, x):
434434
save_data_and_model("slice", input, model)
435435
save_data_and_model("slice_opset_11", input, model, version=11)
436436

437-
class SliceStarts(nn.Module):
438-
def __init__(self, *args, **kwargs):
439-
super(SliceStarts, self).__init__()
437+
def generate_slice_neg_starts():
438+
x = np.random.randn(2, 3, 4, 3).astype(np.float32)
439+
y = x[-1:2, -3:-1, 2:3, 1:-1]
440440

441-
def forward(self, x):
442-
return x[-1:]
441+
starts = np.array([-1, -3, 2, 1], dtype=np.int64)
442+
starts = onnx.numpy_helper.from_array(starts, name='starts')
443+
ends = np.array([ 2, -1, 3, -1], dtype=np.int64)
444+
ends = onnx.numpy_helper.from_array(ends, name='ends')
443445

444-
model = SliceStarts()
445-
input_ = Variable(torch.randn(1, 10, dtype=torch.float32))
446-
save_data_and_model("slice_neg_starts", input_, model)
446+
node = onnx.helper.make_node(
447+
'Slice',
448+
inputs=['X', 'starts', 'ends'],
449+
outputs=['Y'],
450+
)
451+
452+
X = onnx.helper.make_tensor_value_info('X', onnx.TensorProto.FLOAT, list(x.shape))
453+
Y = onnx.helper.make_tensor_value_info('Y', onnx.TensorProto.FLOAT, list(y.shape))
454+
455+
graph = onnx.helper.make_graph(
456+
[node], # nodes
457+
'slice_neg_starts', # name
458+
[X], # inputs
459+
[Y], # outputs
460+
)
461+
462+
graph.initializer.append(starts)
463+
graph.initializer.append(ends)
464+
465+
model = onnx.helper.make_model(graph, producer_name='onnx')
466+
onnx.checker.check_model(model)
467+
468+
name = 'slice_neg_starts'
469+
470+
input_files = os.path.join("data", "input_" + name)
471+
np.save(input_files, x.data)
472+
473+
output_files = os.path.join("data", "output_" + name)
474+
np.save(output_files, np.ascontiguousarray(y.data))
475+
476+
models_files = os.path.join("models", name + ".onnx")
477+
onnx.save(model, models_files)
478+
479+
generate_slice_neg_starts()
447480

448481
input_2 = Variable(torch.randn(6, 6))
449482
custom_slice_list = [
49 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)