Untitled
/// Translates the TensorFlow "Reciprocal" operation: element-wise 1/x of its
/// single input x.
///
/// The reciprocal is expressed as x^(-1) using v1::Power. The exponent is a
/// scalar -1 built with create_same_type_const_scalar, presumably giving it
/// x's element type so both Power inputs agree (TODO confirm against the
/// helper's definition).
///
/// @param node  Frontend context for the TensorFlow node. default_op_checks
///              validates it as a "Reciprocal" node with 1 input.
/// @return      A single output holding x^(-1), named after the original node.
///
/// NOTE(review): an earlier draft branched on complex inputs. It compared the
/// OpenVINO element type with TensorProto_DataType_COMPLEX* protobuf enums and
/// built v1::ComplexDiv / v1::Div nodes. Its ternary mixed two unrelated
/// shared_ptr types, so it could not compile, and its -1 constant was never
/// used. If complex inputs must be supported, add that through the frontend's
/// own complex-tensor handling. TODO confirm the intended mechanism.
/// This body needs the v1::Power op header to be included by this file.
OutputVector translate_reciprocal_op(const NodeContext& node) {
    default_op_checks(node, 1, {"Reciprocal"});
    auto x = node.get_input(0);

    // 1/x == x^(-1). Build the exponent in x's element type so Power's inputs match.
    auto minus_one_const = create_same_type_const_scalar<int32_t>(x, -1);
    auto reciprocal = make_shared<v1::Power>(x, minus_one_const);

    set_node_name(node.get_name(), reciprocal);
    return {reciprocal};
}
Leave a Comment