Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
1.7 kB
2
Indexable
Never
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/reverse_sequence.hpp"
#include "helper_ops/complex_type_mark.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"

#include "common_op_table.hpp"

using namespace std;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

/// @brief Translates a TensorFlow `ReverseSequence` node into OpenVINO `v0::ReverseSequence`.
///
/// Reads the required `seq_dim` attribute and the optional `batch_dim` attribute
/// (default 0) and forwards them as the static batch/sequence axes of
/// `v0::ReverseSequence`. An input wrapped in `ComplexTypeMark` is unwrapped,
/// reversed, and re-wrapped with the same complex part type.
///
/// @param node  Decoder context; input 0 is the tensor to reverse, input 1 is seq_lengths.
/// @return A single-element OutputVector holding the reversed tensor, wrapped in
///         `ComplexTypeMark` when the input was marked complex.
OutputVector translate_reverse_sequence_op(const NodeContext& node) {
    // `true` declares this translator complex-aware. default_op_checks is expected to
    // reject ComplexTypeMark inputs otherwise (verify in utils.hpp), which would make
    // the complex branch below unreachable.
    default_op_checks(node, 2, {"ReverseSequence"}, true);
    auto input = node.get_input(0);
    auto seq_lengths = node.get_input(1);

    // Retrieve attributes. v0::ReverseSequence takes the axes as plain int64_t values,
    // not as graph nodes, so they are adjusted arithmetically below, not with v1::Add.
    auto seq_dim = node.get_attribute<int64_t>("seq_dim");
    auto batch_dim = node.get_attribute<int64_t>("batch_dim", 0);

    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
    if (complex_type_mark) {
        // NOTE(review): assumes the frontend's complex representation, where ComplexTypeMark
        // wraps a real tensor carrying one extra trailing dimension (size 2) for the
        // real/imaginary parts. Non-negative axes are unaffected by that trailing dimension.
        // Negative axes count from the end, so they must skip one more position.
        auto complex_part_type = complex_type_mark->get_complex_part_type();
        // Reverse the underlying real tensor, not the marker node itself.
        input = complex_type_mark->input_value(0);
        if (seq_dim < 0) {
            --seq_dim;
        }
        if (batch_dim < 0) {
            --batch_dim;
        }

        auto reverse_sequence = make_shared<v0::ReverseSequence>(input, seq_lengths, batch_dim, seq_dim);
        set_node_name(node.get_name(), reverse_sequence);

        // Re-attach the complex marker so downstream translators still see a complex tensor.
        auto complex_reverse = make_shared<ComplexTypeMark>(reverse_sequence, complex_part_type);
        return {complex_reverse};
    }

    auto reverse_sequence = make_shared<v0::ReverseSequence>(input, seq_lengths, batch_dim, seq_dim);
    set_node_name(node.get_name(), reverse_sequence);
    return {reverse_sequence};
}
}  // namespace op
}  // namespace tensorflow
}  // namespace frontend
}  // namespace ov
Leave a Comment