Python code

 avatar
unknown
plain_text
4 years ago
2.4 kB
1
Indexable
"""Benchmarks comparing the raw C API and pybind11 bindings in `cast`."""

import contextlib
import time

from absl import logging

from google3.experimental.users.jblespiau.pybind11 import cast
from google3.testing.pybase import googletest


@contextlib.contextmanager
def _timed(label):
  """Logs the wall-clock time spent inside the `with` body.

  The measured code stays inline in the caller, so no extra per-iteration
  call overhead is added to the benchmark.

  Args:
    label: Name of the implementation being measured. The log line reads
      "<label> took <seconds>s".

  Yields:
    None.
  """
  # perf_counter is monotonic and high-resolution, unlike time.time(), which
  # can jump if the system clock is adjusted during a measurement.
  start = time.perf_counter()
  yield
  logging.info("%s took %.2fs", label, time.perf_counter() - start)


class CastTest(googletest.TestCase):
  """Benchmarks the raw C API bindings in `cast` against the pybind11 ones.

  NOTE: This is a benchmark rather than a correctness test. It makes no
  assertions and only logs timings, so it passes as long as nothing raises.
  """

  def test_give_me_a_name(self):
    """Times object creation, field access and bulk creation for both APIs."""
    num_iterations = 1000000

    logging.info("Creating %d objects from Python", num_iterations)
    # Repeated a few times, presumably to expose warm-up / run-to-run noise.
    for _ in range(3):
      with _timed("Pybind11"):
        for _ in range(num_iterations):
          cast.JaxCompiledFunction(1)
      with _timed("Raw API"):
        for _ in range(num_iterations):
          cast.MakeJaxCompiledFunction(1)

    logging.info("")
    logging.info("Accessing the C++ object from the Python wrapped object")
    c_api_obj = cast.MakeJaxCompiledFunction(1)
    with _timed("Raw C API cast"):
      for _ in range(num_iterations):
        cast.AccessFieldPython(c_api_obj)

    pybind11_obj = cast.JaxCompiledFunction(1)
    with _timed("Pybind11"):
      for _ in range(num_iterations):
        cast.AccessFieldPybind11(pybind11_obj)

    logging.info("")
    logging.info("Accessing many C++ objects from the Python wrapped objects")
    c_api_objs = [cast.MakeJaxCompiledFunction(i) for i in range(100)]
    with _timed("Raw C API cast"):
      for _ in range(num_iterations):
        cast.AccessFieldsPython(c_api_objs)

    pybind11_objs = [cast.JaxCompiledFunction(i) for i in range(100)]
    with _timed("Pybind11"):
      for _ in range(num_iterations):
        cast.AccessFieldsPybind11(pybind11_objs)

    logging.info("")
    logging.info("Creating many C++ objects and returning them to Python")
    # Each call creates a whole batch of objects, so this section uses its own,
    # much smaller call count instead of reusing `num_iterations`.
    num_calls = 100
    objects_per_call = 100000
    with _timed("Raw C API cast"):
      for _ in range(num_calls):
        cast.CreateManyObjectsPython(objects_per_call)

    # Labelled "Pybind11" so this result is distinguishable from the raw C API
    # measurement just above.
    with _timed("Pybind11"):
      for _ in range(num_calls):
        cast.CreateManyObjectsPybind11(objects_per_call)



if __name__ == "__main__":
  # Script entry point: run this module's test cases via the googletest runner.
  googletest.main()