Untitled
unknown
plain_text
a year ago
3.1 kB
7
Indexable
class TestTFScatterNDComplex(CommonTFLayerTest):
    """Layer tests for TensorFlow ``ScatterNd`` applied to complex64 tensors.

    Each case builds a TF1 graph that scatters constant complex ``updates``
    into a zero tensor shaped like the ``Input`` placeholder and adds the
    scattered tensor to ``Input``.
    """

    def create_tf_scatternd_placeholder_const_net(self, x_shape, indices, updates, ir_version,
                                                  use_legacy_frontend):
        """Build the TensorFlow graph under test.

        Args:
            x_shape: Static shape of the complex64 ``Input`` placeholder; its
                runtime shape (``tf.shape``) is also the ScatterNd output shape.
            indices: Nested list of integer indices for ``tf.scatter_nd``.
            updates: Nested list of Python complex values scattered at
                ``indices``; converted to a complex64 constant.
            ir_version: Accepted for interface compatibility; unused here.
            use_legacy_frontend: Accepted for interface compatibility; unused here.

        Returns:
            A ``(tf_net, ref_net)`` tuple: the session's ``GraphDef`` and
            ``None`` (no reference network is built by this method).
        """
        import tensorflow as tf

        tf.compat.v1.reset_default_graph()

        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            x = tf.compat.v1.placeholder(tf.complex64, x_shape, 'Input')
            tf_indices = tf.constant(indices)
            tf_updates = tf.constant(updates, dtype=tf.complex64)
            scatter_nd = tf.scatter_nd(tf_indices, tf_updates, tf.shape(x), name="Operation")
            # Final node of the graph: combines the placeholder with the
            # scattered tensor. Its handle is not needed afterwards.
            tf.add(x, scatter_nd)
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        ref_net = None

        return tf_net, ref_net

    # For tf.scatter_nd every case must satisfy:
    #   shape(updates) == shape(indices)[:-1] + x_shape[shape(indices)[-1]:]
    test_data = [
        pytest.param(
            dict(x_shape=[8], indices=[[4], [3], [1], [7]], updates=[9j, 10j, 11j, 12j]),
            marks=pytest.mark.precommit),
        # indices shape [2, 1] -> updates must be [2, 4, 4] (previously had an
        # extra outer bracket, giving [1, 2, 4, 4] and a shape mismatch).
        pytest.param(dict(x_shape=[4, 4, 4], indices=[[0], [2]], updates=[
            [[5j, 5j, 5j, 5j], [6j, 6j, 6j, 6j], [7j, 7j, 7j, 7j], [8j, 8j, 8j, 8j]],
            [[1j, 1j, 1j, 1j], [2j, 2j, 2j, 2j], [3j, 3j, 3j, 3j], [4j, 4j, 4j, 4j]]])),
        pytest.param(dict(x_shape=[2, 2], indices=[[0]], updates=[[5j, 3j]])),
        pytest.param(dict(x_shape=[2, 2], indices=[[1, 1]], updates=[5j])),
        dict(x_shape=[1], indices=[[0]], updates=[3j]),
        dict(x_shape=[20], indices=[[0], [6], [9], [19], [13]], updates=[3j, 7j, -12j, 4j, -99j]),
        dict(x_shape=[4, 2], indices=[[1], [2]], updates=[[9j, 14j], [-76j, 0j]]),
        dict(x_shape=[4, 4, 4], indices=[[0], [1], [3]], updates=[
            [[5j, 1j, 5j, 13j], [8j, 6j, 6j, 8j], [7j, 0j, 0j, 7j], [8j, 8j, 8j, 8j]],
            [[0j, 0j, 0j, 0j], [1j, 2j, 3j, 4j], [5j, 6j, 7j, 8j], [9j, 10j, 11j, 12j]],
            [[5j, 5j, 5j, 5j], [6j, 6j, 6j, 6j], [7j, 7j, 7j, 7j], [8j, 8j, 8j, 8j]]]),
        dict(x_shape=[2, 2, 2], indices=[[1, 1, 1], [0, 1, 0]], updates=[9j, 6.3j]),
        pytest.param(dict(x_shape=[2, 2, 2], indices=[[0, 0], [0, 1]],
                          updates=[[6.7j, 9j], [45j, 8.3j]]),
                     marks=pytest.mark.precommit_tf_fe),
        dict(x_shape=[2, 2, 2], indices=[[1]], updates=[[[6.7j, 9j], [45j, 8.3j]]]),
    ]

    @pytest.mark.parametrize("params", test_data)
    @pytest.mark.nightly
    @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
                       reason='Ticket - 122716')
    def test_tf_scatter_nd(self, params, ie_device, precision, ir_version, temp_dir,
                           use_legacy_frontend):
        """Build the graph for ``params`` and hand it to the base-class ``_test`` harness."""
        self._test(*self.create_tf_scatternd_placeholder_const_net(**params, ir_version=ir_version,
                                                                   use_legacy_frontend=use_legacy_frontend),
                   ie_device, precision, temp_dir=temp_dir, ir_version=ir_version,
                   use_legacy_frontend=use_legacy_frontend, **params)
Editor is loading...
Leave a Comment