# Metadata from the paste page this snippet was copied from:
# Untitled
# unknown
# python
# 2 years ago
# 2.1 kB
# 13
# Indexable
class TestComplexScatterND(CommonTFLayerTest):
    """Layer test for TF ``ScatterNd`` with complex-valued updates.

    The graph does the following:

    * Assembles complex updates from two float32 placeholders
      (``updates_real`` / ``updates_imag``).
    * Scatters them with constant ``indices`` into a zero tensor shaped like
      the float32 ``Input`` placeholder.
    * Adds the result to ``Input``, lifted to complex with a zero imaginary
      part.
    * Exposes the sum through separate ``Real`` and ``Imag`` nodes.
    """

    # Names of the graph placeholders that prepare_input must feed.
    _FLOAT_INPUTS = ('Input', 'updates_real', 'updates_imag')

    def create_complex_scatter_nd_net(self, x_shape, indices, updates, ir_version,
                                      use_legacy_frontend):
        """Build the TF1 graph for the complex ScatterNd test.

        Args:
            x_shape: static shape of the float32 ``Input`` placeholder. It is
                also the output shape handed to ``ScatterNd`` (via
                ``tf.shape``).
            indices: scatter indices, embedded in the graph as an int32
                constant.
            updates: update values. Only ``np.shape(updates)`` is used; it
                sizes the ``updates_real`` / ``updates_imag`` placeholders.
            ir_version: not used by this builder.
            use_legacy_frontend: not used by this builder.

        Returns:
            ``(graph_def, None)``. No reference network is built.
        """
        import numpy as np
        import tensorflow as tf

        tf.compat.v1.reset_default_graph()
        updates_shape = list(np.shape(updates))

        with tf.compat.v1.Session() as sess:
            x = tf.compat.v1.placeholder(tf.float32, x_shape, 'Input')

            # Indices are a graph constant rather than a fed input: random
            # runtime indices could point outside x_shape, which ScatterNd
            # rejects. int32 also matches the dtype of tf.shape(x) below,
            # which ScatterNd requires.
            tf_indices = tf.constant(indices, dtype=tf.int32)

            updates_real = tf.compat.v1.placeholder(tf.float32, updates_shape, 'updates_real')
            updates_imag = tf.compat.v1.placeholder(tf.float32, updates_shape, 'updates_imag')
            tf_updates = tf.raw_ops.Complex(real=updates_real, imag=updates_imag)

            scatter_nd = tf.raw_ops.ScatterNd(indices=tf_indices, updates=tf_updates,
                                              shape=tf.shape(x), name="Operation")

            # Add needs matching dtypes, so lift the float input to complex
            # with a zero imaginary part before summing.
            x_complex = tf.raw_ops.Complex(real=x, imag=tf.zeros_like(x))
            res = tf.add(x_complex, scatter_nd)

            # Emit real and imaginary parts as separate float graph nodes.
            tf.raw_ops.Real(input=res)
            tf.raw_ops.Imag(input=res)

            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        return tf_net, None

    def prepare_input(self, inputs_info):
        """Generate random float32 feed data in [-2, 2) for every placeholder.

        Args:
            inputs_info: mapping from input name to its shape. It must contain
                the keys ``'Input'``, ``'updates_real'`` and ``'updates_imag'``.

        Returns:
            dict mapping each of those names to a float32 array of the
            requested shape.

        NOTE(review): confirm against CommonTFLayerTest which hook name is
        called (``prepare_input`` vs ``_prepare_input``). Also confirm whether
        the keys carry a tensor suffix such as ``':0'``.
        """
        import numpy as np

        rng = np.random.default_rng()
        for name in self._FLOAT_INPUTS:
            assert name in inputs_info, f"missing input '{name}'"

        # All inputs are float32 placeholders.
        return {
            name: 4 * rng.random(inputs_info[name]).astype(np.float32) - 2
            for name in self._FLOAT_INPUTS
        }
# Leave a Comment