Untitled
unknown
c_cpp
2 years ago
1.0 kB
6
Indexable
// Translates ops named ScatterNd / SCATTER_ND into OpenVINO ops.
//
// The output is a zero tensor of the requested `shape` with `updates`
// scattered into it at `indices`. This is built as
// Broadcast(0, shape) -> ScatterNDUpdate(base, indices, updates).
//
// Inputs: 0 - indices, 1 - updates, 2 - shape of the output tensor.
//
// Complex `updates` arrive wrapped in a ComplexTypeMark. The wrapped tensor is
// the real-valued packing of the complex values, carrying one extra trailing
// dimension of size 2 (real, imaginary). That is the ComplexTypeMark helper's
// convention; NOTE(review): confirm if that helper changes. In that case the
// zero base tensor gets the same trailing dimension, and the result is
// re-wrapped in a ComplexTypeMark with the original complex part type.
OutputVector translate_scatter_nd_op(const NodeContext& node) {
    default_op_checks(node, 3, {"ScatterNd", "SCATTER_ND"}, true);
    auto input_indices = node.get_input(0);
    auto updates = node.get_input(1);
    auto shape = node.get_input(2);

    auto complex_type_mark_updates = as_type_ptr<ComplexTypeMark>(updates.get_node_shared_ptr());
    if (complex_type_mark_updates) {
        // unwrap to the packed real-valued representation
        updates = complex_type_mark_updates->input_value(0);
        // The base tensor must carry the same auxiliary trailing dimension as
        // the packed updates, otherwise ScatterNDUpdate sees mismatched shapes.
        // The constant takes the element type of `shape` so Concat is well-typed.
        auto aux_dim = create_same_type_const<int32_t>(shape, vector<int32_t>{2}, Shape{1});
        shape = make_shared<v0::Concat>(OutputVector{shape, aux_dim}, 0);
    }

    // zero of the same element type as `updates`, broadcast to the output shape
    auto input_data = create_same_type_const<int32_t>(updates, vector<int32_t>{0}, Shape{1});
    auto broadcast = make_shared<v3::Broadcast>(input_data, shape);
    auto scatter_nd = make_shared<v3::ScatterNDUpdate>(broadcast, input_indices, updates);
    set_node_name(node.get_name(), scatter_nd);

    // re-wrap so downstream translators keep treating the result as complex
    if (complex_type_mark_updates) {
        auto complex_scatter_nd =
            make_shared<ComplexTypeMark>(scatter_nd, complex_type_mark_updates->get_complex_part_type());
        return {complex_scatter_nd};
    }
    return {scatter_nd};
}
Leave a Comment