Untitled
unknown
c_cpp
a year ago
1.0 kB
5
Indexable
/// Translates a ScatterNd / SCATTER_ND node into a Broadcast + ScatterNDUpdate pair.
///
/// The node's three inputs are read as: 0 - indices, 1 - updates, 2 - output shape.
/// A single-element zero constant (typed after `updates`) is broadcast to `shape`
/// and then the updates are scattered into it at `indices`.
///
/// If `updates` is produced by a ComplexTypeMark, the mark is stripped before
/// building the sub-graph and the result is wrapped in a new ComplexTypeMark
/// carrying the same complex part type.
///
/// @param node  Conversion context of the operation being translated.
/// @return      A single output: the ScatterNDUpdate node, or the ComplexTypeMark
///              wrapping it when the updates were complex.
OutputVector translate_scatter_nd_op(const NodeContext& node) {
    default_op_checks(node, 3, {"ScatterNd", "SCATTER_ND"}, true);

    auto input_indices = node.get_input(0);
    auto updates = node.get_input(1);
    auto shape = node.get_input(2);

    // Null when `updates` is not produced by a ComplexTypeMark node.
    auto complex_type_mark_updates = as_type_ptr<ComplexTypeMark>(updates.get_node_shared_ptr());
    if (complex_type_mark_updates) {
        // Work on the tensor underneath the mark.
        updates = complex_type_mark_updates->input_value(0);
        // NOTE(review): the unwrapped complex tensor presumably carries an auxiliary
        // trailing dimension (real/imag) that `shape` does not include - confirm
        // whether the Broadcast target below must be extended for the complex case.
    }

    // Zero-filled base tensor of the requested shape that the updates are scattered into.
    auto input_data = create_same_type_const<int32_t>(updates, vector<int32_t>{0}, Shape{1});
    auto broadcast = make_shared<v3::Broadcast>(input_data, shape);
    auto scatter_nd = make_shared<v3::ScatterNDUpdate>(broadcast, input_indices, updates);
    set_node_name(node.get_name(), scatter_nd);

    // Guard on the same pointer that is dereferenced below; it is the only
    // indicator in this function that the updates were complex.
    if (complex_type_mark_updates) {
        auto complex_scatter_nd =
            make_shared<ComplexTypeMark>(scatter_nd, complex_type_mark_updates->get_complex_part_type());
        return {complex_scatter_nd};
    }

    return {scatter_nd};
}
Editor is loading...
Leave a Comment