"""Timing tests comparing the pybind11 and raw C API code paths of cast."""
import time
from absl import logging
from google3.experimental.users.jblespiau.pybind11 import cast
from google3.testing.pybase import googletest
class CastTest(googletest.TestCase):
  """Logs how long pybind11-backed and raw C API calls into `cast` take.

  The test makes no assertions: it only times calls into the `cast`
  extension module and writes the elapsed times to the log.
  """

  def test_give_me_a_name(self):
    """Times object creation, field access and bulk creation through `cast`.

    Every measurement uses `time.perf_counter()`, which is monotonic and
    intended for measuring intervals, so wall-clock adjustments cannot skew
    the reported durations.
    """
    num_iterations = 1000000

    logging.info("Creating %d objects from Python", num_iterations)
    # Repeat the creation comparison a few times to expose run-to-run noise.
    for _ in range(3):
      start = time.perf_counter()
      for _ in range(num_iterations):
        # The result stays bound on purpose: the previous object is only
        # released once the next one exists, as in the original measurement.
        obj = cast.JaxCompiledFunction(1)
      end = time.perf_counter()
      logging.info("Pybind11 took %.2fs", end - start)

      start = time.perf_counter()
      for _ in range(num_iterations):
        obj = cast.MakeJaxCompiledFunction(1)
      end = time.perf_counter()
      logging.info("Raw API took %.2fs", end - start)
    logging.info("")

    logging.info("Accessing the C++ object from the Python wrapped object")
    c_api_obj = cast.MakeJaxCompiledFunction(1)
    start = time.perf_counter()
    for _ in range(num_iterations):
      cast.AccessFieldPython(c_api_obj)
    end = time.perf_counter()
    logging.info("Raw C API cast took %.2fs", end - start)

    pybind11_obj = cast.JaxCompiledFunction(1)
    start = time.perf_counter()
    for _ in range(num_iterations):
      cast.AccessFieldPybind11(pybind11_obj)
    end = time.perf_counter()
    logging.info("Pybind11 took %.2fs", end - start)

    logging.info("Accessing many C++ objects from the Python wrapped objects")
    c_api_objs = [cast.MakeJaxCompiledFunction(i) for i in range(100)]
    start = time.perf_counter()
    for _ in range(num_iterations):
      cast.AccessFieldsPython(c_api_objs)
    end = time.perf_counter()
    logging.info("Raw C API cast took %.2fs", end - start)

    pybind11_objs = [cast.JaxCompiledFunction(i) for i in range(100)]
    start = time.perf_counter()
    for _ in range(num_iterations):
      cast.AccessFieldsPybind11(pybind11_objs)
    end = time.perf_counter()
    logging.info("Pybind11 took %.2fs", end - start)
    logging.info("")

    logging.info("Creating many C++ objects and returning them to Python")
    # Far fewer iterations here: each call is asked for 100000 objects.
    num_iterations = 100
    start = time.perf_counter()
    for _ in range(num_iterations):
      cast.CreateManyObjectsPython(100000)
    end = time.perf_counter()
    logging.info("Raw C API cast took %.2fs", end - start)

    start = time.perf_counter()
    for _ in range(num_iterations):
      cast.CreateManyObjectsPybind11(100000)
    end = time.perf_counter()
    # This timing covers the pybind11 variant; it was previously logged with
    # the raw C API label, making the two results indistinguishable.
    logging.info("Pybind11 took %.2fs", end - start)
# Run the test suite only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
  googletest.main()