Untitled
unknown
plain_text
7 months ago
3.3 kB
3
Indexable
Never
import numpy as np
import pytest
import tensorflow as tf


class TestApproximateEqual:
    """Layer test for the TensorFlow ``ApproximateEqual`` operation.

    ``ApproximateEqual`` yields ``abs(x - y) < tolerance`` element-wise. The
    test builds a two-placeholder TF1 graph around the op, runs that graph with
    TensorFlow and compares the result with a NumPy reference. If the instance
    also provides a ``_test`` method (the harness entry point the original test
    body called), the graph is handed to it as well.
    """

    # Tolerance baked into the graph; the NumPy reference and the input
    # generator both derive from it so they cannot drift apart.
    _TOLERANCE = 0.01
    # Explicit op name so the result can be fetched as '<name>:0'.
    _OUTPUT_NAME = 'approximate_equal'

    def _prepare_input(self, inputs_info):
        """Generate float32 feed values for both placeholders.

        Args:
            inputs_info: mapping with keys ``'tensor1:0'`` and ``'tensor2:0'``
                whose values are the input shapes (sequences of ints).

        Returns:
            dict mapping the same keys to float32 arrays of those shapes.

        ``tensor2`` is derived from ``tensor1``: roughly half of the elements
        are offset far inside the tolerance and the rest far outside it. Two
        independent uniform samples in [-10, 10] would almost never be within
        the tolerance, so the op would only ever produce ``False``.
        """
        assert 'tensor1:0' in inputs_info
        assert 'tensor2:0' in inputs_info
        tensor1_shape = tuple(inputs_info['tensor1:0'])
        tensor2_shape = tuple(inputs_info['tensor2:0'])
        # ApproximateEqual compares element-wise and needs matching shapes.
        assert tensor1_shape == tensor2_shape, (
            "ApproximateEqual requires x and y to have the same shape")

        rng = np.random.default_rng()
        tensor1 = rng.uniform(-10, 10, tensor1_shape).astype(np.float32)

        # Offsets stay at <= 0.1 * tol or >= 50 * tol, well away from the
        # boundary, so float32 rounding cannot flip any comparison.
        tol = self._TOLERANCE
        is_close = rng.random(tensor1_shape) < 0.5
        magnitude = np.where(is_close,
                             rng.uniform(0.0, 0.1 * tol, tensor1_shape),
                             rng.uniform(50 * tol, 200 * tol, tensor1_shape))
        sign = np.where(rng.random(tensor1_shape) < 0.5, -1.0, 1.0)
        tensor2 = (tensor1 + sign * magnitude).astype(np.float32)

        return {'tensor1:0': tensor1, 'tensor2:0': tensor2}

    def create_approximate_equal_net(self, input1_shape, input2_shape):
        """Build a TF1 graph computing ``ApproximateEqual(tensor1, tensor2)``.

        Args:
            input1_shape: shape of the ``tensor1`` float32 placeholder.
            input2_shape: shape of the ``tensor2`` float32 placeholder.

        Returns:
            ``(graph_def, None)``. The second element is always ``None`` and
            is forwarded unchanged as the second argument of the harness
            ``_test`` call.
        """
        tf.compat.v1.reset_default_graph()
        # Entering the session makes its graph the default one, so the
        # placeholders and the op below are recorded into ``sess.graph_def``.
        with tf.compat.v1.Session() as sess:
            tensor1 = tf.compat.v1.placeholder(tf.float32, input1_shape, 'tensor1')
            tensor2 = tf.compat.v1.placeholder(tf.float32, input2_shape, 'tensor2')
            tf.raw_ops.ApproximateEqual(x=tensor1, y=tensor2,
                                        tolerance=self._TOLERANCE,
                                        name=self._OUTPUT_NAME)
            tf_net = sess.graph_def
        return tf_net, None

    test_data_basic = [
        dict(input1_shape=[2, 3], input2_shape=[2, 3]),
        dict(input1_shape=[3, 4, 5], input2_shape=[3, 4, 5]),
        dict(input1_shape=[1, 2, 3, 4], input2_shape=[1, 2, 3, 4]),
    ]

    @pytest.mark.parametrize("params", test_data_basic)
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_approximate_equal_basic(self, params, ie_device, precision, ir_version, temp_dir,
                                     use_legacy_frontend):
        """Check the op result against NumPy, then run the harness if present."""
        tf_net, ref_net = self.create_approximate_equal_net(**params)

        inputs_info = {'tensor1:0': params['input1_shape'],
                       'tensor2:0': params['input2_shape']}
        input_data = self._prepare_input(inputs_info)

        # Import the frozen GraphDef into a fresh graph and fetch the op output
        # by name; a GraphDef object itself is not a valid Session.run fetch.
        with tf.Graph().as_default() as graph:
            tf.compat.v1.import_graph_def(tf_net, name='')
            with tf.compat.v1.Session(graph=graph) as sess:
                actual = sess.run(f'{self._OUTPUT_NAME}:0', feed_dict=input_data)

        expected = np.abs(input_data['tensor1:0'] - input_data['tensor2:0']) < self._TOLERANCE
        np.testing.assert_array_equal(actual, expected)

        # This class does not define ``_test``; it is expected from a harness
        # base class that supplies the device/precision fixtures used here
        # (NOTE(review): presumably OpenVINO's CommonTFLayerTest - confirm and
        # inherit from it explicitly). Run it only when it is available.
        harness_test = getattr(self, '_test', None)
        if harness_test is not None:
            harness_test(tf_net, ref_net, ie_device, precision, ir_version, temp_dir=temp_dir,
                         use_legacy_frontend=use_legacy_frontend)
Leave a Comment