Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions model-optimizer/extensions/front/onnx/resize_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@

from extensions.ops.upsample import UpsampleOp
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
from mo.graph.graph import Node
from mo.utils.error import Error


class ResizeExtractor(FrontExtractorOp):
op = 'Resize'
enabled = True

@staticmethod
def extract(node: Node):
@classmethod
def extract(cls, node: Node):
onnx_opset_version = get_onnx_opset_version(node)
if onnx_opset_version is not None and onnx_opset_version >= 11:
raise Error("ONNX Resize operation from opset {} is not supported.".format(onnx_opset_version))
mode = onnx_attr(node, 'mode', 's', default=b'nearest').decode()
UpsampleOp.update_node_stat(node, {'mode': mode})
return __class__.enabled
return cls.enabled
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class ReverseSequenceExtractor(FrontExtractorOp):
op = 'ReverseSequence'
enabled = True

@staticmethod
def extract(node):
@classmethod
def extract(cls, node):
batch_axis = onnx_attr(node, 'batch_axis', 'i', default=1)
time_axis = onnx_attr(node, 'time_axis', 'i', default=0)

Expand All @@ -33,4 +33,4 @@ def extract(node):
'seq_axis': time_axis,
}
ReverseSequence.update_node_stat(node, attrs)
return __class__.enabled
return cls.enabled
7 changes: 3 additions & 4 deletions model-optimizer/extensions/front/tf/bucketize_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ class BucketizeFrontExtractor(FrontExtractorOp):
op = 'Bucketize'
enabled = True

@staticmethod
def extract(node):
@classmethod
def extract(cls, node):
boundaries = np.array(node.pb.attr['boundaries'].list.f, dtype=np.float)
Bucketize.update_node_stat(node, {'boundaries': boundaries, 'with_right_bound': False})

return __class__.enabled
return cls.enabled
7 changes: 3 additions & 4 deletions model-optimizer/extensions/front/tf/sparse_to_dense_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class SparseToDenseFrontExtractor(FrontExtractorOp):
op = 'SparseToDense'
enabled = True

@staticmethod
def extract(node):
@classmethod
def extract(cls, node):
SparseToDense.update_node_stat(node)

return __class__.enabled
return cls.enabled
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class LinearComponentFrontExtractor(FrontExtractorOp):
op = 'linearcomponent'
enabled = True

@staticmethod
def extract(node):
@classmethod
def extract(cls, node):
pb = node.parameters
collect_until_token(pb, b'<Params>')
weights, weights_shape = read_binary_matrix(pb)
Expand All @@ -39,4 +39,4 @@ def extract(node):
embed_input(mapping_rule, 1, 'weights', weights)

FullyConnected.update_node_stat(node, mapping_rule)
return __class__.enabled
return cls.enabled