Untitled

 avatar
unknown
plain_text
a year ago
1.0 kB
2
Indexable
// Translates TensorFlow Reciprocal: element-wise 1/x, where x is the single input.
//
// Real inputs are lowered to Power(x, -1).
// Complex inputs arrive wrapped in a ComplexTypeMark whose payload stores the real and
// imaginary parts along a trailing axis of size 2. For those, 1/(a + bi) = (a - bi) / (a^2 + b^2)
// is computed on the parts, and the result is re-wrapped in a ComplexTypeMark of the same
// part type.
//
// NOTE(review): this relies on the file including the ComplexTypeMark helper header
// (helper_ops/complex_type_mark.hpp in the TF frontend) and the op headers for
// Gather/Concat/Negative/Multiply/Add/Divide/Power — confirm against the file's includes.
OutputVector translate_reciprocal_op(const NodeContext& node) {
    // `true` lets ComplexTypeMark inputs through; without it complex inputs are rejected
    // before the complex branch below could ever run.
    default_op_checks(node, 1, {"Reciprocal"}, true);
    auto x = node.get_input(0);

    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());
    if (complex_type_mark) {
        auto complex_part_type = complex_type_mark->get_complex_part_type();
        // payload has shape [..., 2]: index 0 = real part, index 1 = imaginary part
        auto complex_parts = complex_type_mark->input_value(0);

        auto last_axis = make_shared<v0::Constant>(element::i32, Shape{}, -1);
        auto real_index = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
        auto imag_index = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
        // 1-element index tensors keep the trailing axis (now size 1), so the parts can be
        // concatenated back along it and used as a broadcastable divisor.
        auto real = make_shared<v8::Gather>(complex_parts, real_index, last_axis);
        auto imag = make_shared<v8::Gather>(complex_parts, imag_index, last_axis);

        // |x|^2 = a^2 + b^2, shape [..., 1]
        auto squared_norm =
            make_shared<v1::Add>(make_shared<v1::Multiply>(real, real), make_shared<v1::Multiply>(imag, imag));
        // conj(x) = a - bi, shape [..., 2]
        auto conjugate = make_shared<v0::Concat>(OutputVector{real, make_shared<v0::Negative>(imag)}, -1);
        // 1/x = conj(x) / |x|^2; the [..., 1] divisor broadcasts over the trailing axis
        auto reciprocal = make_shared<v1::Divide>(conjugate, squared_norm);
        set_node_name(node.get_name(), reciprocal);

        auto complex_reciprocal = make_shared<ComplexTypeMark>(reciprocal, complex_part_type);
        return {complex_reciprocal};
    }

    // Real-valued input: 1/x == x^(-1); the exponent constant takes x's element type.
    auto minus_one_const = create_same_type_const_scalar<int32_t>(x, -1);
    auto reciprocal = make_shared<v1::Power>(x, minus_one_const);
    set_node_name(node.get_name(), reciprocal);
    return {reciprocal};
}
Leave a Comment