Untitled
unknown
plain_text
2 years ago
1.0 kB
5
Indexable
// Translates the TensorFlow "Reciprocal" op: computes element-wise 1 / x for the single input x.
//
// @param node  frontend node context; must carry exactly one input (checked by default_op_checks).
// @return      a one-element OutputVector holding the division node, named after the TF node.
OutputVector translate_reciprocal_op(const NodeContext& node) {
    default_op_checks(node, 1, {"Reciprocal"});
    auto x = node.get_input(0);

    // NOTE(review): get_element_type() yields the frontend's element type, whereas
    // TensorProto_DataType_* are protobuf enum values from a different type system.
    // Confirm this comparison compiles and actually detects complex inputs in this project.
    const bool is_complex = x.get_element_type() == TensorProto_DataType_COMPLEX64 ||
                            x.get_element_type() == TensorProto_DataType_COMPLEX128;

    // The two branches produce different node types. A conditional operator over
    // make_shared<ComplexDiv>/make_shared<Div> has no common type and does not compile,
    // so the result is assigned into a base-class pointer instead.
    std::shared_ptr<ov::Node> reciprocal;
    if (is_complex) {
        // Numerator 1 + 0i, created with the same element type as x.
        // NOTE(review): confirm v1::ComplexDiv exists in the opset this file targets.
        auto one = create_same_type_const_scalar<std::complex<float>>(x, std::complex<float>(1.0f, 0.0f));
        reciprocal = make_shared<v1::ComplexDiv>(one, x);
    } else {
        // Numerator 1, created with x's element type so both Divide operands agree.
        auto one = create_same_type_const_scalar<float>(x, 1.0f);
        reciprocal = make_shared<v1::Div>(one, x);
    }

    set_node_name(node.get_name(), reciprocal);
    return {reciprocal};
}
Editor is loading...
Leave a Comment