1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9436c6c9cSStella Laurenzo #include "IRModule.h" 10436c6c9cSStella Laurenzo 11436c6c9cSStella Laurenzo #include "Globals.h" 12436c6c9cSStella Laurenzo #include "PybindUtils.h" 13436c6c9cSStella Laurenzo 14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h" 15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 174acd8457SAlex Zinenko #include "mlir-c/Debug.h" 183f3d1c90SMike Urbach #include "mlir-c/IR.h" 19436c6c9cSStella Laurenzo #include "mlir-c/Registration.h" 20436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h" 21436c6c9cSStella Laurenzo #include <pybind11/stl.h> 22436c6c9cSStella Laurenzo 23436c6c9cSStella Laurenzo namespace py = pybind11; 24436c6c9cSStella Laurenzo using namespace mlir; 25436c6c9cSStella Laurenzo using namespace mlir::python; 26436c6c9cSStella Laurenzo 27436c6c9cSStella Laurenzo using llvm::SmallVector; 28436c6c9cSStella Laurenzo using llvm::StringRef; 29436c6c9cSStella Laurenzo using llvm::Twine; 30436c6c9cSStella Laurenzo 31436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 32436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 33436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 34436c6c9cSStella Laurenzo 35436c6c9cSStella Laurenzo static const char kContextParseTypeDocstring[] = 36436c6c9cSStella Laurenzo R"(Parses the assembly form of a type. 37436c6c9cSStella Laurenzo 38436c6c9cSStella Laurenzo Returns a Type object or raises a ValueError if the type cannot be parsed. 39436c6c9cSStella Laurenzo 40436c6c9cSStella Laurenzo See also: https://mlir.llvm.org/docs/LangRef/#type-system 41436c6c9cSStella Laurenzo )"; 42436c6c9cSStella Laurenzo 43436c6c9cSStella Laurenzo static const char kContextGetFileLocationDocstring[] = 44436c6c9cSStella Laurenzo R"(Gets a Location representing a file, line and column)"; 45436c6c9cSStella Laurenzo 46436c6c9cSStella Laurenzo static const char kModuleParseDocstring[] = 47436c6c9cSStella Laurenzo R"(Parses a module's assembly format from a string. 48436c6c9cSStella Laurenzo 49436c6c9cSStella Laurenzo Returns a new MlirModule or raises a ValueError if the parsing fails. 50436c6c9cSStella Laurenzo 51436c6c9cSStella Laurenzo See also: https://mlir.llvm.org/docs/LangRef/ 52436c6c9cSStella Laurenzo )"; 53436c6c9cSStella Laurenzo 54436c6c9cSStella Laurenzo static const char kOperationCreateDocstring[] = 55436c6c9cSStella Laurenzo R"(Creates a new operation. 56436c6c9cSStella Laurenzo 57436c6c9cSStella Laurenzo Args: 58436c6c9cSStella Laurenzo name: Operation name (e.g. "dialect.operation"). 59436c6c9cSStella Laurenzo results: Sequence of Type representing op result types. 60436c6c9cSStella Laurenzo attributes: Dict of str:Attribute. 61436c6c9cSStella Laurenzo successors: List of Block for the operation's successors. 62436c6c9cSStella Laurenzo regions: Number of regions to create. 63436c6c9cSStella Laurenzo location: A Location object (defaults to resolve from context manager). 64436c6c9cSStella Laurenzo ip: An InsertionPoint (defaults to resolve from context manager or set to 65436c6c9cSStella Laurenzo False to disable insertion, even with an insertion point set in the 66436c6c9cSStella Laurenzo context manager). 67436c6c9cSStella Laurenzo Returns: 68436c6c9cSStella Laurenzo A new "detached" Operation object. Detached operations can be added 69436c6c9cSStella Laurenzo to blocks, which causes them to become "attached." 70436c6c9cSStella Laurenzo )"; 71436c6c9cSStella Laurenzo 72436c6c9cSStella Laurenzo static const char kOperationPrintDocstring[] = 73436c6c9cSStella Laurenzo R"(Prints the assembly form of the operation to a file like object. 74436c6c9cSStella Laurenzo 75436c6c9cSStella Laurenzo Args: 76436c6c9cSStella Laurenzo file: The file like object to write to. Defaults to sys.stdout. 77436c6c9cSStella Laurenzo binary: Whether to write bytes (True) or str (False). Defaults to False. 78436c6c9cSStella Laurenzo large_elements_limit: Whether to elide elements attributes above this 79436c6c9cSStella Laurenzo number of elements. Defaults to None (no limit). 80436c6c9cSStella Laurenzo enable_debug_info: Whether to print debug/location information. Defaults 81436c6c9cSStella Laurenzo to False. 82436c6c9cSStella Laurenzo pretty_debug_info: Whether to format debug information for easier reading 83436c6c9cSStella Laurenzo by a human (warning: the result is unparseable). 84436c6c9cSStella Laurenzo print_generic_op_form: Whether to print the generic assembly forms of all 85436c6c9cSStella Laurenzo ops. Defaults to False. 86436c6c9cSStella Laurenzo use_local_Scope: Whether to print in a way that is more optimized for 87436c6c9cSStella Laurenzo multi-threaded access but may not be consistent with how the overall 88436c6c9cSStella Laurenzo module prints. 89436c6c9cSStella Laurenzo )"; 90436c6c9cSStella Laurenzo 91436c6c9cSStella Laurenzo static const char kOperationGetAsmDocstring[] = 92436c6c9cSStella Laurenzo R"(Gets the assembly form of the operation with all options available. 93436c6c9cSStella Laurenzo 94436c6c9cSStella Laurenzo Args: 95436c6c9cSStella Laurenzo binary: Whether to return a bytes (True) or str (False) object. Defaults to 96436c6c9cSStella Laurenzo False. 97436c6c9cSStella Laurenzo ... others ...: See the print() method for common keyword arguments for 98436c6c9cSStella Laurenzo configuring the printout. 99436c6c9cSStella Laurenzo Returns: 100436c6c9cSStella Laurenzo Either a bytes or str object, depending on the setting of the 'binary' 101436c6c9cSStella Laurenzo argument. 102436c6c9cSStella Laurenzo )"; 103436c6c9cSStella Laurenzo 104436c6c9cSStella Laurenzo static const char kOperationStrDunderDocstring[] = 105436c6c9cSStella Laurenzo R"(Gets the assembly form of the operation with default options. 106436c6c9cSStella Laurenzo 107436c6c9cSStella Laurenzo If more advanced control over the assembly formatting or I/O options is needed, 108436c6c9cSStella Laurenzo use the dedicated print or get_asm method, which supports keyword arguments to 109436c6c9cSStella Laurenzo customize behavior. 110436c6c9cSStella Laurenzo )"; 111436c6c9cSStella Laurenzo 112436c6c9cSStella Laurenzo static const char kDumpDocstring[] = 113436c6c9cSStella Laurenzo R"(Dumps a debug representation of the object to stderr.)"; 114436c6c9cSStella Laurenzo 115436c6c9cSStella Laurenzo static const char kAppendBlockDocstring[] = 116436c6c9cSStella Laurenzo R"(Appends a new block, with argument types as positional args. 117436c6c9cSStella Laurenzo 118436c6c9cSStella Laurenzo Returns: 119436c6c9cSStella Laurenzo The created block. 120436c6c9cSStella Laurenzo )"; 121436c6c9cSStella Laurenzo 122436c6c9cSStella Laurenzo static const char kValueDunderStrDocstring[] = 123436c6c9cSStella Laurenzo R"(Returns the string form of the value. 124436c6c9cSStella Laurenzo 125436c6c9cSStella Laurenzo If the value is a block argument, this is the assembly form of its type and the 126436c6c9cSStella Laurenzo position in the argument list. If the value is an operation result, this is 127436c6c9cSStella Laurenzo equivalent to printing the operation that produced it. 128436c6c9cSStella Laurenzo )"; 129436c6c9cSStella Laurenzo 130436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 131436c6c9cSStella Laurenzo // Utilities. 132436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 133436c6c9cSStella Laurenzo 1344acd8457SAlex Zinenko /// Helper for creating an @classmethod. 135436c6c9cSStella Laurenzo template <class Func, typename... Args> 136436c6c9cSStella Laurenzo py::object classmethod(Func f, Args... args) { 137436c6c9cSStella Laurenzo py::object cf = py::cpp_function(f, args...); 138436c6c9cSStella Laurenzo return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 139436c6c9cSStella Laurenzo } 140436c6c9cSStella Laurenzo 141436c6c9cSStella Laurenzo static py::object 142436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace, 143436c6c9cSStella Laurenzo py::object dialectDescriptor) { 144436c6c9cSStella Laurenzo auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 145436c6c9cSStella Laurenzo if (!dialectClass) { 146436c6c9cSStella Laurenzo // Use the base class. 147436c6c9cSStella Laurenzo return py::cast(PyDialect(std::move(dialectDescriptor))); 148436c6c9cSStella Laurenzo } 149436c6c9cSStella Laurenzo 150436c6c9cSStella Laurenzo // Create the custom implementation. 151436c6c9cSStella Laurenzo return (*dialectClass)(std::move(dialectDescriptor)); 152436c6c9cSStella Laurenzo } 153436c6c9cSStella Laurenzo 154436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 155436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 156436c6c9cSStella Laurenzo } 157436c6c9cSStella Laurenzo 1584acd8457SAlex Zinenko /// Wrapper for the global LLVM debugging flag. 1594acd8457SAlex Zinenko struct PyGlobalDebugFlag { 1604acd8457SAlex Zinenko static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 1614acd8457SAlex Zinenko 1624acd8457SAlex Zinenko static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 1634acd8457SAlex Zinenko 1644acd8457SAlex Zinenko static void bind(py::module &m) { 1654acd8457SAlex Zinenko // Debug flags. 1664acd8457SAlex Zinenko py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") 1674acd8457SAlex Zinenko .def_property_static("flag", &PyGlobalDebugFlag::get, 1684acd8457SAlex Zinenko &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 1694acd8457SAlex Zinenko } 1704acd8457SAlex Zinenko }; 1714acd8457SAlex Zinenko 172436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 173436c6c9cSStella Laurenzo // Collections. 174436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 175436c6c9cSStella Laurenzo 176436c6c9cSStella Laurenzo namespace { 177436c6c9cSStella Laurenzo 178436c6c9cSStella Laurenzo class PyRegionIterator { 179436c6c9cSStella Laurenzo public: 180436c6c9cSStella Laurenzo PyRegionIterator(PyOperationRef operation) 181436c6c9cSStella Laurenzo : operation(std::move(operation)) {} 182436c6c9cSStella Laurenzo 183436c6c9cSStella Laurenzo PyRegionIterator &dunderIter() { return *this; } 184436c6c9cSStella Laurenzo 185436c6c9cSStella Laurenzo PyRegion dunderNext() { 186436c6c9cSStella Laurenzo operation->checkValid(); 187436c6c9cSStella Laurenzo if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 188436c6c9cSStella Laurenzo throw py::stop_iteration(); 189436c6c9cSStella Laurenzo } 190436c6c9cSStella Laurenzo MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 191436c6c9cSStella Laurenzo return PyRegion(operation, region); 192436c6c9cSStella Laurenzo } 193436c6c9cSStella Laurenzo 194436c6c9cSStella Laurenzo static void bind(py::module &m) { 195436c6c9cSStella Laurenzo py::class_<PyRegionIterator>(m, "RegionIterator") 196436c6c9cSStella Laurenzo .def("__iter__", &PyRegionIterator::dunderIter) 197436c6c9cSStella Laurenzo .def("__next__", &PyRegionIterator::dunderNext); 198436c6c9cSStella Laurenzo } 199436c6c9cSStella Laurenzo 200436c6c9cSStella Laurenzo private: 201436c6c9cSStella Laurenzo PyOperationRef operation; 202436c6c9cSStella Laurenzo int nextIndex = 0; 203436c6c9cSStella Laurenzo }; 204436c6c9cSStella Laurenzo 205436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented 206436c6c9cSStella Laurenzo /// with a sequence-like container. 207436c6c9cSStella Laurenzo class PyRegionList { 208436c6c9cSStella Laurenzo public: 209436c6c9cSStella Laurenzo PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 210436c6c9cSStella Laurenzo 211436c6c9cSStella Laurenzo intptr_t dunderLen() { 212436c6c9cSStella Laurenzo operation->checkValid(); 213436c6c9cSStella Laurenzo return mlirOperationGetNumRegions(operation->get()); 214436c6c9cSStella Laurenzo } 215436c6c9cSStella Laurenzo 216436c6c9cSStella Laurenzo PyRegion dunderGetItem(intptr_t index) { 217436c6c9cSStella Laurenzo // dunderLen checks validity. 218436c6c9cSStella Laurenzo if (index < 0 || index >= dunderLen()) { 219436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 220436c6c9cSStella Laurenzo "attempt to access out of bounds region"); 221436c6c9cSStella Laurenzo } 222436c6c9cSStella Laurenzo MlirRegion region = mlirOperationGetRegion(operation->get(), index); 223436c6c9cSStella Laurenzo return PyRegion(operation, region); 224436c6c9cSStella Laurenzo } 225436c6c9cSStella Laurenzo 226436c6c9cSStella Laurenzo static void bind(py::module &m) { 227436c6c9cSStella Laurenzo py::class_<PyRegionList>(m, "RegionSequence") 228436c6c9cSStella Laurenzo .def("__len__", &PyRegionList::dunderLen) 229436c6c9cSStella Laurenzo .def("__getitem__", &PyRegionList::dunderGetItem); 230436c6c9cSStella Laurenzo } 231436c6c9cSStella Laurenzo 232436c6c9cSStella Laurenzo private: 233436c6c9cSStella Laurenzo PyOperationRef operation; 234436c6c9cSStella Laurenzo }; 235436c6c9cSStella Laurenzo 236436c6c9cSStella Laurenzo class PyBlockIterator { 237436c6c9cSStella Laurenzo public: 238436c6c9cSStella Laurenzo PyBlockIterator(PyOperationRef operation, MlirBlock next) 239436c6c9cSStella Laurenzo : operation(std::move(operation)), next(next) {} 240436c6c9cSStella Laurenzo 241436c6c9cSStella Laurenzo PyBlockIterator &dunderIter() { return *this; } 242436c6c9cSStella Laurenzo 243436c6c9cSStella Laurenzo PyBlock dunderNext() { 244436c6c9cSStella Laurenzo operation->checkValid(); 245436c6c9cSStella Laurenzo if (mlirBlockIsNull(next)) { 246436c6c9cSStella Laurenzo throw py::stop_iteration(); 247436c6c9cSStella Laurenzo } 248436c6c9cSStella Laurenzo 249436c6c9cSStella Laurenzo PyBlock returnBlock(operation, next); 250436c6c9cSStella Laurenzo next = mlirBlockGetNextInRegion(next); 251436c6c9cSStella Laurenzo return returnBlock; 252436c6c9cSStella Laurenzo } 253436c6c9cSStella Laurenzo 254436c6c9cSStella Laurenzo static void bind(py::module &m) { 255436c6c9cSStella Laurenzo py::class_<PyBlockIterator>(m, "BlockIterator") 256436c6c9cSStella Laurenzo .def("__iter__", &PyBlockIterator::dunderIter) 257436c6c9cSStella Laurenzo .def("__next__", &PyBlockIterator::dunderNext); 258436c6c9cSStella Laurenzo } 259436c6c9cSStella Laurenzo 260436c6c9cSStella Laurenzo private: 261436c6c9cSStella Laurenzo PyOperationRef operation; 262436c6c9cSStella Laurenzo MlirBlock next; 263436c6c9cSStella Laurenzo }; 264436c6c9cSStella Laurenzo 265436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 266436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize 267436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region. 268436c6c9cSStella Laurenzo class PyBlockList { 269436c6c9cSStella Laurenzo public: 270436c6c9cSStella Laurenzo PyBlockList(PyOperationRef operation, MlirRegion region) 271436c6c9cSStella Laurenzo : operation(std::move(operation)), region(region) {} 272436c6c9cSStella Laurenzo 273436c6c9cSStella Laurenzo PyBlockIterator dunderIter() { 274436c6c9cSStella Laurenzo operation->checkValid(); 275436c6c9cSStella Laurenzo return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 276436c6c9cSStella Laurenzo } 277436c6c9cSStella Laurenzo 278436c6c9cSStella Laurenzo intptr_t dunderLen() { 279436c6c9cSStella Laurenzo operation->checkValid(); 280436c6c9cSStella Laurenzo intptr_t count = 0; 281436c6c9cSStella Laurenzo MlirBlock block = mlirRegionGetFirstBlock(region); 282436c6c9cSStella Laurenzo while (!mlirBlockIsNull(block)) { 283436c6c9cSStella Laurenzo count += 1; 284436c6c9cSStella Laurenzo block = mlirBlockGetNextInRegion(block); 285436c6c9cSStella Laurenzo } 286436c6c9cSStella Laurenzo return count; 287436c6c9cSStella Laurenzo } 288436c6c9cSStella Laurenzo 289436c6c9cSStella Laurenzo PyBlock dunderGetItem(intptr_t index) { 290436c6c9cSStella Laurenzo operation->checkValid(); 291436c6c9cSStella Laurenzo if (index < 0) { 292436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 293436c6c9cSStella Laurenzo "attempt to access out of bounds block"); 294436c6c9cSStella Laurenzo } 295436c6c9cSStella Laurenzo MlirBlock block = mlirRegionGetFirstBlock(region); 296436c6c9cSStella Laurenzo while (!mlirBlockIsNull(block)) { 297436c6c9cSStella Laurenzo if (index == 0) { 298436c6c9cSStella Laurenzo return PyBlock(operation, block); 299436c6c9cSStella Laurenzo } 300436c6c9cSStella Laurenzo block = mlirBlockGetNextInRegion(block); 301436c6c9cSStella Laurenzo index -= 1; 302436c6c9cSStella Laurenzo } 303436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 304436c6c9cSStella Laurenzo } 305436c6c9cSStella Laurenzo 306436c6c9cSStella Laurenzo PyBlock appendBlock(py::args pyArgTypes) { 307436c6c9cSStella Laurenzo operation->checkValid(); 308436c6c9cSStella Laurenzo llvm::SmallVector<MlirType, 4> argTypes; 309436c6c9cSStella Laurenzo argTypes.reserve(pyArgTypes.size()); 310436c6c9cSStella Laurenzo for (auto &pyArg : pyArgTypes) { 311436c6c9cSStella Laurenzo argTypes.push_back(pyArg.cast<PyType &>()); 312436c6c9cSStella Laurenzo } 313436c6c9cSStella Laurenzo 314436c6c9cSStella Laurenzo MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 315436c6c9cSStella Laurenzo mlirRegionAppendOwnedBlock(region, block); 316436c6c9cSStella Laurenzo return PyBlock(operation, block); 317436c6c9cSStella Laurenzo } 318436c6c9cSStella Laurenzo 319436c6c9cSStella Laurenzo static void bind(py::module &m) { 320436c6c9cSStella Laurenzo py::class_<PyBlockList>(m, "BlockList") 321436c6c9cSStella Laurenzo .def("__getitem__", &PyBlockList::dunderGetItem) 322436c6c9cSStella Laurenzo .def("__iter__", &PyBlockList::dunderIter) 323436c6c9cSStella Laurenzo .def("__len__", &PyBlockList::dunderLen) 324436c6c9cSStella Laurenzo .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 325436c6c9cSStella Laurenzo } 326436c6c9cSStella Laurenzo 327436c6c9cSStella Laurenzo private: 328436c6c9cSStella Laurenzo PyOperationRef operation; 329436c6c9cSStella Laurenzo MlirRegion region; 330436c6c9cSStella Laurenzo }; 331436c6c9cSStella Laurenzo 332436c6c9cSStella Laurenzo class PyOperationIterator { 333436c6c9cSStella Laurenzo public: 334436c6c9cSStella Laurenzo PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 335436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), next(next) {} 336436c6c9cSStella Laurenzo 337436c6c9cSStella Laurenzo PyOperationIterator &dunderIter() { return *this; } 338436c6c9cSStella Laurenzo 339436c6c9cSStella Laurenzo py::object dunderNext() { 340436c6c9cSStella Laurenzo parentOperation->checkValid(); 341436c6c9cSStella Laurenzo if (mlirOperationIsNull(next)) { 342436c6c9cSStella Laurenzo throw py::stop_iteration(); 343436c6c9cSStella Laurenzo } 344436c6c9cSStella Laurenzo 345436c6c9cSStella Laurenzo PyOperationRef returnOperation = 346436c6c9cSStella Laurenzo PyOperation::forOperation(parentOperation->getContext(), next); 347436c6c9cSStella Laurenzo next = mlirOperationGetNextInBlock(next); 348436c6c9cSStella Laurenzo return returnOperation->createOpView(); 349436c6c9cSStella Laurenzo } 350436c6c9cSStella Laurenzo 351436c6c9cSStella Laurenzo static void bind(py::module &m) { 352436c6c9cSStella Laurenzo py::class_<PyOperationIterator>(m, "OperationIterator") 353436c6c9cSStella Laurenzo .def("__iter__", &PyOperationIterator::dunderIter) 354436c6c9cSStella Laurenzo .def("__next__", &PyOperationIterator::dunderNext); 355436c6c9cSStella Laurenzo } 356436c6c9cSStella Laurenzo 357436c6c9cSStella Laurenzo private: 358436c6c9cSStella Laurenzo PyOperationRef parentOperation; 359436c6c9cSStella Laurenzo MlirOperation next; 360436c6c9cSStella Laurenzo }; 361436c6c9cSStella Laurenzo 362436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In 363436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but 364436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned 365436c6c9cSStella Laurenzo /// by a block. 366436c6c9cSStella Laurenzo class PyOperationList { 367436c6c9cSStella Laurenzo public: 368436c6c9cSStella Laurenzo PyOperationList(PyOperationRef parentOperation, MlirBlock block) 369436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), block(block) {} 370436c6c9cSStella Laurenzo 371436c6c9cSStella Laurenzo PyOperationIterator dunderIter() { 372436c6c9cSStella Laurenzo parentOperation->checkValid(); 373436c6c9cSStella Laurenzo return PyOperationIterator(parentOperation, 374436c6c9cSStella Laurenzo mlirBlockGetFirstOperation(block)); 375436c6c9cSStella Laurenzo } 376436c6c9cSStella Laurenzo 377436c6c9cSStella Laurenzo intptr_t dunderLen() { 378436c6c9cSStella Laurenzo parentOperation->checkValid(); 379436c6c9cSStella Laurenzo intptr_t count = 0; 380436c6c9cSStella Laurenzo MlirOperation childOp = mlirBlockGetFirstOperation(block); 381436c6c9cSStella Laurenzo while (!mlirOperationIsNull(childOp)) { 382436c6c9cSStella Laurenzo count += 1; 383436c6c9cSStella Laurenzo childOp = mlirOperationGetNextInBlock(childOp); 384436c6c9cSStella Laurenzo } 385436c6c9cSStella Laurenzo return count; 386436c6c9cSStella Laurenzo } 387436c6c9cSStella Laurenzo 388436c6c9cSStella Laurenzo py::object dunderGetItem(intptr_t index) { 389436c6c9cSStella Laurenzo parentOperation->checkValid(); 390436c6c9cSStella Laurenzo if (index < 0) { 391436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 392436c6c9cSStella Laurenzo "attempt to access out of bounds operation"); 393436c6c9cSStella Laurenzo } 394436c6c9cSStella Laurenzo MlirOperation childOp = mlirBlockGetFirstOperation(block); 395436c6c9cSStella Laurenzo while (!mlirOperationIsNull(childOp)) { 396436c6c9cSStella Laurenzo if (index == 0) { 397436c6c9cSStella Laurenzo return PyOperation::forOperation(parentOperation->getContext(), childOp) 398436c6c9cSStella Laurenzo ->createOpView(); 399436c6c9cSStella Laurenzo } 400436c6c9cSStella Laurenzo childOp = mlirOperationGetNextInBlock(childOp); 401436c6c9cSStella Laurenzo index -= 1; 402436c6c9cSStella Laurenzo } 403436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 404436c6c9cSStella Laurenzo "attempt to access out of bounds operation"); 405436c6c9cSStella Laurenzo } 406436c6c9cSStella Laurenzo 407436c6c9cSStella Laurenzo static void bind(py::module &m) { 408436c6c9cSStella Laurenzo py::class_<PyOperationList>(m, "OperationList") 409436c6c9cSStella Laurenzo .def("__getitem__", &PyOperationList::dunderGetItem) 410436c6c9cSStella Laurenzo .def("__iter__", &PyOperationList::dunderIter) 411436c6c9cSStella Laurenzo .def("__len__", &PyOperationList::dunderLen); 412436c6c9cSStella Laurenzo } 413436c6c9cSStella Laurenzo 414436c6c9cSStella Laurenzo private: 415436c6c9cSStella Laurenzo PyOperationRef parentOperation; 416436c6c9cSStella Laurenzo MlirBlock block; 417436c6c9cSStella Laurenzo }; 418436c6c9cSStella Laurenzo 419436c6c9cSStella Laurenzo } // namespace 420436c6c9cSStella Laurenzo 421436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 422436c6c9cSStella Laurenzo // PyMlirContext 423436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 424436c6c9cSStella Laurenzo 425436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 426436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 427436c6c9cSStella Laurenzo auto &liveContexts = getLiveContexts(); 428436c6c9cSStella Laurenzo liveContexts[context.ptr] = this; 429436c6c9cSStella Laurenzo } 430436c6c9cSStella Laurenzo 431436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() { 432436c6c9cSStella Laurenzo // Note that the only public way to construct an instance is via the 433436c6c9cSStella Laurenzo // forContext method, which always puts the associated handle into 434436c6c9cSStella Laurenzo // liveContexts. 435436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 436436c6c9cSStella Laurenzo getLiveContexts().erase(context.ptr); 437436c6c9cSStella Laurenzo mlirContextDestroy(context); 438436c6c9cSStella Laurenzo } 439436c6c9cSStella Laurenzo 440436c6c9cSStella Laurenzo py::object PyMlirContext::getCapsule() { 441436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 442436c6c9cSStella Laurenzo } 443436c6c9cSStella Laurenzo 444436c6c9cSStella Laurenzo py::object PyMlirContext::createFromCapsule(py::object capsule) { 445436c6c9cSStella Laurenzo MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 446436c6c9cSStella Laurenzo if (mlirContextIsNull(rawContext)) 447436c6c9cSStella Laurenzo throw py::error_already_set(); 448436c6c9cSStella Laurenzo return forContext(rawContext).releaseObject(); 449436c6c9cSStella Laurenzo } 450436c6c9cSStella Laurenzo 451436c6c9cSStella Laurenzo PyMlirContext *PyMlirContext::createNewContextForInit() { 452436c6c9cSStella Laurenzo MlirContext context = mlirContextCreate(); 453436c6c9cSStella Laurenzo mlirRegisterAllDialects(context); 454436c6c9cSStella Laurenzo return new PyMlirContext(context); 455436c6c9cSStella Laurenzo } 456436c6c9cSStella Laurenzo 457436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 458436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 459436c6c9cSStella Laurenzo auto &liveContexts = getLiveContexts(); 460436c6c9cSStella Laurenzo auto it = liveContexts.find(context.ptr); 461436c6c9cSStella Laurenzo if (it == liveContexts.end()) { 462436c6c9cSStella Laurenzo // Create. 463436c6c9cSStella Laurenzo PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 464436c6c9cSStella Laurenzo py::object pyRef = py::cast(unownedContextWrapper); 465436c6c9cSStella Laurenzo assert(pyRef && "cast to py::object failed"); 466436c6c9cSStella Laurenzo liveContexts[context.ptr] = unownedContextWrapper; 467436c6c9cSStella Laurenzo return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 468436c6c9cSStella Laurenzo } 469436c6c9cSStella Laurenzo // Use existing. 470436c6c9cSStella Laurenzo py::object pyRef = py::cast(it->second); 471436c6c9cSStella Laurenzo return PyMlirContextRef(it->second, std::move(pyRef)); 472436c6c9cSStella Laurenzo } 473436c6c9cSStella Laurenzo 474436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 475436c6c9cSStella Laurenzo static LiveContextMap liveContexts; 476436c6c9cSStella Laurenzo return liveContexts; 477436c6c9cSStella Laurenzo } 478436c6c9cSStella Laurenzo 479436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 480436c6c9cSStella Laurenzo 481436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 482436c6c9cSStella Laurenzo 483436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 484436c6c9cSStella Laurenzo 485436c6c9cSStella Laurenzo pybind11::object PyMlirContext::contextEnter() { 486436c6c9cSStella Laurenzo return PyThreadContextEntry::pushContext(*this); 487436c6c9cSStella Laurenzo } 488436c6c9cSStella Laurenzo 489436c6c9cSStella Laurenzo void PyMlirContext::contextExit(pybind11::object excType, 490436c6c9cSStella Laurenzo pybind11::object excVal, 491436c6c9cSStella Laurenzo pybind11::object excTb) { 492436c6c9cSStella Laurenzo PyThreadContextEntry::popContext(*this); 493436c6c9cSStella Laurenzo } 494436c6c9cSStella Laurenzo 495436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() { 496436c6c9cSStella Laurenzo PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 497436c6c9cSStella Laurenzo if (!context) { 498436c6c9cSStella Laurenzo throw SetPyError( 499436c6c9cSStella Laurenzo PyExc_RuntimeError, 500436c6c9cSStella Laurenzo "An MLIR function requires a Context but none was provided in the call " 501436c6c9cSStella Laurenzo "or from the surrounding environment. Either pass to the function with " 502436c6c9cSStella Laurenzo "a 'context=' argument or establish a default using 'with Context():'"); 503436c6c9cSStella Laurenzo } 504436c6c9cSStella Laurenzo return *context; 505436c6c9cSStella Laurenzo } 506436c6c9cSStella Laurenzo 507436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 508436c6c9cSStella Laurenzo // PyThreadContextEntry management 509436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 510436c6c9cSStella Laurenzo 511436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 512436c6c9cSStella Laurenzo static thread_local std::vector<PyThreadContextEntry> stack; 513436c6c9cSStella Laurenzo return stack; 514436c6c9cSStella Laurenzo } 515436c6c9cSStella Laurenzo 516436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 517436c6c9cSStella Laurenzo auto &stack = getStack(); 518436c6c9cSStella Laurenzo if (stack.empty()) 519436c6c9cSStella Laurenzo return nullptr; 520436c6c9cSStella Laurenzo return &stack.back(); 521436c6c9cSStella Laurenzo } 522436c6c9cSStella Laurenzo 523436c6c9cSStella Laurenzo void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 524436c6c9cSStella Laurenzo py::object insertionPoint, 525436c6c9cSStella Laurenzo py::object location) { 526436c6c9cSStella Laurenzo auto &stack = getStack(); 527436c6c9cSStella Laurenzo stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 528436c6c9cSStella Laurenzo std::move(location)); 529436c6c9cSStella Laurenzo // If the new stack has more than one entry and the context of the new top 530436c6c9cSStella Laurenzo // entry matches the previous, copy the insertionPoint and location from the 531436c6c9cSStella Laurenzo // previous entry if missing from the new top entry. 532436c6c9cSStella Laurenzo if (stack.size() > 1) { 533436c6c9cSStella Laurenzo auto &prev = *(stack.rbegin() + 1); 534436c6c9cSStella Laurenzo auto ¤t = stack.back(); 535436c6c9cSStella Laurenzo if (current.context.is(prev.context)) { 536436c6c9cSStella Laurenzo // Default non-context objects from the previous entry. 537436c6c9cSStella Laurenzo if (!current.insertionPoint) 538436c6c9cSStella Laurenzo current.insertionPoint = prev.insertionPoint; 539436c6c9cSStella Laurenzo if (!current.location) 540436c6c9cSStella Laurenzo current.location = prev.location; 541436c6c9cSStella Laurenzo } 542436c6c9cSStella Laurenzo } 543436c6c9cSStella Laurenzo } 544436c6c9cSStella Laurenzo 545436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() { 546436c6c9cSStella Laurenzo if (!context) 547436c6c9cSStella Laurenzo return nullptr; 548436c6c9cSStella Laurenzo return py::cast<PyMlirContext *>(context); 549436c6c9cSStella Laurenzo } 550436c6c9cSStella Laurenzo 551436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 552436c6c9cSStella Laurenzo if (!insertionPoint) 553436c6c9cSStella Laurenzo return nullptr; 554436c6c9cSStella Laurenzo return py::cast<PyInsertionPoint *>(insertionPoint); 555436c6c9cSStella Laurenzo } 556436c6c9cSStella Laurenzo 557436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() { 558436c6c9cSStella Laurenzo if (!location) 559436c6c9cSStella Laurenzo return nullptr; 560436c6c9cSStella Laurenzo return py::cast<PyLocation *>(location); 561436c6c9cSStella Laurenzo } 562436c6c9cSStella Laurenzo 563436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() { 564436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 565436c6c9cSStella Laurenzo return tos ? tos->getContext() : nullptr; 566436c6c9cSStella Laurenzo } 567436c6c9cSStella Laurenzo 568436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 569436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 570436c6c9cSStella Laurenzo return tos ? tos->getInsertionPoint() : nullptr; 571436c6c9cSStella Laurenzo } 572436c6c9cSStella Laurenzo 573436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() { 574436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 575436c6c9cSStella Laurenzo return tos ? tos->getLocation() : nullptr; 576436c6c9cSStella Laurenzo } 577436c6c9cSStella Laurenzo 578436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 579436c6c9cSStella Laurenzo py::object contextObj = py::cast(context); 580436c6c9cSStella Laurenzo push(FrameKind::Context, /*context=*/contextObj, 581436c6c9cSStella Laurenzo /*insertionPoint=*/py::object(), 582436c6c9cSStella Laurenzo /*location=*/py::object()); 583436c6c9cSStella Laurenzo return contextObj; 584436c6c9cSStella Laurenzo } 585436c6c9cSStella Laurenzo 586436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) { 587436c6c9cSStella Laurenzo auto &stack = getStack(); 588436c6c9cSStella Laurenzo if (stack.empty()) 589436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 590436c6c9cSStella Laurenzo auto &tos = stack.back(); 591436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 592436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593436c6c9cSStella Laurenzo stack.pop_back(); 594436c6c9cSStella Laurenzo } 595436c6c9cSStella Laurenzo 596436c6c9cSStella Laurenzo py::object 597436c6c9cSStella Laurenzo PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 598436c6c9cSStella Laurenzo py::object contextObj = 599436c6c9cSStella Laurenzo insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 600436c6c9cSStella Laurenzo py::object insertionPointObj = py::cast(insertionPoint); 601436c6c9cSStella Laurenzo push(FrameKind::InsertionPoint, 602436c6c9cSStella Laurenzo /*context=*/contextObj, 603436c6c9cSStella Laurenzo /*insertionPoint=*/insertionPointObj, 604436c6c9cSStella Laurenzo /*location=*/py::object()); 605436c6c9cSStella Laurenzo return insertionPointObj; 606436c6c9cSStella Laurenzo } 607436c6c9cSStella Laurenzo 608436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 609436c6c9cSStella Laurenzo auto &stack = getStack(); 610436c6c9cSStella Laurenzo if (stack.empty()) 611436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, 612436c6c9cSStella Laurenzo "Unbalanced InsertionPoint enter/exit"); 613436c6c9cSStella Laurenzo auto &tos = stack.back(); 614436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::InsertionPoint && 615436c6c9cSStella Laurenzo tos.getInsertionPoint() != &insertionPoint) 616436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, 617436c6c9cSStella Laurenzo "Unbalanced InsertionPoint enter/exit"); 618436c6c9cSStella Laurenzo stack.pop_back(); 619436c6c9cSStella Laurenzo } 620436c6c9cSStella Laurenzo 621436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 622436c6c9cSStella Laurenzo py::object contextObj = location.getContext().getObject(); 623436c6c9cSStella Laurenzo py::object locationObj = py::cast(location); 624436c6c9cSStella Laurenzo push(FrameKind::Location, /*context=*/contextObj, 625436c6c9cSStella Laurenzo /*insertionPoint=*/py::object(), 626436c6c9cSStella Laurenzo /*location=*/locationObj); 627436c6c9cSStella Laurenzo return locationObj; 628436c6c9cSStella Laurenzo } 629436c6c9cSStella Laurenzo 630436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) { 631436c6c9cSStella Laurenzo auto &stack = getStack(); 632436c6c9cSStella Laurenzo if (stack.empty()) 633436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 634436c6c9cSStella Laurenzo auto &tos = stack.back(); 635436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 636436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637436c6c9cSStella Laurenzo stack.pop_back(); 638436c6c9cSStella Laurenzo } 639436c6c9cSStella Laurenzo 640436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 641436c6c9cSStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects 642436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 643436c6c9cSStella Laurenzo 644436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key, 645436c6c9cSStella Laurenzo bool attrError) { 646436c6c9cSStella Laurenzo // If the "std" dialect was asked for, substitute the empty namespace :( 647436c6c9cSStella Laurenzo static const std::string emptyKey; 648436c6c9cSStella Laurenzo const std::string *canonKey = key == "std" ? &emptyKey : &key; 649436c6c9cSStella Laurenzo MlirDialect dialect = mlirContextGetOrLoadDialect( 650436c6c9cSStella Laurenzo getContext()->get(), {canonKey->data(), canonKey->size()}); 651436c6c9cSStella Laurenzo if (mlirDialectIsNull(dialect)) { 652436c6c9cSStella Laurenzo throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 653436c6c9cSStella Laurenzo Twine("Dialect '") + key + "' not found"); 654436c6c9cSStella Laurenzo } 655436c6c9cSStella Laurenzo return dialect; 656436c6c9cSStella Laurenzo } 657436c6c9cSStella Laurenzo 658436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 659436c6c9cSStella Laurenzo // PyLocation 660436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 661436c6c9cSStella Laurenzo 662436c6c9cSStella Laurenzo py::object PyLocation::getCapsule() { 663436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 664436c6c9cSStella Laurenzo } 665436c6c9cSStella Laurenzo 666436c6c9cSStella Laurenzo PyLocation PyLocation::createFromCapsule(py::object capsule) { 667436c6c9cSStella Laurenzo MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 668436c6c9cSStella Laurenzo if (mlirLocationIsNull(rawLoc)) 669436c6c9cSStella Laurenzo throw py::error_already_set(); 670436c6c9cSStella Laurenzo return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 671436c6c9cSStella Laurenzo rawLoc); 672436c6c9cSStella Laurenzo } 673436c6c9cSStella Laurenzo 674436c6c9cSStella Laurenzo py::object PyLocation::contextEnter() { 675436c6c9cSStella Laurenzo return PyThreadContextEntry::pushLocation(*this); 676436c6c9cSStella Laurenzo } 677436c6c9cSStella Laurenzo 678436c6c9cSStella Laurenzo void PyLocation::contextExit(py::object excType, py::object excVal, 679436c6c9cSStella Laurenzo py::object excTb) { 680436c6c9cSStella Laurenzo PyThreadContextEntry::popLocation(*this); 681436c6c9cSStella Laurenzo } 682436c6c9cSStella Laurenzo 683436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() { 684436c6c9cSStella Laurenzo auto *location = PyThreadContextEntry::getDefaultLocation(); 685436c6c9cSStella Laurenzo if (!location) { 686436c6c9cSStella Laurenzo throw SetPyError( 687436c6c9cSStella Laurenzo PyExc_RuntimeError, 688436c6c9cSStella Laurenzo "An MLIR function requires a Location but none was provided in the " 689436c6c9cSStella Laurenzo "call or from the surrounding environment. Either pass to the function " 690436c6c9cSStella Laurenzo "with a 'loc=' argument or establish a default using 'with loc:'"); 691436c6c9cSStella Laurenzo } 692436c6c9cSStella Laurenzo return *location; 693436c6c9cSStella Laurenzo } 694436c6c9cSStella Laurenzo 695436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 696436c6c9cSStella Laurenzo // PyModule 697436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 698436c6c9cSStella Laurenzo 699436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 700436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), module(module) {} 701436c6c9cSStella Laurenzo 702436c6c9cSStella Laurenzo PyModule::~PyModule() { 703436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 704436c6c9cSStella Laurenzo auto &liveModules = getContext()->liveModules; 705436c6c9cSStella Laurenzo assert(liveModules.count(module.ptr) == 1 && 706436c6c9cSStella Laurenzo "destroying module not in live map"); 707436c6c9cSStella Laurenzo liveModules.erase(module.ptr); 708436c6c9cSStella Laurenzo mlirModuleDestroy(module); 709436c6c9cSStella Laurenzo } 710436c6c9cSStella Laurenzo 711436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) { 712436c6c9cSStella Laurenzo MlirContext context = mlirModuleGetContext(module); 713436c6c9cSStella Laurenzo PyMlirContextRef contextRef = PyMlirContext::forContext(context); 714436c6c9cSStella Laurenzo 715436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 716436c6c9cSStella Laurenzo auto &liveModules = contextRef->liveModules; 717436c6c9cSStella Laurenzo auto it = liveModules.find(module.ptr); 718436c6c9cSStella Laurenzo if (it == liveModules.end()) { 719436c6c9cSStella Laurenzo // Create. 720436c6c9cSStella Laurenzo PyModule *unownedModule = new PyModule(std::move(contextRef), module); 721436c6c9cSStella Laurenzo // Note that the default return value policy on cast is automatic_reference, 722436c6c9cSStella Laurenzo // which does not take ownership (delete will not be called). 723436c6c9cSStella Laurenzo // Just be explicit. 724436c6c9cSStella Laurenzo py::object pyRef = 725436c6c9cSStella Laurenzo py::cast(unownedModule, py::return_value_policy::take_ownership); 726436c6c9cSStella Laurenzo unownedModule->handle = pyRef; 727436c6c9cSStella Laurenzo liveModules[module.ptr] = 728436c6c9cSStella Laurenzo std::make_pair(unownedModule->handle, unownedModule); 729436c6c9cSStella Laurenzo return PyModuleRef(unownedModule, std::move(pyRef)); 730436c6c9cSStella Laurenzo } 731436c6c9cSStella Laurenzo // Use existing. 732436c6c9cSStella Laurenzo PyModule *existing = it->second.second; 733436c6c9cSStella Laurenzo py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 734436c6c9cSStella Laurenzo return PyModuleRef(existing, std::move(pyRef)); 735436c6c9cSStella Laurenzo } 736436c6c9cSStella Laurenzo 737436c6c9cSStella Laurenzo py::object PyModule::createFromCapsule(py::object capsule) { 738436c6c9cSStella Laurenzo MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 739436c6c9cSStella Laurenzo if (mlirModuleIsNull(rawModule)) 740436c6c9cSStella Laurenzo throw py::error_already_set(); 741436c6c9cSStella Laurenzo return forModule(rawModule).releaseObject(); 742436c6c9cSStella Laurenzo } 743436c6c9cSStella Laurenzo 744436c6c9cSStella Laurenzo py::object PyModule::getCapsule() { 745436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 746436c6c9cSStella Laurenzo } 747436c6c9cSStella Laurenzo 748436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 749436c6c9cSStella Laurenzo // PyOperation 750436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 751436c6c9cSStella Laurenzo 752436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 753436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), operation(operation) {} 754436c6c9cSStella Laurenzo 755436c6c9cSStella Laurenzo PyOperation::~PyOperation() { 75649745f87SMike Urbach // If the operation has already been invalidated there is nothing to do. 75749745f87SMike Urbach if (!valid) 75849745f87SMike Urbach return; 759436c6c9cSStella Laurenzo auto &liveOperations = getContext()->liveOperations; 760436c6c9cSStella Laurenzo assert(liveOperations.count(operation.ptr) == 1 && 761436c6c9cSStella Laurenzo "destroying operation not in live map"); 762436c6c9cSStella Laurenzo liveOperations.erase(operation.ptr); 763436c6c9cSStella Laurenzo if (!isAttached()) { 764436c6c9cSStella Laurenzo mlirOperationDestroy(operation); 765436c6c9cSStella Laurenzo } 766436c6c9cSStella Laurenzo } 767436c6c9cSStella Laurenzo 768436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 769436c6c9cSStella Laurenzo MlirOperation operation, 770436c6c9cSStella Laurenzo py::object parentKeepAlive) { 771436c6c9cSStella Laurenzo auto &liveOperations = contextRef->liveOperations; 772436c6c9cSStella Laurenzo // Create. 773436c6c9cSStella Laurenzo PyOperation *unownedOperation = 774436c6c9cSStella Laurenzo new PyOperation(std::move(contextRef), operation); 775436c6c9cSStella Laurenzo // Note that the default return value policy on cast is automatic_reference, 776436c6c9cSStella Laurenzo // which does not take ownership (delete will not be called). 777436c6c9cSStella Laurenzo // Just be explicit. 778436c6c9cSStella Laurenzo py::object pyRef = 779436c6c9cSStella Laurenzo py::cast(unownedOperation, py::return_value_policy::take_ownership); 780436c6c9cSStella Laurenzo unownedOperation->handle = pyRef; 781436c6c9cSStella Laurenzo if (parentKeepAlive) { 782436c6c9cSStella Laurenzo unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 783436c6c9cSStella Laurenzo } 784436c6c9cSStella Laurenzo liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 785436c6c9cSStella Laurenzo return PyOperationRef(unownedOperation, std::move(pyRef)); 786436c6c9cSStella Laurenzo } 787436c6c9cSStella Laurenzo 788436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 789436c6c9cSStella Laurenzo MlirOperation operation, 790436c6c9cSStella Laurenzo py::object parentKeepAlive) { 791436c6c9cSStella Laurenzo auto &liveOperations = contextRef->liveOperations; 792436c6c9cSStella Laurenzo auto it = liveOperations.find(operation.ptr); 793436c6c9cSStella Laurenzo if (it == liveOperations.end()) { 794436c6c9cSStella Laurenzo // Create. 795436c6c9cSStella Laurenzo return createInstance(std::move(contextRef), operation, 796436c6c9cSStella Laurenzo std::move(parentKeepAlive)); 797436c6c9cSStella Laurenzo } 798436c6c9cSStella Laurenzo // Use existing. 799436c6c9cSStella Laurenzo PyOperation *existing = it->second.second; 800436c6c9cSStella Laurenzo py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 801436c6c9cSStella Laurenzo return PyOperationRef(existing, std::move(pyRef)); 802436c6c9cSStella Laurenzo } 803436c6c9cSStella Laurenzo 804436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 805436c6c9cSStella Laurenzo MlirOperation operation, 806436c6c9cSStella Laurenzo py::object parentKeepAlive) { 807436c6c9cSStella Laurenzo auto &liveOperations = contextRef->liveOperations; 808436c6c9cSStella Laurenzo assert(liveOperations.count(operation.ptr) == 0 && 809436c6c9cSStella Laurenzo "cannot create detached operation that already exists"); 810436c6c9cSStella Laurenzo (void)liveOperations; 811436c6c9cSStella Laurenzo 812436c6c9cSStella Laurenzo PyOperationRef created = createInstance(std::move(contextRef), operation, 813436c6c9cSStella Laurenzo std::move(parentKeepAlive)); 814436c6c9cSStella Laurenzo created->attached = false; 815436c6c9cSStella Laurenzo return created; 816436c6c9cSStella Laurenzo } 817436c6c9cSStella Laurenzo 818436c6c9cSStella Laurenzo void PyOperation::checkValid() const { 819436c6c9cSStella Laurenzo if (!valid) { 820436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 821436c6c9cSStella Laurenzo } 822436c6c9cSStella Laurenzo } 823436c6c9cSStella Laurenzo 824436c6c9cSStella Laurenzo void PyOperationBase::print(py::object fileObject, bool binary, 825436c6c9cSStella Laurenzo llvm::Optional<int64_t> largeElementsLimit, 826436c6c9cSStella Laurenzo bool enableDebugInfo, bool prettyDebugInfo, 827436c6c9cSStella Laurenzo bool printGenericOpForm, bool useLocalScope) { 828436c6c9cSStella Laurenzo PyOperation &operation = getOperation(); 829436c6c9cSStella Laurenzo operation.checkValid(); 830436c6c9cSStella Laurenzo if (fileObject.is_none()) 831436c6c9cSStella Laurenzo fileObject = py::module::import("sys").attr("stdout"); 832436c6c9cSStella Laurenzo 833436c6c9cSStella Laurenzo if (!printGenericOpForm && !mlirOperationVerify(operation)) { 834436c6c9cSStella Laurenzo fileObject.attr("write")("// Verification failed, printing generic form\n"); 835436c6c9cSStella Laurenzo printGenericOpForm = true; 836436c6c9cSStella Laurenzo } 837436c6c9cSStella Laurenzo 838436c6c9cSStella Laurenzo MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 839436c6c9cSStella Laurenzo if (largeElementsLimit) 840436c6c9cSStella Laurenzo mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 841436c6c9cSStella Laurenzo if (enableDebugInfo) 842436c6c9cSStella Laurenzo mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 843436c6c9cSStella Laurenzo if (printGenericOpForm) 844436c6c9cSStella Laurenzo mlirOpPrintingFlagsPrintGenericOpForm(flags); 845436c6c9cSStella Laurenzo 846436c6c9cSStella Laurenzo PyFileAccumulator accum(fileObject, binary); 847436c6c9cSStella Laurenzo py::gil_scoped_release(); 848436c6c9cSStella Laurenzo mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 849436c6c9cSStella Laurenzo accum.getUserData()); 850436c6c9cSStella Laurenzo mlirOpPrintingFlagsDestroy(flags); 851436c6c9cSStella Laurenzo } 852436c6c9cSStella Laurenzo 853436c6c9cSStella Laurenzo py::object PyOperationBase::getAsm(bool binary, 854436c6c9cSStella Laurenzo llvm::Optional<int64_t> largeElementsLimit, 855436c6c9cSStella Laurenzo bool enableDebugInfo, bool prettyDebugInfo, 856436c6c9cSStella Laurenzo bool printGenericOpForm, 857436c6c9cSStella Laurenzo bool useLocalScope) { 858436c6c9cSStella Laurenzo py::object fileObject; 859436c6c9cSStella Laurenzo if (binary) { 860436c6c9cSStella Laurenzo fileObject = py::module::import("io").attr("BytesIO")(); 861436c6c9cSStella Laurenzo } else { 862436c6c9cSStella Laurenzo fileObject = py::module::import("io").attr("StringIO")(); 863436c6c9cSStella Laurenzo } 864436c6c9cSStella Laurenzo print(fileObject, /*binary=*/binary, 865436c6c9cSStella Laurenzo /*largeElementsLimit=*/largeElementsLimit, 866436c6c9cSStella Laurenzo /*enableDebugInfo=*/enableDebugInfo, 867436c6c9cSStella Laurenzo /*prettyDebugInfo=*/prettyDebugInfo, 868436c6c9cSStella Laurenzo /*printGenericOpForm=*/printGenericOpForm, 869436c6c9cSStella Laurenzo /*useLocalScope=*/useLocalScope); 870436c6c9cSStella Laurenzo 871436c6c9cSStella Laurenzo return fileObject.attr("getvalue")(); 872436c6c9cSStella Laurenzo } 873436c6c9cSStella Laurenzo 874436c6c9cSStella Laurenzo PyOperationRef PyOperation::getParentOperation() { 87549745f87SMike Urbach checkValid(); 876436c6c9cSStella Laurenzo if (!isAttached()) 877436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 878436c6c9cSStella Laurenzo MlirOperation operation = mlirOperationGetParentOperation(get()); 879436c6c9cSStella Laurenzo if (mlirOperationIsNull(operation)) 880436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "Operation has no parent."); 881436c6c9cSStella Laurenzo return PyOperation::forOperation(getContext(), operation); 882436c6c9cSStella Laurenzo } 883436c6c9cSStella Laurenzo 884436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() { 88549745f87SMike Urbach checkValid(); 886436c6c9cSStella Laurenzo PyOperationRef parentOperation = getParentOperation(); 887436c6c9cSStella Laurenzo MlirBlock block = mlirOperationGetBlock(get()); 888436c6c9cSStella Laurenzo assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 889436c6c9cSStella Laurenzo return PyBlock{std::move(parentOperation), block}; 890436c6c9cSStella Laurenzo } 891436c6c9cSStella Laurenzo 8920126e906SJohn Demme py::object PyOperation::getCapsule() { 89349745f87SMike Urbach checkValid(); 8940126e906SJohn Demme return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 8950126e906SJohn Demme } 8960126e906SJohn Demme 8970126e906SJohn Demme py::object PyOperation::createFromCapsule(py::object capsule) { 8980126e906SJohn Demme MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 8990126e906SJohn Demme if (mlirOperationIsNull(rawOperation)) 9000126e906SJohn Demme throw py::error_already_set(); 9010126e906SJohn Demme MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 9020126e906SJohn Demme return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 9030126e906SJohn Demme .releaseObject(); 9040126e906SJohn Demme } 9050126e906SJohn Demme 906436c6c9cSStella Laurenzo py::object PyOperation::create( 907436c6c9cSStella Laurenzo std::string name, llvm::Optional<std::vector<PyType *>> results, 908436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyValue *>> operands, 909436c6c9cSStella Laurenzo llvm::Optional<py::dict> attributes, 910436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyBlock *>> successors, int regions, 911436c6c9cSStella Laurenzo DefaultingPyLocation location, py::object maybeIp) { 912436c6c9cSStella Laurenzo llvm::SmallVector<MlirValue, 4> mlirOperands; 913436c6c9cSStella Laurenzo llvm::SmallVector<MlirType, 4> mlirResults; 914436c6c9cSStella Laurenzo llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 915436c6c9cSStella Laurenzo llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 916436c6c9cSStella Laurenzo 917436c6c9cSStella Laurenzo // General parameter validation. 918436c6c9cSStella Laurenzo if (regions < 0) 919436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 920436c6c9cSStella Laurenzo 921436c6c9cSStella Laurenzo // Unpack/validate operands. 922436c6c9cSStella Laurenzo if (operands) { 923436c6c9cSStella Laurenzo mlirOperands.reserve(operands->size()); 924436c6c9cSStella Laurenzo for (PyValue *operand : *operands) { 925436c6c9cSStella Laurenzo if (!operand) 926436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 927436c6c9cSStella Laurenzo mlirOperands.push_back(operand->get()); 928436c6c9cSStella Laurenzo } 929436c6c9cSStella Laurenzo } 930436c6c9cSStella Laurenzo 931436c6c9cSStella Laurenzo // Unpack/validate results. 932436c6c9cSStella Laurenzo if (results) { 933436c6c9cSStella Laurenzo mlirResults.reserve(results->size()); 934436c6c9cSStella Laurenzo for (PyType *result : *results) { 935436c6c9cSStella Laurenzo // TODO: Verify result type originate from the same context. 936436c6c9cSStella Laurenzo if (!result) 937436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "result type cannot be None"); 938436c6c9cSStella Laurenzo mlirResults.push_back(*result); 939436c6c9cSStella Laurenzo } 940436c6c9cSStella Laurenzo } 941436c6c9cSStella Laurenzo // Unpack/validate attributes. 942436c6c9cSStella Laurenzo if (attributes) { 943436c6c9cSStella Laurenzo mlirAttributes.reserve(attributes->size()); 944436c6c9cSStella Laurenzo for (auto &it : *attributes) { 945436c6c9cSStella Laurenzo std::string key; 946436c6c9cSStella Laurenzo try { 947436c6c9cSStella Laurenzo key = it.first.cast<std::string>(); 948436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 949436c6c9cSStella Laurenzo std::string msg = "Invalid attribute key (not a string) when " 950436c6c9cSStella Laurenzo "attempting to create the operation \"" + 951436c6c9cSStella Laurenzo name + "\" (" + err.what() + ")"; 952436c6c9cSStella Laurenzo throw py::cast_error(msg); 953436c6c9cSStella Laurenzo } 954436c6c9cSStella Laurenzo try { 955436c6c9cSStella Laurenzo auto &attribute = it.second.cast<PyAttribute &>(); 956436c6c9cSStella Laurenzo // TODO: Verify attribute originates from the same context. 957436c6c9cSStella Laurenzo mlirAttributes.emplace_back(std::move(key), attribute); 958436c6c9cSStella Laurenzo } catch (py::reference_cast_error &) { 959436c6c9cSStella Laurenzo // This exception seems thrown when the value is "None". 960436c6c9cSStella Laurenzo std::string msg = 961436c6c9cSStella Laurenzo "Found an invalid (`None`?) attribute value for the key \"" + key + 962436c6c9cSStella Laurenzo "\" when attempting to create the operation \"" + name + "\""; 963436c6c9cSStella Laurenzo throw py::cast_error(msg); 964436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 965436c6c9cSStella Laurenzo std::string msg = "Invalid attribute value for the key \"" + key + 966436c6c9cSStella Laurenzo "\" when attempting to create the operation \"" + 967436c6c9cSStella Laurenzo name + "\" (" + err.what() + ")"; 968436c6c9cSStella Laurenzo throw py::cast_error(msg); 969436c6c9cSStella Laurenzo } 970436c6c9cSStella Laurenzo } 971436c6c9cSStella Laurenzo } 972436c6c9cSStella Laurenzo // Unpack/validate successors. 973436c6c9cSStella Laurenzo if (successors) { 974436c6c9cSStella Laurenzo llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 975436c6c9cSStella Laurenzo mlirSuccessors.reserve(successors->size()); 976436c6c9cSStella Laurenzo for (auto *successor : *successors) { 977436c6c9cSStella Laurenzo // TODO: Verify successor originate from the same context. 978436c6c9cSStella Laurenzo if (!successor) 979436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 980436c6c9cSStella Laurenzo mlirSuccessors.push_back(successor->get()); 981436c6c9cSStella Laurenzo } 982436c6c9cSStella Laurenzo } 983436c6c9cSStella Laurenzo 984436c6c9cSStella Laurenzo // Apply unpacked/validated to the operation state. Beyond this 985436c6c9cSStella Laurenzo // point, exceptions cannot be thrown or else the state will leak. 986436c6c9cSStella Laurenzo MlirOperationState state = 987436c6c9cSStella Laurenzo mlirOperationStateGet(toMlirStringRef(name), location); 988436c6c9cSStella Laurenzo if (!mlirOperands.empty()) 989436c6c9cSStella Laurenzo mlirOperationStateAddOperands(&state, mlirOperands.size(), 990436c6c9cSStella Laurenzo mlirOperands.data()); 991436c6c9cSStella Laurenzo if (!mlirResults.empty()) 992436c6c9cSStella Laurenzo mlirOperationStateAddResults(&state, mlirResults.size(), 993436c6c9cSStella Laurenzo mlirResults.data()); 994436c6c9cSStella Laurenzo if (!mlirAttributes.empty()) { 995436c6c9cSStella Laurenzo // Note that the attribute names directly reference bytes in 996436c6c9cSStella Laurenzo // mlirAttributes, so that vector must not be changed from here 997436c6c9cSStella Laurenzo // on. 998436c6c9cSStella Laurenzo llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 999436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(mlirAttributes.size()); 1000436c6c9cSStella Laurenzo for (auto &it : mlirAttributes) 1001436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1002436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(it.second), 1003436c6c9cSStella Laurenzo toMlirStringRef(it.first)), 1004436c6c9cSStella Laurenzo it.second)); 1005436c6c9cSStella Laurenzo mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1006436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 1007436c6c9cSStella Laurenzo } 1008436c6c9cSStella Laurenzo if (!mlirSuccessors.empty()) 1009436c6c9cSStella Laurenzo mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1010436c6c9cSStella Laurenzo mlirSuccessors.data()); 1011436c6c9cSStella Laurenzo if (regions) { 1012436c6c9cSStella Laurenzo llvm::SmallVector<MlirRegion, 4> mlirRegions; 1013436c6c9cSStella Laurenzo mlirRegions.resize(regions); 1014436c6c9cSStella Laurenzo for (int i = 0; i < regions; ++i) 1015436c6c9cSStella Laurenzo mlirRegions[i] = mlirRegionCreate(); 1016436c6c9cSStella Laurenzo mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1017436c6c9cSStella Laurenzo mlirRegions.data()); 1018436c6c9cSStella Laurenzo } 1019436c6c9cSStella Laurenzo 1020436c6c9cSStella Laurenzo // Construct the operation. 1021436c6c9cSStella Laurenzo MlirOperation operation = mlirOperationCreate(&state); 1022436c6c9cSStella Laurenzo PyOperationRef created = 1023436c6c9cSStella Laurenzo PyOperation::createDetached(location->getContext(), operation); 1024436c6c9cSStella Laurenzo 1025436c6c9cSStella Laurenzo // InsertPoint active? 1026436c6c9cSStella Laurenzo if (!maybeIp.is(py::cast(false))) { 1027436c6c9cSStella Laurenzo PyInsertionPoint *ip; 1028436c6c9cSStella Laurenzo if (maybeIp.is_none()) { 1029436c6c9cSStella Laurenzo ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1030436c6c9cSStella Laurenzo } else { 1031436c6c9cSStella Laurenzo ip = py::cast<PyInsertionPoint *>(maybeIp); 1032436c6c9cSStella Laurenzo } 1033436c6c9cSStella Laurenzo if (ip) 1034436c6c9cSStella Laurenzo ip->insert(*created.get()); 1035436c6c9cSStella Laurenzo } 1036436c6c9cSStella Laurenzo 1037436c6c9cSStella Laurenzo return created->createOpView(); 1038436c6c9cSStella Laurenzo } 1039436c6c9cSStella Laurenzo 1040436c6c9cSStella Laurenzo py::object PyOperation::createOpView() { 104149745f87SMike Urbach checkValid(); 1042436c6c9cSStella Laurenzo MlirIdentifier ident = mlirOperationGetName(get()); 1043436c6c9cSStella Laurenzo MlirStringRef identStr = mlirIdentifierStr(ident); 1044436c6c9cSStella Laurenzo auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1045436c6c9cSStella Laurenzo StringRef(identStr.data, identStr.length)); 1046436c6c9cSStella Laurenzo if (opViewClass) 1047436c6c9cSStella Laurenzo return (*opViewClass)(getRef().getObject()); 1048436c6c9cSStella Laurenzo return py::cast(PyOpView(getRef().getObject())); 1049436c6c9cSStella Laurenzo } 1050436c6c9cSStella Laurenzo 105149745f87SMike Urbach void PyOperation::erase() { 105249745f87SMike Urbach checkValid(); 105349745f87SMike Urbach // TODO: Fix memory hazards when erasing a tree of operations for which a deep 105449745f87SMike Urbach // Python reference to a child operation is live. All children should also 105549745f87SMike Urbach // have their `valid` bit set to false. 105649745f87SMike Urbach auto &liveOperations = getContext()->liveOperations; 105749745f87SMike Urbach if (liveOperations.count(operation.ptr)) 105849745f87SMike Urbach liveOperations.erase(operation.ptr); 105949745f87SMike Urbach mlirOperationDestroy(operation); 106049745f87SMike Urbach valid = false; 106149745f87SMike Urbach } 106249745f87SMike Urbach 1063436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1064436c6c9cSStella Laurenzo // PyOpView 1065436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1066436c6c9cSStella Laurenzo 1067436c6c9cSStella Laurenzo py::object 1068436c6c9cSStella Laurenzo PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1069436c6c9cSStella Laurenzo py::list operandList, 1070436c6c9cSStella Laurenzo llvm::Optional<py::dict> attributes, 1071436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyBlock *>> successors, 1072436c6c9cSStella Laurenzo llvm::Optional<int> regions, 1073436c6c9cSStella Laurenzo DefaultingPyLocation location, py::object maybeIp) { 1074436c6c9cSStella Laurenzo PyMlirContextRef context = location->getContext(); 1075436c6c9cSStella Laurenzo // Class level operation construction metadata. 1076436c6c9cSStella Laurenzo std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1077436c6c9cSStella Laurenzo // Operand and result segment specs are either none, which does no 1078436c6c9cSStella Laurenzo // variadic unpacking, or a list of ints with segment sizes, where each 1079436c6c9cSStella Laurenzo // element is either a positive number (typically 1 for a scalar) or -1 to 1080436c6c9cSStella Laurenzo // indicate that it is derived from the length of the same-indexed operand 1081436c6c9cSStella Laurenzo // or result (implying that it is a list at that position). 1082436c6c9cSStella Laurenzo py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1083436c6c9cSStella Laurenzo py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1084436c6c9cSStella Laurenzo 10858d05a288SStella Laurenzo std::vector<uint32_t> operandSegmentLengths; 10868d05a288SStella Laurenzo std::vector<uint32_t> resultSegmentLengths; 1087436c6c9cSStella Laurenzo 1088436c6c9cSStella Laurenzo // Validate/determine region count. 1089436c6c9cSStella Laurenzo auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1090436c6c9cSStella Laurenzo int opMinRegionCount = std::get<0>(opRegionSpec); 1091436c6c9cSStella Laurenzo bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1092436c6c9cSStella Laurenzo if (!regions) { 1093436c6c9cSStella Laurenzo regions = opMinRegionCount; 1094436c6c9cSStella Laurenzo } 1095436c6c9cSStella Laurenzo if (*regions < opMinRegionCount) { 1096436c6c9cSStella Laurenzo throw py::value_error( 1097436c6c9cSStella Laurenzo (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1098436c6c9cSStella Laurenzo llvm::Twine(opMinRegionCount) + 1099436c6c9cSStella Laurenzo " regions but was built with regions=" + llvm::Twine(*regions)) 1100436c6c9cSStella Laurenzo .str()); 1101436c6c9cSStella Laurenzo } 1102436c6c9cSStella Laurenzo if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1103436c6c9cSStella Laurenzo throw py::value_error( 1104436c6c9cSStella Laurenzo (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1105436c6c9cSStella Laurenzo llvm::Twine(opMinRegionCount) + 1106436c6c9cSStella Laurenzo " regions but was built with regions=" + llvm::Twine(*regions)) 1107436c6c9cSStella Laurenzo .str()); 1108436c6c9cSStella Laurenzo } 1109436c6c9cSStella Laurenzo 1110436c6c9cSStella Laurenzo // Unpack results. 1111436c6c9cSStella Laurenzo std::vector<PyType *> resultTypes; 1112436c6c9cSStella Laurenzo resultTypes.reserve(resultTypeList.size()); 1113436c6c9cSStella Laurenzo if (resultSegmentSpecObj.is_none()) { 1114436c6c9cSStella Laurenzo // Non-variadic result unpacking. 1115436c6c9cSStella Laurenzo for (auto it : llvm::enumerate(resultTypeList)) { 1116436c6c9cSStella Laurenzo try { 1117436c6c9cSStella Laurenzo resultTypes.push_back(py::cast<PyType *>(it.value())); 1118436c6c9cSStella Laurenzo if (!resultTypes.back()) 1119436c6c9cSStella Laurenzo throw py::cast_error(); 1120436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1121436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Result ") + 1122436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1123436c6c9cSStella Laurenzo name + "\" must be a Type (" + err.what() + ")") 1124436c6c9cSStella Laurenzo .str()); 1125436c6c9cSStella Laurenzo } 1126436c6c9cSStella Laurenzo } 1127436c6c9cSStella Laurenzo } else { 1128436c6c9cSStella Laurenzo // Sized result unpacking. 1129436c6c9cSStella Laurenzo auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1130436c6c9cSStella Laurenzo if (resultSegmentSpec.size() != resultTypeList.size()) { 1131436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operation \"") + name + 1132436c6c9cSStella Laurenzo "\" requires " + 1133436c6c9cSStella Laurenzo llvm::Twine(resultSegmentSpec.size()) + 1134436c6c9cSStella Laurenzo "result segments but was provided " + 1135436c6c9cSStella Laurenzo llvm::Twine(resultTypeList.size())) 1136436c6c9cSStella Laurenzo .str()); 1137436c6c9cSStella Laurenzo } 1138436c6c9cSStella Laurenzo resultSegmentLengths.reserve(resultTypeList.size()); 1139436c6c9cSStella Laurenzo for (auto it : 1140436c6c9cSStella Laurenzo llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1141436c6c9cSStella Laurenzo int segmentSpec = std::get<1>(it.value()); 1142436c6c9cSStella Laurenzo if (segmentSpec == 1 || segmentSpec == 0) { 1143436c6c9cSStella Laurenzo // Unpack unary element. 1144436c6c9cSStella Laurenzo try { 1145436c6c9cSStella Laurenzo auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1146436c6c9cSStella Laurenzo if (resultType) { 1147436c6c9cSStella Laurenzo resultTypes.push_back(resultType); 1148436c6c9cSStella Laurenzo resultSegmentLengths.push_back(1); 1149436c6c9cSStella Laurenzo } else if (segmentSpec == 0) { 1150436c6c9cSStella Laurenzo // Allowed to be optional. 1151436c6c9cSStella Laurenzo resultSegmentLengths.push_back(0); 1152436c6c9cSStella Laurenzo } else { 1153436c6c9cSStella Laurenzo throw py::cast_error("was None and result is not optional"); 1154436c6c9cSStella Laurenzo } 1155436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1156436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Result ") + 1157436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1158436c6c9cSStella Laurenzo name + "\" must be a Type (" + err.what() + 1159436c6c9cSStella Laurenzo ")") 1160436c6c9cSStella Laurenzo .str()); 1161436c6c9cSStella Laurenzo } 1162436c6c9cSStella Laurenzo } else if (segmentSpec == -1) { 1163436c6c9cSStella Laurenzo // Unpack sequence by appending. 1164436c6c9cSStella Laurenzo try { 1165436c6c9cSStella Laurenzo if (std::get<0>(it.value()).is_none()) { 1166436c6c9cSStella Laurenzo // Treat it as an empty list. 1167436c6c9cSStella Laurenzo resultSegmentLengths.push_back(0); 1168436c6c9cSStella Laurenzo } else { 1169436c6c9cSStella Laurenzo // Unpack the list. 1170436c6c9cSStella Laurenzo auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1171436c6c9cSStella Laurenzo for (py::object segmentItem : segment) { 1172436c6c9cSStella Laurenzo resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1173436c6c9cSStella Laurenzo if (!resultTypes.back()) { 1174436c6c9cSStella Laurenzo throw py::cast_error("contained a None item"); 1175436c6c9cSStella Laurenzo } 1176436c6c9cSStella Laurenzo } 1177436c6c9cSStella Laurenzo resultSegmentLengths.push_back(segment.size()); 1178436c6c9cSStella Laurenzo } 1179436c6c9cSStella Laurenzo } catch (std::exception &err) { 1180436c6c9cSStella Laurenzo // NOTE: Sloppy to be using a catch-all here, but there are at least 1181436c6c9cSStella Laurenzo // three different unrelated exceptions that can be thrown in the 1182436c6c9cSStella Laurenzo // above "casts". Just keep the scope above small and catch them all. 1183436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Result ") + 1184436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1185436c6c9cSStella Laurenzo name + "\" must be a Sequence of Types (" + 1186436c6c9cSStella Laurenzo err.what() + ")") 1187436c6c9cSStella Laurenzo .str()); 1188436c6c9cSStella Laurenzo } 1189436c6c9cSStella Laurenzo } else { 1190436c6c9cSStella Laurenzo throw py::value_error("Unexpected segment spec"); 1191436c6c9cSStella Laurenzo } 1192436c6c9cSStella Laurenzo } 1193436c6c9cSStella Laurenzo } 1194436c6c9cSStella Laurenzo 1195436c6c9cSStella Laurenzo // Unpack operands. 1196436c6c9cSStella Laurenzo std::vector<PyValue *> operands; 1197436c6c9cSStella Laurenzo operands.reserve(operands.size()); 1198436c6c9cSStella Laurenzo if (operandSegmentSpecObj.is_none()) { 1199436c6c9cSStella Laurenzo // Non-sized operand unpacking. 1200436c6c9cSStella Laurenzo for (auto it : llvm::enumerate(operandList)) { 1201436c6c9cSStella Laurenzo try { 1202436c6c9cSStella Laurenzo operands.push_back(py::cast<PyValue *>(it.value())); 1203436c6c9cSStella Laurenzo if (!operands.back()) 1204436c6c9cSStella Laurenzo throw py::cast_error(); 1205436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1206436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operand ") + 1207436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1208436c6c9cSStella Laurenzo name + "\" must be a Value (" + err.what() + ")") 1209436c6c9cSStella Laurenzo .str()); 1210436c6c9cSStella Laurenzo } 1211436c6c9cSStella Laurenzo } 1212436c6c9cSStella Laurenzo } else { 1213436c6c9cSStella Laurenzo // Sized operand unpacking. 1214436c6c9cSStella Laurenzo auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1215436c6c9cSStella Laurenzo if (operandSegmentSpec.size() != operandList.size()) { 1216436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operation \"") + name + 1217436c6c9cSStella Laurenzo "\" requires " + 1218436c6c9cSStella Laurenzo llvm::Twine(operandSegmentSpec.size()) + 1219436c6c9cSStella Laurenzo "operand segments but was provided " + 1220436c6c9cSStella Laurenzo llvm::Twine(operandList.size())) 1221436c6c9cSStella Laurenzo .str()); 1222436c6c9cSStella Laurenzo } 1223436c6c9cSStella Laurenzo operandSegmentLengths.reserve(operandList.size()); 1224436c6c9cSStella Laurenzo for (auto it : 1225436c6c9cSStella Laurenzo llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1226436c6c9cSStella Laurenzo int segmentSpec = std::get<1>(it.value()); 1227436c6c9cSStella Laurenzo if (segmentSpec == 1 || segmentSpec == 0) { 1228436c6c9cSStella Laurenzo // Unpack unary element. 1229436c6c9cSStella Laurenzo try { 1230436c6c9cSStella Laurenzo auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1231436c6c9cSStella Laurenzo if (operandValue) { 1232436c6c9cSStella Laurenzo operands.push_back(operandValue); 1233436c6c9cSStella Laurenzo operandSegmentLengths.push_back(1); 1234436c6c9cSStella Laurenzo } else if (segmentSpec == 0) { 1235436c6c9cSStella Laurenzo // Allowed to be optional. 1236436c6c9cSStella Laurenzo operandSegmentLengths.push_back(0); 1237436c6c9cSStella Laurenzo } else { 1238436c6c9cSStella Laurenzo throw py::cast_error("was None and operand is not optional"); 1239436c6c9cSStella Laurenzo } 1240436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1241436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operand ") + 1242436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1243436c6c9cSStella Laurenzo name + "\" must be a Value (" + err.what() + 1244436c6c9cSStella Laurenzo ")") 1245436c6c9cSStella Laurenzo .str()); 1246436c6c9cSStella Laurenzo } 1247436c6c9cSStella Laurenzo } else if (segmentSpec == -1) { 1248436c6c9cSStella Laurenzo // Unpack sequence by appending. 1249436c6c9cSStella Laurenzo try { 1250436c6c9cSStella Laurenzo if (std::get<0>(it.value()).is_none()) { 1251436c6c9cSStella Laurenzo // Treat it as an empty list. 1252436c6c9cSStella Laurenzo operandSegmentLengths.push_back(0); 1253436c6c9cSStella Laurenzo } else { 1254436c6c9cSStella Laurenzo // Unpack the list. 1255436c6c9cSStella Laurenzo auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1256436c6c9cSStella Laurenzo for (py::object segmentItem : segment) { 1257436c6c9cSStella Laurenzo operands.push_back(py::cast<PyValue *>(segmentItem)); 1258436c6c9cSStella Laurenzo if (!operands.back()) { 1259436c6c9cSStella Laurenzo throw py::cast_error("contained a None item"); 1260436c6c9cSStella Laurenzo } 1261436c6c9cSStella Laurenzo } 1262436c6c9cSStella Laurenzo operandSegmentLengths.push_back(segment.size()); 1263436c6c9cSStella Laurenzo } 1264436c6c9cSStella Laurenzo } catch (std::exception &err) { 1265436c6c9cSStella Laurenzo // NOTE: Sloppy to be using a catch-all here, but there are at least 1266436c6c9cSStella Laurenzo // three different unrelated exceptions that can be thrown in the 1267436c6c9cSStella Laurenzo // above "casts". Just keep the scope above small and catch them all. 1268436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operand ") + 1269436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1270436c6c9cSStella Laurenzo name + "\" must be a Sequence of Values (" + 1271436c6c9cSStella Laurenzo err.what() + ")") 1272436c6c9cSStella Laurenzo .str()); 1273436c6c9cSStella Laurenzo } 1274436c6c9cSStella Laurenzo } else { 1275436c6c9cSStella Laurenzo throw py::value_error("Unexpected segment spec"); 1276436c6c9cSStella Laurenzo } 1277436c6c9cSStella Laurenzo } 1278436c6c9cSStella Laurenzo } 1279436c6c9cSStella Laurenzo 1280436c6c9cSStella Laurenzo // Merge operand/result segment lengths into attributes if needed. 1281436c6c9cSStella Laurenzo if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1282436c6c9cSStella Laurenzo // Dup. 1283436c6c9cSStella Laurenzo if (attributes) { 1284436c6c9cSStella Laurenzo attributes = py::dict(*attributes); 1285436c6c9cSStella Laurenzo } else { 1286436c6c9cSStella Laurenzo attributes = py::dict(); 1287436c6c9cSStella Laurenzo } 1288436c6c9cSStella Laurenzo if (attributes->contains("result_segment_sizes") || 1289436c6c9cSStella Laurenzo attributes->contains("operand_segment_sizes")) { 1290436c6c9cSStella Laurenzo throw py::value_error("Manually setting a 'result_segment_sizes' or " 1291436c6c9cSStella Laurenzo "'operand_segment_sizes' attribute is unsupported. " 1292436c6c9cSStella Laurenzo "Use Operation.create for such low-level access."); 1293436c6c9cSStella Laurenzo } 1294436c6c9cSStella Laurenzo 1295436c6c9cSStella Laurenzo // Add result_segment_sizes attribute. 1296436c6c9cSStella Laurenzo if (!resultSegmentLengths.empty()) { 1297436c6c9cSStella Laurenzo int64_t size = resultSegmentLengths.size(); 12988d05a288SStella Laurenzo MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 12998d05a288SStella Laurenzo mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1300436c6c9cSStella Laurenzo resultSegmentLengths.size(), resultSegmentLengths.data()); 1301436c6c9cSStella Laurenzo (*attributes)["result_segment_sizes"] = 1302436c6c9cSStella Laurenzo PyAttribute(context, segmentLengthAttr); 1303436c6c9cSStella Laurenzo } 1304436c6c9cSStella Laurenzo 1305436c6c9cSStella Laurenzo // Add operand_segment_sizes attribute. 1306436c6c9cSStella Laurenzo if (!operandSegmentLengths.empty()) { 1307436c6c9cSStella Laurenzo int64_t size = operandSegmentLengths.size(); 13088d05a288SStella Laurenzo MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 13098d05a288SStella Laurenzo mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1310436c6c9cSStella Laurenzo operandSegmentLengths.size(), operandSegmentLengths.data()); 1311436c6c9cSStella Laurenzo (*attributes)["operand_segment_sizes"] = 1312436c6c9cSStella Laurenzo PyAttribute(context, segmentLengthAttr); 1313436c6c9cSStella Laurenzo } 1314436c6c9cSStella Laurenzo } 1315436c6c9cSStella Laurenzo 1316436c6c9cSStella Laurenzo // Delegate to create. 1317436c6c9cSStella Laurenzo return PyOperation::create(std::move(name), 1318436c6c9cSStella Laurenzo /*results=*/std::move(resultTypes), 1319436c6c9cSStella Laurenzo /*operands=*/std::move(operands), 1320436c6c9cSStella Laurenzo /*attributes=*/std::move(attributes), 1321436c6c9cSStella Laurenzo /*successors=*/std::move(successors), 1322436c6c9cSStella Laurenzo /*regions=*/*regions, location, maybeIp); 1323436c6c9cSStella Laurenzo } 1324436c6c9cSStella Laurenzo 1325436c6c9cSStella Laurenzo PyOpView::PyOpView(py::object operationObject) 1326436c6c9cSStella Laurenzo // Casting through the PyOperationBase base-class and then back to the 1327436c6c9cSStella Laurenzo // Operation lets us accept any PyOperationBase subclass. 1328436c6c9cSStella Laurenzo : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1329436c6c9cSStella Laurenzo operationObject(operation.getRef().getObject()) {} 1330436c6c9cSStella Laurenzo 1331436c6c9cSStella Laurenzo py::object PyOpView::createRawSubclass(py::object userClass) { 1332436c6c9cSStella Laurenzo // This is... a little gross. The typical pattern is to have a pure python 1333436c6c9cSStella Laurenzo // class that extends OpView like: 1334436c6c9cSStella Laurenzo // class AddFOp(_cext.ir.OpView): 1335436c6c9cSStella Laurenzo // def __init__(self, loc, lhs, rhs): 1336436c6c9cSStella Laurenzo // operation = loc.context.create_operation( 1337436c6c9cSStella Laurenzo // "addf", lhs, rhs, results=[lhs.type]) 1338436c6c9cSStella Laurenzo // super().__init__(operation) 1339436c6c9cSStella Laurenzo // 1340436c6c9cSStella Laurenzo // I.e. The goal of the user facing type is to provide a nice constructor 1341436c6c9cSStella Laurenzo // that has complete freedom for the op under construction. This is at odds 1342436c6c9cSStella Laurenzo // with our other desire to sometimes create this object by just passing an 1343436c6c9cSStella Laurenzo // operation (to initialize the base class). We could do *arg and **kwargs 1344436c6c9cSStella Laurenzo // munging to try to make it work, but instead, we synthesize a new class 1345436c6c9cSStella Laurenzo // on the fly which extends this user class (AddFOp in this example) and 1346436c6c9cSStella Laurenzo // *give it* the base class's __init__ method, thus bypassing the 1347436c6c9cSStella Laurenzo // intermediate subclass's __init__ method entirely. While slightly, 1348436c6c9cSStella Laurenzo // underhanded, this is safe/legal because the type hierarchy has not changed 1349436c6c9cSStella Laurenzo // (we just added a new leaf) and we aren't mucking around with __new__. 1350436c6c9cSStella Laurenzo // Typically, this new class will be stored on the original as "_Raw" and will 1351436c6c9cSStella Laurenzo // be used for casts and other things that need a variant of the class that 1352436c6c9cSStella Laurenzo // is initialized purely from an operation. 1353436c6c9cSStella Laurenzo py::object parentMetaclass = 1354436c6c9cSStella Laurenzo py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1355436c6c9cSStella Laurenzo py::dict attributes; 1356436c6c9cSStella Laurenzo // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1357436c6c9cSStella Laurenzo // now. 1358436c6c9cSStella Laurenzo // auto opViewType = py::type::of<PyOpView>(); 1359436c6c9cSStella Laurenzo auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1360436c6c9cSStella Laurenzo attributes["__init__"] = opViewType.attr("__init__"); 1361436c6c9cSStella Laurenzo py::str origName = userClass.attr("__name__"); 1362436c6c9cSStella Laurenzo py::str newName = py::str("_") + origName; 1363436c6c9cSStella Laurenzo return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1364436c6c9cSStella Laurenzo } 1365436c6c9cSStella Laurenzo 1366436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1367436c6c9cSStella Laurenzo // PyInsertionPoint. 1368436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1369436c6c9cSStella Laurenzo 1370436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1371436c6c9cSStella Laurenzo 1372436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1373436c6c9cSStella Laurenzo : refOperation(beforeOperationBase.getOperation().getRef()), 1374436c6c9cSStella Laurenzo block((*refOperation)->getBlock()) {} 1375436c6c9cSStella Laurenzo 1376436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1377436c6c9cSStella Laurenzo PyOperation &operation = operationBase.getOperation(); 1378436c6c9cSStella Laurenzo if (operation.isAttached()) 1379436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 1380436c6c9cSStella Laurenzo "Attempt to insert operation that is already attached"); 1381436c6c9cSStella Laurenzo block.getParentOperation()->checkValid(); 1382436c6c9cSStella Laurenzo MlirOperation beforeOp = {nullptr}; 1383436c6c9cSStella Laurenzo if (refOperation) { 1384436c6c9cSStella Laurenzo // Insert before operation. 1385436c6c9cSStella Laurenzo (*refOperation)->checkValid(); 1386436c6c9cSStella Laurenzo beforeOp = (*refOperation)->get(); 1387436c6c9cSStella Laurenzo } else { 1388436c6c9cSStella Laurenzo // Insert at end (before null) is only valid if the block does not 1389436c6c9cSStella Laurenzo // already end in a known terminator (violating this will cause assertion 1390436c6c9cSStella Laurenzo // failures later). 1391436c6c9cSStella Laurenzo if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1392436c6c9cSStella Laurenzo throw py::index_error("Cannot insert operation at the end of a block " 1393436c6c9cSStella Laurenzo "that already has a terminator. Did you mean to " 1394436c6c9cSStella Laurenzo "use 'InsertionPoint.at_block_terminator(block)' " 1395436c6c9cSStella Laurenzo "versus 'InsertionPoint(block)'?"); 1396436c6c9cSStella Laurenzo } 1397436c6c9cSStella Laurenzo } 1398436c6c9cSStella Laurenzo mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1399436c6c9cSStella Laurenzo operation.setAttached(); 1400436c6c9cSStella Laurenzo } 1401436c6c9cSStella Laurenzo 1402436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1403436c6c9cSStella Laurenzo MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1404436c6c9cSStella Laurenzo if (mlirOperationIsNull(firstOp)) { 1405436c6c9cSStella Laurenzo // Just insert at end. 1406436c6c9cSStella Laurenzo return PyInsertionPoint(block); 1407436c6c9cSStella Laurenzo } 1408436c6c9cSStella Laurenzo 1409436c6c9cSStella Laurenzo // Insert before first op. 1410436c6c9cSStella Laurenzo PyOperationRef firstOpRef = PyOperation::forOperation( 1411436c6c9cSStella Laurenzo block.getParentOperation()->getContext(), firstOp); 1412436c6c9cSStella Laurenzo return PyInsertionPoint{block, std::move(firstOpRef)}; 1413436c6c9cSStella Laurenzo } 1414436c6c9cSStella Laurenzo 1415436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1416436c6c9cSStella Laurenzo MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1417436c6c9cSStella Laurenzo if (mlirOperationIsNull(terminator)) 1418436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1419436c6c9cSStella Laurenzo PyOperationRef terminatorOpRef = PyOperation::forOperation( 1420436c6c9cSStella Laurenzo block.getParentOperation()->getContext(), terminator); 1421436c6c9cSStella Laurenzo return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1422436c6c9cSStella Laurenzo } 1423436c6c9cSStella Laurenzo 1424436c6c9cSStella Laurenzo py::object PyInsertionPoint::contextEnter() { 1425436c6c9cSStella Laurenzo return PyThreadContextEntry::pushInsertionPoint(*this); 1426436c6c9cSStella Laurenzo } 1427436c6c9cSStella Laurenzo 1428436c6c9cSStella Laurenzo void PyInsertionPoint::contextExit(pybind11::object excType, 1429436c6c9cSStella Laurenzo pybind11::object excVal, 1430436c6c9cSStella Laurenzo pybind11::object excTb) { 1431436c6c9cSStella Laurenzo PyThreadContextEntry::popInsertionPoint(*this); 1432436c6c9cSStella Laurenzo } 1433436c6c9cSStella Laurenzo 1434436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1435436c6c9cSStella Laurenzo // PyAttribute. 1436436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1437436c6c9cSStella Laurenzo 1438436c6c9cSStella Laurenzo bool PyAttribute::operator==(const PyAttribute &other) { 1439436c6c9cSStella Laurenzo return mlirAttributeEqual(attr, other.attr); 1440436c6c9cSStella Laurenzo } 1441436c6c9cSStella Laurenzo 1442436c6c9cSStella Laurenzo py::object PyAttribute::getCapsule() { 1443436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1444436c6c9cSStella Laurenzo } 1445436c6c9cSStella Laurenzo 1446436c6c9cSStella Laurenzo PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1447436c6c9cSStella Laurenzo MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1448436c6c9cSStella Laurenzo if (mlirAttributeIsNull(rawAttr)) 1449436c6c9cSStella Laurenzo throw py::error_already_set(); 1450436c6c9cSStella Laurenzo return PyAttribute( 1451436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1452436c6c9cSStella Laurenzo } 1453436c6c9cSStella Laurenzo 1454436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1455436c6c9cSStella Laurenzo // PyNamedAttribute. 1456436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1457436c6c9cSStella Laurenzo 1458436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1459436c6c9cSStella Laurenzo : ownedName(new std::string(std::move(ownedName))) { 1460436c6c9cSStella Laurenzo namedAttr = mlirNamedAttributeGet( 1461436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(attr), 1462436c6c9cSStella Laurenzo toMlirStringRef(*this->ownedName)), 1463436c6c9cSStella Laurenzo attr); 1464436c6c9cSStella Laurenzo } 1465436c6c9cSStella Laurenzo 1466436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1467436c6c9cSStella Laurenzo // PyType. 1468436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1469436c6c9cSStella Laurenzo 1470436c6c9cSStella Laurenzo bool PyType::operator==(const PyType &other) { 1471436c6c9cSStella Laurenzo return mlirTypeEqual(type, other.type); 1472436c6c9cSStella Laurenzo } 1473436c6c9cSStella Laurenzo 1474436c6c9cSStella Laurenzo py::object PyType::getCapsule() { 1475436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1476436c6c9cSStella Laurenzo } 1477436c6c9cSStella Laurenzo 1478436c6c9cSStella Laurenzo PyType PyType::createFromCapsule(py::object capsule) { 1479436c6c9cSStella Laurenzo MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1480436c6c9cSStella Laurenzo if (mlirTypeIsNull(rawType)) 1481436c6c9cSStella Laurenzo throw py::error_already_set(); 1482436c6c9cSStella Laurenzo return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1483436c6c9cSStella Laurenzo rawType); 1484436c6c9cSStella Laurenzo } 1485436c6c9cSStella Laurenzo 1486436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1487436c6c9cSStella Laurenzo // PyValue and subclases. 1488436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1489436c6c9cSStella Laurenzo 14903f3d1c90SMike Urbach pybind11::object PyValue::getCapsule() { 14913f3d1c90SMike Urbach return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 14923f3d1c90SMike Urbach } 14933f3d1c90SMike Urbach 14943f3d1c90SMike Urbach PyValue PyValue::createFromCapsule(pybind11::object capsule) { 14953f3d1c90SMike Urbach MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 14963f3d1c90SMike Urbach if (mlirValueIsNull(value)) 14973f3d1c90SMike Urbach throw py::error_already_set(); 14983f3d1c90SMike Urbach MlirOperation owner; 14993f3d1c90SMike Urbach if (mlirValueIsAOpResult(value)) 15003f3d1c90SMike Urbach owner = mlirOpResultGetOwner(value); 15013f3d1c90SMike Urbach if (mlirValueIsABlockArgument(value)) 15023f3d1c90SMike Urbach owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 15033f3d1c90SMike Urbach if (mlirOperationIsNull(owner)) 15043f3d1c90SMike Urbach throw py::error_already_set(); 15053f3d1c90SMike Urbach MlirContext ctx = mlirOperationGetContext(owner); 15063f3d1c90SMike Urbach PyOperationRef ownerRef = 15073f3d1c90SMike Urbach PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 15083f3d1c90SMike Urbach return PyValue(ownerRef, value); 15093f3d1c90SMike Urbach } 15103f3d1c90SMike Urbach 1511436c6c9cSStella Laurenzo namespace { 1512436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR values that subclass Value and should be 1513436c6c9cSStella Laurenzo /// castable from it. The value hierarchy is one level deep and is not supposed 1514436c6c9cSStella Laurenzo /// to accommodate other levels unless core MLIR changes. 1515436c6c9cSStella Laurenzo template <typename DerivedTy> 1516436c6c9cSStella Laurenzo class PyConcreteValue : public PyValue { 1517436c6c9cSStella Laurenzo public: 1518436c6c9cSStella Laurenzo // Derived classes must define statics for: 1519436c6c9cSStella Laurenzo // IsAFunctionTy isaFunction 1520436c6c9cSStella Laurenzo // const char *pyClassName 1521436c6c9cSStella Laurenzo // and redefine bindDerived. 1522436c6c9cSStella Laurenzo using ClassTy = py::class_<DerivedTy, PyValue>; 1523436c6c9cSStella Laurenzo using IsAFunctionTy = bool (*)(MlirValue); 1524436c6c9cSStella Laurenzo 1525436c6c9cSStella Laurenzo PyConcreteValue() = default; 1526436c6c9cSStella Laurenzo PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1527436c6c9cSStella Laurenzo : PyValue(operationRef, value) {} 1528436c6c9cSStella Laurenzo PyConcreteValue(PyValue &orig) 1529436c6c9cSStella Laurenzo : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1530436c6c9cSStella Laurenzo 1531436c6c9cSStella Laurenzo /// Attempts to cast the original value to the derived type and throws on 1532436c6c9cSStella Laurenzo /// type mismatches. 1533436c6c9cSStella Laurenzo static MlirValue castFrom(PyValue &orig) { 1534436c6c9cSStella Laurenzo if (!DerivedTy::isaFunction(orig.get())) { 1535436c6c9cSStella Laurenzo auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1536436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1537436c6c9cSStella Laurenzo DerivedTy::pyClassName + 1538436c6c9cSStella Laurenzo " (from " + origRepr + ")"); 1539436c6c9cSStella Laurenzo } 1540436c6c9cSStella Laurenzo return orig.get(); 1541436c6c9cSStella Laurenzo } 1542436c6c9cSStella Laurenzo 1543436c6c9cSStella Laurenzo /// Binds the Python module objects to functions of this class. 1544436c6c9cSStella Laurenzo static void bind(py::module &m) { 1545436c6c9cSStella Laurenzo auto cls = ClassTy(m, DerivedTy::pyClassName); 1546436c6c9cSStella Laurenzo cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1547436c6c9cSStella Laurenzo DerivedTy::bindDerived(cls); 1548436c6c9cSStella Laurenzo } 1549436c6c9cSStella Laurenzo 1550436c6c9cSStella Laurenzo /// Implemented by derived classes to add methods to the Python subclass. 1551436c6c9cSStella Laurenzo static void bindDerived(ClassTy &m) {} 1552436c6c9cSStella Laurenzo }; 1553436c6c9cSStella Laurenzo 1554436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument. 1555436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1556436c6c9cSStella Laurenzo public: 1557436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1558436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BlockArgument"; 1559436c6c9cSStella Laurenzo using PyConcreteValue::PyConcreteValue; 1560436c6c9cSStella Laurenzo 1561436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1562436c6c9cSStella Laurenzo c.def_property_readonly("owner", [](PyBlockArgument &self) { 1563436c6c9cSStella Laurenzo return PyBlock(self.getParentOperation(), 1564436c6c9cSStella Laurenzo mlirBlockArgumentGetOwner(self.get())); 1565436c6c9cSStella Laurenzo }); 1566436c6c9cSStella Laurenzo c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1567436c6c9cSStella Laurenzo return mlirBlockArgumentGetArgNumber(self.get()); 1568436c6c9cSStella Laurenzo }); 1569436c6c9cSStella Laurenzo c.def("set_type", [](PyBlockArgument &self, PyType type) { 1570436c6c9cSStella Laurenzo return mlirBlockArgumentSetType(self.get(), type); 1571436c6c9cSStella Laurenzo }); 1572436c6c9cSStella Laurenzo } 1573436c6c9cSStella Laurenzo }; 1574436c6c9cSStella Laurenzo 1575436c6c9cSStella Laurenzo /// Python wrapper for MlirOpResult. 1576436c6c9cSStella Laurenzo class PyOpResult : public PyConcreteValue<PyOpResult> { 1577436c6c9cSStella Laurenzo public: 1578436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1579436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "OpResult"; 1580436c6c9cSStella Laurenzo using PyConcreteValue::PyConcreteValue; 1581436c6c9cSStella Laurenzo 1582436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1583436c6c9cSStella Laurenzo c.def_property_readonly("owner", [](PyOpResult &self) { 1584436c6c9cSStella Laurenzo assert( 1585436c6c9cSStella Laurenzo mlirOperationEqual(self.getParentOperation()->get(), 1586436c6c9cSStella Laurenzo mlirOpResultGetOwner(self.get())) && 1587436c6c9cSStella Laurenzo "expected the owner of the value in Python to match that in the IR"); 15886ff74f96SMike Urbach return self.getParentOperation().getObject(); 1589436c6c9cSStella Laurenzo }); 1590436c6c9cSStella Laurenzo c.def_property_readonly("result_number", [](PyOpResult &self) { 1591436c6c9cSStella Laurenzo return mlirOpResultGetResultNumber(self.get()); 1592436c6c9cSStella Laurenzo }); 1593436c6c9cSStella Laurenzo } 1594436c6c9cSStella Laurenzo }; 1595436c6c9cSStella Laurenzo 1596436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive 1597436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the 1598436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in 1599436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime. 1600436c6c9cSStella Laurenzo class PyBlockArgumentList { 1601436c6c9cSStella Laurenzo public: 1602436c6c9cSStella Laurenzo PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1603436c6c9cSStella Laurenzo : operation(std::move(operation)), block(block) {} 1604436c6c9cSStella Laurenzo 1605436c6c9cSStella Laurenzo /// Returns the length of the block argument list. 1606436c6c9cSStella Laurenzo intptr_t dunderLen() { 1607436c6c9cSStella Laurenzo operation->checkValid(); 1608436c6c9cSStella Laurenzo return mlirBlockGetNumArguments(block); 1609436c6c9cSStella Laurenzo } 1610436c6c9cSStella Laurenzo 1611436c6c9cSStella Laurenzo /// Returns `index`-th element of the block argument list. 1612436c6c9cSStella Laurenzo PyBlockArgument dunderGetItem(intptr_t index) { 1613436c6c9cSStella Laurenzo if (index < 0 || index >= dunderLen()) { 1614436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 1615436c6c9cSStella Laurenzo "attempt to access out of bounds region"); 1616436c6c9cSStella Laurenzo } 1617436c6c9cSStella Laurenzo PyValue value(operation, mlirBlockGetArgument(block, index)); 1618436c6c9cSStella Laurenzo return PyBlockArgument(value); 1619436c6c9cSStella Laurenzo } 1620436c6c9cSStella Laurenzo 1621436c6c9cSStella Laurenzo /// Defines a Python class in the bindings. 1622436c6c9cSStella Laurenzo static void bind(py::module &m) { 1623436c6c9cSStella Laurenzo py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1624436c6c9cSStella Laurenzo .def("__len__", &PyBlockArgumentList::dunderLen) 1625436c6c9cSStella Laurenzo .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1626436c6c9cSStella Laurenzo } 1627436c6c9cSStella Laurenzo 1628436c6c9cSStella Laurenzo private: 1629436c6c9cSStella Laurenzo PyOperationRef operation; 1630436c6c9cSStella Laurenzo MlirBlock block; 1631436c6c9cSStella Laurenzo }; 1632436c6c9cSStella Laurenzo 1633436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive 1634436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the 1635436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this 1636436c6c9cSStella Laurenzo /// operation. 1637436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1638436c6c9cSStella Laurenzo public: 1639436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "OpOperandList"; 1640436c6c9cSStella Laurenzo 1641436c6c9cSStella Laurenzo PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1642436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 1643436c6c9cSStella Laurenzo : Sliceable(startIndex, 1644436c6c9cSStella Laurenzo length == -1 ? mlirOperationGetNumOperands(operation->get()) 1645436c6c9cSStella Laurenzo : length, 1646436c6c9cSStella Laurenzo step), 1647436c6c9cSStella Laurenzo operation(operation) {} 1648436c6c9cSStella Laurenzo 1649436c6c9cSStella Laurenzo intptr_t getNumElements() { 1650436c6c9cSStella Laurenzo operation->checkValid(); 1651436c6c9cSStella Laurenzo return mlirOperationGetNumOperands(operation->get()); 1652436c6c9cSStella Laurenzo } 1653436c6c9cSStella Laurenzo 1654436c6c9cSStella Laurenzo PyValue getElement(intptr_t pos) { 1655436c6c9cSStella Laurenzo return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); 1656436c6c9cSStella Laurenzo } 1657436c6c9cSStella Laurenzo 1658436c6c9cSStella Laurenzo PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1659436c6c9cSStella Laurenzo return PyOpOperandList(operation, startIndex, length, step); 1660436c6c9cSStella Laurenzo } 1661436c6c9cSStella Laurenzo 166263d16d06SMike Urbach void dunderSetItem(intptr_t index, PyValue value) { 166363d16d06SMike Urbach index = wrapIndex(index); 166463d16d06SMike Urbach mlirOperationSetOperand(operation->get(), index, value.get()); 166563d16d06SMike Urbach } 166663d16d06SMike Urbach 166763d16d06SMike Urbach static void bindDerived(ClassTy &c) { 166863d16d06SMike Urbach c.def("__setitem__", &PyOpOperandList::dunderSetItem); 166963d16d06SMike Urbach } 167063d16d06SMike Urbach 1671436c6c9cSStella Laurenzo private: 1672436c6c9cSStella Laurenzo PyOperationRef operation; 1673436c6c9cSStella Laurenzo }; 1674436c6c9cSStella Laurenzo 1675436c6c9cSStella Laurenzo /// A list of operation results. Internally, these are stored as consecutive 1676436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the 1677436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this 1678436c6c9cSStella Laurenzo /// operation. 1679436c6c9cSStella Laurenzo class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1680436c6c9cSStella Laurenzo public: 1681436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "OpResultList"; 1682436c6c9cSStella Laurenzo 1683436c6c9cSStella Laurenzo PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1684436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 1685436c6c9cSStella Laurenzo : Sliceable(startIndex, 1686436c6c9cSStella Laurenzo length == -1 ? mlirOperationGetNumResults(operation->get()) 1687436c6c9cSStella Laurenzo : length, 1688436c6c9cSStella Laurenzo step), 1689436c6c9cSStella Laurenzo operation(operation) {} 1690436c6c9cSStella Laurenzo 1691436c6c9cSStella Laurenzo intptr_t getNumElements() { 1692436c6c9cSStella Laurenzo operation->checkValid(); 1693436c6c9cSStella Laurenzo return mlirOperationGetNumResults(operation->get()); 1694436c6c9cSStella Laurenzo } 1695436c6c9cSStella Laurenzo 1696436c6c9cSStella Laurenzo PyOpResult getElement(intptr_t index) { 1697436c6c9cSStella Laurenzo PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1698436c6c9cSStella Laurenzo return PyOpResult(value); 1699436c6c9cSStella Laurenzo } 1700436c6c9cSStella Laurenzo 1701436c6c9cSStella Laurenzo PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1702436c6c9cSStella Laurenzo return PyOpResultList(operation, startIndex, length, step); 1703436c6c9cSStella Laurenzo } 1704436c6c9cSStella Laurenzo 1705436c6c9cSStella Laurenzo private: 1706436c6c9cSStella Laurenzo PyOperationRef operation; 1707436c6c9cSStella Laurenzo }; 1708436c6c9cSStella Laurenzo 1709436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing 1710436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes. 1711436c6c9cSStella Laurenzo class PyOpAttributeMap { 1712436c6c9cSStella Laurenzo public: 1713436c6c9cSStella Laurenzo PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1714436c6c9cSStella Laurenzo 1715436c6c9cSStella Laurenzo PyAttribute dunderGetItemNamed(const std::string &name) { 1716436c6c9cSStella Laurenzo MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1717436c6c9cSStella Laurenzo toMlirStringRef(name)); 1718436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 1719436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 1720436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 1721436c6c9cSStella Laurenzo } 1722436c6c9cSStella Laurenzo return PyAttribute(operation->getContext(), attr); 1723436c6c9cSStella Laurenzo } 1724436c6c9cSStella Laurenzo 1725436c6c9cSStella Laurenzo PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1726436c6c9cSStella Laurenzo if (index < 0 || index >= dunderLen()) { 1727436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 1728436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 1729436c6c9cSStella Laurenzo } 1730436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = 1731436c6c9cSStella Laurenzo mlirOperationGetAttribute(operation->get(), index); 1732436c6c9cSStella Laurenzo return PyNamedAttribute( 1733436c6c9cSStella Laurenzo namedAttr.attribute, 1734436c6c9cSStella Laurenzo std::string(mlirIdentifierStr(namedAttr.name).data)); 1735436c6c9cSStella Laurenzo } 1736436c6c9cSStella Laurenzo 1737436c6c9cSStella Laurenzo void dunderSetItem(const std::string &name, PyAttribute attr) { 1738436c6c9cSStella Laurenzo mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1739436c6c9cSStella Laurenzo attr); 1740436c6c9cSStella Laurenzo } 1741436c6c9cSStella Laurenzo 1742436c6c9cSStella Laurenzo void dunderDelItem(const std::string &name) { 1743436c6c9cSStella Laurenzo int removed = mlirOperationRemoveAttributeByName(operation->get(), 1744436c6c9cSStella Laurenzo toMlirStringRef(name)); 1745436c6c9cSStella Laurenzo if (!removed) 1746436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 1747436c6c9cSStella Laurenzo "attempt to delete a non-existent attribute"); 1748436c6c9cSStella Laurenzo } 1749436c6c9cSStella Laurenzo 1750436c6c9cSStella Laurenzo intptr_t dunderLen() { 1751436c6c9cSStella Laurenzo return mlirOperationGetNumAttributes(operation->get()); 1752436c6c9cSStella Laurenzo } 1753436c6c9cSStella Laurenzo 1754436c6c9cSStella Laurenzo bool dunderContains(const std::string &name) { 1755436c6c9cSStella Laurenzo return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1756436c6c9cSStella Laurenzo operation->get(), toMlirStringRef(name))); 1757436c6c9cSStella Laurenzo } 1758436c6c9cSStella Laurenzo 1759436c6c9cSStella Laurenzo static void bind(py::module &m) { 1760436c6c9cSStella Laurenzo py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1761436c6c9cSStella Laurenzo .def("__contains__", &PyOpAttributeMap::dunderContains) 1762436c6c9cSStella Laurenzo .def("__len__", &PyOpAttributeMap::dunderLen) 1763436c6c9cSStella Laurenzo .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1764436c6c9cSStella Laurenzo .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1765436c6c9cSStella Laurenzo .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1766436c6c9cSStella Laurenzo .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1767436c6c9cSStella Laurenzo } 1768436c6c9cSStella Laurenzo 1769436c6c9cSStella Laurenzo private: 1770436c6c9cSStella Laurenzo PyOperationRef operation; 1771436c6c9cSStella Laurenzo }; 1772436c6c9cSStella Laurenzo 1773436c6c9cSStella Laurenzo } // end namespace 1774436c6c9cSStella Laurenzo 1775436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1776436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule. 1777436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1778436c6c9cSStella Laurenzo 1779436c6c9cSStella Laurenzo void mlir::python::populateIRCore(py::module &m) { 1780436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 17814acd8457SAlex Zinenko // Mapping of MlirContext. 1782436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1783436c6c9cSStella Laurenzo py::class_<PyMlirContext>(m, "Context") 1784436c6c9cSStella Laurenzo .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1785436c6c9cSStella Laurenzo .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1786436c6c9cSStella Laurenzo .def("_get_context_again", 1787436c6c9cSStella Laurenzo [](PyMlirContext &self) { 1788436c6c9cSStella Laurenzo PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1789436c6c9cSStella Laurenzo return ref.releaseObject(); 1790436c6c9cSStella Laurenzo }) 1791436c6c9cSStella Laurenzo .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1792436c6c9cSStella Laurenzo .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1793436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1794436c6c9cSStella Laurenzo &PyMlirContext::getCapsule) 1795436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1796436c6c9cSStella Laurenzo .def("__enter__", &PyMlirContext::contextEnter) 1797436c6c9cSStella Laurenzo .def("__exit__", &PyMlirContext::contextExit) 1798436c6c9cSStella Laurenzo .def_property_readonly_static( 1799436c6c9cSStella Laurenzo "current", 1800436c6c9cSStella Laurenzo [](py::object & /*class*/) { 1801436c6c9cSStella Laurenzo auto *context = PyThreadContextEntry::getDefaultContext(); 1802436c6c9cSStella Laurenzo if (!context) 1803436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "No current Context"); 1804436c6c9cSStella Laurenzo return context; 1805436c6c9cSStella Laurenzo }, 1806436c6c9cSStella Laurenzo "Gets the Context bound to the current thread or raises ValueError") 1807436c6c9cSStella Laurenzo .def_property_readonly( 1808436c6c9cSStella Laurenzo "dialects", 1809436c6c9cSStella Laurenzo [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1810436c6c9cSStella Laurenzo "Gets a container for accessing dialects by name") 1811436c6c9cSStella Laurenzo .def_property_readonly( 1812436c6c9cSStella Laurenzo "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1813436c6c9cSStella Laurenzo "Alias for 'dialect'") 1814436c6c9cSStella Laurenzo .def( 1815436c6c9cSStella Laurenzo "get_dialect_descriptor", 1816436c6c9cSStella Laurenzo [=](PyMlirContext &self, std::string &name) { 1817436c6c9cSStella Laurenzo MlirDialect dialect = mlirContextGetOrLoadDialect( 1818436c6c9cSStella Laurenzo self.get(), {name.data(), name.size()}); 1819436c6c9cSStella Laurenzo if (mlirDialectIsNull(dialect)) { 1820436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 1821436c6c9cSStella Laurenzo Twine("Dialect '") + name + "' not found"); 1822436c6c9cSStella Laurenzo } 1823436c6c9cSStella Laurenzo return PyDialectDescriptor(self.getRef(), dialect); 1824436c6c9cSStella Laurenzo }, 1825436c6c9cSStella Laurenzo "Gets or loads a dialect by name, returning its descriptor object") 1826436c6c9cSStella Laurenzo .def_property( 1827436c6c9cSStella Laurenzo "allow_unregistered_dialects", 1828436c6c9cSStella Laurenzo [](PyMlirContext &self) -> bool { 1829436c6c9cSStella Laurenzo return mlirContextGetAllowUnregisteredDialects(self.get()); 1830436c6c9cSStella Laurenzo }, 1831436c6c9cSStella Laurenzo [](PyMlirContext &self, bool value) { 1832436c6c9cSStella Laurenzo mlirContextSetAllowUnregisteredDialects(self.get(), value); 18339a9214faSStella Laurenzo }) 1834caa159f0SNicolas Vasilache .def("enable_multithreading", 1835caa159f0SNicolas Vasilache [](PyMlirContext &self, bool enable) { 1836caa159f0SNicolas Vasilache mlirContextEnableMultithreading(self.get(), enable); 1837caa159f0SNicolas Vasilache }) 18389a9214faSStella Laurenzo .def("is_registered_operation", 18399a9214faSStella Laurenzo [](PyMlirContext &self, std::string &name) { 18409a9214faSStella Laurenzo return mlirContextIsRegisteredOperation( 18419a9214faSStella Laurenzo self.get(), MlirStringRef{name.data(), name.size()}); 1842436c6c9cSStella Laurenzo }); 1843436c6c9cSStella Laurenzo 1844436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1845436c6c9cSStella Laurenzo // Mapping of PyDialectDescriptor 1846436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1847436c6c9cSStella Laurenzo py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1848436c6c9cSStella Laurenzo .def_property_readonly("namespace", 1849436c6c9cSStella Laurenzo [](PyDialectDescriptor &self) { 1850436c6c9cSStella Laurenzo MlirStringRef ns = 1851436c6c9cSStella Laurenzo mlirDialectGetNamespace(self.get()); 1852436c6c9cSStella Laurenzo return py::str(ns.data, ns.length); 1853436c6c9cSStella Laurenzo }) 1854436c6c9cSStella Laurenzo .def("__repr__", [](PyDialectDescriptor &self) { 1855436c6c9cSStella Laurenzo MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1856436c6c9cSStella Laurenzo std::string repr("<DialectDescriptor "); 1857436c6c9cSStella Laurenzo repr.append(ns.data, ns.length); 1858436c6c9cSStella Laurenzo repr.append(">"); 1859436c6c9cSStella Laurenzo return repr; 1860436c6c9cSStella Laurenzo }); 1861436c6c9cSStella Laurenzo 1862436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1863436c6c9cSStella Laurenzo // Mapping of PyDialects 1864436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1865436c6c9cSStella Laurenzo py::class_<PyDialects>(m, "Dialects") 1866436c6c9cSStella Laurenzo .def("__getitem__", 1867436c6c9cSStella Laurenzo [=](PyDialects &self, std::string keyName) { 1868436c6c9cSStella Laurenzo MlirDialect dialect = 1869436c6c9cSStella Laurenzo self.getDialectForKey(keyName, /*attrError=*/false); 1870436c6c9cSStella Laurenzo py::object descriptor = 1871436c6c9cSStella Laurenzo py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1872436c6c9cSStella Laurenzo return createCustomDialectWrapper(keyName, std::move(descriptor)); 1873436c6c9cSStella Laurenzo }) 1874436c6c9cSStella Laurenzo .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1875436c6c9cSStella Laurenzo MlirDialect dialect = 1876436c6c9cSStella Laurenzo self.getDialectForKey(attrName, /*attrError=*/true); 1877436c6c9cSStella Laurenzo py::object descriptor = 1878436c6c9cSStella Laurenzo py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1879436c6c9cSStella Laurenzo return createCustomDialectWrapper(attrName, std::move(descriptor)); 1880436c6c9cSStella Laurenzo }); 1881436c6c9cSStella Laurenzo 1882436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1883436c6c9cSStella Laurenzo // Mapping of PyDialect 1884436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1885436c6c9cSStella Laurenzo py::class_<PyDialect>(m, "Dialect") 1886436c6c9cSStella Laurenzo .def(py::init<py::object>(), "descriptor") 1887436c6c9cSStella Laurenzo .def_property_readonly( 1888436c6c9cSStella Laurenzo "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1889436c6c9cSStella Laurenzo .def("__repr__", [](py::object self) { 1890436c6c9cSStella Laurenzo auto clazz = self.attr("__class__"); 1891436c6c9cSStella Laurenzo return py::str("<Dialect ") + 1892436c6c9cSStella Laurenzo self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1893436c6c9cSStella Laurenzo clazz.attr("__module__") + py::str(".") + 1894436c6c9cSStella Laurenzo clazz.attr("__name__") + py::str(")>"); 1895436c6c9cSStella Laurenzo }); 1896436c6c9cSStella Laurenzo 1897436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1898436c6c9cSStella Laurenzo // Mapping of Location 1899436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1900436c6c9cSStella Laurenzo py::class_<PyLocation>(m, "Location") 1901436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1902436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1903436c6c9cSStella Laurenzo .def("__enter__", &PyLocation::contextEnter) 1904436c6c9cSStella Laurenzo .def("__exit__", &PyLocation::contextExit) 1905436c6c9cSStella Laurenzo .def("__eq__", 1906436c6c9cSStella Laurenzo [](PyLocation &self, PyLocation &other) -> bool { 1907436c6c9cSStella Laurenzo return mlirLocationEqual(self, other); 1908436c6c9cSStella Laurenzo }) 1909436c6c9cSStella Laurenzo .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1910436c6c9cSStella Laurenzo .def_property_readonly_static( 1911436c6c9cSStella Laurenzo "current", 1912436c6c9cSStella Laurenzo [](py::object & /*class*/) { 1913436c6c9cSStella Laurenzo auto *loc = PyThreadContextEntry::getDefaultLocation(); 1914436c6c9cSStella Laurenzo if (!loc) 1915436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "No current Location"); 1916436c6c9cSStella Laurenzo return loc; 1917436c6c9cSStella Laurenzo }, 1918436c6c9cSStella Laurenzo "Gets the Location bound to the current thread or raises ValueError") 1919436c6c9cSStella Laurenzo .def_static( 1920436c6c9cSStella Laurenzo "unknown", 1921436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 1922436c6c9cSStella Laurenzo return PyLocation(context->getRef(), 1923436c6c9cSStella Laurenzo mlirLocationUnknownGet(context->get())); 1924436c6c9cSStella Laurenzo }, 1925436c6c9cSStella Laurenzo py::arg("context") = py::none(), 1926436c6c9cSStella Laurenzo "Gets a Location representing an unknown location") 1927436c6c9cSStella Laurenzo .def_static( 1928436c6c9cSStella Laurenzo "file", 1929436c6c9cSStella Laurenzo [](std::string filename, int line, int col, 1930436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 1931436c6c9cSStella Laurenzo return PyLocation( 1932436c6c9cSStella Laurenzo context->getRef(), 1933436c6c9cSStella Laurenzo mlirLocationFileLineColGet( 1934436c6c9cSStella Laurenzo context->get(), toMlirStringRef(filename), line, col)); 1935436c6c9cSStella Laurenzo }, 1936436c6c9cSStella Laurenzo py::arg("filename"), py::arg("line"), py::arg("col"), 1937436c6c9cSStella Laurenzo py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1938436c6c9cSStella Laurenzo .def_property_readonly( 1939436c6c9cSStella Laurenzo "context", 1940436c6c9cSStella Laurenzo [](PyLocation &self) { return self.getContext().getObject(); }, 1941436c6c9cSStella Laurenzo "Context that owns the Location") 1942436c6c9cSStella Laurenzo .def("__repr__", [](PyLocation &self) { 1943436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 1944436c6c9cSStella Laurenzo mlirLocationPrint(self, printAccum.getCallback(), 1945436c6c9cSStella Laurenzo printAccum.getUserData()); 1946436c6c9cSStella Laurenzo return printAccum.join(); 1947436c6c9cSStella Laurenzo }); 1948436c6c9cSStella Laurenzo 1949436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1950436c6c9cSStella Laurenzo // Mapping of Module 1951436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1952436c6c9cSStella Laurenzo py::class_<PyModule>(m, "Module") 1953436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1954436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1955436c6c9cSStella Laurenzo .def_static( 1956436c6c9cSStella Laurenzo "parse", 1957436c6c9cSStella Laurenzo [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1958436c6c9cSStella Laurenzo MlirModule module = mlirModuleCreateParse( 1959436c6c9cSStella Laurenzo context->get(), toMlirStringRef(moduleAsm)); 1960436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 1961436c6c9cSStella Laurenzo // in C API. 1962436c6c9cSStella Laurenzo if (mlirModuleIsNull(module)) { 1963436c6c9cSStella Laurenzo throw SetPyError( 1964436c6c9cSStella Laurenzo PyExc_ValueError, 1965436c6c9cSStella Laurenzo "Unable to parse module assembly (see diagnostics)"); 1966436c6c9cSStella Laurenzo } 1967436c6c9cSStella Laurenzo return PyModule::forModule(module).releaseObject(); 1968436c6c9cSStella Laurenzo }, 1969436c6c9cSStella Laurenzo py::arg("asm"), py::arg("context") = py::none(), 1970436c6c9cSStella Laurenzo kModuleParseDocstring) 1971436c6c9cSStella Laurenzo .def_static( 1972436c6c9cSStella Laurenzo "create", 1973436c6c9cSStella Laurenzo [](DefaultingPyLocation loc) { 1974436c6c9cSStella Laurenzo MlirModule module = mlirModuleCreateEmpty(loc); 1975436c6c9cSStella Laurenzo return PyModule::forModule(module).releaseObject(); 1976436c6c9cSStella Laurenzo }, 1977436c6c9cSStella Laurenzo py::arg("loc") = py::none(), "Creates an empty module") 1978436c6c9cSStella Laurenzo .def_property_readonly( 1979436c6c9cSStella Laurenzo "context", 1980436c6c9cSStella Laurenzo [](PyModule &self) { return self.getContext().getObject(); }, 1981436c6c9cSStella Laurenzo "Context that created the Module") 1982436c6c9cSStella Laurenzo .def_property_readonly( 1983436c6c9cSStella Laurenzo "operation", 1984436c6c9cSStella Laurenzo [](PyModule &self) { 1985436c6c9cSStella Laurenzo return PyOperation::forOperation(self.getContext(), 1986436c6c9cSStella Laurenzo mlirModuleGetOperation(self.get()), 1987436c6c9cSStella Laurenzo self.getRef().releaseObject()) 1988436c6c9cSStella Laurenzo .releaseObject(); 1989436c6c9cSStella Laurenzo }, 1990436c6c9cSStella Laurenzo "Accesses the module as an operation") 1991436c6c9cSStella Laurenzo .def_property_readonly( 1992436c6c9cSStella Laurenzo "body", 1993436c6c9cSStella Laurenzo [](PyModule &self) { 1994436c6c9cSStella Laurenzo PyOperationRef module_op = PyOperation::forOperation( 1995436c6c9cSStella Laurenzo self.getContext(), mlirModuleGetOperation(self.get()), 1996436c6c9cSStella Laurenzo self.getRef().releaseObject()); 1997436c6c9cSStella Laurenzo PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 1998436c6c9cSStella Laurenzo return returnBlock; 1999436c6c9cSStella Laurenzo }, 2000436c6c9cSStella Laurenzo "Return the block for this module") 2001436c6c9cSStella Laurenzo .def( 2002436c6c9cSStella Laurenzo "dump", 2003436c6c9cSStella Laurenzo [](PyModule &self) { 2004436c6c9cSStella Laurenzo mlirOperationDump(mlirModuleGetOperation(self.get())); 2005436c6c9cSStella Laurenzo }, 2006436c6c9cSStella Laurenzo kDumpDocstring) 2007436c6c9cSStella Laurenzo .def( 2008436c6c9cSStella Laurenzo "__str__", 2009436c6c9cSStella Laurenzo [](PyModule &self) { 2010436c6c9cSStella Laurenzo MlirOperation operation = mlirModuleGetOperation(self.get()); 2011436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2012436c6c9cSStella Laurenzo mlirOperationPrint(operation, printAccum.getCallback(), 2013436c6c9cSStella Laurenzo printAccum.getUserData()); 2014436c6c9cSStella Laurenzo return printAccum.join(); 2015436c6c9cSStella Laurenzo }, 2016436c6c9cSStella Laurenzo kOperationStrDunderDocstring); 2017436c6c9cSStella Laurenzo 2018436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2019436c6c9cSStella Laurenzo // Mapping of Operation. 2020436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2021436c6c9cSStella Laurenzo py::class_<PyOperationBase>(m, "_OperationBase") 2022436c6c9cSStella Laurenzo .def("__eq__", 2023436c6c9cSStella Laurenzo [](PyOperationBase &self, PyOperationBase &other) { 2024436c6c9cSStella Laurenzo return &self.getOperation() == &other.getOperation(); 2025436c6c9cSStella Laurenzo }) 2026436c6c9cSStella Laurenzo .def("__eq__", 2027436c6c9cSStella Laurenzo [](PyOperationBase &self, py::object other) { return false; }) 2028436c6c9cSStella Laurenzo .def_property_readonly("attributes", 2029436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2030436c6c9cSStella Laurenzo return PyOpAttributeMap( 2031436c6c9cSStella Laurenzo self.getOperation().getRef()); 2032436c6c9cSStella Laurenzo }) 2033436c6c9cSStella Laurenzo .def_property_readonly("operands", 2034436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2035436c6c9cSStella Laurenzo return PyOpOperandList( 2036436c6c9cSStella Laurenzo self.getOperation().getRef()); 2037436c6c9cSStella Laurenzo }) 2038436c6c9cSStella Laurenzo .def_property_readonly("regions", 2039436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2040436c6c9cSStella Laurenzo return PyRegionList( 2041436c6c9cSStella Laurenzo self.getOperation().getRef()); 2042436c6c9cSStella Laurenzo }) 2043436c6c9cSStella Laurenzo .def_property_readonly( 2044436c6c9cSStella Laurenzo "results", 2045436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2046436c6c9cSStella Laurenzo return PyOpResultList(self.getOperation().getRef()); 2047436c6c9cSStella Laurenzo }, 2048436c6c9cSStella Laurenzo "Returns the list of Operation results.") 2049436c6c9cSStella Laurenzo .def_property_readonly( 2050436c6c9cSStella Laurenzo "result", 2051436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2052436c6c9cSStella Laurenzo auto &operation = self.getOperation(); 2053436c6c9cSStella Laurenzo auto numResults = mlirOperationGetNumResults(operation); 2054436c6c9cSStella Laurenzo if (numResults != 1) { 2055436c6c9cSStella Laurenzo auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2056436c6c9cSStella Laurenzo throw SetPyError( 2057436c6c9cSStella Laurenzo PyExc_ValueError, 2058436c6c9cSStella Laurenzo Twine("Cannot call .result on operation ") + 2059436c6c9cSStella Laurenzo StringRef(name.data, name.length) + " which has " + 2060436c6c9cSStella Laurenzo Twine(numResults) + 2061436c6c9cSStella Laurenzo " results (it is only valid for operations with a " 2062436c6c9cSStella Laurenzo "single result)"); 2063436c6c9cSStella Laurenzo } 2064436c6c9cSStella Laurenzo return PyOpResult(operation.getRef(), 2065436c6c9cSStella Laurenzo mlirOperationGetResult(operation, 0)); 2066436c6c9cSStella Laurenzo }, 2067436c6c9cSStella Laurenzo "Shortcut to get an op result if it has only one (throws an error " 2068436c6c9cSStella Laurenzo "otherwise).") 2069436c6c9cSStella Laurenzo .def("__iter__", 2070436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2071436c6c9cSStella Laurenzo return PyRegionIterator(self.getOperation().getRef()); 2072436c6c9cSStella Laurenzo }) 2073436c6c9cSStella Laurenzo .def( 2074436c6c9cSStella Laurenzo "__str__", 2075436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2076436c6c9cSStella Laurenzo return self.getAsm(/*binary=*/false, 2077436c6c9cSStella Laurenzo /*largeElementsLimit=*/llvm::None, 2078436c6c9cSStella Laurenzo /*enableDebugInfo=*/false, 2079436c6c9cSStella Laurenzo /*prettyDebugInfo=*/false, 2080436c6c9cSStella Laurenzo /*printGenericOpForm=*/false, 2081436c6c9cSStella Laurenzo /*useLocalScope=*/false); 2082436c6c9cSStella Laurenzo }, 2083436c6c9cSStella Laurenzo "Returns the assembly form of the operation.") 2084436c6c9cSStella Laurenzo .def("print", &PyOperationBase::print, 2085436c6c9cSStella Laurenzo // Careful: Lots of arguments must match up with print method. 2086436c6c9cSStella Laurenzo py::arg("file") = py::none(), py::arg("binary") = false, 2087436c6c9cSStella Laurenzo py::arg("large_elements_limit") = py::none(), 2088436c6c9cSStella Laurenzo py::arg("enable_debug_info") = false, 2089436c6c9cSStella Laurenzo py::arg("pretty_debug_info") = false, 2090436c6c9cSStella Laurenzo py::arg("print_generic_op_form") = false, 2091436c6c9cSStella Laurenzo py::arg("use_local_scope") = false, kOperationPrintDocstring) 2092436c6c9cSStella Laurenzo .def("get_asm", &PyOperationBase::getAsm, 2093436c6c9cSStella Laurenzo // Careful: Lots of arguments must match up with get_asm method. 2094436c6c9cSStella Laurenzo py::arg("binary") = false, 2095436c6c9cSStella Laurenzo py::arg("large_elements_limit") = py::none(), 2096436c6c9cSStella Laurenzo py::arg("enable_debug_info") = false, 2097436c6c9cSStella Laurenzo py::arg("pretty_debug_info") = false, 2098436c6c9cSStella Laurenzo py::arg("print_generic_op_form") = false, 2099436c6c9cSStella Laurenzo py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2100436c6c9cSStella Laurenzo .def( 2101436c6c9cSStella Laurenzo "verify", 2102436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2103436c6c9cSStella Laurenzo return mlirOperationVerify(self.getOperation()); 2104436c6c9cSStella Laurenzo }, 2105436c6c9cSStella Laurenzo "Verify the operation and return true if it passes, false if it " 2106436c6c9cSStella Laurenzo "fails."); 2107436c6c9cSStella Laurenzo 2108436c6c9cSStella Laurenzo py::class_<PyOperation, PyOperationBase>(m, "Operation") 2109436c6c9cSStella Laurenzo .def_static("create", &PyOperation::create, py::arg("name"), 2110436c6c9cSStella Laurenzo py::arg("results") = py::none(), 2111436c6c9cSStella Laurenzo py::arg("operands") = py::none(), 2112436c6c9cSStella Laurenzo py::arg("attributes") = py::none(), 2113436c6c9cSStella Laurenzo py::arg("successors") = py::none(), py::arg("regions") = 0, 2114436c6c9cSStella Laurenzo py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2115436c6c9cSStella Laurenzo kOperationCreateDocstring) 2116*c65bb760SJohn Demme .def_property_readonly("parent", 2117*c65bb760SJohn Demme [](PyOperation &self) { 2118*c65bb760SJohn Demme return self.getParentOperation().getObject(); 2119*c65bb760SJohn Demme }) 212049745f87SMike Urbach .def("erase", &PyOperation::erase) 21210126e906SJohn Demme .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 21220126e906SJohn Demme &PyOperation::getCapsule) 21230126e906SJohn Demme .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2124436c6c9cSStella Laurenzo .def_property_readonly("name", 2125436c6c9cSStella Laurenzo [](PyOperation &self) { 212649745f87SMike Urbach self.checkValid(); 2127436c6c9cSStella Laurenzo MlirOperation operation = self.get(); 2128436c6c9cSStella Laurenzo MlirStringRef name = mlirIdentifierStr( 2129436c6c9cSStella Laurenzo mlirOperationGetName(operation)); 2130436c6c9cSStella Laurenzo return py::str(name.data, name.length); 2131436c6c9cSStella Laurenzo }) 2132436c6c9cSStella Laurenzo .def_property_readonly( 2133436c6c9cSStella Laurenzo "context", 213449745f87SMike Urbach [](PyOperation &self) { 213549745f87SMike Urbach self.checkValid(); 213649745f87SMike Urbach return self.getContext().getObject(); 213749745f87SMike Urbach }, 2138436c6c9cSStella Laurenzo "Context that owns the Operation") 2139436c6c9cSStella Laurenzo .def_property_readonly("opview", &PyOperation::createOpView); 2140436c6c9cSStella Laurenzo 2141436c6c9cSStella Laurenzo auto opViewClass = 2142436c6c9cSStella Laurenzo py::class_<PyOpView, PyOperationBase>(m, "OpView") 2143436c6c9cSStella Laurenzo .def(py::init<py::object>()) 2144436c6c9cSStella Laurenzo .def_property_readonly("operation", &PyOpView::getOperationObject) 2145436c6c9cSStella Laurenzo .def_property_readonly( 2146436c6c9cSStella Laurenzo "context", 2147436c6c9cSStella Laurenzo [](PyOpView &self) { 2148436c6c9cSStella Laurenzo return self.getOperation().getContext().getObject(); 2149436c6c9cSStella Laurenzo }, 2150436c6c9cSStella Laurenzo "Context that owns the Operation") 2151436c6c9cSStella Laurenzo .def("__str__", [](PyOpView &self) { 2152436c6c9cSStella Laurenzo return py::str(self.getOperationObject()); 2153436c6c9cSStella Laurenzo }); 2154436c6c9cSStella Laurenzo opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2155436c6c9cSStella Laurenzo opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2156436c6c9cSStella Laurenzo opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2157436c6c9cSStella Laurenzo opViewClass.attr("build_generic") = classmethod( 2158436c6c9cSStella Laurenzo &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2159436c6c9cSStella Laurenzo py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2160436c6c9cSStella Laurenzo py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2161436c6c9cSStella Laurenzo py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2162436c6c9cSStella Laurenzo "Builds a specific, generated OpView based on class level attributes."); 2163436c6c9cSStella Laurenzo 2164436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2165436c6c9cSStella Laurenzo // Mapping of PyRegion. 2166436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2167436c6c9cSStella Laurenzo py::class_<PyRegion>(m, "Region") 2168436c6c9cSStella Laurenzo .def_property_readonly( 2169436c6c9cSStella Laurenzo "blocks", 2170436c6c9cSStella Laurenzo [](PyRegion &self) { 2171436c6c9cSStella Laurenzo return PyBlockList(self.getParentOperation(), self.get()); 2172436c6c9cSStella Laurenzo }, 2173436c6c9cSStella Laurenzo "Returns a forward-optimized sequence of blocks.") 2174436c6c9cSStella Laurenzo .def( 2175436c6c9cSStella Laurenzo "__iter__", 2176436c6c9cSStella Laurenzo [](PyRegion &self) { 2177436c6c9cSStella Laurenzo self.checkValid(); 2178436c6c9cSStella Laurenzo MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2179436c6c9cSStella Laurenzo return PyBlockIterator(self.getParentOperation(), firstBlock); 2180436c6c9cSStella Laurenzo }, 2181436c6c9cSStella Laurenzo "Iterates over blocks in the region.") 2182436c6c9cSStella Laurenzo .def("__eq__", 2183436c6c9cSStella Laurenzo [](PyRegion &self, PyRegion &other) { 2184436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 2185436c6c9cSStella Laurenzo }) 2186436c6c9cSStella Laurenzo .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2187436c6c9cSStella Laurenzo 2188436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2189436c6c9cSStella Laurenzo // Mapping of PyBlock. 2190436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2191436c6c9cSStella Laurenzo py::class_<PyBlock>(m, "Block") 2192436c6c9cSStella Laurenzo .def_property_readonly( 2193436c6c9cSStella Laurenzo "arguments", 2194436c6c9cSStella Laurenzo [](PyBlock &self) { 2195436c6c9cSStella Laurenzo return PyBlockArgumentList(self.getParentOperation(), self.get()); 2196436c6c9cSStella Laurenzo }, 2197436c6c9cSStella Laurenzo "Returns a list of block arguments.") 2198436c6c9cSStella Laurenzo .def_property_readonly( 2199436c6c9cSStella Laurenzo "operations", 2200436c6c9cSStella Laurenzo [](PyBlock &self) { 2201436c6c9cSStella Laurenzo return PyOperationList(self.getParentOperation(), self.get()); 2202436c6c9cSStella Laurenzo }, 2203436c6c9cSStella Laurenzo "Returns a forward-optimized sequence of operations.") 2204436c6c9cSStella Laurenzo .def( 2205436c6c9cSStella Laurenzo "__iter__", 2206436c6c9cSStella Laurenzo [](PyBlock &self) { 2207436c6c9cSStella Laurenzo self.checkValid(); 2208436c6c9cSStella Laurenzo MlirOperation firstOperation = 2209436c6c9cSStella Laurenzo mlirBlockGetFirstOperation(self.get()); 2210436c6c9cSStella Laurenzo return PyOperationIterator(self.getParentOperation(), 2211436c6c9cSStella Laurenzo firstOperation); 2212436c6c9cSStella Laurenzo }, 2213436c6c9cSStella Laurenzo "Iterates over operations in the block.") 2214436c6c9cSStella Laurenzo .def("__eq__", 2215436c6c9cSStella Laurenzo [](PyBlock &self, PyBlock &other) { 2216436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 2217436c6c9cSStella Laurenzo }) 2218436c6c9cSStella Laurenzo .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2219436c6c9cSStella Laurenzo .def( 2220436c6c9cSStella Laurenzo "__str__", 2221436c6c9cSStella Laurenzo [](PyBlock &self) { 2222436c6c9cSStella Laurenzo self.checkValid(); 2223436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2224436c6c9cSStella Laurenzo mlirBlockPrint(self.get(), printAccum.getCallback(), 2225436c6c9cSStella Laurenzo printAccum.getUserData()); 2226436c6c9cSStella Laurenzo return printAccum.join(); 2227436c6c9cSStella Laurenzo }, 2228436c6c9cSStella Laurenzo "Returns the assembly form of the block."); 2229436c6c9cSStella Laurenzo 2230436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2231436c6c9cSStella Laurenzo // Mapping of PyInsertionPoint. 2232436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2233436c6c9cSStella Laurenzo 2234436c6c9cSStella Laurenzo py::class_<PyInsertionPoint>(m, "InsertionPoint") 2235436c6c9cSStella Laurenzo .def(py::init<PyBlock &>(), py::arg("block"), 2236436c6c9cSStella Laurenzo "Inserts after the last operation but still inside the block.") 2237436c6c9cSStella Laurenzo .def("__enter__", &PyInsertionPoint::contextEnter) 2238436c6c9cSStella Laurenzo .def("__exit__", &PyInsertionPoint::contextExit) 2239436c6c9cSStella Laurenzo .def_property_readonly_static( 2240436c6c9cSStella Laurenzo "current", 2241436c6c9cSStella Laurenzo [](py::object & /*class*/) { 2242436c6c9cSStella Laurenzo auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2243436c6c9cSStella Laurenzo if (!ip) 2244436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2245436c6c9cSStella Laurenzo return ip; 2246436c6c9cSStella Laurenzo }, 2247436c6c9cSStella Laurenzo "Gets the InsertionPoint bound to the current thread or raises " 2248436c6c9cSStella Laurenzo "ValueError if none has been set") 2249436c6c9cSStella Laurenzo .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2250436c6c9cSStella Laurenzo "Inserts before a referenced operation.") 2251436c6c9cSStella Laurenzo .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2252436c6c9cSStella Laurenzo py::arg("block"), "Inserts at the beginning of the block.") 2253436c6c9cSStella Laurenzo .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2254436c6c9cSStella Laurenzo py::arg("block"), "Inserts before the block terminator.") 2255436c6c9cSStella Laurenzo .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2256436c6c9cSStella Laurenzo "Inserts an operation."); 2257436c6c9cSStella Laurenzo 2258436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2259436c6c9cSStella Laurenzo // Mapping of PyAttribute. 2260436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2261436c6c9cSStella Laurenzo py::class_<PyAttribute>(m, "Attribute") 2262b57d6fe4SStella Laurenzo // Delegate to the PyAttribute copy constructor, which will also lifetime 2263b57d6fe4SStella Laurenzo // extend the backing context which owns the MlirAttribute. 2264b57d6fe4SStella Laurenzo .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2265b57d6fe4SStella Laurenzo "Casts the passed attribute to the generic Attribute") 2266436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2267436c6c9cSStella Laurenzo &PyAttribute::getCapsule) 2268436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2269436c6c9cSStella Laurenzo .def_static( 2270436c6c9cSStella Laurenzo "parse", 2271436c6c9cSStella Laurenzo [](std::string attrSpec, DefaultingPyMlirContext context) { 2272436c6c9cSStella Laurenzo MlirAttribute type = mlirAttributeParseGet( 2273436c6c9cSStella Laurenzo context->get(), toMlirStringRef(attrSpec)); 2274436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 2275436c6c9cSStella Laurenzo // in C API. 2276436c6c9cSStella Laurenzo if (mlirAttributeIsNull(type)) { 2277436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 2278436c6c9cSStella Laurenzo Twine("Unable to parse attribute: '") + 2279436c6c9cSStella Laurenzo attrSpec + "'"); 2280436c6c9cSStella Laurenzo } 2281436c6c9cSStella Laurenzo return PyAttribute(context->getRef(), type); 2282436c6c9cSStella Laurenzo }, 2283436c6c9cSStella Laurenzo py::arg("asm"), py::arg("context") = py::none(), 2284436c6c9cSStella Laurenzo "Parses an attribute from an assembly form") 2285436c6c9cSStella Laurenzo .def_property_readonly( 2286436c6c9cSStella Laurenzo "context", 2287436c6c9cSStella Laurenzo [](PyAttribute &self) { return self.getContext().getObject(); }, 2288436c6c9cSStella Laurenzo "Context that owns the Attribute") 2289436c6c9cSStella Laurenzo .def_property_readonly("type", 2290436c6c9cSStella Laurenzo [](PyAttribute &self) { 2291436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 2292436c6c9cSStella Laurenzo mlirAttributeGetType(self)); 2293436c6c9cSStella Laurenzo }) 2294436c6c9cSStella Laurenzo .def( 2295436c6c9cSStella Laurenzo "get_named", 2296436c6c9cSStella Laurenzo [](PyAttribute &self, std::string name) { 2297436c6c9cSStella Laurenzo return PyNamedAttribute(self, std::move(name)); 2298436c6c9cSStella Laurenzo }, 2299436c6c9cSStella Laurenzo py::keep_alive<0, 1>(), "Binds a name to the attribute") 2300436c6c9cSStella Laurenzo .def("__eq__", 2301436c6c9cSStella Laurenzo [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2302436c6c9cSStella Laurenzo .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2303436c6c9cSStella Laurenzo .def( 2304436c6c9cSStella Laurenzo "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2305436c6c9cSStella Laurenzo kDumpDocstring) 2306436c6c9cSStella Laurenzo .def( 2307436c6c9cSStella Laurenzo "__str__", 2308436c6c9cSStella Laurenzo [](PyAttribute &self) { 2309436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2310436c6c9cSStella Laurenzo mlirAttributePrint(self, printAccum.getCallback(), 2311436c6c9cSStella Laurenzo printAccum.getUserData()); 2312436c6c9cSStella Laurenzo return printAccum.join(); 2313436c6c9cSStella Laurenzo }, 2314436c6c9cSStella Laurenzo "Returns the assembly form of the Attribute.") 2315436c6c9cSStella Laurenzo .def("__repr__", [](PyAttribute &self) { 2316436c6c9cSStella Laurenzo // Generally, assembly formats are not printed for __repr__ because 2317436c6c9cSStella Laurenzo // this can cause exceptionally long debug output and exceptions. 2318436c6c9cSStella Laurenzo // However, attribute values are generally considered useful and are 2319436c6c9cSStella Laurenzo // printed. This may need to be re-evaluated if debug dumps end up 2320436c6c9cSStella Laurenzo // being excessive. 2321436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2322436c6c9cSStella Laurenzo printAccum.parts.append("Attribute("); 2323436c6c9cSStella Laurenzo mlirAttributePrint(self, printAccum.getCallback(), 2324436c6c9cSStella Laurenzo printAccum.getUserData()); 2325436c6c9cSStella Laurenzo printAccum.parts.append(")"); 2326436c6c9cSStella Laurenzo return printAccum.join(); 2327436c6c9cSStella Laurenzo }); 2328436c6c9cSStella Laurenzo 2329436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2330436c6c9cSStella Laurenzo // Mapping of PyNamedAttribute 2331436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2332436c6c9cSStella Laurenzo py::class_<PyNamedAttribute>(m, "NamedAttribute") 2333436c6c9cSStella Laurenzo .def("__repr__", 2334436c6c9cSStella Laurenzo [](PyNamedAttribute &self) { 2335436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2336436c6c9cSStella Laurenzo printAccum.parts.append("NamedAttribute("); 2337436c6c9cSStella Laurenzo printAccum.parts.append( 2338436c6c9cSStella Laurenzo mlirIdentifierStr(self.namedAttr.name).data); 2339436c6c9cSStella Laurenzo printAccum.parts.append("="); 2340436c6c9cSStella Laurenzo mlirAttributePrint(self.namedAttr.attribute, 2341436c6c9cSStella Laurenzo printAccum.getCallback(), 2342436c6c9cSStella Laurenzo printAccum.getUserData()); 2343436c6c9cSStella Laurenzo printAccum.parts.append(")"); 2344436c6c9cSStella Laurenzo return printAccum.join(); 2345436c6c9cSStella Laurenzo }) 2346436c6c9cSStella Laurenzo .def_property_readonly( 2347436c6c9cSStella Laurenzo "name", 2348436c6c9cSStella Laurenzo [](PyNamedAttribute &self) { 2349436c6c9cSStella Laurenzo return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2350436c6c9cSStella Laurenzo mlirIdentifierStr(self.namedAttr.name).length); 2351436c6c9cSStella Laurenzo }, 2352436c6c9cSStella Laurenzo "The name of the NamedAttribute binding") 2353436c6c9cSStella Laurenzo .def_property_readonly( 2354436c6c9cSStella Laurenzo "attr", 2355436c6c9cSStella Laurenzo [](PyNamedAttribute &self) { 2356436c6c9cSStella Laurenzo // TODO: When named attribute is removed/refactored, also remove 2357436c6c9cSStella Laurenzo // this constructor (it does an inefficient table lookup). 2358436c6c9cSStella Laurenzo auto contextRef = PyMlirContext::forContext( 2359436c6c9cSStella Laurenzo mlirAttributeGetContext(self.namedAttr.attribute)); 2360436c6c9cSStella Laurenzo return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2361436c6c9cSStella Laurenzo }, 2362436c6c9cSStella Laurenzo py::keep_alive<0, 1>(), 2363436c6c9cSStella Laurenzo "The underlying generic attribute of the NamedAttribute binding"); 2364436c6c9cSStella Laurenzo 2365436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2366436c6c9cSStella Laurenzo // Mapping of PyType. 2367436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2368436c6c9cSStella Laurenzo py::class_<PyType>(m, "Type") 2369b57d6fe4SStella Laurenzo // Delegate to the PyType copy constructor, which will also lifetime 2370b57d6fe4SStella Laurenzo // extend the backing context which owns the MlirType. 2371b57d6fe4SStella Laurenzo .def(py::init<PyType &>(), py::arg("cast_from_type"), 2372b57d6fe4SStella Laurenzo "Casts the passed type to the generic Type") 2373436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2374436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2375436c6c9cSStella Laurenzo .def_static( 2376436c6c9cSStella Laurenzo "parse", 2377436c6c9cSStella Laurenzo [](std::string typeSpec, DefaultingPyMlirContext context) { 2378436c6c9cSStella Laurenzo MlirType type = 2379436c6c9cSStella Laurenzo mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2380436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 2381436c6c9cSStella Laurenzo // in C API. 2382436c6c9cSStella Laurenzo if (mlirTypeIsNull(type)) { 2383436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 2384436c6c9cSStella Laurenzo Twine("Unable to parse type: '") + typeSpec + 2385436c6c9cSStella Laurenzo "'"); 2386436c6c9cSStella Laurenzo } 2387436c6c9cSStella Laurenzo return PyType(context->getRef(), type); 2388436c6c9cSStella Laurenzo }, 2389436c6c9cSStella Laurenzo py::arg("asm"), py::arg("context") = py::none(), 2390436c6c9cSStella Laurenzo kContextParseTypeDocstring) 2391436c6c9cSStella Laurenzo .def_property_readonly( 2392436c6c9cSStella Laurenzo "context", [](PyType &self) { return self.getContext().getObject(); }, 2393436c6c9cSStella Laurenzo "Context that owns the Type") 2394436c6c9cSStella Laurenzo .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2395436c6c9cSStella Laurenzo .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2396436c6c9cSStella Laurenzo .def( 2397436c6c9cSStella Laurenzo "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2398436c6c9cSStella Laurenzo .def( 2399436c6c9cSStella Laurenzo "__str__", 2400436c6c9cSStella Laurenzo [](PyType &self) { 2401436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2402436c6c9cSStella Laurenzo mlirTypePrint(self, printAccum.getCallback(), 2403436c6c9cSStella Laurenzo printAccum.getUserData()); 2404436c6c9cSStella Laurenzo return printAccum.join(); 2405436c6c9cSStella Laurenzo }, 2406436c6c9cSStella Laurenzo "Returns the assembly form of the type.") 2407436c6c9cSStella Laurenzo .def("__repr__", [](PyType &self) { 2408436c6c9cSStella Laurenzo // Generally, assembly formats are not printed for __repr__ because 2409436c6c9cSStella Laurenzo // this can cause exceptionally long debug output and exceptions. 2410436c6c9cSStella Laurenzo // However, types are an exception as they typically have compact 2411436c6c9cSStella Laurenzo // assembly forms and printing them is useful. 2412436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2413436c6c9cSStella Laurenzo printAccum.parts.append("Type("); 2414436c6c9cSStella Laurenzo mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2415436c6c9cSStella Laurenzo printAccum.parts.append(")"); 2416436c6c9cSStella Laurenzo return printAccum.join(); 2417436c6c9cSStella Laurenzo }); 2418436c6c9cSStella Laurenzo 2419436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2420436c6c9cSStella Laurenzo // Mapping of Value. 2421436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2422436c6c9cSStella Laurenzo py::class_<PyValue>(m, "Value") 24233f3d1c90SMike Urbach .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 24243f3d1c90SMike Urbach .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2425436c6c9cSStella Laurenzo .def_property_readonly( 2426436c6c9cSStella Laurenzo "context", 2427436c6c9cSStella Laurenzo [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2428436c6c9cSStella Laurenzo "Context in which the value lives.") 2429436c6c9cSStella Laurenzo .def( 2430436c6c9cSStella Laurenzo "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2431436c6c9cSStella Laurenzo kDumpDocstring) 2432436c6c9cSStella Laurenzo .def("__eq__", 2433436c6c9cSStella Laurenzo [](PyValue &self, PyValue &other) { 2434436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 2435436c6c9cSStella Laurenzo }) 2436436c6c9cSStella Laurenzo .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2437436c6c9cSStella Laurenzo .def( 2438436c6c9cSStella Laurenzo "__str__", 2439436c6c9cSStella Laurenzo [](PyValue &self) { 2440436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2441436c6c9cSStella Laurenzo printAccum.parts.append("Value("); 2442436c6c9cSStella Laurenzo mlirValuePrint(self.get(), printAccum.getCallback(), 2443436c6c9cSStella Laurenzo printAccum.getUserData()); 2444436c6c9cSStella Laurenzo printAccum.parts.append(")"); 2445436c6c9cSStella Laurenzo return printAccum.join(); 2446436c6c9cSStella Laurenzo }, 2447436c6c9cSStella Laurenzo kValueDunderStrDocstring) 2448436c6c9cSStella Laurenzo .def_property_readonly("type", [](PyValue &self) { 2449436c6c9cSStella Laurenzo return PyType(self.getParentOperation()->getContext(), 2450436c6c9cSStella Laurenzo mlirValueGetType(self.get())); 2451436c6c9cSStella Laurenzo }); 2452436c6c9cSStella Laurenzo PyBlockArgument::bind(m); 2453436c6c9cSStella Laurenzo PyOpResult::bind(m); 2454436c6c9cSStella Laurenzo 2455436c6c9cSStella Laurenzo // Container bindings. 2456436c6c9cSStella Laurenzo PyBlockArgumentList::bind(m); 2457436c6c9cSStella Laurenzo PyBlockIterator::bind(m); 2458436c6c9cSStella Laurenzo PyBlockList::bind(m); 2459436c6c9cSStella Laurenzo PyOperationIterator::bind(m); 2460436c6c9cSStella Laurenzo PyOperationList::bind(m); 2461436c6c9cSStella Laurenzo PyOpAttributeMap::bind(m); 2462436c6c9cSStella Laurenzo PyOpOperandList::bind(m); 2463436c6c9cSStella Laurenzo PyOpResultList::bind(m); 2464436c6c9cSStella Laurenzo PyRegionIterator::bind(m); 2465436c6c9cSStella Laurenzo PyRegionList::bind(m); 24664acd8457SAlex Zinenko 24674acd8457SAlex Zinenko // Debug bindings. 24684acd8457SAlex Zinenko PyGlobalDebugFlag::bind(m); 2469436c6c9cSStella Laurenzo } 2470