diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md index 9cab2c59e862..fea746bb02f4 100644 --- a/docs/api/python/symbol/symbol.md +++ b/docs/api/python/symbol/symbol.md @@ -337,6 +337,7 @@ Composite multiple symbols into a new one by an operator. :nosignatures: Symbol.infer_type + Symbol.infer_type_partial Symbol.infer_shape Symbol.infer_shape_partial ``` diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 13ee903407b3..76a4995d15c0 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1562,6 +1562,38 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, const int **aux_type_data, int *complete); +/*! + * \brief partially infer type of unknown input types given the known one. + * + * Return partially inferred results if not all types could be inferred. + * The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data + * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. + * + * \param sym symbol handle + * \param num_args numbe of input arguments. + * \param keys the key of keyword args (optional) + * \param arg_type_data the content of the CSR + * \param in_type_size sizeof the returning array of in_types + * \param in_type_data returning array of pointers to head of the input type. + * \param out_type_size sizeof the returning array of out_types + * \param out_type_data returning array of pointers to head of the input type. + * \param aux_type_size sizeof the returning array of aux_types + * \param aux_type_data returning array of pointers to head of the auxiliary type. + * \param complete whether infer type completes or more information is needed. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int *arg_type_data, + mx_uint *in_type_size, + const int **in_type_data, + mx_uint *out_type_size, + const int **out_type_data, + mx_uint *aux_type_size, + const int **aux_type_data, + int *complete); + /*! * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 * \param sym_handle symbol to be converted diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 43de0c9d7535..3e3e79ed59f7 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -882,6 +882,81 @@ def infer_type(self, *args, **kwargs): List of auxiliary state types. The order is same as the order of list_auxiliary_states(). """ + try: + res = self._infer_type_impl(False, *args, **kwargs) + if res[1] is None: + arg_shapes, _, _ = self._infer_type_impl(True, *args, **kwargs) + arg_names = self.list_arguments() + unknowns = [] + for name, dtype in zip(arg_names, arg_shapes): + if not dtype: + if len(unknowns) >= 10: + unknowns.append('...') + break + unknowns.append('%s: %s' % (name, str(dtype))) + warnings.warn( + "Cannot decide type for the following arguments. " + + "Consider providing them as input:\n\t" + + "\n\t".join(unknowns), stacklevel=2) + return res + except MXNetError: + print("infer_type error. Arguments:") + for i, arg in enumerate(args): + print(" #%d: %s" % (i, arg)) + for k, v in kwargs.items(): + print(" %s: %s" % (k, v)) + raise + + def infer_type_partial(self, *args, **kwargs): + """Infers the type partially. + + This functions works the same way as `infer_type`, + except that this function can return partial results. + + In the following example, information about fc2 is not available. So, `infer_shape` + will return a tuple of `None` values but `infer_shape_partial` will return partial values. + + Example + ------- + >>> data = mx.sym.Variable('data') + >>> prev = mx.sym.Variable('prev') + >>> casted_prev = mx.sym.cast(prev, dtype='float32') + >>> out = mx.sym.Activation(data=mx.sym.elemwise_add(data, casted_prev), act_type='relu') + >>> out.list_arguments() + ['data', 'prev'] + >>> out.infer_type(data='float32') + (None, None, None) + >>> out.infer_type_partial(data='float32') + ([numpy.float32, None], [numpy.float32], []) + >>> # infers type if you give information about prev + >>> out.infer_type(data='float32', prev='float16') + ([numpy.float32, numpy.float16], [numpy.float32], []) + + Parameters + ---------- + *args : + Type of known arguments in a positional way. + Unknown type can be marked as None. + + **kwargs : + Keyword arguments of known types. + + Returns + ------- + arg_types : list of numpy.dtype or None + List of argument types. + The order is same as the order of list_arguments(). + out_types : list of numpy.dtype or None + List of output types. + The order is same as the order of list_outputs(). + aux_types : list of numpy.dtype or None + List of auxiliary state types. + The order is same as the order of list_auxiliary_states(). + """ + return self._infer_type_impl(True, *args, **kwargs) + + def _infer_type_impl(self, partial, *args, **kwargs): + """The actual implementation for calling type inference API.""" # pylint: disable=too-many-locals if len(args) != 0 and len(kwargs) != 0: raise ValueError('Can only specify known argument \ @@ -912,7 +987,11 @@ def infer_type(self, *args, **kwargs): aux_type_size = mx_uint() aux_type_data = ctypes.POINTER(ctypes.c_int)() complete = ctypes.c_int() - check_call(_LIB.MXSymbolInferType( + if partial: + infer_func = _LIB.MXSymbolInferTypePartial + else: + infer_func = _LIB.MXSymbolInferType + check_call(infer_func( self.handle, mx_uint(len(sdata)), keys, diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 32b63c11dd9a..9f0d2834fcce 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -638,6 +638,27 @@ int MXSymbolInferType(SymbolHandle sym, API_END(); } +int MXSymbolInferTypePartial(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const int *arg_type_data, + mx_uint *in_type_size, + const int **in_type_data, + mx_uint *out_type_size, + const int **out_type_data, + mx_uint *aux_type_size, + const int **aux_type_data, + int *complete) { + int succ; + *complete = 1; + return MXSymbolInferType(sym, num_args, keys, + arg_type_data, + in_type_size, in_type_data, + out_type_size, out_type_data, + aux_type_size, aux_type_data, + &succ); +} + int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) { API_BEGIN(); LOG(FATAL) << "not implemented"; diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index c5c1b018b081..ac4564b66fa0 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -120,6 +120,13 @@ def test_symbol_infer_type(): assert out == [np.float32] assert aux == [] + # partial infer type + arg, out, aux = mlp.infer_type_partial() + assert arg == [None, np.float32, np.float32, np.float32] + assert out == [np.float32] + assert aux == [] + + def test_symbol_infer_shape(): num_hidden = 128 num_dim = 64