Untitled

 avatar
unknown
python
a year ago
2.1 kB
3
Indexable
class TestComplexScatterND(CommonTFLayerTest):
    """Layer test for ``ScatterNd`` fed with complex-valued updates.

    The graph is driven by four real-valued placeholders (``Input``,
    ``indices``, ``updates_real`` and ``updates_imag``). Complex values exist
    only inside the graph and are split back into real and imaginary parts
    at the end.
    """

    def create_complex_scatter_nd_net(self, x_shape, indices, updates, ir_version,
                                      use_legacy_frontend):
        """Build the TensorFlow graph under test.

        Args:
            x_shape: Static shape of the float32 ``Input`` placeholder; the
                tensor produced by ``ScatterNd`` has the same shape.
            indices: Sample indices. Only their shape is used, to size the
                ``indices`` placeholder (assumed to be an array-like of index
                values -- TODO confirm against the test parametrization).
            updates: Sample updates. Only their shape is used, to size the
                ``updates_real`` / ``updates_imag`` placeholders (same
                assumption as for ``indices``).
            ir_version: Not used by this method.
            use_legacy_frontend: Not used by this method.

        Returns:
            A ``(graph_def, None)`` tuple.
        """
        import tensorflow as tf
        tf.compat.v1.reset_default_graph()

        # Static placeholder shapes taken from the sample arguments, so the
        # fed data can be generated with matching shapes.
        indices_shape = list(np.shape(indices))
        updates_shape = list(np.shape(updates))

        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            x = tf.compat.v1.placeholder(tf.float32, x_shape, 'Input')
            tf_indices = tf.compat.v1.placeholder(tf.int32, indices_shape, 'indices')
            updates_real = tf.compat.v1.placeholder(tf.float32, updates_shape, 'updates_real')
            updates_imag = tf.compat.v1.placeholder(tf.float32, updates_shape, 'updates_imag')
            # Separate local name: do not shadow the ``updates`` argument.
            complex_updates = tf.raw_ops.Complex(real=updates_real, imag=updates_imag)

            # Raw ops accept keyword arguments only; the indices must come
            # from the placeholder so that the fed values are actually used.
            scatter_nd = tf.raw_ops.ScatterNd(indices=tf_indices, updates=complex_updates,
                                              shape=tf.shape(x), name="Operation")

            # ``Input`` is real-valued, so lift it to complex (zero imaginary
            # part) before adding it to the complex scatter result.
            x_complex = tf.raw_ops.Complex(real=x, imag=tf.zeros_like(x))
            res = tf.add(x_complex, scatter_nd)

            # Real and imaginary parts are left unconsumed at the end of the graph.
            tf.raw_ops.Real(input=res)
            tf.raw_ops.Imag(input=res)
            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    def prepare_input(self, inputs_info):
        """Generate random data for every placeholder of the graph.

        Args:
            inputs_info: Mapping from placeholder name to its shape. Must
                contain ``'Input'``, ``'indices'``, ``'updates_real'`` and
                ``'updates_imag'``.

        Returns:
            Dict mapping each placeholder name to a NumPy array: float32
            values in ``[-2, 2)`` for ``Input`` and both update tensors, and
            int32 indices that stay inside the bounds of ``Input``.
        """
        rng = np.random.default_rng()

        # Keys are the placeholder names used when the graph is built.
        for name in ('Input', 'indices', 'updates_real', 'updates_imag'):
            assert name in inputs_info

        x_shape = tuple(inputs_info['Input'])
        indices_shape = tuple(inputs_info['indices'])
        updates_real_shape = tuple(inputs_info['updates_real'])
        updates_imag_shape = tuple(inputs_info['updates_imag'])

        # The last axis of ``indices`` addresses the leading dimensions of
        # ``Input``; bound each index component by the matching dimension so
        # no generated index is out of range.
        index_depth = indices_shape[-1]
        upper_bounds = np.asarray(x_shape[:index_depth])

        inputs_data = {}
        inputs_data['Input'] = 4 * rng.random(x_shape).astype(np.float32) - 2
        inputs_data['indices'] = rng.integers(0, upper_bounds, size=indices_shape).astype(np.int32)
        inputs_data['updates_real'] = 4 * rng.random(updates_real_shape).astype(np.float32) - 2
        inputs_data['updates_imag'] = 4 * rng.random(updates_imag_shape).astype(np.float32) - 2

        return inputs_data
Editor is loading...
Leave a Comment