Untitled
unknown
plain_text
a year ago
5.1 kB
10
Indexable
// Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include "openvino/op/gather.hpp" #include "common_op_table.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/gather_nd.hpp" #include "helper_ops/complex_type_mark.hpp" #include "openvino/op/add.hpp" #include "openvino/op/equal.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/subtract.hpp" using namespace std; using namespace ov::op; namespace ov { namespace frontend { namespace tensorflow { namespace op { OutputVector translate_basic_gather_op(const NodeContext& node, const ov::Output<ov::Node>& axis, int64_t batch_dims) { auto op_type = node.get_op_type(); TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 2, op_type + " must have at least two inputs."); auto params = node.get_input(0); auto indices = node.get_input(1); auto gather = make_shared<v8::Gather>(params, indices, axis, batch_dims); set_node_name(node.get_name(), gather); return {gather}; } OutputVector translate_gather_op(const NodeContext& node) { // Gather has two inputs: data and indices // axis by which data is sliced is always equal to 0, batch_dims is always equal to 0 default_op_checks(node, 2, {"Gather"}); auto params = node.get_input(0); auto complex_type_mark = as_type_ptr<ComplexTypeMark>(params.get_node_shared_ptr()); if (complex_type_mark) { params = complex_type_mark->input_value(0); // Since complex type outputs would have an extra axis for Real & Image, we route the indicies to // be picked from the 2nd dimension instead auto axis = make_shared<v0::Constant>(element::i64, Shape{}, 1); auto indices = node.get_input(1); auto gather = make_shared<v8::Gather>(params, indices, axis, 0); set_node_name(node.get_name(), gather); auto complex_reshape = make_shared<ComplexTypeMark>(gather, complex_type_mark->get_complex_part_type()); return {complex_reshape->output(0)}; } else { auto axis = make_shared<v0::Constant>(element::i64, Shape{}, 0); return translate_basic_gather_op(node, axis, 0); } } OutputVector translate_resource_gather_op(const NodeContext& node) { // ResourceGather has two inputs: data and indices // axis by which data is sliced is always equal to 0, batch_dims is an attribute and can vary default_op_checks(node, 2, {"ResourceGather"}); auto axis = make_shared<v0::Constant>(element::i64, Shape{}, 0); auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0); return translate_basic_gather_op(node, axis, batch_dims); } OutputVector translate_gather_v2_op(const NodeContext& node) { // GatherV2 has three inputs: data, indices, and axis by which data is sliced // batch_dims is an attribute and can vary default_op_checks(node, 3, {"GatherV2"}); auto axis = node.get_input(2); auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0); auto params = node.get_input(0); auto complex_type_mark = as_type_ptr<ComplexTypeMark>(params.get_node_shared_ptr()); if (complex_type_mark) { params = complex_type_mark->input_value(0); auto const_one = create_same_type_const_scalar<float>(axis, 1); auto minus_one = create_same_type_const_scalar<float>(axis, -1); if (make_shared<v1::Equal>(axis, minus_one)){ auto params_shape = make_shared<v3::ShapeOf>(params, ov::element::i32); auto params_rank = make_shared<v3::ShapeOf>(params_shape, ov::element::i32); axis = make_shared<v1::Subtract>(params_rank, make_shared<v0::Constant>(ov::element::i32, Shape{}, 1)); } else { axis = make_shared<v1::Add>(axis, const_one); } // Since complex type outputs would have an extra axis for Real & Image, we route the indicies to // be picked from the 2nd dimension instead auto indices = node.get_input(1); auto gather = make_shared<v8::Gather>(params, indices, axis, 0); set_node_name(node.get_name(), gather); auto complex_reshape = make_shared<ComplexTypeMark>(gather, complex_type_mark->get_complex_part_type()); return {complex_reshape->output(0)}; } else { auto axis = make_shared<v0::Constant>(element::i64, Shape{}, 0); return translate_basic_gather_op(node, axis, 0); } return translate_basic_gather_op(node, axis, batch_dims); } OutputVector translate_gather_nd_op(const NodeContext& node) { // GatherND has two inputs: data and indices // batch_dims is always equal to 0 default_op_checks(node, 2, {"GatherNd", "GATHER_ND"}); auto input = node.get_input(0); auto input_indices = node.get_input(1); auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0); auto gather_nd = make_shared<v8::GatherND>(input, input_indices, batch_dims); set_node_name(node.get_name(), gather_nd); return {gather_nd}; } } // namespace op } // namespace tensorflow } // namespace frontend } // namespace ov
Editor is loading...
Leave a Comment