C++ wrapper
unknown
plain_text
4 years ago
9.5 kB
4
Indexable
#include "pybind11/cast.h"
#include <Python.h>

#include <cassert>
#include <cstddef>
#include <new>
#include <vector>

#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "third_party/absl/status/status.h"
#include "third_party/absl/status/statusor.h"

namespace xla {

namespace py = pybind11;

// Small payload class that is exposed to Python in two ways:
//   * embedded inside a hand-written CPython object (JaxCompiledFunctionObject)
//   * as a regular pybind11 class ("JaxCompiledFunction" in the module below)
class JaxCompiledFunction {
 public:
  explicit JaxCompiledFunction(int int_field) : int_field(int_field) {}

  int int_field;

  // Returns the Python object that embeds `this`. Only meaningful when `this`
  // is the `fun` member of a JaxCompiledFunctionObject.
  pybind11::handle AsPyHandle();
};

// Type object for the hand-written "CompiledFunction" Python type. Set once by
// the module initializer below; nullptr until then.
PyObject* JaxCompiledFunction_Type = nullptr;

// Memory layout of an instance of the hand-written Python type.
struct JaxCompiledFunctionObject {
  PyObject_HEAD
  PyObject* dict;      // Dictionary for __dict__
  PyObject* weakrefs;  // Weak references; for use by the Python interpreter.
  JaxCompiledFunction fun;
};

// Returns true iff `handle` refers to an object whose exact type is
// JaxCompiledFunction_Type (subclasses do not match). A null handle yields
// false.
bool JaxCompiledFunction_Check(py::handle handle) {
  PyObject* obj = handle.ptr();
  return obj != nullptr &&
         reinterpret_cast<PyObject*>(Py_TYPE(obj)) == JaxCompiledFunction_Type;
}

// Reinterprets `handle` as a JaxCompiledFunctionObject and returns its payload.
// No type check is performed; the caller must guarantee the type.
JaxCompiledFunction* AsCompiledFunctionUnchecked(py::handle handle) {
  return &(reinterpret_cast<JaxCompiledFunctionObject*>(handle.ptr())->fun);
}

// Checked variant of AsCompiledFunctionUnchecked: returns InvalidArgument when
// `handle` is not an instance of the hand-written type.
absl::StatusOr<JaxCompiledFunction*> AsCompiledFunction(py::handle handle) {
  if (!JaxCompiledFunction_Check(handle)) {
    return absl::InvalidArgumentError("Expected a CompiledFunction");
  }
  return AsCompiledFunctionUnchecked(handle);
}

py::handle JaxCompiledFunction::AsPyHandle() {
  // Walk back from the `fun` member to the start of the enclosing object.
  return reinterpret_cast<PyObject*>(
      reinterpret_cast<char*>(this) -
      offsetof(JaxCompiledFunctionObject, fun));
}

// Reads the field from a C++ reference.
int AccessField(const JaxCompiledFunction& fun) { return fun.int_field; }

// Reads the field from an instance of the hand-written type. The type is only
// verified by assert(), i.e. not at all when NDEBUG is defined.
int AccessFieldPython(py::handle handle) {
  assert(JaxCompiledFunction_Check(handle));
  JaxCompiledFunction* f = AsCompiledFunctionUnchecked(handle);
  return f->int_field;
}

// Sums the field over several instances of the hand-written type. As above,
// the element types are only verified by assert().
int AccessFieldsPython(std::vector<py::handle> handles) {
  int v = 0;
  for (const py::handle& handle : handles) {
    assert(JaxCompiledFunction_Check(handle));
    JaxCompiledFunction* f = AsCompiledFunctionUnchecked(handle);
    v += f->int_field;
  }
  return v;
}

// Reads the field from an instance of the pybind11-registered class.
// A reference cast is used rather than a pointer cast: casting None to a
// pointer yields nullptr, which the previous code dereferenced.
int AccessFieldPybind11(py::handle handle) {
  return py::cast<JaxCompiledFunction&>(handle).int_field;
}

// Sums the field over several instances of the pybind11-registered class.
int AccessFieldsPybind11(std::vector<py::handle> handles) {
  int v = 0;
  for (const py::handle& handle : handles) {
    v += py::cast<JaxCompiledFunction&>(handle).int_field;
  }
  return v;
}

// For the C API: https://docs.python.org/3/c-api/typeobj.html
extern "C" {

// tp_new: allocates an instance and puts every member into a defined state.
// `args` and `kwds` are ignored.
PyObject* JaxCompiledFunction_tp_new(PyTypeObject* subtype, PyObject* args,
                                     PyObject* kwds) {
  JaxCompiledFunctionObject* self =
      reinterpret_cast<JaxCompiledFunctionObject*>(
          subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->dict = nullptr;
  self->weakrefs = nullptr;
  // tp_dealloc destroys `fun` unconditionally, so construct it here; this also
  // gives instances created directly from Python a defined int_field.
  new (&self->fun) JaxCompiledFunction(0);
  return reinterpret_cast<PyObject*>(self);
}

// tp_dealloc: releases everything owned by the instance and frees it.
void JaxCompiledFunction_tp_dealloc(PyObject* self) {
  // The type has Py_TPFLAGS_HAVE_GC, so the object must be removed from the
  // collector's tracking before its fields are invalidated.
  PyObject_GC_UnTrack(self);
  PyTypeObject* tp = Py_TYPE(self);
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  Py_CLEAR(o->dict);
  o->fun.~JaxCompiledFunction();
  tp->tp_free(self);
  // Instances of heap types hold a reference to their type.
  Py_DECREF(tp);
}

// tp_traverse: reports the references held by the instance to the GC.
int JaxCompiledFunction_tp_traverse(PyObject* self, visitproc visit,
                                    void* arg) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
#if PY_VERSION_HEX >= 0x03090000
  // Since Python 3.9, instances of heap types are expected to visit their type.
  Py_VISIT(Py_TYPE(self));
#endif
  Py_VISIT(o->dict);
  // Py_VISIT(o->fun.fun().ptr());
  // Py_VISIT(o->fun.cache_miss().ptr());
  // Py_VISIT(o->fun.get_device().ptr());
  return 0;
}

// tp_clear: drops the references that may participate in reference cycles.
int JaxCompiledFunction_tp_clear(PyObject* self) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  Py_CLEAR(o->dict);
  // o->fun.ClearPythonReferences();
  return 0;
}

// Implements the Python descriptor protocol so JIT-compiled functions can be
// used as bound methods. See:
// https://docs.python.org/3/howto/descriptor.html#functions-and-methods
PyObject* JaxCompiledFunction_tp_descr_get(PyObject* self, PyObject* obj,
                                           PyObject* type) {
  if (obj == nullptr || obj == Py_None) {
    Py_INCREF(self);
    return self;
  }
  return PyMethod_New(self, obj);
}

// Support d = instance.__dict__. The dictionary is created on first access.
// Returns a new reference, or nullptr (with a Python error set) if the
// dictionary could not be created.
PyObject* JaxCompiledFunction_get_dict(PyObject* self, void*) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (!o->dict) {
    o->dict = PyDict_New();
  }
  Py_XINCREF(o->dict);
  return o->dict;
}

// Support instance.__dict__ = d. Returns 0 on success, -1 with a TypeError set
// when `new_dict` is not a dictionary or when deletion is requested.
int JaxCompiledFunction_set_dict(PyObject* self, PyObject* new_dict, void*) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  // CPython calls the setter with a null value for `del instance.__dict__`.
  if (new_dict == nullptr) {
    PyErr_SetString(PyExc_TypeError, "cannot delete __dict__");
    return -1;
  }
  if (!PyDict_Check(new_dict)) {
    PyErr_Format(PyExc_TypeError,
                 "__dict__ must be set to a dictionary, not a '%s'",
                 Py_TYPE(new_dict)->tp_name);
    return -1;
  }
  // Install the new dictionary before releasing the old one so that the slot
  // is never observed empty while arbitrary code runs during the decref.
  Py_INCREF(new_dict);
  PyObject* old_dict = o->dict;
  o->dict = new_dict;
  Py_XDECREF(old_dict);
  return 0;
}

static PyGetSetDef JaxCompiledFunction_tp_getset[] = {
    // Having a __dict__ seems necessary to allow functools.wraps to override
    // __doc__.
    {const_cast<char*>("__dict__"), JaxCompiledFunction_get_dict,
     JaxCompiledFunction_set_dict, nullptr, nullptr},
    {nullptr, nullptr, nullptr, nullptr, nullptr}};

// tp_call: placeholder implementation that ignores its arguments and returns
// the string "ddede".
//
// Returns a NEW reference, as the tp_call contract requires. The previous
// version returned the raw pointer of a local py::object, whose destructor
// dropped the only reference before the caller could use the result.
PyObject* JaxCompiledFunction_tp_call(PyObject* self, PyObject* args,
                                      PyObject* kwargs) {
  // tensorflow::profiler::TraceMe traceme("JaxCompiledFunction::tp_call");
  // Uses the C API directly so that no C++ exception can escape this
  // extern "C" function; on failure nullptr is returned with the error set.
  return PyUnicode_FromString("ddede");

  // Intended dispatch, currently disabled:
  //
  // JaxCompiledFunctionObject* o =
  //     reinterpret_cast<JaxCompiledFunctionObject*>(self);
  // absl::optional<py::kwargs> py_kwargs;
  // if (kwargs) {
  //   py_kwargs = py::reinterpret_borrow<py::kwargs>(kwargs);
  // }
  // try {
  //   xla::StatusOr<py::object> out = o->fun.Call(args, std::move(py_kwargs));
  //   if (!out.ok()) {
  //     PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str());
  //     return nullptr;
  //   }
  //   return out.ValueOrDie().release().ptr();
  // } catch (py::error_already_set& e) {
  //   e.restore();
  //   return nullptr;
  // } catch (py::cast_error& e) {
  //   PyErr_SetString(PyExc_ValueError, e.what());
  //   return nullptr;
  // } catch (std::invalid_argument& e) {
  //   PyErr_SetString(PyExc_ValueError, e.what());
  //   return nullptr;
  // }
}

}  // extern "C"

// Creates an instance of the hand-written Python type holding `int_field`.
// Throws py::error_already_set if the instance cannot be allocated.
// Must only be called after the module initializer has set
// JaxCompiledFunction_Type.
py::object MakeJaxCompiledFunction(int int_field) {
  py::object obj = py::reinterpret_steal<py::object>(JaxCompiledFunction_tp_new(
      reinterpret_cast<PyTypeObject*>(JaxCompiledFunction_Type), nullptr,
      nullptr));
  if (!obj) {
    // tp_alloc failed and left a Python error set; do not touch the payload.
    throw py::error_already_set();
  }
  JaxCompiledFunctionObject* buf =
      reinterpret_cast<JaxCompiledFunctionObject*>(obj.ptr());
  // tp_new already constructed `fun`; overwrite it with the requested value.
  buf->fun = JaxCompiledFunction(int_field);
  return obj;
}

// Creates `num_obj` instances of the hand-written type with fields 0..num_obj-1.
std::vector<py::object> CreateManyObjectsPython(int num_obj) {
  std::vector<py::object> results;
  results.reserve(num_obj);
  for (int i = 0; i < num_obj; ++i) {
    results.push_back(MakeJaxCompiledFunction(i));
  }
  return results;
}

// Creates `num_obj` instances of the pybind11-registered class with fields
// 0..num_obj-1.
std::vector<py::object> CreateManyObjectsPybind11(int num_obj) {
  std::vector<py::object> results;
  results.reserve(num_obj);
  for (int i = 0; i < num_obj; ++i) {
    results.push_back(py::cast(JaxCompiledFunction(i)));
  }
  return results;
}

PYBIND11_MODULE(cast, m) {
  m.doc() = "Jax XLA function";

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  py::object cfun;
  {
    py::str name = py::str("CompiledFunction");
    py::str qualname = py::str("CompiledFunction");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    //
    // This is a real check rather than an assert(): assert() is compiled out
    // under NDEBUG, which would leave a null pointer to be dereferenced below.
    if (heap_type == nullptr) {
      throw py::error_already_set();
    }
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "CompiledFunction";
    type->tp_basicsize = sizeof(JaxCompiledFunctionObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC;
    type->tp_new = JaxCompiledFunction_tp_new;
    type->tp_dealloc = JaxCompiledFunction_tp_dealloc;
    type->tp_dictoffset = offsetof(JaxCompiledFunctionObject, dict);
    type->tp_traverse = JaxCompiledFunction_tp_traverse;
    type->tp_clear = JaxCompiledFunction_tp_clear;
    type->tp_weaklistoffset = offsetof(JaxCompiledFunctionObject, weakrefs);
    type->tp_getset = JaxCompiledFunction_tp_getset;
    type->tp_descr_get = JaxCompiledFunction_tp_descr_get;
    type->tp_call = JaxCompiledFunction_tp_call;
    // PyType_Ready() must run in every build mode; it used to sit inside an
    // assert() and was therefore skipped entirely when NDEBUG was defined.
    if (PyType_Ready(type) != 0) {
      throw py::error_already_set();
    }
    JaxCompiledFunction_Type = reinterpret_cast<PyObject*>(type);
    cfun = py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
  }
  // Make the hand-written type reachable from Python (e.g. for isinstance).
  cfun.attr("__module__") = m.attr("__name__");
  m.attr("CompiledFunction") = cfun;

  py::class_<JaxCompiledFunction>(m, "JaxCompiledFunction")
      .def(py::init<int>())
      .def_readwrite("int_field", &JaxCompiledFunction::int_field);

  m.def("MakeJaxCompiledFunction", &MakeJaxCompiledFunction);
  m.def("AccessFieldPython", &AccessFieldPython);
  m.def("AccessFieldPybind11", &AccessFieldPybind11);
  m.def("AccessFieldsPython", &AccessFieldsPython);
  m.def("AccessFieldsPybind11", &AccessFieldsPybind11);
  m.def("CreateManyObjectsPython", &CreateManyObjectsPython);
  m.def("CreateManyObjectsPybind11", &CreateManyObjectsPybind11);
}  // NOLINT(readability/fn_size)

}  // namespace xla
Editor is loading...