[MXNET-344] [ONNX-MXNet] Add new Operator Translations for ONNX import module #11140
Changes from all commits
@@ -18,6 +18,7 @@
 # coding: utf-8
 """Module for translating ONNX operators into MXNet operators."""
 # pylint: disable=unused-argument,protected-access
+import numpy as np
 from . import translation_utils
 from .... import symbol
@@ -80,6 +81,22 @@ def divide(attrs, inputs, proto_obj):
         return op_value, new_attr, inputs
     return 'broadcast_div', new_attr, inputs

+def logical_and(attrs, inputs, proto_obj):
+    """Logical and of two input arrays."""
+    return 'broadcast_logical_and', attrs, inputs
+
+def logical_or(attrs, inputs, proto_obj):
+    """Logical or of two input arrays."""
+    return 'broadcast_logical_or', attrs, inputs
+
+def logical_xor(attrs, inputs, proto_obj):
+    """Logical xor of two input arrays."""
+    return 'broadcast_logical_xor', attrs, inputs
+
+def logical_not(attrs, inputs, proto_obj):
+    """Element-wise logical not of the input array."""
+    return 'logical_not', attrs, inputs
+
 def absolute(attrs, inputs, proto_obj):
     """Returns element-wise absolute value of the input."""
     return 'abs', attrs, inputs
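For a sense of what these translators map to, a minimal sketch (assuming an MXNet build that ships the broadcast_logical_* operators this PR targets; values are illustrative):

    import mxnet as mx

    # Broadcasting logical and: b's shape (2, 1) broadcasts across a's columns.
    a = mx.nd.array([[0, 1], [2, 0]])
    b = mx.nd.array([[1], [0]])
    print(mx.nd.broadcast_logical_and(a, b).asnumpy())
    # expected: [[0. 1.] [0. 0.]]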
@@ -97,7 +114,6 @@ def argmax(attrs, inputs, proto_obj):
     """Returns indices of the maximum values along an axis"""
     return 'argmax', attrs, inputs

-
 def argmin(attrs, inputs, proto_obj):
     """Returns indices of the minimum values along an axis."""
     return 'argmin', attrs, inputs
@@ -130,6 +146,18 @@ def minimum(attrs, inputs, proto_obj):
         mxnet_op = inputs[0]
     return mxnet_op, attrs, inputs

+def lesser(attrs, inputs, proto_obj):
+    """Logical Lesser operator with broadcasting."""
+    return 'broadcast_lesser', attrs, inputs
+
+def greater(attrs, inputs, proto_obj):
+    """Logical Greater operator with broadcasting."""
+    return 'broadcast_greater', attrs, inputs
+
+def equal(attrs, inputs, proto_obj):
+    """Logical Equal operator with broadcasting."""
+    return 'broadcast_equal', attrs, inputs
+
 # Hyperbolic functions
 def tanh(attrs, inputs, proto_obj):
     """Returns the hyperbolic tangent of the input array."""
@@ -151,6 +179,10 @@ def concat(attrs, inputs, proto_obj):
     return 'concat', new_attrs, inputs

 # Basic neural network functions
+def softsign(attrs, inputs, proto_obj):
+    """Computes softsign of x element-wise."""
+    return 'softsign', attrs, inputs
+
 def sigmoid(attrs, inputs, proto_obj):
     """Computes elementwise sigmoid of the input array"""
     return 'sigmoid', attrs, inputs
@@ -183,6 +215,11 @@ def batch_norm(attrs, inputs, proto_obj):
     new_attrs['fix_gamma'] = not attrs.get('is_test', 1)
     return 'BatchNorm', new_attrs, inputs

+def instance_norm(attrs, inputs, proto_obj):
+    """Instance Normalization."""
+    new_attrs = translation_utils._fix_attribute_names(attrs, {'epsilon' : 'eps'})
+    return 'InstanceNorm', new_attrs, inputs
+
 def leaky_relu(attrs, inputs, proto_obj):
     """Leaky Relu function"""
     if 'alpha' in attrs:
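Most of these translators lean on translation_utils to rename ONNX attribute names to their MXNet equivalents. A minimal sketch of that renaming, assuming the helper behaves as its call sites here suggest (the real implementation in translation_utils may differ):

    # Hypothetical stand-in for translation_utils._fix_attribute_names.
    def fix_attribute_names(attrs, change_map):
        # Rename keys found in change_map; leave the rest untouched.
        return {change_map.get(key, key): value for key, value in attrs.items()}

    print(fix_attribute_names({'epsilon': 1e-5}, {'epsilon': 'eps'}))
    # {'eps': 1e-05}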
@@ -211,6 +248,16 @@ def softmax(attrs, inputs, proto_obj):
         attrs = translation_utils._add_extra_attributes(attrs, {'axis': 1})
     return 'softmax', attrs, inputs

+def log_softmax(attrs, inputs, proto_obj):
+    """Computes the log softmax of the input. This is equivalent to
+    computing softmax followed by log."""
+    return 'log_softmax', attrs, inputs
+
+def softplus(attrs, inputs, proto_obj):
+    """Applies the softplus activation function element-wise to the input."""
+    new_attrs = translation_utils._add_extra_attributes(attrs, {'act_type' : 'softrelu'})
+    return 'Activation', new_attrs, inputs
+
 def conv(attrs, inputs, proto_obj):
     """Compute N-D convolution on (N+2)-D input."""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'kernel_shape' : 'kernel',
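The softplus translation works because MXNet's 'softrelu' activation computes log(1 + exp(x)), which matches ONNX Softplus. A quick numpy check of the identity (values illustrative):

    import numpy as np

    x = np.array([-2.0, 0.0, 3.0])
    # softplus(x) = log(1 + exp(x)); log1p is the numerically stable form.
    print(np.log1p(np.exp(x)))
    # approx: [0.1269, 0.6931, 3.0486]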
@@ -389,15 +436,9 @@ def transpose(attrs, inputs, proto_obj):

 def squeeze(attrs, inputs, proto_obj):
     """Remove single-dimensional entries from the shape of a tensor."""
-    # MXNet doesnt have a squeeze operator.
-    # Using "split" to perform similar operation.
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'axes' : 'axis'})
-    axes = new_attrs.get('axis')
-    mxnet_op = symbol.split(inputs[0], axis=axes[0], num_outputs=1, squeeze_axis=1)
-    for i in axes[1:]:
-        mxnet_op = symbol.split(mxnet_op, axis=i-1, num_outputs=1, squeeze_axis=1)
-    return mxnet_op, new_attrs, inputs
+    return 'squeeze', new_attrs, inputs

 def unsqueeze(attrs, inputs, cls):
     """Inserts a new axis of size 1 into the array shape"""
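The rewrite above drops the split-based emulation in favor of MXNet's native squeeze operator, which removes all requested singleton axes in one call. A minimal sketch, assuming an MXNet build recent enough to include squeeze:

    import mxnet as mx

    # Input with singleton axes 0 and 2, as in ONNX Squeeze(axes=[0, 2]).
    x = mx.nd.ones((1, 3, 1, 5))
    y = mx.nd.squeeze(x, axis=(0, 2))
    print(y.shape)  # expected: (3, 5)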
@@ -417,6 +458,16 @@ def flatten(attrs, inputs, proto_obj):
     new_attrs = translation_utils._remove_attributes(attrs, ['axis'])
     return 'Flatten', new_attrs, inputs

+def clip(attrs, inputs, proto_obj):
+    """Clips (limits) the values in an array."""
+    new_attrs = translation_utils._fix_attribute_names(attrs, {'min' : 'a_min',
+                                                               'max' : 'a_max'})
+    if 'a_max' not in new_attrs:
+        new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_max' : np.inf})
+    if 'a_min' not in new_attrs:
+        new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_min' : -np.inf})
+    return 'clip', new_attrs, inputs
+
 # Powers
 def reciprocal(attrs, inputs, proto_obj):
     """Returns the reciprocal of the argument, element-wise."""
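ONNX Clip may carry only one of its two bounds, so the translation fills the missing side with an infinity before handing both a_min and a_max to MXNet's clip. The same defaulting behavior, sketched with numpy (values illustrative):

    import numpy as np

    x = np.array([-3.0, -1.0, 0.5, 2.0, 7.0])
    # Only 'max' present in the ONNX node: a_min defaults to -np.inf.
    print(np.clip(x, -np.inf, 2.0))   # [-3.  -1.   0.5  2.   2. ]
    # Only 'min' present: a_max defaults to np.inf.
    print(np.clip(x, 0.0, np.inf))    # [0.   0.   0.5  2.   7. ]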
@@ -454,20 +505,49 @@ def reduce_mean(attrs, inputs, proto_obj):
     return 'mean', new_attrs, inputs

 def reduce_min(attrs, inputs, proto_obj):
-    """Reduce the array along a given axis by mean value"""
+    """Reduce the array along a given axis by minimum value"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
     return 'min', new_attrs, inputs

 def reduce_sum(attrs, inputs, proto_obj):
-    """Reduce the array along a given axis by mean value"""
+    """Reduce the array along a given axis by sum value"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
     return 'sum', new_attrs, inputs

 def reduce_prod(attrs, inputs, proto_obj):
-    """Reduce the array along a given axis by mean value"""
+    """Reduce the array along a given axis by product value"""
     new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
     return 'prod', new_attrs, inputs

+def reduce_log_sum(attrs, inputs, proto_obj):
+    """Reduce the array along a given axis by log sum value"""
+    keep_dims = True if 'keepdims' not in attrs else attrs.get('keepdims')
+    sum_op = symbol.sum(inputs[0], axis=attrs.get('axes'),
+                        keepdims=keep_dims)
+    log_sym = symbol.log(sum_op)
+    return log_sym, attrs, inputs
+
+def reduce_log_sum_exp(attrs, inputs, proto_obj):
+    """Reduce the array along a given axis by log sum exp value"""
+    keep_dims = True if 'keepdims' not in attrs else attrs.get('keepdims')
+    exp_op = symbol.exp(inputs[0])
+    sum_op = symbol.sum(exp_op, axis=attrs.get('axes'),
+                        keepdims=keep_dims)
+    log_sym = symbol.log(sum_op)
+    return log_sym, attrs, inputs
+
+def reduce_sum_square(attrs, inputs, proto_obj):
+    """Reduce the array along a given axis by sum square value"""
+    square_op = symbol.square(inputs[0])
+    sum_op = symbol.sum(square_op, axis=attrs.get('axes'),
+                        keepdims=attrs.get('keepdims'))
+    return sum_op, attrs, inputs
+
+def reduce_l2(attrs, inputs, proto_obj):
+    """Reduce input tensor by l2 normalization."""
+    new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
+    return 'norm', new_attrs, inputs
+
 def avg_pooling(attrs, inputs, proto_obj):
     """ Average pooling"""
     new_attrs = translation_utils._fix_attribute_names(attrs,

Review thread on the sum_op = symbol.sum(inputs[0], ...) line in reduce_log_sum:

Comment: Why inputs[0]?

Reply: The ONNX ReduceLogSum op takes a single input tensor - https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceLogSum. While translating this ONNX operator into MXNet, we split the ReduceLogSum op into a sum op followed by a log op.

Reply: inputs[0] has been used similarly in other operator translations, for example here - https://github.com/apache/incubator-mxnet/pull/11140/files/6f518f13413387b40a5c45d96bd2c1d2decf7804#diff-ea46c490a749b158ef8986126da6bc12R284
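As a sanity check on the composite translations above, the same chains written directly in numpy (shapes and values illustrative):

    import numpy as np

    x = np.array([[1.0, 2.0], [3.0, 4.0]])
    # ReduceLogSum: sum, then log (mirrors reduce_log_sum).
    print(np.log(x.sum(axis=1, keepdims=True)))
    # ReduceLogSumExp: exp, sum, log (mirrors reduce_log_sum_exp).
    print(np.log(np.exp(x).sum(axis=1, keepdims=True)))
    # ReduceSumSquare: square, then sum (mirrors reduce_sum_square).
    print(np.square(x).sum(axis=1, keepdims=True))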
@@ -497,3 +577,11 @@ def max_pooling(attrs, inputs, proto_obj):
     new_op = translation_utils._fix_pooling('max', inputs, new_attrs)

     return new_op, new_attrs, inputs
+
+def max_roi_pooling(attrs, inputs, proto_obj):
+    """Max ROI Pooling."""
+    new_attrs = translation_utils._fix_attribute_names(attrs,
+                                                       {'pooled_shape': 'pooled_size',
+                                                        'spatial_scale': 'spatial_scale'
+                                                       })
+    return 'ROIPooling', new_attrs, inputs
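For reference, MXNet's ROIPooling takes the feature map plus (batch_index, x1, y1, x2, y2) regions, which is what the renamed attributes feed. A minimal sketch with made-up shapes:

    import mxnet as mx

    data = mx.nd.arange(16).reshape((1, 1, 4, 4))  # one 4x4 feature map
    rois = mx.nd.array([[0, 0, 0, 3, 3]])          # one ROI covering it all
    out = mx.nd.ROIPooling(data, rois, pooled_size=(2, 2), spatial_scale=1.0)
    print(out.shape)  # (1, 1, 2, 2)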
Review thread on the clip translation's use of np.inf:

Comment: Can we use np.NINF? I think that is a float, please check once.

Reply: np.inf can be used as a float - https://stackoverflow.com/questions/42315541/difference-between-np-inf-and-floatinf
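A quick check that settles the question above (plain Python plus numpy):

    import numpy as np

    # np.inf is a plain Python float, so it is valid anywhere a float
    # attribute such as clip's a_min / a_max is expected.
    print(type(np.inf))                # <class 'float'>
    print(isinstance(-np.inf, float))  # True
    print(-np.inf == float('-inf'))    # True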