Untitled
unknown
c_cpp
a year ago
2.7 kB
10
Indexable
#include "common_op_table.hpp"
// Needed for complex type support (ComplexTypeMark).
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/op/add.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

/// @brief Translates a TensorFlow AddV2 node into an OpenVINO v1::Add.
///
/// Real-valued inputs become a single element-wise v1::Add, named after the
/// TensorFlow node.
///
/// Complex-valued inputs arrive wrapped in ComplexTypeMark nodes whose input 0
/// holds the real and imaginary parts packed along the last axis. Because
/// (a + bi) + (c + di) = (a + c) + (b + d)i is element-wise on those packed
/// pairs, one Add on the packed tensors yields the packed complex sum directly;
/// no split / re-pack is required. The sum is re-wrapped in a ComplexTypeMark
/// carrying the left operand's complex part type.
///
/// @param node TensorFlow node context; expected to provide two inputs.
/// @return A single-element OutputVector: the Add (real case) or a
///         ComplexTypeMark wrapping the packed Add (complex case).
/// @throws (via FRONT_END_GENERAL_CHECK) when exactly one input is complex.
OutputVector translate_addv2_op(const NodeContext& node) {
    // NOTE(review): supported_ops is empty here; confirm default_op_checks still
    // enforces the 2-input minimum when no op-type list is given.
    default_op_checks(node, 2, {}, true);
    auto lhs = node.get_input(0);
    auto rhs = node.get_input(1);

    auto complex_type_mark_lhs = ov::as_type_ptr<ComplexTypeMark>(lhs.get_node_shared_ptr());
    auto complex_type_mark_rhs = ov::as_type_ptr<ComplexTypeMark>(rhs.get_node_shared_ptr());

    if (complex_type_mark_lhs || complex_type_mark_rhs) {
        FRONT_END_GENERAL_CHECK(complex_type_mark_lhs != nullptr && complex_type_mark_rhs != nullptr,
                                "AddV2 got complex and non-complex inputs. Inputs should be of the same type.");

        const auto complex_part_type = complex_type_mark_lhs->get_complex_part_type();

        // Add the packed [re, im] tensors directly: the sum of each pair is the
        // pair of sums, so the result is already in packed complex form.
        auto packed_sum = std::make_shared<ov::op::v1::Add>(complex_type_mark_lhs->input_value(0),
                                                            complex_type_mark_rhs->input_value(0));

        auto complex_result = std::make_shared<ComplexTypeMark>(packed_sum->output(0), complex_part_type);
        return {complex_result};
    }

    // Real-valued inputs: build the Add only on this path so the complex path
    // does not leave an orphan node in the graph.
    auto result = std::make_shared<ov::op::v1::Add>(lhs, rhs);
    set_node_name(node.get_name(), result);
    return {result};
}

}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
Editor is loading...
Leave a Comment