Untitled
unknown
c_cpp
2 years ago
2.7 kB
17
Indexable
#include "openvino/op/add.hpp"
// Provides ComplexTypeMark, used to detect and unwrap complex-valued inputs.
#include "helper_ops/complex_type_mark.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
/// @brief Translates a TensorFlow AddV2 node into an OpenVINO v1::Add.
///
/// Real inputs are added directly. If either input is produced by a
/// ComplexTypeMark, both inputs must be; the marks are unwrapped, the packed
/// underlying tensors are added with a single v1::Add, and the sum is
/// re-wrapped in a ComplexTypeMark that carries the left operand's complex
/// part type.
///
/// @param node TensorFlow node context for AddV2; inputs 0 and 1 are the operands.
/// @return One output: the v1::Add result, or (for complex inputs) a
///         ComplexTypeMark wrapping it.
/// Fails FRONT_END_GENERAL_CHECK when exactly one operand is complex.
OutputVector translate_addv2_op(const NodeContext& node) {
    // Trailing `true` presumably opts this translator into complex-input support.
    default_op_checks(node, 2, {}, true);
    auto lhs = node.get_input(0);
    auto rhs = node.get_input(1);

    auto complex_type_mark_lhs = as_type_ptr<ComplexTypeMark>(lhs.get_node_shared_ptr());
    auto complex_type_mark_rhs = as_type_ptr<ComplexTypeMark>(rhs.get_node_shared_ptr());
    const bool has_complex_input = complex_type_mark_lhs || complex_type_mark_rhs;

    if (has_complex_input) {
        FRONT_END_GENERAL_CHECK(complex_type_mark_lhs != nullptr && complex_type_mark_rhs != nullptr,
                                "AddV2 got complex and non-complex inputs. Inputs should be of the same type.");
        // Operate on the tensors behind the marks. Assuming the packed layout
        // stores (real, imag) along a trailing axis of size 2,
        // (a + bi) + (c + di) = (a + c) + (b + d)i is exactly an elementwise Add
        // of the packed tensors. Numpy broadcasting over the leading axes is
        // unaffected because both trailing axes have size 2, so there is no need
        // to split, add and re-concatenate the parts.
        lhs = complex_type_mark_lhs->input_value(0);
        rhs = complex_type_mark_rhs->input_value(0);
    }

    // Build the Add only after unwrapping, so that no node is ever created on
    // top of the ComplexTypeMark placeholders; name it in both paths.
    auto result = make_shared<v1::Add>(lhs, rhs);
    set_node_name(node.get_name(), result);

    if (has_complex_input) {
        auto complex_result =
            make_shared<ComplexTypeMark>(result->output(0), complex_type_mark_lhs->get_complex_part_type());
        return {complex_result};
    }
    return {result};
}
}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
Editor is loading...
Leave a Comment