Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
5.1 kB
2
Indexable
Never
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/gather.hpp"

#include "common_op_table.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/gather_nd.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
// Common lowering shared by the Gather-family translators.
// Takes the first two inputs of `node` (data and indices) and builds a v8::Gather
// that slices the data along `axis` using the given `batch_dims`.
// The resulting node is named after the original framework node.
OutputVector translate_basic_gather_op(const NodeContext& node, const ov::Output<ov::Node>& axis, int64_t batch_dims) {
    const auto op_type = node.get_op_type();
    TENSORFLOW_OP_VALIDATION(node, node.get_input_size() >= 2, op_type + " must have at least two inputs.");

    // input 0 is the data to gather from, input 1 holds the indices
    const auto data = node.get_input(0);
    const auto gather_indices = node.get_input(1);

    auto gather_node = make_shared<v8::Gather>(data, gather_indices, axis, batch_dims);
    set_node_name(node.get_name(), gather_node);
    return {gather_node};
}

// Translates TensorFlow Gather.
// Gather has two inputs: data and indices.
// The axis by which data is sliced is always equal to 0, batch_dims is always equal to 0.
//
// When the data input is wrapped into ComplexTypeMark, the gather is applied to the
// tensor carried by the mark and the result is wrapped back into ComplexTypeMark
// with the same complex part type.
// NOTE(review): default_op_checks is called without any complex-related argument here;
// confirm that it does not reject ComplexTypeMark inputs before this code is reached.
OutputVector translate_gather_op(const NodeContext& node) {
    default_op_checks(node, 2, {"Gather"});
    auto params = node.get_input(0);
    auto axis = make_shared<v0::Constant>(element::i64, Shape{}, 0);

    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(params.get_node_shared_ptr());
    if (!complex_type_mark) {
        return translate_basic_gather_op(node, axis, 0);
    }

    // Assumes ComplexTypeMark carries a floating-point tensor whose trailing (innermost)
    // dimension of size 2 stores the real and imaginary parts - confirm against
    // helper_ops/complex_type_mark.hpp. With that layout the auxiliary dimension does not
    // shift leading axes, so slicing along axis 0 of the carried tensor is exactly
    // slicing along axis 0 of the complex tensor.
    params = complex_type_mark->input_value(0);
    auto indices = node.get_input(1);
    auto gather = make_shared<v8::Gather>(params, indices, axis, 0);
    set_node_name(node.get_name(), gather);

    // propagate the complex marker so that downstream translators keep treating the result as complex
    auto complex_result = make_shared<ComplexTypeMark>(gather, complex_type_mark->get_complex_part_type());
    return {complex_result->output(0)};
}

// Translates ResourceGather.
// The operation has two inputs (data and indices); data is always sliced along axis 0,
// while batch_dims comes from the attribute of the same name (0 when the attribute is absent).
OutputVector translate_resource_gather_op(const NodeContext& node) {
    default_op_checks(node, 2, {"ResourceGather"});

    const auto batch_dims_attr = node.get_attribute<int64_t>("batch_dims", 0);
    const auto zero_axis = make_shared<v0::Constant>(element::i64, Shape{}, 0);

    return translate_basic_gather_op(node, zero_axis, batch_dims_attr);
}

// Translates TensorFlow GatherV2.
// GatherV2 has three inputs: data, indices, and axis by which data is sliced.
// batch_dims is an attribute and can vary (0 when the attribute is absent).
//
// Regular (non-complex) data is lowered by translate_basic_gather_op with the axis input
// and the batch_dims attribute passed through unchanged.
// Data wrapped into ComplexTypeMark is gathered on the tensor carried by the mark and
// the result is wrapped back into ComplexTypeMark with the same complex part type.
// NOTE(review): default_op_checks is called without any complex-related argument here;
// confirm that it does not reject ComplexTypeMark inputs before this code is reached.
OutputVector translate_gather_v2_op(const NodeContext& node) {
    default_op_checks(node, 3, {"GatherV2"});
    auto params = node.get_input(0);
    auto axis = node.get_input(2);
    auto batch_dims = node.get_attribute<int64_t>("batch_dims", 0);

    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(params.get_node_shared_ptr());
    if (!complex_type_mark) {
        return translate_basic_gather_op(node, axis, batch_dims);
    }

    params = complex_type_mark->input_value(0);
    auto indices = node.get_input(1);

    // Assumes ComplexTypeMark carries a floating-point tensor whose trailing (innermost)
    // dimension of size 2 stores the real and imaginary parts - confirm against
    // helper_ops/complex_type_mark.hpp. With that layout:
    //   - a non-negative axis addresses the same dimension in both tensors;
    //   - a negative axis is counted from the end, so it must be moved one position
    //     further to skip the auxiliary dimension (e.g. -1 becomes -2).
    // The axis is a graph input, so its sign is tested with graph operations
    // rather than with a C++ condition at translation time.
    auto zero = create_same_type_const_scalar<int32_t>(axis, 0);
    auto minus_one = create_same_type_const_scalar<int32_t>(axis, -1);
    auto is_negative_axis = make_shared<v1::Less>(axis, zero);
    auto axis_shift = make_shared<v1::Select>(is_negative_axis, minus_one, zero);
    auto complex_axis = make_shared<v1::Add>(axis, axis_shift);

    // batch_dims is forwarded as is: a trailing auxiliary dimension does not move the leading batch dimensions
    auto gather = make_shared<v8::Gather>(params, indices, complex_axis, batch_dims);
    set_node_name(node.get_name(), gather);

    // propagate the complex marker so that downstream translators keep treating the result as complex
    auto complex_result = make_shared<ComplexTypeMark>(gather, complex_type_mark->get_complex_part_type());
    return {complex_result->output(0)};
}

// Translates GatherNd / GATHER_ND.
// The operation has two inputs: data and indices.
// batch_dims is taken from the attribute of the same name and falls back to 0 when it is absent.
OutputVector translate_gather_nd_op(const NodeContext& node) {
    default_op_checks(node, 2, {"GatherNd", "GATHER_ND"});

    const auto data = node.get_input(0);
    const auto nd_indices = node.get_input(1);
    const auto batch_dims_attr = node.get_attribute<int64_t>("batch_dims", 0);

    auto result = make_shared<v8::GatherND>(data, nd_indices, batch_dims_attr);
    set_node_name(node.get_name(), result);
    return {result};
}

}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
Leave a Comment