Untitled
unknown
python
a year ago
2.1 kB
3
Indexable
class TestComplexScatterND(CommonTFLayerTest):
    """Layer test for TensorFlow ``ScatterNd`` with complex-valued updates.

    The graph builds complex updates from two float32 placeholders via
    ``Complex``, scatters them with ``ScatterNd`` into a tensor shaped like the
    real ``Input`` placeholder, adds the result to ``Input`` promoted to
    complex, and emits ``Real``/``Imag`` nodes for the sum.
    """

    def create_complex_scatter_nd_net(self, x_shape, indices, updates, ir_version,
                                      use_legacy_frontend):
        """Build the TF1 graph for the complex ScatterNd test.

        Args:
            x_shape: Static shape of the float32 ``Input`` placeholder.
            indices: Scatter indices; embedded in the graph as an int32 constant.
            updates: Example updates; only its shape is used, to size the
                ``updates_real``/``updates_imag`` placeholders. When None, the
                placeholders fall back to shape ``[None]``.
            ir_version: Unused in this body (presumably required by the shared
                test-harness signature).
            use_legacy_frontend: Unused in this body (same caveat).

        Returns:
            ``(graph_def, None)``: the serialized graph and no reference net.
        """
        import numpy as np
        import tensorflow as tf

        tf.compat.v1.reset_default_graph()

        with tf.compat.v1.Session() as sess:
            x = tf.compat.v1.placeholder(tf.float32, x_shape, 'Input')

            # The `indices` argument is what ScatterNd consumes; a separate,
            # unconnected 'indices' placeholder would only add a dead input.
            tf_indices = tf.constant(np.asarray(indices, dtype=np.int32))

            updates_shape = [None] if updates is None else list(np.shape(updates))
            updates_real = tf.compat.v1.placeholder(np.float32, updates_shape, 'updates_real')
            updates_imag = tf.compat.v1.placeholder(np.float32, updates_shape, 'updates_imag')
            complex_updates = tf.raw_ops.Complex(real=updates_real, imag=updates_imag)

            # tf.raw_ops functions accept keyword arguments only.
            scatter_nd = tf.raw_ops.ScatterNd(indices=tf_indices, updates=complex_updates,
                                              shape=tf.shape(x), name="Operation")

            # ScatterNd yields a complex tensor here, so the real input must be
            # promoted (zero imaginary part) before the element-wise add.
            x_complex = tf.raw_ops.Complex(real=x, imag=tf.zeros_like(x))
            res = tf.add(x_complex, scatter_nd)

            # These nodes become the graph outputs; the Python handles are not
            # needed afterwards.
            tf.raw_ops.Real(input=res)
            tf.raw_ops.Imag(input=res)

            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    def prepare_input(self, inputs_info):
        """Generate random feed data for the graph's placeholders.

        Args:
            inputs_info: Mapping from placeholder name to its shape. It must
                contain ``'Input'``, ``'updates_real'`` and ``'updates_imag'``,
                the placeholder names created in create_complex_scatter_nd_net.

        Returns:
            Dict mapping each placeholder name to a float32 numpy array with
            values drawn uniformly from [-2, 2).
        """
        import numpy as np

        # NOTE(review): confirm whether the harness reports names with a ':0'
        # suffix and whether it invokes `_prepare_input` rather than this name.
        rng = np.random.default_rng()
        inputs_data = {}
        for name in ('Input', 'updates_real', 'updates_imag'):
            assert name in inputs_info, f"missing input {name!r}"
            inputs_data[name] = 4 * rng.random(inputs_info[name]).astype(np.float32) - 2
        return inputs_data
Editor is loading...
Leave a Comment