Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
3.1 kB
3
Indexable
Never
class TestTFScatterNDComplex(CommonTFLayerTest):
    """Layer tests for TensorFlow ``ScatterNd`` on ``complex64`` tensors.

    Each case builds a small TF1 graph ``x + scatter_nd(indices, updates, shape(x))``
    with a complex64 placeholder ``x`` and constant ``indices``/``updates``, and hands
    the resulting GraphDef to the shared ``CommonTFLayerTest._test`` harness.
    """

    def create_tf_scatternd_placeholder_const_net(self, x_shape, indices, updates, ir_version,
                                                  use_legacy_frontend):
        """Build the TensorFlow graph under test.

        Args:
            x_shape: Static shape of the complex64 placeholder named ``Input``; it is
                also the shape of the tensor that ``updates`` are scattered into.
            indices: Nested list of integer indices, converted with ``tf.constant``.
            updates: Nested list of complex values, converted to a complex64 constant.
                Its shape must be ``shape(indices)[:-1] + x_shape[shape(indices)[-1]:]``.
            ir_version: Unused here; passed through by ``test_tf_scatter_nd``.
            use_legacy_frontend: Unused here; passed through by ``test_tf_scatter_nd``.

        Returns:
            Tuple ``(tf_net, ref_net)``. ``tf_net`` is the session's GraphDef, and
            ``ref_net`` is always ``None`` because no reference network is built.
        """
        import tensorflow as tf

        tf.compat.v1.reset_default_graph()

        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            # NOTE(review): the complex64 input is fed directly through a placeholder;
            # confirm the harness's input generation supports complex dtypes.
            x = tf.compat.v1.placeholder(tf.complex64, x_shape, 'Input')
            tf_indices = tf.constant(indices)
            tf_updates = tf.constant(updates, dtype=tf.complex64)

            # scatter_nd writes `updates` into a zero tensor of shape(x). Adding x makes
            # the result depend on the runtime values of the input, not only its shape.
            scatter_nd = tf.scatter_nd(tf_indices, tf_updates, tf.shape(x), name="Operation")
            # Only the Add node needs to exist in the graph (presumably it is picked up
            # as the model output by the harness); no Python handle to it is used.
            tf.add(x, scatter_nd)
            # The graph has no variables and this initializer op is never run. The call
            # is kept so the exported GraphDef stays as before.
            tf.compat.v1.global_variables_initializer()

            tf_net = sess.graph_def

        ref_net = None

        return tf_net, ref_net

    # Every case must satisfy the tf.scatter_nd shape rule:
    #   shape(updates) == shape(indices)[:-1] + x_shape[shape(indices)[-1]:]
    test_data = [
        pytest.param(
            dict(x_shape=[8], indices=[[4], [3], [1], [7]], updates=[9j, 10j, 11j, 12j]),
            marks=pytest.mark.precommit),
        # indices has shape [2, 1], so updates must have shape [2] + [4, 4] = [2, 4, 4].
        pytest.param(dict(x_shape=[4, 4, 4], indices=[[0], [2]],
                          updates=[[[5j, 5j, 5j, 5j], [6j, 6j, 6j, 6j], [7j, 7j, 7j, 7j], [8j, 8j, 8j, 8j]],
                                   [[1j, 1j, 1j, 1j], [2j, 2j, 2j, 2j], [3j, 3j, 3j, 3j], [4j, 4j, 4j, 4j]]])),
        pytest.param(dict(x_shape=[2, 2], indices=[[0]], updates=[[5j, 3j]])),
        pytest.param(dict(x_shape=[2, 2], indices=[[1, 1]], updates=[5j])),
        dict(x_shape=[1], indices=[[0]], updates=[3j]),
        dict(x_shape=[20], indices=[[0], [6], [9], [19], [13]],
             updates=[3j, 7j, -12j, 4j, -99j]),
        dict(x_shape=[4, 2], indices=[[1], [2]], updates=[[9j, 14j], [-76j, 0j]]),
        dict(x_shape=[4, 4, 4], indices=[[0], [1], [3]], updates=[
            [[5j, 1j, 5j, 13j], [8j, 6j, 6j, 8j], [7j, 0j, 0j, 7j], [8j, 8j, 8j, 8j]],
            [[0j, 0j, 0j, 0j], [1j, 2j, 3j, 4j], [5j, 6j, 7j, 8j], [9j, 10j, 11j, 12j]],
            [[5j, 5j, 5j, 5j], [6j, 6j, 6j, 6j], [7j, 7j, 7j, 7j], [8j, 8j, 8j, 8j]]]),
        dict(x_shape=[2, 2, 2], indices=[[1, 1, 1], [0, 1, 0]], updates=[9j, 6.3j]),
        pytest.param(dict(x_shape=[2, 2, 2], indices=[[0, 0], [0, 1]], updates=[[6.7j, 9j], [45j, 8.3j]]),
                     marks=pytest.mark.precommit_tf_fe),
        dict(x_shape=[2, 2, 2], indices=[[1]], updates=[[[6.7j, 9j], [45j, 8.3j]]]),
    ]

    @pytest.mark.parametrize("params", test_data)
    @pytest.mark.nightly
    # Expected to fail on Apple-silicon macOS (Darwin/arm64); tracked as ticket 122716.
    @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
                       reason='Ticket - 122716')
    def test_tf_scatter_nd(self, params, ie_device, precision, ir_version, temp_dir,
                           use_legacy_frontend):
        """Build the graph for one ``test_data`` case and run it through ``self._test``.

        ``params`` is forwarded both to the graph builder and, as keyword arguments,
        to ``self._test``.
        """
        self._test(*self.create_tf_scatternd_placeholder_const_net(**params, ir_version=ir_version,
                                                                   use_legacy_frontend=use_legacy_frontend),
                   ie_device, precision, temp_dir=temp_dir, ir_version=ir_version,
                   use_legacy_frontend=use_legacy_frontend, **params)
Leave a Comment