// Untitled
//
//  avatar
// unknown
// c_cpp
// a year ago
// 2.7 kB
// 10
// Indexable
#include "openvino/op/add.hpp"

// Provides ComplexTypeMark, used below to detect and re-wrap complex-valued inputs.
#include "helper_ops/complex_type_mark.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

/// @brief Translates a TensorFlow AddV2 node into an OpenVINO v1::Add.
///
/// Real-valued inputs map directly onto a single v1::Add.
///
/// Complex-valued inputs arrive wrapped in ComplexTypeMark nodes. The wrapped
/// tensor holds the real part at index 0 and the imaginary part at index 1 of
/// its last axis (NOTE(review): that axis is assumed to have size exactly 2 —
/// confirm against the producers of ComplexTypeMark). Since
/// (a + bi) + (c + di) = (a + c) + (b + d)i, complex addition is just an
/// element-wise Add of the packed tensors: the leading dimensions broadcast the
/// same way as they would for the separated real/imaginary parts, and the
/// trailing [real, imag] pairs are added component-wise. The sum is re-wrapped
/// in a ComplexTypeMark with the lhs complex part type.
///
/// @param node  TensorFlow node context; at least two inputs are required.
/// @return      One output: the v1::Add result, or a ComplexTypeMark wrapping
///              it when the inputs are complex.
/// @throws      Fails FRONT_END_GENERAL_CHECK when exactly one input is complex.
OutputVector translate_addv2_op(const NodeContext& node) {
    default_op_checks(node, 2, {}, true);
    auto lhs = node.get_input(0);
    auto rhs = node.get_input(1);

    auto complex_type_mark_lhs = as_type_ptr<ComplexTypeMark>(lhs.get_node_shared_ptr());
    auto complex_type_mark_rhs = as_type_ptr<ComplexTypeMark>(rhs.get_node_shared_ptr());
    const bool is_complex = complex_type_mark_lhs || complex_type_mark_rhs;

    if (is_complex) {
        FRONT_END_GENERAL_CHECK(complex_type_mark_lhs != nullptr && complex_type_mark_rhs != nullptr,
                                "AddV2 got complex and non-complex inputs. Inputs should be of the same type.");
        // Unwrap to the packed real/imaginary tensors underneath the marks.
        lhs = complex_type_mark_lhs->input_value(0);
        rhs = complex_type_mark_rhs->input_value(0);
    }

    // Build the Add only after the operands are final, so no unused Add is
    // created on the ComplexTypeMark outputs in the complex case.
    auto result = make_shared<v1::Add>(lhs, rhs);
    // Name the computing node on both paths (the complex path previously
    // left the converted node unnamed).
    set_node_name(node.get_name(), result);

    if (is_complex) {
        auto complex_result = make_shared<ComplexTypeMark>(result, complex_type_mark_lhs->get_complex_part_type());
        return {complex_result};
    }

    return {result};
}

}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
// Editor is loading...
// Leave a Comment