C++ code

It depends on absl and TensorFlow (the latter only for CHECK), in addition to pybind11 and the Python C API.
mail@pastecode.io avatar
unknown
plain_text
3 years ago
9.7 kB
3
Indexable
Never
#include "pybind11/cast.h"

#include <Python.h>

#include "third_party/absl/status/status.h"
#include "third_party/absl/status/statusor.h"
#include "third_party/pybind11/include/pybind11/cast.h"
#include "third_party/pybind11/include/pybind11/numpy.h"
#include "third_party/pybind11/include/pybind11/pybind11.h"
#include "third_party/pybind11/include/pybind11/pytypes.h"
#include "third_party/pybind11/include/pybind11/stl.h"
#include "third_party/tensorflow/core/profiler/lib/traceme.h"

namespace xla {

namespace py = pybind11;

// Plain C++ payload embedded in each CompiledFunction Python object (see
// JaxCompiledFunctionObject). The same class is also exported through
// pybind11 so the two binding approaches can be compared.
class JaxCompiledFunction {
 public:
  explicit JaxCompiledFunction(int value) : int_field(value) {}

  // Returns a non-owning handle to the Python object whose `fun` member is
  // this instance. Only meaningful for instances embedded in a
  // JaxCompiledFunctionObject (not for pybind11-owned instances).
  pybind11::handle AsPyHandle();

  int int_field;
};

// Type object for CompiledFunction; null until the module initializer has
// built it and run PyType_Ready() on it.
PyObject* JaxCompiledFunction_Type{nullptr};

// Memory layout of a CompiledFunction Python object. The C++ payload `fun` is
// constructed in place after allocation (see MakeJaxCompiledFunction) and
// destroyed explicitly in tp_dealloc.
struct JaxCompiledFunctionObject {
  PyObject_HEAD
  // Instance __dict__; null until first requested, then created lazily.
  PyObject* dict;
  // Head of the weak-reference list, maintained by the interpreter through
  // tp_weaklistoffset.
  PyObject* weakrefs;
  // Embedded C++ state.
  JaxCompiledFunction fun;
};

// Returns true iff `handle` refers to an object whose exact type is the
// CompiledFunction type created in the module initializer.
//
// The type pointers are compared directly rather than through
// `handle.get_type() == ...`, which goes through pybind11's deprecated
// `get_type()` and deprecated `handle::operator==`. The direct comparison is
// also the cheapest possible check. A null handle never matches (the original
// expression would have dereferenced it).
bool JaxCompiledFunction_Check(py::handle handle) {
  PyObject* const object = handle.ptr();
  if (object == nullptr) {
    return false;
  }
  return reinterpret_cast<PyObject*>(Py_TYPE(object)) ==
         JaxCompiledFunction_Type;
}

// Returns the C++ payload embedded in `handle`'s object without verifying its
// type; callers must first ensure JaxCompiledFunction_Check(handle) holds.
JaxCompiledFunction* AsCompiledFunctionUnchecked(py::handle handle) {
  auto* object = reinterpret_cast<JaxCompiledFunctionObject*>(handle.ptr());
  return &object->fun;
}

// Checked variant of AsCompiledFunctionUnchecked: returns the embedded payload,
// or InvalidArgumentError if `handle` is not a CompiledFunction.
absl::StatusOr<JaxCompiledFunction*> AsCompiledFunction(py::handle handle) {
  if (JaxCompiledFunction_Check(handle)) {
    return AsCompiledFunctionUnchecked(handle);
  }
  return absl::InvalidArgumentError("Expected a CompiledFunction");
}

py::handle JaxCompiledFunction::AsPyHandle() {
  // `this` is the `fun` member of a JaxCompiledFunctionObject; step back by the
  // member's offset to recover the start of the enclosing Python object.
  char* const member_address = reinterpret_cast<char*>(this);
  char* const object_address =
      member_address - offsetof(JaxCompiledFunctionObject, fun);
  return py::handle(reinterpret_cast<PyObject*>(object_address));
}

// Baseline accessor: reads the field straight from a C++ reference.
int AccessField(const JaxCompiledFunction& fun) {
  const int value = fun.int_field;
  return value;
}

// Reads `int_field` from an object of the raw C-API CompiledFunction type.
// Aborts via CHECK if `handle` is not of exactly that type.
int AccessFieldPython(py::handle handle) {
  CHECK(JaxCompiledFunction_Check(handle));
  return AsCompiledFunctionUnchecked(handle)->int_field;
}

// Sums `int_field` over objects of the raw C-API CompiledFunction type.
// Aborts via CHECK on any element of another type.
//
// The running total is kept in a 64-bit accumulator. Summing in an `int` (for
// example the 0..n-1 fields produced by CreateManyObjectsPython) overflows once
// n exceeds roughly 65k, which is undefined behavior for signed integers. The
// total is narrowed back to `int` only at the end, so the return type is
// unchanged.
int AccessFieldsPython(std::vector<py::handle> handles) {
  long long total = 0;
  for (const py::handle& handle : handles) {
    CHECK(JaxCompiledFunction_Check(handle));
    total += AsCompiledFunctionUnchecked(handle)->int_field;
  }
  return static_cast<int>(total);
}

// Reads `int_field` through pybind11's type caster for the py::class_-bound
// JaxCompiledFunction. Non-matching types raise pybind11's cast error.
//
// pybind11 converts None to a null pointer when casting to a pointer type, so
// None is rejected explicitly. Without this check, the dereference would
// crash the interpreter.
int AccessFieldPybind11(py::handle handle) {
  JaxCompiledFunction* fun = py::cast<JaxCompiledFunction*>(handle);
  if (fun == nullptr) {
    throw py::type_error("Expected a JaxCompiledFunction, got None");
  }
  return fun->int_field;
}

// Sums `int_field` over pybind11-bound JaxCompiledFunction objects.
// Non-matching types raise pybind11's cast error, and None raises TypeError.
// pybind11 casts None to a null pointer, which would otherwise be dereferenced.
//
// Like AccessFieldsPython, the sum is accumulated in 64 bits to avoid
// signed-overflow UB and narrowed to `int` only at the end.
int AccessFieldsPybind11(std::vector<py::handle> handles) {
  long long total = 0;
  for (const py::handle& handle : handles) {
    JaxCompiledFunction* fun = py::cast<JaxCompiledFunction*>(handle);
    if (fun == nullptr) {
      throw py::type_error("Expected a JaxCompiledFunction, got None");
    }
    total += fun->int_field;
  }
  return static_cast<int>(total);
}

// CPython type-slot implementations for CompiledFunction; see the C API type object reference: https://docs.python.org/3/c-api/typeobj.html
extern "C" {

// tp_new slot: allocates a CompiledFunction instance via the type's tp_alloc
// and nulls its Python-reference fields. `args` and `kwds` are ignored. The
// embedded `fun` member is NOT constructed here; MakeJaxCompiledFunction
// constructs it with placement new.
PyObject* JaxCompiledFunction_tp_new(PyTypeObject* subtype, PyObject* args,
                                     PyObject* kwds) {
  PyObject* const raw = subtype->tp_alloc(subtype, 0);
  if (raw == nullptr) {
    return nullptr;
  }
  auto* instance = reinterpret_cast<JaxCompiledFunctionObject*>(raw);
  instance->dict = nullptr;
  instance->weakrefs = nullptr;
  return raw;
}

// tp_dealloc slot: tears down a CompiledFunction instance and releases the
// reference it holds on its (heap-allocated) type.
void JaxCompiledFunction_tp_dealloc(PyObject* self) {
  // The type is registered with Py_TPFLAGS_HAVE_GC, so the object must leave
  // the collector before any member is cleared. Py_CLEAR below can run
  // arbitrary code, including a GC pass, which would otherwise traverse a
  // half-destroyed object.
  PyObject_GC_UnTrack(self);
  PyTypeObject* tp = Py_TYPE(self);
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  Py_CLEAR(o->dict);
  // Ends the lifetime of the payload constructed in place after allocation.
  o->fun.~JaxCompiledFunction();
  tp->tp_free(self);
  // Instances of a heap type own a reference to their type; drop it last,
  // after the memory has been returned.
  Py_DECREF(tp);
}

// tp_traverse slot: reports every Python object this instance holds a strong
// reference to, so the cyclic GC can find reference cycles through it.
int JaxCompiledFunction_tp_traverse(PyObject* self, visitproc visit,
                                    void* arg) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  // The instance holds a reference to its heap-allocated type (released in
  // tp_dealloc). CPython 3.9+ expects heap types to visit it here.
#if PY_VERSION_HEX >= 0x03090000
  Py_VISIT(Py_TYPE(self));
#endif
  Py_VISIT(o->dict);
  // `fun` currently holds no Python references; visit any it gains here.
  return 0;
}

// tp_clear slot: drops the references that can take part in cycles. Today
// only the instance dictionary is owned; `fun` holds no Python references.
int JaxCompiledFunction_tp_clear(PyObject* self) {
  auto* instance = reinterpret_cast<JaxCompiledFunctionObject*>(self);
  Py_CLEAR(instance->dict);
  return 0;
}

// tp_descr_get slot, so CompiledFunction objects stored as class attributes
// behave like plain functions. When fetched through an instance they become
// bound methods; when fetched through the class (obj is null or None) they
// are returned unchanged. See:
// https://docs.python.org/3/howto/descriptor.html#functions-and-methods
PyObject* JaxCompiledFunction_tp_descr_get(PyObject* self, PyObject* obj,
                                           PyObject* type) {
  const bool bound_to_instance = obj != nullptr && obj != Py_None;
  if (bound_to_instance) {
    return PyMethod_New(self, obj);
  }
  Py_INCREF(self);
  return self;
}

// Getter for `__dict__` (supports `d = instance.__dict__`). The dictionary is
// created on first access. Returns a new reference, or null if PyDict_New()
// failed.
PyObject* JaxCompiledFunction_get_dict(PyObject* self, void*) {
  auto* instance = reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (instance->dict == nullptr) {
    instance->dict = PyDict_New();
  }
  PyObject* const dict = instance->dict;
  Py_XINCREF(dict);
  return dict;
}

// Setter for `__dict__`: replaces the instance dictionary with `new_dict`.
// Returns 0 on success. Returns -1 with TypeError set if `new_dict` is not a
// dict or the attribute is being deleted.
int JaxCompiledFunction_set_dict(PyObject* self, PyObject* new_dict, void*) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  // CPython calls setters with a null value for `del obj.__dict__`.
  // PyDict_Check would dereference it, so reject deletion explicitly, as
  // PyObject_GenericSetDict does.
  if (new_dict == nullptr) {
    PyErr_SetString(PyExc_TypeError, "cannot delete __dict__");
    return -1;
  }
  if (!PyDict_Check(new_dict)) {
    PyErr_Format(PyExc_TypeError,
                 "__dict__ must be set to a dictionary, not a '%s'",
                 Py_TYPE(new_dict)->tp_name);
    return -1;
  }
  // Take the new reference before releasing the old one.
  Py_INCREF(new_dict);
  Py_CLEAR(o->dict);
  o->dict = new_dict;
  return 0;
}

// Attribute table for CompiledFunction. Exposing `__dict__` appears to be
// necessary for functools.wraps to be able to override `__doc__`.
static PyGetSetDef JaxCompiledFunction_tp_getset[] = {
    {
        const_cast<char*>("__dict__"),  // name
        JaxCompiledFunction_get_dict,   // get
        JaxCompiledFunction_set_dict,   // set
        nullptr,                        // doc
        nullptr,                        // closure
    },
    // Sentinel that terminates the table.
    {nullptr, nullptr, nullptr, nullptr, nullptr},
};

// tp_call slot. Placeholder implementation: ignores its arguments and returns
// the string "ddede".
//
// A tp_call slot must return a NEW reference (or null with an exception set).
// The previous version returned `obj.ptr()` of a local py::object whose
// destructor then dropped the only reference, handing the caller a dangling
// pointer. `release()` transfers ownership to the caller instead. Because
// this function has C linkage and is invoked by the interpreter, no C++
// exception may escape it; failures are converted to Python exceptions.
PyObject* JaxCompiledFunction_tp_call(PyObject* self, PyObject* args,
                                      PyObject* kwargs) {
  // TODO: forward `args`/`kwargs` to the embedded JaxCompiledFunction once it
  // implements a call operation.
  try {
    py::str result("ddede");
    return result.release().ptr();
  } catch (py::error_already_set& e) {
    e.restore();
    return nullptr;
  } catch (std::exception& e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return nullptr;
  }
}

}  // extern "C"

// Allocates an object of the raw C-API CompiledFunction type and constructs its
// embedded C++ payload with `int_field`. If allocation fails, throws
// py::error_already_set, which rethrows the pending Python exception.
py::object MakeJaxCompiledFunction(int int_field) {
  py::object obj = py::reinterpret_steal<py::object>(JaxCompiledFunction_tp_new(
      reinterpret_cast<PyTypeObject*>(JaxCompiledFunction_Type), nullptr,
      nullptr));
  if (!obj) {
    // tp_new returns null when tp_alloc fails. By C-API convention the
    // allocator leaves a Python exception (e.g. MemoryError) set.
    throw py::error_already_set();
  }
  JaxCompiledFunctionObject* buf =
      reinterpret_cast<JaxCompiledFunctionObject*>(obj.ptr());

  // tp_new leaves `fun` unconstructed; begin its lifetime in place. It is
  // destroyed explicitly in tp_dealloc.
  new (&buf->fun) JaxCompiledFunction(int_field);
  return obj;
}

// Creates `num_obj` objects of the raw C-API CompiledFunction type, with
// `int_field` values 0, 1, ..., num_obj - 1. A non-positive count yields an
// empty list.
std::vector<py::object> CreateManyObjectsPython(int num_obj) {
  std::vector<py::object> results;
  // A negative int converts to a huge size_t, which would make reserve()
  // throw std::length_error instead of returning an empty result.
  if (num_obj > 0) {
    results.reserve(num_obj);
  }
  for (int i = 0; i < num_obj; ++i) {
    results.push_back(MakeJaxCompiledFunction(i));
  }
  return results;
}

// Creates `num_obj` pybind11-bound JaxCompiledFunction objects, with
// `int_field` values 0, 1, ..., num_obj - 1. A non-positive count yields an
// empty list.
std::vector<py::object> CreateManyObjectsPybind11(int num_obj) {
  std::vector<py::object> results;
  // A negative int converts to a huge size_t, which would make reserve()
  // throw std::length_error instead of returning an empty result.
  if (num_obj > 0) {
    results.reserve(num_obj);
  }
  for (int i = 0; i < num_obj; ++i) {
    // py::cast of a temporary moves it into a new pybind11-owned instance.
    results.push_back(py::cast(JaxCompiledFunction(i)));
  }
  return results;
}

PYBIND11_MODULE(cast, m) {
  m.doc() = "Jax XLA function";

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  py::object cfun;
  {
    py::str name = py::str("CompiledFunction");
    py::str qualname = py::str("CompiledFunction");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    CHECK(heap_type);  //  << "Unable to create heap type object";
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "CompiledFunction";
    type->tp_basicsize = sizeof(JaxCompiledFunctionObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC;
    type->tp_new = JaxCompiledFunction_tp_new;
    type->tp_dealloc = JaxCompiledFunction_tp_dealloc;
    type->tp_dictoffset = offsetof(JaxCompiledFunctionObject, dict);
    type->tp_traverse = JaxCompiledFunction_tp_traverse;
    type->tp_clear = JaxCompiledFunction_tp_clear;
    type->tp_weaklistoffset = offsetof(JaxCompiledFunctionObject, weakrefs);
    type->tp_getset = JaxCompiledFunction_tp_getset;
    type->tp_descr_get = JaxCompiledFunction_tp_descr_get;
    type->tp_call = JaxCompiledFunction_tp_call;
    CHECK(PyType_Ready(type) == 0);
    JaxCompiledFunction_Type = reinterpret_cast<PyObject*>(type);
    cfun = py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
  }
  // Heap types read `__module__` from their own dict, and nothing above sets
  // it, so record the defining module explicitly.
  cfun.attr("__module__") = m.attr("__name__");
  // Publish the type so Python code can name it (e.g. for isinstance checks).
  m.attr("CompiledFunction") = cfun;

  py::class_<JaxCompiledFunction>(m, "JaxCompiledFunction")
      .def(py::init<int>())
      .def_readwrite("int_field", &JaxCompiledFunction::int_field);

  m.def("MakeJaxCompiledFunction", &MakeJaxCompiledFunction);
  m.def("AccessFieldPython", &AccessFieldPython);
  m.def("AccessFieldPybind11", &AccessFieldPybind11);
  m.def("AccessFieldsPython", &AccessFieldsPython);
  m.def("AccessFieldsPybind11", &AccessFieldsPybind11);

  m.def("CreateManyObjectsPython", &CreateManyObjectsPython);
  m.def("CreateManyObjectsPybind11", &CreateManyObjectsPybind11);

}  // NOLINT(readability/fn_size)

}  // namespace xla