1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9436c6c9cSStella Laurenzo #include "IRModule.h" 10436c6c9cSStella Laurenzo 11436c6c9cSStella Laurenzo #include "Globals.h" 12436c6c9cSStella Laurenzo #include "PybindUtils.h" 13436c6c9cSStella Laurenzo 14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h" 15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h" 16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h" 174acd8457SAlex Zinenko #include "mlir-c/Debug.h" 183f3d1c90SMike Urbach #include "mlir-c/IR.h" 19436c6c9cSStella Laurenzo #include "mlir-c/Registration.h" 20e67cbbefSJacques Pienaar #include "llvm/ADT/ArrayRef.h" 21436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h" 22436c6c9cSStella Laurenzo #include <pybind11/stl.h> 23436c6c9cSStella Laurenzo 24436c6c9cSStella Laurenzo namespace py = pybind11; 25436c6c9cSStella Laurenzo using namespace mlir; 26436c6c9cSStella Laurenzo using namespace mlir::python; 27436c6c9cSStella Laurenzo 28436c6c9cSStella Laurenzo using llvm::SmallVector; 29436c6c9cSStella Laurenzo using llvm::StringRef; 30436c6c9cSStella Laurenzo using llvm::Twine; 31436c6c9cSStella Laurenzo 32436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 33436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline). 34436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 35436c6c9cSStella Laurenzo 36436c6c9cSStella Laurenzo static const char kContextParseTypeDocstring[] = 37436c6c9cSStella Laurenzo R"(Parses the assembly form of a type. 38436c6c9cSStella Laurenzo 39436c6c9cSStella Laurenzo Returns a Type object or raises a ValueError if the type cannot be parsed. 40436c6c9cSStella Laurenzo 41436c6c9cSStella Laurenzo See also: https://mlir.llvm.org/docs/LangRef/#type-system 42436c6c9cSStella Laurenzo )"; 43436c6c9cSStella Laurenzo 44e67cbbefSJacques Pienaar static const char kContextGetCallSiteLocationDocstring[] = 45e67cbbefSJacques Pienaar R"(Gets a Location representing a caller and callsite)"; 46e67cbbefSJacques Pienaar 47436c6c9cSStella Laurenzo static const char kContextGetFileLocationDocstring[] = 48436c6c9cSStella Laurenzo R"(Gets a Location representing a file, line and column)"; 49436c6c9cSStella Laurenzo 5004d76d36SJacques Pienaar static const char kContextGetNameLocationDocString[] = 5104d76d36SJacques Pienaar R"(Gets a Location representing a named location with optional child location)"; 5204d76d36SJacques Pienaar 53436c6c9cSStella Laurenzo static const char kModuleParseDocstring[] = 54436c6c9cSStella Laurenzo R"(Parses a module's assembly format from a string. 55436c6c9cSStella Laurenzo 56436c6c9cSStella Laurenzo Returns a new MlirModule or raises a ValueError if the parsing fails. 57436c6c9cSStella Laurenzo 58436c6c9cSStella Laurenzo See also: https://mlir.llvm.org/docs/LangRef/ 59436c6c9cSStella Laurenzo )"; 60436c6c9cSStella Laurenzo 61436c6c9cSStella Laurenzo static const char kOperationCreateDocstring[] = 62436c6c9cSStella Laurenzo R"(Creates a new operation. 63436c6c9cSStella Laurenzo 64436c6c9cSStella Laurenzo Args: 65436c6c9cSStella Laurenzo name: Operation name (e.g. "dialect.operation"). 66436c6c9cSStella Laurenzo results: Sequence of Type representing op result types. 67436c6c9cSStella Laurenzo attributes: Dict of str:Attribute. 68436c6c9cSStella Laurenzo successors: List of Block for the operation's successors. 69436c6c9cSStella Laurenzo regions: Number of regions to create. 70436c6c9cSStella Laurenzo location: A Location object (defaults to resolve from context manager). 71436c6c9cSStella Laurenzo ip: An InsertionPoint (defaults to resolve from context manager or set to 72436c6c9cSStella Laurenzo False to disable insertion, even with an insertion point set in the 73436c6c9cSStella Laurenzo context manager). 74436c6c9cSStella Laurenzo Returns: 75436c6c9cSStella Laurenzo A new "detached" Operation object. Detached operations can be added 76436c6c9cSStella Laurenzo to blocks, which causes them to become "attached." 77436c6c9cSStella Laurenzo )"; 78436c6c9cSStella Laurenzo 79436c6c9cSStella Laurenzo static const char kOperationPrintDocstring[] = 80436c6c9cSStella Laurenzo R"(Prints the assembly form of the operation to a file like object. 81436c6c9cSStella Laurenzo 82436c6c9cSStella Laurenzo Args: 83436c6c9cSStella Laurenzo file: The file like object to write to. Defaults to sys.stdout. 84436c6c9cSStella Laurenzo binary: Whether to write bytes (True) or str (False). Defaults to False. 85436c6c9cSStella Laurenzo large_elements_limit: Whether to elide elements attributes above this 86436c6c9cSStella Laurenzo number of elements. Defaults to None (no limit). 87436c6c9cSStella Laurenzo enable_debug_info: Whether to print debug/location information. Defaults 88436c6c9cSStella Laurenzo to False. 89436c6c9cSStella Laurenzo pretty_debug_info: Whether to format debug information for easier reading 90436c6c9cSStella Laurenzo by a human (warning: the result is unparseable). 91436c6c9cSStella Laurenzo print_generic_op_form: Whether to print the generic assembly forms of all 92436c6c9cSStella Laurenzo ops. Defaults to False. 93436c6c9cSStella Laurenzo use_local_Scope: Whether to print in a way that is more optimized for 94436c6c9cSStella Laurenzo multi-threaded access but may not be consistent with how the overall 95436c6c9cSStella Laurenzo module prints. 96ace1d0adSStella Laurenzo assume_verified: By default, if not printing generic form, the verifier 97ace1d0adSStella Laurenzo will be run and if it fails, generic form will be printed with a comment 98ace1d0adSStella Laurenzo about failed verification. While a reasonable default for interactive use, 99ace1d0adSStella Laurenzo for systematic use, it is often better for the caller to verify explicitly 100ace1d0adSStella Laurenzo and report failures in a more robust fashion. Set this to True if doing this 101ace1d0adSStella Laurenzo in order to avoid running a redundant verification. If the IR is actually 102ace1d0adSStella Laurenzo invalid, behavior is undefined. 103436c6c9cSStella Laurenzo )"; 104436c6c9cSStella Laurenzo 105436c6c9cSStella Laurenzo static const char kOperationGetAsmDocstring[] = 106436c6c9cSStella Laurenzo R"(Gets the assembly form of the operation with all options available. 107436c6c9cSStella Laurenzo 108436c6c9cSStella Laurenzo Args: 109436c6c9cSStella Laurenzo binary: Whether to return a bytes (True) or str (False) object. Defaults to 110436c6c9cSStella Laurenzo False. 111436c6c9cSStella Laurenzo ... others ...: See the print() method for common keyword arguments for 112436c6c9cSStella Laurenzo configuring the printout. 113436c6c9cSStella Laurenzo Returns: 114436c6c9cSStella Laurenzo Either a bytes or str object, depending on the setting of the 'binary' 115436c6c9cSStella Laurenzo argument. 116436c6c9cSStella Laurenzo )"; 117436c6c9cSStella Laurenzo 118436c6c9cSStella Laurenzo static const char kOperationStrDunderDocstring[] = 119436c6c9cSStella Laurenzo R"(Gets the assembly form of the operation with default options. 120436c6c9cSStella Laurenzo 121436c6c9cSStella Laurenzo If more advanced control over the assembly formatting or I/O options is needed, 122436c6c9cSStella Laurenzo use the dedicated print or get_asm method, which supports keyword arguments to 123436c6c9cSStella Laurenzo customize behavior. 124436c6c9cSStella Laurenzo )"; 125436c6c9cSStella Laurenzo 126436c6c9cSStella Laurenzo static const char kDumpDocstring[] = 127436c6c9cSStella Laurenzo R"(Dumps a debug representation of the object to stderr.)"; 128436c6c9cSStella Laurenzo 129436c6c9cSStella Laurenzo static const char kAppendBlockDocstring[] = 130436c6c9cSStella Laurenzo R"(Appends a new block, with argument types as positional args. 131436c6c9cSStella Laurenzo 132436c6c9cSStella Laurenzo Returns: 133436c6c9cSStella Laurenzo The created block. 134436c6c9cSStella Laurenzo )"; 135436c6c9cSStella Laurenzo 136436c6c9cSStella Laurenzo static const char kValueDunderStrDocstring[] = 137436c6c9cSStella Laurenzo R"(Returns the string form of the value. 138436c6c9cSStella Laurenzo 139436c6c9cSStella Laurenzo If the value is a block argument, this is the assembly form of its type and the 140436c6c9cSStella Laurenzo position in the argument list. If the value is an operation result, this is 141436c6c9cSStella Laurenzo equivalent to printing the operation that produced it. 142436c6c9cSStella Laurenzo )"; 143436c6c9cSStella Laurenzo 144436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 145436c6c9cSStella Laurenzo // Utilities. 146436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 147436c6c9cSStella Laurenzo 1484acd8457SAlex Zinenko /// Helper for creating an @classmethod. 149436c6c9cSStella Laurenzo template <class Func, typename... Args> 150436c6c9cSStella Laurenzo py::object classmethod(Func f, Args... args) { 151436c6c9cSStella Laurenzo py::object cf = py::cpp_function(f, args...); 152436c6c9cSStella Laurenzo return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 153436c6c9cSStella Laurenzo } 154436c6c9cSStella Laurenzo 155436c6c9cSStella Laurenzo static py::object 156436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace, 157436c6c9cSStella Laurenzo py::object dialectDescriptor) { 158436c6c9cSStella Laurenzo auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 159436c6c9cSStella Laurenzo if (!dialectClass) { 160436c6c9cSStella Laurenzo // Use the base class. 161436c6c9cSStella Laurenzo return py::cast(PyDialect(std::move(dialectDescriptor))); 162436c6c9cSStella Laurenzo } 163436c6c9cSStella Laurenzo 164436c6c9cSStella Laurenzo // Create the custom implementation. 165436c6c9cSStella Laurenzo return (*dialectClass)(std::move(dialectDescriptor)); 166436c6c9cSStella Laurenzo } 167436c6c9cSStella Laurenzo 168436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) { 169436c6c9cSStella Laurenzo return mlirStringRefCreate(s.data(), s.size()); 170436c6c9cSStella Laurenzo } 171436c6c9cSStella Laurenzo 1724acd8457SAlex Zinenko /// Wrapper for the global LLVM debugging flag. 1734acd8457SAlex Zinenko struct PyGlobalDebugFlag { 1744acd8457SAlex Zinenko static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 1754acd8457SAlex Zinenko 1764acd8457SAlex Zinenko static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 1774acd8457SAlex Zinenko 1784acd8457SAlex Zinenko static void bind(py::module &m) { 1794acd8457SAlex Zinenko // Debug flags. 180f05ff4f7SStella Laurenzo py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 1814acd8457SAlex Zinenko .def_property_static("flag", &PyGlobalDebugFlag::get, 1824acd8457SAlex Zinenko &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 1834acd8457SAlex Zinenko } 1844acd8457SAlex Zinenko }; 1854acd8457SAlex Zinenko 186436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 187436c6c9cSStella Laurenzo // Collections. 188436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 189436c6c9cSStella Laurenzo 190436c6c9cSStella Laurenzo namespace { 191436c6c9cSStella Laurenzo 192436c6c9cSStella Laurenzo class PyRegionIterator { 193436c6c9cSStella Laurenzo public: 194436c6c9cSStella Laurenzo PyRegionIterator(PyOperationRef operation) 195436c6c9cSStella Laurenzo : operation(std::move(operation)) {} 196436c6c9cSStella Laurenzo 197436c6c9cSStella Laurenzo PyRegionIterator &dunderIter() { return *this; } 198436c6c9cSStella Laurenzo 199436c6c9cSStella Laurenzo PyRegion dunderNext() { 200436c6c9cSStella Laurenzo operation->checkValid(); 201436c6c9cSStella Laurenzo if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 202436c6c9cSStella Laurenzo throw py::stop_iteration(); 203436c6c9cSStella Laurenzo } 204436c6c9cSStella Laurenzo MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 205436c6c9cSStella Laurenzo return PyRegion(operation, region); 206436c6c9cSStella Laurenzo } 207436c6c9cSStella Laurenzo 208436c6c9cSStella Laurenzo static void bind(py::module &m) { 209f05ff4f7SStella Laurenzo py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 210436c6c9cSStella Laurenzo .def("__iter__", &PyRegionIterator::dunderIter) 211436c6c9cSStella Laurenzo .def("__next__", &PyRegionIterator::dunderNext); 212436c6c9cSStella Laurenzo } 213436c6c9cSStella Laurenzo 214436c6c9cSStella Laurenzo private: 215436c6c9cSStella Laurenzo PyOperationRef operation; 216436c6c9cSStella Laurenzo int nextIndex = 0; 217436c6c9cSStella Laurenzo }; 218436c6c9cSStella Laurenzo 219436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented 220436c6c9cSStella Laurenzo /// with a sequence-like container. 221436c6c9cSStella Laurenzo class PyRegionList { 222436c6c9cSStella Laurenzo public: 223436c6c9cSStella Laurenzo PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 224436c6c9cSStella Laurenzo 225436c6c9cSStella Laurenzo intptr_t dunderLen() { 226436c6c9cSStella Laurenzo operation->checkValid(); 227436c6c9cSStella Laurenzo return mlirOperationGetNumRegions(operation->get()); 228436c6c9cSStella Laurenzo } 229436c6c9cSStella Laurenzo 230436c6c9cSStella Laurenzo PyRegion dunderGetItem(intptr_t index) { 231436c6c9cSStella Laurenzo // dunderLen checks validity. 232436c6c9cSStella Laurenzo if (index < 0 || index >= dunderLen()) { 233436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 234436c6c9cSStella Laurenzo "attempt to access out of bounds region"); 235436c6c9cSStella Laurenzo } 236436c6c9cSStella Laurenzo MlirRegion region = mlirOperationGetRegion(operation->get(), index); 237436c6c9cSStella Laurenzo return PyRegion(operation, region); 238436c6c9cSStella Laurenzo } 239436c6c9cSStella Laurenzo 240436c6c9cSStella Laurenzo static void bind(py::module &m) { 241f05ff4f7SStella Laurenzo py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 242436c6c9cSStella Laurenzo .def("__len__", &PyRegionList::dunderLen) 243436c6c9cSStella Laurenzo .def("__getitem__", &PyRegionList::dunderGetItem); 244436c6c9cSStella Laurenzo } 245436c6c9cSStella Laurenzo 246436c6c9cSStella Laurenzo private: 247436c6c9cSStella Laurenzo PyOperationRef operation; 248436c6c9cSStella Laurenzo }; 249436c6c9cSStella Laurenzo 250436c6c9cSStella Laurenzo class PyBlockIterator { 251436c6c9cSStella Laurenzo public: 252436c6c9cSStella Laurenzo PyBlockIterator(PyOperationRef operation, MlirBlock next) 253436c6c9cSStella Laurenzo : operation(std::move(operation)), next(next) {} 254436c6c9cSStella Laurenzo 255436c6c9cSStella Laurenzo PyBlockIterator &dunderIter() { return *this; } 256436c6c9cSStella Laurenzo 257436c6c9cSStella Laurenzo PyBlock dunderNext() { 258436c6c9cSStella Laurenzo operation->checkValid(); 259436c6c9cSStella Laurenzo if (mlirBlockIsNull(next)) { 260436c6c9cSStella Laurenzo throw py::stop_iteration(); 261436c6c9cSStella Laurenzo } 262436c6c9cSStella Laurenzo 263436c6c9cSStella Laurenzo PyBlock returnBlock(operation, next); 264436c6c9cSStella Laurenzo next = mlirBlockGetNextInRegion(next); 265436c6c9cSStella Laurenzo return returnBlock; 266436c6c9cSStella Laurenzo } 267436c6c9cSStella Laurenzo 268436c6c9cSStella Laurenzo static void bind(py::module &m) { 269f05ff4f7SStella Laurenzo py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 270436c6c9cSStella Laurenzo .def("__iter__", &PyBlockIterator::dunderIter) 271436c6c9cSStella Laurenzo .def("__next__", &PyBlockIterator::dunderNext); 272436c6c9cSStella Laurenzo } 273436c6c9cSStella Laurenzo 274436c6c9cSStella Laurenzo private: 275436c6c9cSStella Laurenzo PyOperationRef operation; 276436c6c9cSStella Laurenzo MlirBlock next; 277436c6c9cSStella Laurenzo }; 278436c6c9cSStella Laurenzo 279436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 280436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize 281436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region. 282436c6c9cSStella Laurenzo class PyBlockList { 283436c6c9cSStella Laurenzo public: 284436c6c9cSStella Laurenzo PyBlockList(PyOperationRef operation, MlirRegion region) 285436c6c9cSStella Laurenzo : operation(std::move(operation)), region(region) {} 286436c6c9cSStella Laurenzo 287436c6c9cSStella Laurenzo PyBlockIterator dunderIter() { 288436c6c9cSStella Laurenzo operation->checkValid(); 289436c6c9cSStella Laurenzo return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 290436c6c9cSStella Laurenzo } 291436c6c9cSStella Laurenzo 292436c6c9cSStella Laurenzo intptr_t dunderLen() { 293436c6c9cSStella Laurenzo operation->checkValid(); 294436c6c9cSStella Laurenzo intptr_t count = 0; 295436c6c9cSStella Laurenzo MlirBlock block = mlirRegionGetFirstBlock(region); 296436c6c9cSStella Laurenzo while (!mlirBlockIsNull(block)) { 297436c6c9cSStella Laurenzo count += 1; 298436c6c9cSStella Laurenzo block = mlirBlockGetNextInRegion(block); 299436c6c9cSStella Laurenzo } 300436c6c9cSStella Laurenzo return count; 301436c6c9cSStella Laurenzo } 302436c6c9cSStella Laurenzo 303436c6c9cSStella Laurenzo PyBlock dunderGetItem(intptr_t index) { 304436c6c9cSStella Laurenzo operation->checkValid(); 305436c6c9cSStella Laurenzo if (index < 0) { 306436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 307436c6c9cSStella Laurenzo "attempt to access out of bounds block"); 308436c6c9cSStella Laurenzo } 309436c6c9cSStella Laurenzo MlirBlock block = mlirRegionGetFirstBlock(region); 310436c6c9cSStella Laurenzo while (!mlirBlockIsNull(block)) { 311436c6c9cSStella Laurenzo if (index == 0) { 312436c6c9cSStella Laurenzo return PyBlock(operation, block); 313436c6c9cSStella Laurenzo } 314436c6c9cSStella Laurenzo block = mlirBlockGetNextInRegion(block); 315436c6c9cSStella Laurenzo index -= 1; 316436c6c9cSStella Laurenzo } 317436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 318436c6c9cSStella Laurenzo } 319436c6c9cSStella Laurenzo 320436c6c9cSStella Laurenzo PyBlock appendBlock(py::args pyArgTypes) { 321436c6c9cSStella Laurenzo operation->checkValid(); 322436c6c9cSStella Laurenzo llvm::SmallVector<MlirType, 4> argTypes; 323436c6c9cSStella Laurenzo argTypes.reserve(pyArgTypes.size()); 324436c6c9cSStella Laurenzo for (auto &pyArg : pyArgTypes) { 325436c6c9cSStella Laurenzo argTypes.push_back(pyArg.cast<PyType &>()); 326436c6c9cSStella Laurenzo } 327436c6c9cSStella Laurenzo 328436c6c9cSStella Laurenzo MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 329436c6c9cSStella Laurenzo mlirRegionAppendOwnedBlock(region, block); 330436c6c9cSStella Laurenzo return PyBlock(operation, block); 331436c6c9cSStella Laurenzo } 332436c6c9cSStella Laurenzo 333436c6c9cSStella Laurenzo static void bind(py::module &m) { 334f05ff4f7SStella Laurenzo py::class_<PyBlockList>(m, "BlockList", py::module_local()) 335436c6c9cSStella Laurenzo .def("__getitem__", &PyBlockList::dunderGetItem) 336436c6c9cSStella Laurenzo .def("__iter__", &PyBlockList::dunderIter) 337436c6c9cSStella Laurenzo .def("__len__", &PyBlockList::dunderLen) 338436c6c9cSStella Laurenzo .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 339436c6c9cSStella Laurenzo } 340436c6c9cSStella Laurenzo 341436c6c9cSStella Laurenzo private: 342436c6c9cSStella Laurenzo PyOperationRef operation; 343436c6c9cSStella Laurenzo MlirRegion region; 344436c6c9cSStella Laurenzo }; 345436c6c9cSStella Laurenzo 346436c6c9cSStella Laurenzo class PyOperationIterator { 347436c6c9cSStella Laurenzo public: 348436c6c9cSStella Laurenzo PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 349436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), next(next) {} 350436c6c9cSStella Laurenzo 351436c6c9cSStella Laurenzo PyOperationIterator &dunderIter() { return *this; } 352436c6c9cSStella Laurenzo 353436c6c9cSStella Laurenzo py::object dunderNext() { 354436c6c9cSStella Laurenzo parentOperation->checkValid(); 355436c6c9cSStella Laurenzo if (mlirOperationIsNull(next)) { 356436c6c9cSStella Laurenzo throw py::stop_iteration(); 357436c6c9cSStella Laurenzo } 358436c6c9cSStella Laurenzo 359436c6c9cSStella Laurenzo PyOperationRef returnOperation = 360436c6c9cSStella Laurenzo PyOperation::forOperation(parentOperation->getContext(), next); 361436c6c9cSStella Laurenzo next = mlirOperationGetNextInBlock(next); 362436c6c9cSStella Laurenzo return returnOperation->createOpView(); 363436c6c9cSStella Laurenzo } 364436c6c9cSStella Laurenzo 365436c6c9cSStella Laurenzo static void bind(py::module &m) { 366f05ff4f7SStella Laurenzo py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 367436c6c9cSStella Laurenzo .def("__iter__", &PyOperationIterator::dunderIter) 368436c6c9cSStella Laurenzo .def("__next__", &PyOperationIterator::dunderNext); 369436c6c9cSStella Laurenzo } 370436c6c9cSStella Laurenzo 371436c6c9cSStella Laurenzo private: 372436c6c9cSStella Laurenzo PyOperationRef parentOperation; 373436c6c9cSStella Laurenzo MlirOperation next; 374436c6c9cSStella Laurenzo }; 375436c6c9cSStella Laurenzo 376436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In 377436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but 378436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned 379436c6c9cSStella Laurenzo /// by a block. 380436c6c9cSStella Laurenzo class PyOperationList { 381436c6c9cSStella Laurenzo public: 382436c6c9cSStella Laurenzo PyOperationList(PyOperationRef parentOperation, MlirBlock block) 383436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), block(block) {} 384436c6c9cSStella Laurenzo 385436c6c9cSStella Laurenzo PyOperationIterator dunderIter() { 386436c6c9cSStella Laurenzo parentOperation->checkValid(); 387436c6c9cSStella Laurenzo return PyOperationIterator(parentOperation, 388436c6c9cSStella Laurenzo mlirBlockGetFirstOperation(block)); 389436c6c9cSStella Laurenzo } 390436c6c9cSStella Laurenzo 391436c6c9cSStella Laurenzo intptr_t dunderLen() { 392436c6c9cSStella Laurenzo parentOperation->checkValid(); 393436c6c9cSStella Laurenzo intptr_t count = 0; 394436c6c9cSStella Laurenzo MlirOperation childOp = mlirBlockGetFirstOperation(block); 395436c6c9cSStella Laurenzo while (!mlirOperationIsNull(childOp)) { 396436c6c9cSStella Laurenzo count += 1; 397436c6c9cSStella Laurenzo childOp = mlirOperationGetNextInBlock(childOp); 398436c6c9cSStella Laurenzo } 399436c6c9cSStella Laurenzo return count; 400436c6c9cSStella Laurenzo } 401436c6c9cSStella Laurenzo 402436c6c9cSStella Laurenzo py::object dunderGetItem(intptr_t index) { 403436c6c9cSStella Laurenzo parentOperation->checkValid(); 404436c6c9cSStella Laurenzo if (index < 0) { 405436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 406436c6c9cSStella Laurenzo "attempt to access out of bounds operation"); 407436c6c9cSStella Laurenzo } 408436c6c9cSStella Laurenzo MlirOperation childOp = mlirBlockGetFirstOperation(block); 409436c6c9cSStella Laurenzo while (!mlirOperationIsNull(childOp)) { 410436c6c9cSStella Laurenzo if (index == 0) { 411436c6c9cSStella Laurenzo return PyOperation::forOperation(parentOperation->getContext(), childOp) 412436c6c9cSStella Laurenzo ->createOpView(); 413436c6c9cSStella Laurenzo } 414436c6c9cSStella Laurenzo childOp = mlirOperationGetNextInBlock(childOp); 415436c6c9cSStella Laurenzo index -= 1; 416436c6c9cSStella Laurenzo } 417436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 418436c6c9cSStella Laurenzo "attempt to access out of bounds operation"); 419436c6c9cSStella Laurenzo } 420436c6c9cSStella Laurenzo 421436c6c9cSStella Laurenzo static void bind(py::module &m) { 422f05ff4f7SStella Laurenzo py::class_<PyOperationList>(m, "OperationList", py::module_local()) 423436c6c9cSStella Laurenzo .def("__getitem__", &PyOperationList::dunderGetItem) 424436c6c9cSStella Laurenzo .def("__iter__", &PyOperationList::dunderIter) 425436c6c9cSStella Laurenzo .def("__len__", &PyOperationList::dunderLen); 426436c6c9cSStella Laurenzo } 427436c6c9cSStella Laurenzo 428436c6c9cSStella Laurenzo private: 429436c6c9cSStella Laurenzo PyOperationRef parentOperation; 430436c6c9cSStella Laurenzo MlirBlock block; 431436c6c9cSStella Laurenzo }; 432436c6c9cSStella Laurenzo 433436c6c9cSStella Laurenzo } // namespace 434436c6c9cSStella Laurenzo 435436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 436436c6c9cSStella Laurenzo // PyMlirContext 437436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 438436c6c9cSStella Laurenzo 439436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 440436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 441436c6c9cSStella Laurenzo auto &liveContexts = getLiveContexts(); 442436c6c9cSStella Laurenzo liveContexts[context.ptr] = this; 443436c6c9cSStella Laurenzo } 444436c6c9cSStella Laurenzo 445436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() { 446436c6c9cSStella Laurenzo // Note that the only public way to construct an instance is via the 447436c6c9cSStella Laurenzo // forContext method, which always puts the associated handle into 448436c6c9cSStella Laurenzo // liveContexts. 449436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 450436c6c9cSStella Laurenzo getLiveContexts().erase(context.ptr); 451436c6c9cSStella Laurenzo mlirContextDestroy(context); 452436c6c9cSStella Laurenzo } 453436c6c9cSStella Laurenzo 454436c6c9cSStella Laurenzo py::object PyMlirContext::getCapsule() { 455436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 456436c6c9cSStella Laurenzo } 457436c6c9cSStella Laurenzo 458436c6c9cSStella Laurenzo py::object PyMlirContext::createFromCapsule(py::object capsule) { 459436c6c9cSStella Laurenzo MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 460436c6c9cSStella Laurenzo if (mlirContextIsNull(rawContext)) 461436c6c9cSStella Laurenzo throw py::error_already_set(); 462436c6c9cSStella Laurenzo return forContext(rawContext).releaseObject(); 463436c6c9cSStella Laurenzo } 464436c6c9cSStella Laurenzo 465436c6c9cSStella Laurenzo PyMlirContext *PyMlirContext::createNewContextForInit() { 466436c6c9cSStella Laurenzo MlirContext context = mlirContextCreate(); 467436c6c9cSStella Laurenzo mlirRegisterAllDialects(context); 468436c6c9cSStella Laurenzo return new PyMlirContext(context); 469436c6c9cSStella Laurenzo } 470436c6c9cSStella Laurenzo 471436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 472436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 473436c6c9cSStella Laurenzo auto &liveContexts = getLiveContexts(); 474436c6c9cSStella Laurenzo auto it = liveContexts.find(context.ptr); 475436c6c9cSStella Laurenzo if (it == liveContexts.end()) { 476436c6c9cSStella Laurenzo // Create. 477436c6c9cSStella Laurenzo PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 478436c6c9cSStella Laurenzo py::object pyRef = py::cast(unownedContextWrapper); 479436c6c9cSStella Laurenzo assert(pyRef && "cast to py::object failed"); 480436c6c9cSStella Laurenzo liveContexts[context.ptr] = unownedContextWrapper; 481436c6c9cSStella Laurenzo return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 482436c6c9cSStella Laurenzo } 483436c6c9cSStella Laurenzo // Use existing. 484436c6c9cSStella Laurenzo py::object pyRef = py::cast(it->second); 485436c6c9cSStella Laurenzo return PyMlirContextRef(it->second, std::move(pyRef)); 486436c6c9cSStella Laurenzo } 487436c6c9cSStella Laurenzo 488436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 489436c6c9cSStella Laurenzo static LiveContextMap liveContexts; 490436c6c9cSStella Laurenzo return liveContexts; 491436c6c9cSStella Laurenzo } 492436c6c9cSStella Laurenzo 493436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 494436c6c9cSStella Laurenzo 495436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 496436c6c9cSStella Laurenzo 497436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 498436c6c9cSStella Laurenzo 499436c6c9cSStella Laurenzo pybind11::object PyMlirContext::contextEnter() { 500436c6c9cSStella Laurenzo return PyThreadContextEntry::pushContext(*this); 501436c6c9cSStella Laurenzo } 502436c6c9cSStella Laurenzo 503436c6c9cSStella Laurenzo void PyMlirContext::contextExit(pybind11::object excType, 504436c6c9cSStella Laurenzo pybind11::object excVal, 505436c6c9cSStella Laurenzo pybind11::object excTb) { 506436c6c9cSStella Laurenzo PyThreadContextEntry::popContext(*this); 507436c6c9cSStella Laurenzo } 508436c6c9cSStella Laurenzo 509436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() { 510436c6c9cSStella Laurenzo PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 511436c6c9cSStella Laurenzo if (!context) { 512436c6c9cSStella Laurenzo throw SetPyError( 513436c6c9cSStella Laurenzo PyExc_RuntimeError, 514436c6c9cSStella Laurenzo "An MLIR function requires a Context but none was provided in the call " 515436c6c9cSStella Laurenzo "or from the surrounding environment. Either pass to the function with " 516436c6c9cSStella Laurenzo "a 'context=' argument or establish a default using 'with Context():'"); 517436c6c9cSStella Laurenzo } 518436c6c9cSStella Laurenzo return *context; 519436c6c9cSStella Laurenzo } 520436c6c9cSStella Laurenzo 521436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 522436c6c9cSStella Laurenzo // PyThreadContextEntry management 523436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 524436c6c9cSStella Laurenzo 525436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 526436c6c9cSStella Laurenzo static thread_local std::vector<PyThreadContextEntry> stack; 527436c6c9cSStella Laurenzo return stack; 528436c6c9cSStella Laurenzo } 529436c6c9cSStella Laurenzo 530436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 531436c6c9cSStella Laurenzo auto &stack = getStack(); 532436c6c9cSStella Laurenzo if (stack.empty()) 533436c6c9cSStella Laurenzo return nullptr; 534436c6c9cSStella Laurenzo return &stack.back(); 535436c6c9cSStella Laurenzo } 536436c6c9cSStella Laurenzo 537436c6c9cSStella Laurenzo void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 538436c6c9cSStella Laurenzo py::object insertionPoint, 539436c6c9cSStella Laurenzo py::object location) { 540436c6c9cSStella Laurenzo auto &stack = getStack(); 541436c6c9cSStella Laurenzo stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 542436c6c9cSStella Laurenzo std::move(location)); 543436c6c9cSStella Laurenzo // If the new stack has more than one entry and the context of the new top 544436c6c9cSStella Laurenzo // entry matches the previous, copy the insertionPoint and location from the 545436c6c9cSStella Laurenzo // previous entry if missing from the new top entry. 546436c6c9cSStella Laurenzo if (stack.size() > 1) { 547436c6c9cSStella Laurenzo auto &prev = *(stack.rbegin() + 1); 548436c6c9cSStella Laurenzo auto ¤t = stack.back(); 549436c6c9cSStella Laurenzo if (current.context.is(prev.context)) { 550436c6c9cSStella Laurenzo // Default non-context objects from the previous entry. 551436c6c9cSStella Laurenzo if (!current.insertionPoint) 552436c6c9cSStella Laurenzo current.insertionPoint = prev.insertionPoint; 553436c6c9cSStella Laurenzo if (!current.location) 554436c6c9cSStella Laurenzo current.location = prev.location; 555436c6c9cSStella Laurenzo } 556436c6c9cSStella Laurenzo } 557436c6c9cSStella Laurenzo } 558436c6c9cSStella Laurenzo 559436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() { 560436c6c9cSStella Laurenzo if (!context) 561436c6c9cSStella Laurenzo return nullptr; 562436c6c9cSStella Laurenzo return py::cast<PyMlirContext *>(context); 563436c6c9cSStella Laurenzo } 564436c6c9cSStella Laurenzo 565436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 566436c6c9cSStella Laurenzo if (!insertionPoint) 567436c6c9cSStella Laurenzo return nullptr; 568436c6c9cSStella Laurenzo return py::cast<PyInsertionPoint *>(insertionPoint); 569436c6c9cSStella Laurenzo } 570436c6c9cSStella Laurenzo 571436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() { 572436c6c9cSStella Laurenzo if (!location) 573436c6c9cSStella Laurenzo return nullptr; 574436c6c9cSStella Laurenzo return py::cast<PyLocation *>(location); 575436c6c9cSStella Laurenzo } 576436c6c9cSStella Laurenzo 577436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() { 578436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 579436c6c9cSStella Laurenzo return tos ? tos->getContext() : nullptr; 580436c6c9cSStella Laurenzo } 581436c6c9cSStella Laurenzo 582436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 583436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 584436c6c9cSStella Laurenzo return tos ? tos->getInsertionPoint() : nullptr; 585436c6c9cSStella Laurenzo } 586436c6c9cSStella Laurenzo 587436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() { 588436c6c9cSStella Laurenzo auto *tos = getTopOfStack(); 589436c6c9cSStella Laurenzo return tos ? tos->getLocation() : nullptr; 590436c6c9cSStella Laurenzo } 591436c6c9cSStella Laurenzo 592436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 593436c6c9cSStella Laurenzo py::object contextObj = py::cast(context); 594436c6c9cSStella Laurenzo push(FrameKind::Context, /*context=*/contextObj, 595436c6c9cSStella Laurenzo /*insertionPoint=*/py::object(), 596436c6c9cSStella Laurenzo /*location=*/py::object()); 597436c6c9cSStella Laurenzo return contextObj; 598436c6c9cSStella Laurenzo } 599436c6c9cSStella Laurenzo 600436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) { 601436c6c9cSStella Laurenzo auto &stack = getStack(); 602436c6c9cSStella Laurenzo if (stack.empty()) 603436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 604436c6c9cSStella Laurenzo auto &tos = stack.back(); 605436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 606436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 607436c6c9cSStella Laurenzo stack.pop_back(); 608436c6c9cSStella Laurenzo } 609436c6c9cSStella Laurenzo 610436c6c9cSStella Laurenzo py::object 611436c6c9cSStella Laurenzo PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 612436c6c9cSStella Laurenzo py::object contextObj = 613436c6c9cSStella Laurenzo insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 614436c6c9cSStella Laurenzo py::object insertionPointObj = py::cast(insertionPoint); 615436c6c9cSStella Laurenzo push(FrameKind::InsertionPoint, 616436c6c9cSStella Laurenzo /*context=*/contextObj, 617436c6c9cSStella Laurenzo /*insertionPoint=*/insertionPointObj, 618436c6c9cSStella Laurenzo /*location=*/py::object()); 619436c6c9cSStella Laurenzo return insertionPointObj; 620436c6c9cSStella Laurenzo } 621436c6c9cSStella Laurenzo 622436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 623436c6c9cSStella Laurenzo auto &stack = getStack(); 624436c6c9cSStella Laurenzo if (stack.empty()) 625436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, 626436c6c9cSStella Laurenzo "Unbalanced InsertionPoint enter/exit"); 627436c6c9cSStella Laurenzo auto &tos = stack.back(); 628436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::InsertionPoint && 629436c6c9cSStella Laurenzo tos.getInsertionPoint() != &insertionPoint) 630436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, 631436c6c9cSStella Laurenzo "Unbalanced InsertionPoint enter/exit"); 632436c6c9cSStella Laurenzo stack.pop_back(); 633436c6c9cSStella Laurenzo } 634436c6c9cSStella Laurenzo 635436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 636436c6c9cSStella Laurenzo py::object contextObj = location.getContext().getObject(); 637436c6c9cSStella Laurenzo py::object locationObj = py::cast(location); 638436c6c9cSStella Laurenzo push(FrameKind::Location, /*context=*/contextObj, 639436c6c9cSStella Laurenzo /*insertionPoint=*/py::object(), 640436c6c9cSStella Laurenzo /*location=*/locationObj); 641436c6c9cSStella Laurenzo return locationObj; 642436c6c9cSStella Laurenzo } 643436c6c9cSStella Laurenzo 644436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) { 645436c6c9cSStella Laurenzo auto &stack = getStack(); 646436c6c9cSStella Laurenzo if (stack.empty()) 647436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 648436c6c9cSStella Laurenzo auto &tos = stack.back(); 649436c6c9cSStella Laurenzo if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 650436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 651436c6c9cSStella Laurenzo stack.pop_back(); 652436c6c9cSStella Laurenzo } 653436c6c9cSStella Laurenzo 654436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 655436c6c9cSStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects 656436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 657436c6c9cSStella Laurenzo 658436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key, 659436c6c9cSStella Laurenzo bool attrError) { 660f8479d9dSRiver Riddle MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 661f8479d9dSRiver Riddle {key.data(), key.size()}); 662436c6c9cSStella Laurenzo if (mlirDialectIsNull(dialect)) { 663436c6c9cSStella Laurenzo throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 664436c6c9cSStella Laurenzo Twine("Dialect '") + key + "' not found"); 665436c6c9cSStella Laurenzo } 666436c6c9cSStella Laurenzo return dialect; 667436c6c9cSStella Laurenzo } 668436c6c9cSStella Laurenzo 669436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 670436c6c9cSStella Laurenzo // PyLocation 671436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 672436c6c9cSStella Laurenzo 673436c6c9cSStella Laurenzo py::object PyLocation::getCapsule() { 674436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 675436c6c9cSStella Laurenzo } 676436c6c9cSStella Laurenzo 677436c6c9cSStella Laurenzo PyLocation PyLocation::createFromCapsule(py::object capsule) { 678436c6c9cSStella Laurenzo MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 679436c6c9cSStella Laurenzo if (mlirLocationIsNull(rawLoc)) 680436c6c9cSStella Laurenzo throw py::error_already_set(); 681436c6c9cSStella Laurenzo return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 682436c6c9cSStella Laurenzo rawLoc); 683436c6c9cSStella Laurenzo } 684436c6c9cSStella Laurenzo 685436c6c9cSStella Laurenzo py::object PyLocation::contextEnter() { 686436c6c9cSStella Laurenzo return PyThreadContextEntry::pushLocation(*this); 687436c6c9cSStella Laurenzo } 688436c6c9cSStella Laurenzo 689436c6c9cSStella Laurenzo void PyLocation::contextExit(py::object excType, py::object excVal, 690436c6c9cSStella Laurenzo py::object excTb) { 691436c6c9cSStella Laurenzo PyThreadContextEntry::popLocation(*this); 692436c6c9cSStella Laurenzo } 693436c6c9cSStella Laurenzo 694436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() { 695436c6c9cSStella Laurenzo auto *location = PyThreadContextEntry::getDefaultLocation(); 696436c6c9cSStella Laurenzo if (!location) { 697436c6c9cSStella Laurenzo throw SetPyError( 698436c6c9cSStella Laurenzo PyExc_RuntimeError, 699436c6c9cSStella Laurenzo "An MLIR function requires a Location but none was provided in the " 700436c6c9cSStella Laurenzo "call or from the surrounding environment. Either pass to the function " 701436c6c9cSStella Laurenzo "with a 'loc=' argument or establish a default using 'with loc:'"); 702436c6c9cSStella Laurenzo } 703436c6c9cSStella Laurenzo return *location; 704436c6c9cSStella Laurenzo } 705436c6c9cSStella Laurenzo 706436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 707436c6c9cSStella Laurenzo // PyModule 708436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 709436c6c9cSStella Laurenzo 710436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 711436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), module(module) {} 712436c6c9cSStella Laurenzo 713436c6c9cSStella Laurenzo PyModule::~PyModule() { 714436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 715436c6c9cSStella Laurenzo auto &liveModules = getContext()->liveModules; 716436c6c9cSStella Laurenzo assert(liveModules.count(module.ptr) == 1 && 717436c6c9cSStella Laurenzo "destroying module not in live map"); 718436c6c9cSStella Laurenzo liveModules.erase(module.ptr); 719436c6c9cSStella Laurenzo mlirModuleDestroy(module); 720436c6c9cSStella Laurenzo } 721436c6c9cSStella Laurenzo 722436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) { 723436c6c9cSStella Laurenzo MlirContext context = mlirModuleGetContext(module); 724436c6c9cSStella Laurenzo PyMlirContextRef contextRef = PyMlirContext::forContext(context); 725436c6c9cSStella Laurenzo 726436c6c9cSStella Laurenzo py::gil_scoped_acquire acquire; 727436c6c9cSStella Laurenzo auto &liveModules = contextRef->liveModules; 728436c6c9cSStella Laurenzo auto it = liveModules.find(module.ptr); 729436c6c9cSStella Laurenzo if (it == liveModules.end()) { 730436c6c9cSStella Laurenzo // Create. 731436c6c9cSStella Laurenzo PyModule *unownedModule = new PyModule(std::move(contextRef), module); 732436c6c9cSStella Laurenzo // Note that the default return value policy on cast is automatic_reference, 733436c6c9cSStella Laurenzo // which does not take ownership (delete will not be called). 734436c6c9cSStella Laurenzo // Just be explicit. 735436c6c9cSStella Laurenzo py::object pyRef = 736436c6c9cSStella Laurenzo py::cast(unownedModule, py::return_value_policy::take_ownership); 737436c6c9cSStella Laurenzo unownedModule->handle = pyRef; 738436c6c9cSStella Laurenzo liveModules[module.ptr] = 739436c6c9cSStella Laurenzo std::make_pair(unownedModule->handle, unownedModule); 740436c6c9cSStella Laurenzo return PyModuleRef(unownedModule, std::move(pyRef)); 741436c6c9cSStella Laurenzo } 742436c6c9cSStella Laurenzo // Use existing. 743436c6c9cSStella Laurenzo PyModule *existing = it->second.second; 744436c6c9cSStella Laurenzo py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 745436c6c9cSStella Laurenzo return PyModuleRef(existing, std::move(pyRef)); 746436c6c9cSStella Laurenzo } 747436c6c9cSStella Laurenzo 748436c6c9cSStella Laurenzo py::object PyModule::createFromCapsule(py::object capsule) { 749436c6c9cSStella Laurenzo MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 750436c6c9cSStella Laurenzo if (mlirModuleIsNull(rawModule)) 751436c6c9cSStella Laurenzo throw py::error_already_set(); 752436c6c9cSStella Laurenzo return forModule(rawModule).releaseObject(); 753436c6c9cSStella Laurenzo } 754436c6c9cSStella Laurenzo 755436c6c9cSStella Laurenzo py::object PyModule::getCapsule() { 756436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 757436c6c9cSStella Laurenzo } 758436c6c9cSStella Laurenzo 759436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 760436c6c9cSStella Laurenzo // PyOperation 761436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 762436c6c9cSStella Laurenzo 763436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 764436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), operation(operation) {} 765436c6c9cSStella Laurenzo 766436c6c9cSStella Laurenzo PyOperation::~PyOperation() { 76749745f87SMike Urbach // If the operation has already been invalidated there is nothing to do. 76849745f87SMike Urbach if (!valid) 76949745f87SMike Urbach return; 770436c6c9cSStella Laurenzo auto &liveOperations = getContext()->liveOperations; 771436c6c9cSStella Laurenzo assert(liveOperations.count(operation.ptr) == 1 && 772436c6c9cSStella Laurenzo "destroying operation not in live map"); 773436c6c9cSStella Laurenzo liveOperations.erase(operation.ptr); 774436c6c9cSStella Laurenzo if (!isAttached()) { 775436c6c9cSStella Laurenzo mlirOperationDestroy(operation); 776436c6c9cSStella Laurenzo } 777436c6c9cSStella Laurenzo } 778436c6c9cSStella Laurenzo 779436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 780436c6c9cSStella Laurenzo MlirOperation operation, 781436c6c9cSStella Laurenzo py::object parentKeepAlive) { 782436c6c9cSStella Laurenzo auto &liveOperations = contextRef->liveOperations; 783436c6c9cSStella Laurenzo // Create. 784436c6c9cSStella Laurenzo PyOperation *unownedOperation = 785436c6c9cSStella Laurenzo new PyOperation(std::move(contextRef), operation); 786436c6c9cSStella Laurenzo // Note that the default return value policy on cast is automatic_reference, 787436c6c9cSStella Laurenzo // which does not take ownership (delete will not be called). 788436c6c9cSStella Laurenzo // Just be explicit. 789436c6c9cSStella Laurenzo py::object pyRef = 790436c6c9cSStella Laurenzo py::cast(unownedOperation, py::return_value_policy::take_ownership); 791436c6c9cSStella Laurenzo unownedOperation->handle = pyRef; 792436c6c9cSStella Laurenzo if (parentKeepAlive) { 793436c6c9cSStella Laurenzo unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 794436c6c9cSStella Laurenzo } 795436c6c9cSStella Laurenzo liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 796436c6c9cSStella Laurenzo return PyOperationRef(unownedOperation, std::move(pyRef)); 797436c6c9cSStella Laurenzo } 798436c6c9cSStella Laurenzo 799436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 800436c6c9cSStella Laurenzo MlirOperation operation, 801436c6c9cSStella Laurenzo py::object parentKeepAlive) { 802436c6c9cSStella Laurenzo auto &liveOperations = contextRef->liveOperations; 803436c6c9cSStella Laurenzo auto it = liveOperations.find(operation.ptr); 804436c6c9cSStella Laurenzo if (it == liveOperations.end()) { 805436c6c9cSStella Laurenzo // Create. 806436c6c9cSStella Laurenzo return createInstance(std::move(contextRef), operation, 807436c6c9cSStella Laurenzo std::move(parentKeepAlive)); 808436c6c9cSStella Laurenzo } 809436c6c9cSStella Laurenzo // Use existing. 810436c6c9cSStella Laurenzo PyOperation *existing = it->second.second; 811436c6c9cSStella Laurenzo py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 812436c6c9cSStella Laurenzo return PyOperationRef(existing, std::move(pyRef)); 813436c6c9cSStella Laurenzo } 814436c6c9cSStella Laurenzo 815436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 816436c6c9cSStella Laurenzo MlirOperation operation, 817436c6c9cSStella Laurenzo py::object parentKeepAlive) { 818436c6c9cSStella Laurenzo auto &liveOperations = contextRef->liveOperations; 819436c6c9cSStella Laurenzo assert(liveOperations.count(operation.ptr) == 0 && 820436c6c9cSStella Laurenzo "cannot create detached operation that already exists"); 821436c6c9cSStella Laurenzo (void)liveOperations; 822436c6c9cSStella Laurenzo 823436c6c9cSStella Laurenzo PyOperationRef created = createInstance(std::move(contextRef), operation, 824436c6c9cSStella Laurenzo std::move(parentKeepAlive)); 825436c6c9cSStella Laurenzo created->attached = false; 826436c6c9cSStella Laurenzo return created; 827436c6c9cSStella Laurenzo } 828436c6c9cSStella Laurenzo 829436c6c9cSStella Laurenzo void PyOperation::checkValid() const { 830436c6c9cSStella Laurenzo if (!valid) { 831436c6c9cSStella Laurenzo throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 832436c6c9cSStella Laurenzo } 833436c6c9cSStella Laurenzo } 834436c6c9cSStella Laurenzo 835436c6c9cSStella Laurenzo void PyOperationBase::print(py::object fileObject, bool binary, 836436c6c9cSStella Laurenzo llvm::Optional<int64_t> largeElementsLimit, 837436c6c9cSStella Laurenzo bool enableDebugInfo, bool prettyDebugInfo, 838ace1d0adSStella Laurenzo bool printGenericOpForm, bool useLocalScope, 839ace1d0adSStella Laurenzo bool assumeVerified) { 840436c6c9cSStella Laurenzo PyOperation &operation = getOperation(); 841436c6c9cSStella Laurenzo operation.checkValid(); 842436c6c9cSStella Laurenzo if (fileObject.is_none()) 843436c6c9cSStella Laurenzo fileObject = py::module::import("sys").attr("stdout"); 844436c6c9cSStella Laurenzo 845ace1d0adSStella Laurenzo if (!assumeVerified && !printGenericOpForm && 846ace1d0adSStella Laurenzo !mlirOperationVerify(operation)) { 847ace1d0adSStella Laurenzo std::string message("// Verification failed, printing generic form\n"); 848ace1d0adSStella Laurenzo if (binary) { 849ace1d0adSStella Laurenzo fileObject.attr("write")(py::bytes(message)); 850ace1d0adSStella Laurenzo } else { 851ace1d0adSStella Laurenzo fileObject.attr("write")(py::str(message)); 852ace1d0adSStella Laurenzo } 853436c6c9cSStella Laurenzo printGenericOpForm = true; 854436c6c9cSStella Laurenzo } 855436c6c9cSStella Laurenzo 856436c6c9cSStella Laurenzo MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 857436c6c9cSStella Laurenzo if (largeElementsLimit) 858436c6c9cSStella Laurenzo mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 859436c6c9cSStella Laurenzo if (enableDebugInfo) 860436c6c9cSStella Laurenzo mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 861436c6c9cSStella Laurenzo if (printGenericOpForm) 862436c6c9cSStella Laurenzo mlirOpPrintingFlagsPrintGenericOpForm(flags); 863436c6c9cSStella Laurenzo 864436c6c9cSStella Laurenzo PyFileAccumulator accum(fileObject, binary); 865436c6c9cSStella Laurenzo py::gil_scoped_release(); 866436c6c9cSStella Laurenzo mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 867436c6c9cSStella Laurenzo accum.getUserData()); 868436c6c9cSStella Laurenzo mlirOpPrintingFlagsDestroy(flags); 869436c6c9cSStella Laurenzo } 870436c6c9cSStella Laurenzo 871436c6c9cSStella Laurenzo py::object PyOperationBase::getAsm(bool binary, 872436c6c9cSStella Laurenzo llvm::Optional<int64_t> largeElementsLimit, 873436c6c9cSStella Laurenzo bool enableDebugInfo, bool prettyDebugInfo, 874ace1d0adSStella Laurenzo bool printGenericOpForm, bool useLocalScope, 875ace1d0adSStella Laurenzo bool assumeVerified) { 876436c6c9cSStella Laurenzo py::object fileObject; 877436c6c9cSStella Laurenzo if (binary) { 878436c6c9cSStella Laurenzo fileObject = py::module::import("io").attr("BytesIO")(); 879436c6c9cSStella Laurenzo } else { 880436c6c9cSStella Laurenzo fileObject = py::module::import("io").attr("StringIO")(); 881436c6c9cSStella Laurenzo } 882436c6c9cSStella Laurenzo print(fileObject, /*binary=*/binary, 883436c6c9cSStella Laurenzo /*largeElementsLimit=*/largeElementsLimit, 884436c6c9cSStella Laurenzo /*enableDebugInfo=*/enableDebugInfo, 885436c6c9cSStella Laurenzo /*prettyDebugInfo=*/prettyDebugInfo, 886436c6c9cSStella Laurenzo /*printGenericOpForm=*/printGenericOpForm, 887ace1d0adSStella Laurenzo /*useLocalScope=*/useLocalScope, 888ace1d0adSStella Laurenzo /*assumeVerified=*/assumeVerified); 889436c6c9cSStella Laurenzo 890436c6c9cSStella Laurenzo return fileObject.attr("getvalue")(); 891436c6c9cSStella Laurenzo } 892436c6c9cSStella Laurenzo 89324685aaeSAlex Zinenko void PyOperationBase::moveAfter(PyOperationBase &other) { 89424685aaeSAlex Zinenko PyOperation &operation = getOperation(); 89524685aaeSAlex Zinenko PyOperation &otherOp = other.getOperation(); 89624685aaeSAlex Zinenko operation.checkValid(); 89724685aaeSAlex Zinenko otherOp.checkValid(); 89824685aaeSAlex Zinenko mlirOperationMoveAfter(operation, otherOp); 89924685aaeSAlex Zinenko operation.parentKeepAlive = otherOp.parentKeepAlive; 90024685aaeSAlex Zinenko } 90124685aaeSAlex Zinenko 90224685aaeSAlex Zinenko void PyOperationBase::moveBefore(PyOperationBase &other) { 90324685aaeSAlex Zinenko PyOperation &operation = getOperation(); 90424685aaeSAlex Zinenko PyOperation &otherOp = other.getOperation(); 90524685aaeSAlex Zinenko operation.checkValid(); 90624685aaeSAlex Zinenko otherOp.checkValid(); 90724685aaeSAlex Zinenko mlirOperationMoveBefore(operation, otherOp); 90824685aaeSAlex Zinenko operation.parentKeepAlive = otherOp.parentKeepAlive; 90924685aaeSAlex Zinenko } 91024685aaeSAlex Zinenko 9111689dadeSJohn Demme llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 91249745f87SMike Urbach checkValid(); 913436c6c9cSStella Laurenzo if (!isAttached()) 914436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 915436c6c9cSStella Laurenzo MlirOperation operation = mlirOperationGetParentOperation(get()); 916436c6c9cSStella Laurenzo if (mlirOperationIsNull(operation)) 9171689dadeSJohn Demme return {}; 918436c6c9cSStella Laurenzo return PyOperation::forOperation(getContext(), operation); 919436c6c9cSStella Laurenzo } 920436c6c9cSStella Laurenzo 921436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() { 92249745f87SMike Urbach checkValid(); 9231689dadeSJohn Demme llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 924436c6c9cSStella Laurenzo MlirBlock block = mlirOperationGetBlock(get()); 925436c6c9cSStella Laurenzo assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 9261689dadeSJohn Demme assert(parentOperation && "Operation has no parent"); 9271689dadeSJohn Demme return PyBlock{std::move(*parentOperation), block}; 928436c6c9cSStella Laurenzo } 929436c6c9cSStella Laurenzo 9300126e906SJohn Demme py::object PyOperation::getCapsule() { 93149745f87SMike Urbach checkValid(); 9320126e906SJohn Demme return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 9330126e906SJohn Demme } 9340126e906SJohn Demme 9350126e906SJohn Demme py::object PyOperation::createFromCapsule(py::object capsule) { 9360126e906SJohn Demme MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 9370126e906SJohn Demme if (mlirOperationIsNull(rawOperation)) 9380126e906SJohn Demme throw py::error_already_set(); 9390126e906SJohn Demme MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 9400126e906SJohn Demme return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 9410126e906SJohn Demme .releaseObject(); 9420126e906SJohn Demme } 9430126e906SJohn Demme 944436c6c9cSStella Laurenzo py::object PyOperation::create( 945436c6c9cSStella Laurenzo std::string name, llvm::Optional<std::vector<PyType *>> results, 946436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyValue *>> operands, 947436c6c9cSStella Laurenzo llvm::Optional<py::dict> attributes, 948436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyBlock *>> successors, int regions, 949436c6c9cSStella Laurenzo DefaultingPyLocation location, py::object maybeIp) { 950436c6c9cSStella Laurenzo llvm::SmallVector<MlirValue, 4> mlirOperands; 951436c6c9cSStella Laurenzo llvm::SmallVector<MlirType, 4> mlirResults; 952436c6c9cSStella Laurenzo llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 953436c6c9cSStella Laurenzo llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 954436c6c9cSStella Laurenzo 955436c6c9cSStella Laurenzo // General parameter validation. 956436c6c9cSStella Laurenzo if (regions < 0) 957436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 958436c6c9cSStella Laurenzo 959436c6c9cSStella Laurenzo // Unpack/validate operands. 960436c6c9cSStella Laurenzo if (operands) { 961436c6c9cSStella Laurenzo mlirOperands.reserve(operands->size()); 962436c6c9cSStella Laurenzo for (PyValue *operand : *operands) { 963436c6c9cSStella Laurenzo if (!operand) 964436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 965436c6c9cSStella Laurenzo mlirOperands.push_back(operand->get()); 966436c6c9cSStella Laurenzo } 967436c6c9cSStella Laurenzo } 968436c6c9cSStella Laurenzo 969436c6c9cSStella Laurenzo // Unpack/validate results. 970436c6c9cSStella Laurenzo if (results) { 971436c6c9cSStella Laurenzo mlirResults.reserve(results->size()); 972436c6c9cSStella Laurenzo for (PyType *result : *results) { 973436c6c9cSStella Laurenzo // TODO: Verify result type originate from the same context. 974436c6c9cSStella Laurenzo if (!result) 975436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "result type cannot be None"); 976436c6c9cSStella Laurenzo mlirResults.push_back(*result); 977436c6c9cSStella Laurenzo } 978436c6c9cSStella Laurenzo } 979436c6c9cSStella Laurenzo // Unpack/validate attributes. 980436c6c9cSStella Laurenzo if (attributes) { 981436c6c9cSStella Laurenzo mlirAttributes.reserve(attributes->size()); 982436c6c9cSStella Laurenzo for (auto &it : *attributes) { 983436c6c9cSStella Laurenzo std::string key; 984436c6c9cSStella Laurenzo try { 985436c6c9cSStella Laurenzo key = it.first.cast<std::string>(); 986436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 987436c6c9cSStella Laurenzo std::string msg = "Invalid attribute key (not a string) when " 988436c6c9cSStella Laurenzo "attempting to create the operation \"" + 989436c6c9cSStella Laurenzo name + "\" (" + err.what() + ")"; 990436c6c9cSStella Laurenzo throw py::cast_error(msg); 991436c6c9cSStella Laurenzo } 992436c6c9cSStella Laurenzo try { 993436c6c9cSStella Laurenzo auto &attribute = it.second.cast<PyAttribute &>(); 994436c6c9cSStella Laurenzo // TODO: Verify attribute originates from the same context. 995436c6c9cSStella Laurenzo mlirAttributes.emplace_back(std::move(key), attribute); 996436c6c9cSStella Laurenzo } catch (py::reference_cast_error &) { 997436c6c9cSStella Laurenzo // This exception seems thrown when the value is "None". 998436c6c9cSStella Laurenzo std::string msg = 999436c6c9cSStella Laurenzo "Found an invalid (`None`?) attribute value for the key \"" + key + 1000436c6c9cSStella Laurenzo "\" when attempting to create the operation \"" + name + "\""; 1001436c6c9cSStella Laurenzo throw py::cast_error(msg); 1002436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1003436c6c9cSStella Laurenzo std::string msg = "Invalid attribute value for the key \"" + key + 1004436c6c9cSStella Laurenzo "\" when attempting to create the operation \"" + 1005436c6c9cSStella Laurenzo name + "\" (" + err.what() + ")"; 1006436c6c9cSStella Laurenzo throw py::cast_error(msg); 1007436c6c9cSStella Laurenzo } 1008436c6c9cSStella Laurenzo } 1009436c6c9cSStella Laurenzo } 1010436c6c9cSStella Laurenzo // Unpack/validate successors. 1011436c6c9cSStella Laurenzo if (successors) { 1012436c6c9cSStella Laurenzo mlirSuccessors.reserve(successors->size()); 1013436c6c9cSStella Laurenzo for (auto *successor : *successors) { 1014436c6c9cSStella Laurenzo // TODO: Verify successor originate from the same context. 1015436c6c9cSStella Laurenzo if (!successor) 1016436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1017436c6c9cSStella Laurenzo mlirSuccessors.push_back(successor->get()); 1018436c6c9cSStella Laurenzo } 1019436c6c9cSStella Laurenzo } 1020436c6c9cSStella Laurenzo 1021436c6c9cSStella Laurenzo // Apply unpacked/validated to the operation state. Beyond this 1022436c6c9cSStella Laurenzo // point, exceptions cannot be thrown or else the state will leak. 1023436c6c9cSStella Laurenzo MlirOperationState state = 1024436c6c9cSStella Laurenzo mlirOperationStateGet(toMlirStringRef(name), location); 1025436c6c9cSStella Laurenzo if (!mlirOperands.empty()) 1026436c6c9cSStella Laurenzo mlirOperationStateAddOperands(&state, mlirOperands.size(), 1027436c6c9cSStella Laurenzo mlirOperands.data()); 1028436c6c9cSStella Laurenzo if (!mlirResults.empty()) 1029436c6c9cSStella Laurenzo mlirOperationStateAddResults(&state, mlirResults.size(), 1030436c6c9cSStella Laurenzo mlirResults.data()); 1031436c6c9cSStella Laurenzo if (!mlirAttributes.empty()) { 1032436c6c9cSStella Laurenzo // Note that the attribute names directly reference bytes in 1033436c6c9cSStella Laurenzo // mlirAttributes, so that vector must not be changed from here 1034436c6c9cSStella Laurenzo // on. 1035436c6c9cSStella Laurenzo llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1036436c6c9cSStella Laurenzo mlirNamedAttributes.reserve(mlirAttributes.size()); 1037436c6c9cSStella Laurenzo for (auto &it : mlirAttributes) 1038436c6c9cSStella Laurenzo mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1039436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(it.second), 1040436c6c9cSStella Laurenzo toMlirStringRef(it.first)), 1041436c6c9cSStella Laurenzo it.second)); 1042436c6c9cSStella Laurenzo mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1043436c6c9cSStella Laurenzo mlirNamedAttributes.data()); 1044436c6c9cSStella Laurenzo } 1045436c6c9cSStella Laurenzo if (!mlirSuccessors.empty()) 1046436c6c9cSStella Laurenzo mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1047436c6c9cSStella Laurenzo mlirSuccessors.data()); 1048436c6c9cSStella Laurenzo if (regions) { 1049436c6c9cSStella Laurenzo llvm::SmallVector<MlirRegion, 4> mlirRegions; 1050436c6c9cSStella Laurenzo mlirRegions.resize(regions); 1051436c6c9cSStella Laurenzo for (int i = 0; i < regions; ++i) 1052436c6c9cSStella Laurenzo mlirRegions[i] = mlirRegionCreate(); 1053436c6c9cSStella Laurenzo mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1054436c6c9cSStella Laurenzo mlirRegions.data()); 1055436c6c9cSStella Laurenzo } 1056436c6c9cSStella Laurenzo 1057436c6c9cSStella Laurenzo // Construct the operation. 1058436c6c9cSStella Laurenzo MlirOperation operation = mlirOperationCreate(&state); 1059436c6c9cSStella Laurenzo PyOperationRef created = 1060436c6c9cSStella Laurenzo PyOperation::createDetached(location->getContext(), operation); 1061436c6c9cSStella Laurenzo 1062436c6c9cSStella Laurenzo // InsertPoint active? 1063436c6c9cSStella Laurenzo if (!maybeIp.is(py::cast(false))) { 1064436c6c9cSStella Laurenzo PyInsertionPoint *ip; 1065436c6c9cSStella Laurenzo if (maybeIp.is_none()) { 1066436c6c9cSStella Laurenzo ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1067436c6c9cSStella Laurenzo } else { 1068436c6c9cSStella Laurenzo ip = py::cast<PyInsertionPoint *>(maybeIp); 1069436c6c9cSStella Laurenzo } 1070436c6c9cSStella Laurenzo if (ip) 1071436c6c9cSStella Laurenzo ip->insert(*created.get()); 1072436c6c9cSStella Laurenzo } 1073436c6c9cSStella Laurenzo 1074436c6c9cSStella Laurenzo return created->createOpView(); 1075436c6c9cSStella Laurenzo } 1076436c6c9cSStella Laurenzo 1077436c6c9cSStella Laurenzo py::object PyOperation::createOpView() { 107849745f87SMike Urbach checkValid(); 1079436c6c9cSStella Laurenzo MlirIdentifier ident = mlirOperationGetName(get()); 1080436c6c9cSStella Laurenzo MlirStringRef identStr = mlirIdentifierStr(ident); 1081436c6c9cSStella Laurenzo auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1082436c6c9cSStella Laurenzo StringRef(identStr.data, identStr.length)); 1083436c6c9cSStella Laurenzo if (opViewClass) 1084436c6c9cSStella Laurenzo return (*opViewClass)(getRef().getObject()); 1085436c6c9cSStella Laurenzo return py::cast(PyOpView(getRef().getObject())); 1086436c6c9cSStella Laurenzo } 1087436c6c9cSStella Laurenzo 108849745f87SMike Urbach void PyOperation::erase() { 108949745f87SMike Urbach checkValid(); 109049745f87SMike Urbach // TODO: Fix memory hazards when erasing a tree of operations for which a deep 109149745f87SMike Urbach // Python reference to a child operation is live. All children should also 109249745f87SMike Urbach // have their `valid` bit set to false. 109349745f87SMike Urbach auto &liveOperations = getContext()->liveOperations; 109449745f87SMike Urbach if (liveOperations.count(operation.ptr)) 109549745f87SMike Urbach liveOperations.erase(operation.ptr); 109649745f87SMike Urbach mlirOperationDestroy(operation); 109749745f87SMike Urbach valid = false; 109849745f87SMike Urbach } 109949745f87SMike Urbach 1100436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1101436c6c9cSStella Laurenzo // PyOpView 1102436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1103436c6c9cSStella Laurenzo 1104436c6c9cSStella Laurenzo py::object 1105436c6c9cSStella Laurenzo PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1106436c6c9cSStella Laurenzo py::list operandList, 1107436c6c9cSStella Laurenzo llvm::Optional<py::dict> attributes, 1108436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyBlock *>> successors, 1109436c6c9cSStella Laurenzo llvm::Optional<int> regions, 1110436c6c9cSStella Laurenzo DefaultingPyLocation location, py::object maybeIp) { 1111436c6c9cSStella Laurenzo PyMlirContextRef context = location->getContext(); 1112436c6c9cSStella Laurenzo // Class level operation construction metadata. 1113436c6c9cSStella Laurenzo std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1114436c6c9cSStella Laurenzo // Operand and result segment specs are either none, which does no 1115436c6c9cSStella Laurenzo // variadic unpacking, or a list of ints with segment sizes, where each 1116436c6c9cSStella Laurenzo // element is either a positive number (typically 1 for a scalar) or -1 to 1117436c6c9cSStella Laurenzo // indicate that it is derived from the length of the same-indexed operand 1118436c6c9cSStella Laurenzo // or result (implying that it is a list at that position). 1119436c6c9cSStella Laurenzo py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1120436c6c9cSStella Laurenzo py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1121436c6c9cSStella Laurenzo 11228d05a288SStella Laurenzo std::vector<uint32_t> operandSegmentLengths; 11238d05a288SStella Laurenzo std::vector<uint32_t> resultSegmentLengths; 1124436c6c9cSStella Laurenzo 1125436c6c9cSStella Laurenzo // Validate/determine region count. 1126436c6c9cSStella Laurenzo auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1127436c6c9cSStella Laurenzo int opMinRegionCount = std::get<0>(opRegionSpec); 1128436c6c9cSStella Laurenzo bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1129436c6c9cSStella Laurenzo if (!regions) { 1130436c6c9cSStella Laurenzo regions = opMinRegionCount; 1131436c6c9cSStella Laurenzo } 1132436c6c9cSStella Laurenzo if (*regions < opMinRegionCount) { 1133436c6c9cSStella Laurenzo throw py::value_error( 1134436c6c9cSStella Laurenzo (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1135436c6c9cSStella Laurenzo llvm::Twine(opMinRegionCount) + 1136436c6c9cSStella Laurenzo " regions but was built with regions=" + llvm::Twine(*regions)) 1137436c6c9cSStella Laurenzo .str()); 1138436c6c9cSStella Laurenzo } 1139436c6c9cSStella Laurenzo if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1140436c6c9cSStella Laurenzo throw py::value_error( 1141436c6c9cSStella Laurenzo (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1142436c6c9cSStella Laurenzo llvm::Twine(opMinRegionCount) + 1143436c6c9cSStella Laurenzo " regions but was built with regions=" + llvm::Twine(*regions)) 1144436c6c9cSStella Laurenzo .str()); 1145436c6c9cSStella Laurenzo } 1146436c6c9cSStella Laurenzo 1147436c6c9cSStella Laurenzo // Unpack results. 1148436c6c9cSStella Laurenzo std::vector<PyType *> resultTypes; 1149436c6c9cSStella Laurenzo resultTypes.reserve(resultTypeList.size()); 1150436c6c9cSStella Laurenzo if (resultSegmentSpecObj.is_none()) { 1151436c6c9cSStella Laurenzo // Non-variadic result unpacking. 1152436c6c9cSStella Laurenzo for (auto it : llvm::enumerate(resultTypeList)) { 1153436c6c9cSStella Laurenzo try { 1154436c6c9cSStella Laurenzo resultTypes.push_back(py::cast<PyType *>(it.value())); 1155436c6c9cSStella Laurenzo if (!resultTypes.back()) 1156436c6c9cSStella Laurenzo throw py::cast_error(); 1157436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1158436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Result ") + 1159436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1160436c6c9cSStella Laurenzo name + "\" must be a Type (" + err.what() + ")") 1161436c6c9cSStella Laurenzo .str()); 1162436c6c9cSStella Laurenzo } 1163436c6c9cSStella Laurenzo } 1164436c6c9cSStella Laurenzo } else { 1165436c6c9cSStella Laurenzo // Sized result unpacking. 1166436c6c9cSStella Laurenzo auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1167436c6c9cSStella Laurenzo if (resultSegmentSpec.size() != resultTypeList.size()) { 1168436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operation \"") + name + 1169436c6c9cSStella Laurenzo "\" requires " + 1170436c6c9cSStella Laurenzo llvm::Twine(resultSegmentSpec.size()) + 1171436c6c9cSStella Laurenzo " result segments but was provided " + 1172436c6c9cSStella Laurenzo llvm::Twine(resultTypeList.size())) 1173436c6c9cSStella Laurenzo .str()); 1174436c6c9cSStella Laurenzo } 1175436c6c9cSStella Laurenzo resultSegmentLengths.reserve(resultTypeList.size()); 1176436c6c9cSStella Laurenzo for (auto it : 1177436c6c9cSStella Laurenzo llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1178436c6c9cSStella Laurenzo int segmentSpec = std::get<1>(it.value()); 1179436c6c9cSStella Laurenzo if (segmentSpec == 1 || segmentSpec == 0) { 1180436c6c9cSStella Laurenzo // Unpack unary element. 1181436c6c9cSStella Laurenzo try { 11826981e5ecSAlex Zinenko auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1183436c6c9cSStella Laurenzo if (resultType) { 1184436c6c9cSStella Laurenzo resultTypes.push_back(resultType); 1185436c6c9cSStella Laurenzo resultSegmentLengths.push_back(1); 1186436c6c9cSStella Laurenzo } else if (segmentSpec == 0) { 1187436c6c9cSStella Laurenzo // Allowed to be optional. 1188436c6c9cSStella Laurenzo resultSegmentLengths.push_back(0); 1189436c6c9cSStella Laurenzo } else { 1190436c6c9cSStella Laurenzo throw py::cast_error("was None and result is not optional"); 1191436c6c9cSStella Laurenzo } 1192436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1193436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Result ") + 1194436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1195436c6c9cSStella Laurenzo name + "\" must be a Type (" + err.what() + 1196436c6c9cSStella Laurenzo ")") 1197436c6c9cSStella Laurenzo .str()); 1198436c6c9cSStella Laurenzo } 1199436c6c9cSStella Laurenzo } else if (segmentSpec == -1) { 1200436c6c9cSStella Laurenzo // Unpack sequence by appending. 1201436c6c9cSStella Laurenzo try { 1202436c6c9cSStella Laurenzo if (std::get<0>(it.value()).is_none()) { 1203436c6c9cSStella Laurenzo // Treat it as an empty list. 1204436c6c9cSStella Laurenzo resultSegmentLengths.push_back(0); 1205436c6c9cSStella Laurenzo } else { 1206436c6c9cSStella Laurenzo // Unpack the list. 1207436c6c9cSStella Laurenzo auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1208436c6c9cSStella Laurenzo for (py::object segmentItem : segment) { 1209436c6c9cSStella Laurenzo resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1210436c6c9cSStella Laurenzo if (!resultTypes.back()) { 1211436c6c9cSStella Laurenzo throw py::cast_error("contained a None item"); 1212436c6c9cSStella Laurenzo } 1213436c6c9cSStella Laurenzo } 1214436c6c9cSStella Laurenzo resultSegmentLengths.push_back(segment.size()); 1215436c6c9cSStella Laurenzo } 1216436c6c9cSStella Laurenzo } catch (std::exception &err) { 1217436c6c9cSStella Laurenzo // NOTE: Sloppy to be using a catch-all here, but there are at least 1218436c6c9cSStella Laurenzo // three different unrelated exceptions that can be thrown in the 1219436c6c9cSStella Laurenzo // above "casts". Just keep the scope above small and catch them all. 1220436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Result ") + 1221436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1222436c6c9cSStella Laurenzo name + "\" must be a Sequence of Types (" + 1223436c6c9cSStella Laurenzo err.what() + ")") 1224436c6c9cSStella Laurenzo .str()); 1225436c6c9cSStella Laurenzo } 1226436c6c9cSStella Laurenzo } else { 1227436c6c9cSStella Laurenzo throw py::value_error("Unexpected segment spec"); 1228436c6c9cSStella Laurenzo } 1229436c6c9cSStella Laurenzo } 1230436c6c9cSStella Laurenzo } 1231436c6c9cSStella Laurenzo 1232436c6c9cSStella Laurenzo // Unpack operands. 1233436c6c9cSStella Laurenzo std::vector<PyValue *> operands; 1234436c6c9cSStella Laurenzo operands.reserve(operands.size()); 1235436c6c9cSStella Laurenzo if (operandSegmentSpecObj.is_none()) { 1236436c6c9cSStella Laurenzo // Non-sized operand unpacking. 1237436c6c9cSStella Laurenzo for (auto it : llvm::enumerate(operandList)) { 1238436c6c9cSStella Laurenzo try { 1239436c6c9cSStella Laurenzo operands.push_back(py::cast<PyValue *>(it.value())); 1240436c6c9cSStella Laurenzo if (!operands.back()) 1241436c6c9cSStella Laurenzo throw py::cast_error(); 1242436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1243436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operand ") + 1244436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1245436c6c9cSStella Laurenzo name + "\" must be a Value (" + err.what() + ")") 1246436c6c9cSStella Laurenzo .str()); 1247436c6c9cSStella Laurenzo } 1248436c6c9cSStella Laurenzo } 1249436c6c9cSStella Laurenzo } else { 1250436c6c9cSStella Laurenzo // Sized operand unpacking. 1251436c6c9cSStella Laurenzo auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1252436c6c9cSStella Laurenzo if (operandSegmentSpec.size() != operandList.size()) { 1253436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operation \"") + name + 1254436c6c9cSStella Laurenzo "\" requires " + 1255436c6c9cSStella Laurenzo llvm::Twine(operandSegmentSpec.size()) + 1256436c6c9cSStella Laurenzo "operand segments but was provided " + 1257436c6c9cSStella Laurenzo llvm::Twine(operandList.size())) 1258436c6c9cSStella Laurenzo .str()); 1259436c6c9cSStella Laurenzo } 1260436c6c9cSStella Laurenzo operandSegmentLengths.reserve(operandList.size()); 1261436c6c9cSStella Laurenzo for (auto it : 1262436c6c9cSStella Laurenzo llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1263436c6c9cSStella Laurenzo int segmentSpec = std::get<1>(it.value()); 1264436c6c9cSStella Laurenzo if (segmentSpec == 1 || segmentSpec == 0) { 1265436c6c9cSStella Laurenzo // Unpack unary element. 1266436c6c9cSStella Laurenzo try { 1267436c6c9cSStella Laurenzo auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1268436c6c9cSStella Laurenzo if (operandValue) { 1269436c6c9cSStella Laurenzo operands.push_back(operandValue); 1270436c6c9cSStella Laurenzo operandSegmentLengths.push_back(1); 1271436c6c9cSStella Laurenzo } else if (segmentSpec == 0) { 1272436c6c9cSStella Laurenzo // Allowed to be optional. 1273436c6c9cSStella Laurenzo operandSegmentLengths.push_back(0); 1274436c6c9cSStella Laurenzo } else { 1275436c6c9cSStella Laurenzo throw py::cast_error("was None and operand is not optional"); 1276436c6c9cSStella Laurenzo } 1277436c6c9cSStella Laurenzo } catch (py::cast_error &err) { 1278436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operand ") + 1279436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1280436c6c9cSStella Laurenzo name + "\" must be a Value (" + err.what() + 1281436c6c9cSStella Laurenzo ")") 1282436c6c9cSStella Laurenzo .str()); 1283436c6c9cSStella Laurenzo } 1284436c6c9cSStella Laurenzo } else if (segmentSpec == -1) { 1285436c6c9cSStella Laurenzo // Unpack sequence by appending. 1286436c6c9cSStella Laurenzo try { 1287436c6c9cSStella Laurenzo if (std::get<0>(it.value()).is_none()) { 1288436c6c9cSStella Laurenzo // Treat it as an empty list. 1289436c6c9cSStella Laurenzo operandSegmentLengths.push_back(0); 1290436c6c9cSStella Laurenzo } else { 1291436c6c9cSStella Laurenzo // Unpack the list. 1292436c6c9cSStella Laurenzo auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1293436c6c9cSStella Laurenzo for (py::object segmentItem : segment) { 1294436c6c9cSStella Laurenzo operands.push_back(py::cast<PyValue *>(segmentItem)); 1295436c6c9cSStella Laurenzo if (!operands.back()) { 1296436c6c9cSStella Laurenzo throw py::cast_error("contained a None item"); 1297436c6c9cSStella Laurenzo } 1298436c6c9cSStella Laurenzo } 1299436c6c9cSStella Laurenzo operandSegmentLengths.push_back(segment.size()); 1300436c6c9cSStella Laurenzo } 1301436c6c9cSStella Laurenzo } catch (std::exception &err) { 1302436c6c9cSStella Laurenzo // NOTE: Sloppy to be using a catch-all here, but there are at least 1303436c6c9cSStella Laurenzo // three different unrelated exceptions that can be thrown in the 1304436c6c9cSStella Laurenzo // above "casts". Just keep the scope above small and catch them all. 1305436c6c9cSStella Laurenzo throw py::value_error((llvm::Twine("Operand ") + 1306436c6c9cSStella Laurenzo llvm::Twine(it.index()) + " of operation \"" + 1307436c6c9cSStella Laurenzo name + "\" must be a Sequence of Values (" + 1308436c6c9cSStella Laurenzo err.what() + ")") 1309436c6c9cSStella Laurenzo .str()); 1310436c6c9cSStella Laurenzo } 1311436c6c9cSStella Laurenzo } else { 1312436c6c9cSStella Laurenzo throw py::value_error("Unexpected segment spec"); 1313436c6c9cSStella Laurenzo } 1314436c6c9cSStella Laurenzo } 1315436c6c9cSStella Laurenzo } 1316436c6c9cSStella Laurenzo 1317436c6c9cSStella Laurenzo // Merge operand/result segment lengths into attributes if needed. 1318436c6c9cSStella Laurenzo if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1319436c6c9cSStella Laurenzo // Dup. 1320436c6c9cSStella Laurenzo if (attributes) { 1321436c6c9cSStella Laurenzo attributes = py::dict(*attributes); 1322436c6c9cSStella Laurenzo } else { 1323436c6c9cSStella Laurenzo attributes = py::dict(); 1324436c6c9cSStella Laurenzo } 1325436c6c9cSStella Laurenzo if (attributes->contains("result_segment_sizes") || 1326436c6c9cSStella Laurenzo attributes->contains("operand_segment_sizes")) { 1327436c6c9cSStella Laurenzo throw py::value_error("Manually setting a 'result_segment_sizes' or " 1328436c6c9cSStella Laurenzo "'operand_segment_sizes' attribute is unsupported. " 1329436c6c9cSStella Laurenzo "Use Operation.create for such low-level access."); 1330436c6c9cSStella Laurenzo } 1331436c6c9cSStella Laurenzo 1332436c6c9cSStella Laurenzo // Add result_segment_sizes attribute. 1333436c6c9cSStella Laurenzo if (!resultSegmentLengths.empty()) { 1334436c6c9cSStella Laurenzo int64_t size = resultSegmentLengths.size(); 13358d05a288SStella Laurenzo MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 13368d05a288SStella Laurenzo mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1337436c6c9cSStella Laurenzo resultSegmentLengths.size(), resultSegmentLengths.data()); 1338436c6c9cSStella Laurenzo (*attributes)["result_segment_sizes"] = 1339436c6c9cSStella Laurenzo PyAttribute(context, segmentLengthAttr); 1340436c6c9cSStella Laurenzo } 1341436c6c9cSStella Laurenzo 1342436c6c9cSStella Laurenzo // Add operand_segment_sizes attribute. 1343436c6c9cSStella Laurenzo if (!operandSegmentLengths.empty()) { 1344436c6c9cSStella Laurenzo int64_t size = operandSegmentLengths.size(); 13458d05a288SStella Laurenzo MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 13468d05a288SStella Laurenzo mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1347436c6c9cSStella Laurenzo operandSegmentLengths.size(), operandSegmentLengths.data()); 1348436c6c9cSStella Laurenzo (*attributes)["operand_segment_sizes"] = 1349436c6c9cSStella Laurenzo PyAttribute(context, segmentLengthAttr); 1350436c6c9cSStella Laurenzo } 1351436c6c9cSStella Laurenzo } 1352436c6c9cSStella Laurenzo 1353436c6c9cSStella Laurenzo // Delegate to create. 1354436c6c9cSStella Laurenzo return PyOperation::create(std::move(name), 1355436c6c9cSStella Laurenzo /*results=*/std::move(resultTypes), 1356436c6c9cSStella Laurenzo /*operands=*/std::move(operands), 1357436c6c9cSStella Laurenzo /*attributes=*/std::move(attributes), 1358436c6c9cSStella Laurenzo /*successors=*/std::move(successors), 1359436c6c9cSStella Laurenzo /*regions=*/*regions, location, maybeIp); 1360436c6c9cSStella Laurenzo } 1361436c6c9cSStella Laurenzo 1362436c6c9cSStella Laurenzo PyOpView::PyOpView(py::object operationObject) 1363436c6c9cSStella Laurenzo // Casting through the PyOperationBase base-class and then back to the 1364436c6c9cSStella Laurenzo // Operation lets us accept any PyOperationBase subclass. 1365436c6c9cSStella Laurenzo : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1366436c6c9cSStella Laurenzo operationObject(operation.getRef().getObject()) {} 1367436c6c9cSStella Laurenzo 1368436c6c9cSStella Laurenzo py::object PyOpView::createRawSubclass(py::object userClass) { 1369436c6c9cSStella Laurenzo // This is... a little gross. The typical pattern is to have a pure python 1370436c6c9cSStella Laurenzo // class that extends OpView like: 1371436c6c9cSStella Laurenzo // class AddFOp(_cext.ir.OpView): 1372436c6c9cSStella Laurenzo // def __init__(self, loc, lhs, rhs): 1373436c6c9cSStella Laurenzo // operation = loc.context.create_operation( 1374436c6c9cSStella Laurenzo // "addf", lhs, rhs, results=[lhs.type]) 1375436c6c9cSStella Laurenzo // super().__init__(operation) 1376436c6c9cSStella Laurenzo // 1377436c6c9cSStella Laurenzo // I.e. The goal of the user facing type is to provide a nice constructor 1378436c6c9cSStella Laurenzo // that has complete freedom for the op under construction. This is at odds 1379436c6c9cSStella Laurenzo // with our other desire to sometimes create this object by just passing an 1380436c6c9cSStella Laurenzo // operation (to initialize the base class). We could do *arg and **kwargs 1381436c6c9cSStella Laurenzo // munging to try to make it work, but instead, we synthesize a new class 1382436c6c9cSStella Laurenzo // on the fly which extends this user class (AddFOp in this example) and 1383436c6c9cSStella Laurenzo // *give it* the base class's __init__ method, thus bypassing the 1384436c6c9cSStella Laurenzo // intermediate subclass's __init__ method entirely. While slightly, 1385436c6c9cSStella Laurenzo // underhanded, this is safe/legal because the type hierarchy has not changed 1386436c6c9cSStella Laurenzo // (we just added a new leaf) and we aren't mucking around with __new__. 1387436c6c9cSStella Laurenzo // Typically, this new class will be stored on the original as "_Raw" and will 1388436c6c9cSStella Laurenzo // be used for casts and other things that need a variant of the class that 1389436c6c9cSStella Laurenzo // is initialized purely from an operation. 1390436c6c9cSStella Laurenzo py::object parentMetaclass = 1391436c6c9cSStella Laurenzo py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1392436c6c9cSStella Laurenzo py::dict attributes; 1393436c6c9cSStella Laurenzo // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1394436c6c9cSStella Laurenzo // now. 1395436c6c9cSStella Laurenzo // auto opViewType = py::type::of<PyOpView>(); 1396436c6c9cSStella Laurenzo auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1397436c6c9cSStella Laurenzo attributes["__init__"] = opViewType.attr("__init__"); 1398436c6c9cSStella Laurenzo py::str origName = userClass.attr("__name__"); 1399436c6c9cSStella Laurenzo py::str newName = py::str("_") + origName; 1400436c6c9cSStella Laurenzo return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1401436c6c9cSStella Laurenzo } 1402436c6c9cSStella Laurenzo 1403436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1404436c6c9cSStella Laurenzo // PyInsertionPoint. 1405436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1406436c6c9cSStella Laurenzo 1407436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1408436c6c9cSStella Laurenzo 1409436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1410436c6c9cSStella Laurenzo : refOperation(beforeOperationBase.getOperation().getRef()), 1411436c6c9cSStella Laurenzo block((*refOperation)->getBlock()) {} 1412436c6c9cSStella Laurenzo 1413436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1414436c6c9cSStella Laurenzo PyOperation &operation = operationBase.getOperation(); 1415436c6c9cSStella Laurenzo if (operation.isAttached()) 1416436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 1417436c6c9cSStella Laurenzo "Attempt to insert operation that is already attached"); 1418436c6c9cSStella Laurenzo block.getParentOperation()->checkValid(); 1419436c6c9cSStella Laurenzo MlirOperation beforeOp = {nullptr}; 1420436c6c9cSStella Laurenzo if (refOperation) { 1421436c6c9cSStella Laurenzo // Insert before operation. 1422436c6c9cSStella Laurenzo (*refOperation)->checkValid(); 1423436c6c9cSStella Laurenzo beforeOp = (*refOperation)->get(); 1424436c6c9cSStella Laurenzo } else { 1425436c6c9cSStella Laurenzo // Insert at end (before null) is only valid if the block does not 1426436c6c9cSStella Laurenzo // already end in a known terminator (violating this will cause assertion 1427436c6c9cSStella Laurenzo // failures later). 1428436c6c9cSStella Laurenzo if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1429436c6c9cSStella Laurenzo throw py::index_error("Cannot insert operation at the end of a block " 1430436c6c9cSStella Laurenzo "that already has a terminator. Did you mean to " 1431436c6c9cSStella Laurenzo "use 'InsertionPoint.at_block_terminator(block)' " 1432436c6c9cSStella Laurenzo "versus 'InsertionPoint(block)'?"); 1433436c6c9cSStella Laurenzo } 1434436c6c9cSStella Laurenzo } 1435436c6c9cSStella Laurenzo mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1436436c6c9cSStella Laurenzo operation.setAttached(); 1437436c6c9cSStella Laurenzo } 1438436c6c9cSStella Laurenzo 1439436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1440436c6c9cSStella Laurenzo MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1441436c6c9cSStella Laurenzo if (mlirOperationIsNull(firstOp)) { 1442436c6c9cSStella Laurenzo // Just insert at end. 1443436c6c9cSStella Laurenzo return PyInsertionPoint(block); 1444436c6c9cSStella Laurenzo } 1445436c6c9cSStella Laurenzo 1446436c6c9cSStella Laurenzo // Insert before first op. 1447436c6c9cSStella Laurenzo PyOperationRef firstOpRef = PyOperation::forOperation( 1448436c6c9cSStella Laurenzo block.getParentOperation()->getContext(), firstOp); 1449436c6c9cSStella Laurenzo return PyInsertionPoint{block, std::move(firstOpRef)}; 1450436c6c9cSStella Laurenzo } 1451436c6c9cSStella Laurenzo 1452436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1453436c6c9cSStella Laurenzo MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1454436c6c9cSStella Laurenzo if (mlirOperationIsNull(terminator)) 1455436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1456436c6c9cSStella Laurenzo PyOperationRef terminatorOpRef = PyOperation::forOperation( 1457436c6c9cSStella Laurenzo block.getParentOperation()->getContext(), terminator); 1458436c6c9cSStella Laurenzo return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1459436c6c9cSStella Laurenzo } 1460436c6c9cSStella Laurenzo 1461436c6c9cSStella Laurenzo py::object PyInsertionPoint::contextEnter() { 1462436c6c9cSStella Laurenzo return PyThreadContextEntry::pushInsertionPoint(*this); 1463436c6c9cSStella Laurenzo } 1464436c6c9cSStella Laurenzo 1465436c6c9cSStella Laurenzo void PyInsertionPoint::contextExit(pybind11::object excType, 1466436c6c9cSStella Laurenzo pybind11::object excVal, 1467436c6c9cSStella Laurenzo pybind11::object excTb) { 1468436c6c9cSStella Laurenzo PyThreadContextEntry::popInsertionPoint(*this); 1469436c6c9cSStella Laurenzo } 1470436c6c9cSStella Laurenzo 1471436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1472436c6c9cSStella Laurenzo // PyAttribute. 1473436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1474436c6c9cSStella Laurenzo 1475436c6c9cSStella Laurenzo bool PyAttribute::operator==(const PyAttribute &other) { 1476436c6c9cSStella Laurenzo return mlirAttributeEqual(attr, other.attr); 1477436c6c9cSStella Laurenzo } 1478436c6c9cSStella Laurenzo 1479436c6c9cSStella Laurenzo py::object PyAttribute::getCapsule() { 1480436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1481436c6c9cSStella Laurenzo } 1482436c6c9cSStella Laurenzo 1483436c6c9cSStella Laurenzo PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1484436c6c9cSStella Laurenzo MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1485436c6c9cSStella Laurenzo if (mlirAttributeIsNull(rawAttr)) 1486436c6c9cSStella Laurenzo throw py::error_already_set(); 1487436c6c9cSStella Laurenzo return PyAttribute( 1488436c6c9cSStella Laurenzo PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1489436c6c9cSStella Laurenzo } 1490436c6c9cSStella Laurenzo 1491436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1492436c6c9cSStella Laurenzo // PyNamedAttribute. 1493436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1494436c6c9cSStella Laurenzo 1495436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1496436c6c9cSStella Laurenzo : ownedName(new std::string(std::move(ownedName))) { 1497436c6c9cSStella Laurenzo namedAttr = mlirNamedAttributeGet( 1498436c6c9cSStella Laurenzo mlirIdentifierGet(mlirAttributeGetContext(attr), 1499436c6c9cSStella Laurenzo toMlirStringRef(*this->ownedName)), 1500436c6c9cSStella Laurenzo attr); 1501436c6c9cSStella Laurenzo } 1502436c6c9cSStella Laurenzo 1503436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1504436c6c9cSStella Laurenzo // PyType. 1505436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1506436c6c9cSStella Laurenzo 1507436c6c9cSStella Laurenzo bool PyType::operator==(const PyType &other) { 1508436c6c9cSStella Laurenzo return mlirTypeEqual(type, other.type); 1509436c6c9cSStella Laurenzo } 1510436c6c9cSStella Laurenzo 1511436c6c9cSStella Laurenzo py::object PyType::getCapsule() { 1512436c6c9cSStella Laurenzo return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1513436c6c9cSStella Laurenzo } 1514436c6c9cSStella Laurenzo 1515436c6c9cSStella Laurenzo PyType PyType::createFromCapsule(py::object capsule) { 1516436c6c9cSStella Laurenzo MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1517436c6c9cSStella Laurenzo if (mlirTypeIsNull(rawType)) 1518436c6c9cSStella Laurenzo throw py::error_already_set(); 1519436c6c9cSStella Laurenzo return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1520436c6c9cSStella Laurenzo rawType); 1521436c6c9cSStella Laurenzo } 1522436c6c9cSStella Laurenzo 1523436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1524436c6c9cSStella Laurenzo // PyValue and subclases. 1525436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1526436c6c9cSStella Laurenzo 15273f3d1c90SMike Urbach pybind11::object PyValue::getCapsule() { 15283f3d1c90SMike Urbach return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 15293f3d1c90SMike Urbach } 15303f3d1c90SMike Urbach 15313f3d1c90SMike Urbach PyValue PyValue::createFromCapsule(pybind11::object capsule) { 15323f3d1c90SMike Urbach MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 15333f3d1c90SMike Urbach if (mlirValueIsNull(value)) 15343f3d1c90SMike Urbach throw py::error_already_set(); 15353f3d1c90SMike Urbach MlirOperation owner; 15363f3d1c90SMike Urbach if (mlirValueIsAOpResult(value)) 15373f3d1c90SMike Urbach owner = mlirOpResultGetOwner(value); 15383f3d1c90SMike Urbach if (mlirValueIsABlockArgument(value)) 15393f3d1c90SMike Urbach owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 15403f3d1c90SMike Urbach if (mlirOperationIsNull(owner)) 15413f3d1c90SMike Urbach throw py::error_already_set(); 15423f3d1c90SMike Urbach MlirContext ctx = mlirOperationGetContext(owner); 15433f3d1c90SMike Urbach PyOperationRef ownerRef = 15443f3d1c90SMike Urbach PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 15453f3d1c90SMike Urbach return PyValue(ownerRef, value); 15463f3d1c90SMike Urbach } 15473f3d1c90SMike Urbach 154830d61893SAlex Zinenko //------------------------------------------------------------------------------ 154930d61893SAlex Zinenko // PySymbolTable. 155030d61893SAlex Zinenko //------------------------------------------------------------------------------ 155130d61893SAlex Zinenko 155230d61893SAlex Zinenko PySymbolTable::PySymbolTable(PyOperationBase &operation) 155330d61893SAlex Zinenko : operation(operation.getOperation().getRef()) { 155430d61893SAlex Zinenko symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 155530d61893SAlex Zinenko if (mlirSymbolTableIsNull(symbolTable)) { 155630d61893SAlex Zinenko throw py::cast_error("Operation is not a Symbol Table."); 155730d61893SAlex Zinenko } 155830d61893SAlex Zinenko } 155930d61893SAlex Zinenko 156030d61893SAlex Zinenko py::object PySymbolTable::dunderGetItem(const std::string &name) { 156130d61893SAlex Zinenko operation->checkValid(); 156230d61893SAlex Zinenko MlirOperation symbol = mlirSymbolTableLookup( 156330d61893SAlex Zinenko symbolTable, mlirStringRefCreate(name.data(), name.length())); 156430d61893SAlex Zinenko if (mlirOperationIsNull(symbol)) 156530d61893SAlex Zinenko throw py::key_error("Symbol '" + name + "' not in the symbol table."); 156630d61893SAlex Zinenko 156730d61893SAlex Zinenko return PyOperation::forOperation(operation->getContext(), symbol, 156830d61893SAlex Zinenko operation.getObject()) 156930d61893SAlex Zinenko ->createOpView(); 157030d61893SAlex Zinenko } 157130d61893SAlex Zinenko 157230d61893SAlex Zinenko void PySymbolTable::erase(PyOperationBase &symbol) { 157330d61893SAlex Zinenko operation->checkValid(); 157430d61893SAlex Zinenko symbol.getOperation().checkValid(); 157530d61893SAlex Zinenko mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 157630d61893SAlex Zinenko // The operation is also erased, so we must invalidate it. There may be Python 157730d61893SAlex Zinenko // references to this operation so we don't want to delete it from the list of 157830d61893SAlex Zinenko // live operations here. 157930d61893SAlex Zinenko symbol.getOperation().valid = false; 158030d61893SAlex Zinenko } 158130d61893SAlex Zinenko 158230d61893SAlex Zinenko void PySymbolTable::dunderDel(const std::string &name) { 158330d61893SAlex Zinenko py::object operation = dunderGetItem(name); 158430d61893SAlex Zinenko erase(py::cast<PyOperationBase &>(operation)); 158530d61893SAlex Zinenko } 158630d61893SAlex Zinenko 158730d61893SAlex Zinenko PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 158830d61893SAlex Zinenko operation->checkValid(); 158930d61893SAlex Zinenko symbol.getOperation().checkValid(); 159030d61893SAlex Zinenko MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 159130d61893SAlex Zinenko symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 159230d61893SAlex Zinenko if (mlirAttributeIsNull(symbolAttr)) 159330d61893SAlex Zinenko throw py::value_error("Expected operation to have a symbol name."); 159430d61893SAlex Zinenko return PyAttribute( 159530d61893SAlex Zinenko symbol.getOperation().getContext(), 159630d61893SAlex Zinenko mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 159730d61893SAlex Zinenko } 159830d61893SAlex Zinenko 1599436c6c9cSStella Laurenzo namespace { 1600436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR values that subclass Value and should be 1601436c6c9cSStella Laurenzo /// castable from it. The value hierarchy is one level deep and is not supposed 1602436c6c9cSStella Laurenzo /// to accommodate other levels unless core MLIR changes. 1603436c6c9cSStella Laurenzo template <typename DerivedTy> 1604436c6c9cSStella Laurenzo class PyConcreteValue : public PyValue { 1605436c6c9cSStella Laurenzo public: 1606436c6c9cSStella Laurenzo // Derived classes must define statics for: 1607436c6c9cSStella Laurenzo // IsAFunctionTy isaFunction 1608436c6c9cSStella Laurenzo // const char *pyClassName 1609436c6c9cSStella Laurenzo // and redefine bindDerived. 1610436c6c9cSStella Laurenzo using ClassTy = py::class_<DerivedTy, PyValue>; 1611436c6c9cSStella Laurenzo using IsAFunctionTy = bool (*)(MlirValue); 1612436c6c9cSStella Laurenzo 1613436c6c9cSStella Laurenzo PyConcreteValue() = default; 1614436c6c9cSStella Laurenzo PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1615436c6c9cSStella Laurenzo : PyValue(operationRef, value) {} 1616436c6c9cSStella Laurenzo PyConcreteValue(PyValue &orig) 1617436c6c9cSStella Laurenzo : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1618436c6c9cSStella Laurenzo 1619436c6c9cSStella Laurenzo /// Attempts to cast the original value to the derived type and throws on 1620436c6c9cSStella Laurenzo /// type mismatches. 1621436c6c9cSStella Laurenzo static MlirValue castFrom(PyValue &orig) { 1622436c6c9cSStella Laurenzo if (!DerivedTy::isaFunction(orig.get())) { 1623436c6c9cSStella Laurenzo auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1624436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1625436c6c9cSStella Laurenzo DerivedTy::pyClassName + 1626436c6c9cSStella Laurenzo " (from " + origRepr + ")"); 1627436c6c9cSStella Laurenzo } 1628436c6c9cSStella Laurenzo return orig.get(); 1629436c6c9cSStella Laurenzo } 1630436c6c9cSStella Laurenzo 1631436c6c9cSStella Laurenzo /// Binds the Python module objects to functions of this class. 1632436c6c9cSStella Laurenzo static void bind(py::module &m) { 1633f05ff4f7SStella Laurenzo auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1634*a6e7d024SStella Laurenzo cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1635*a6e7d024SStella Laurenzo cls.def_static( 1636*a6e7d024SStella Laurenzo "isinstance", 1637*a6e7d024SStella Laurenzo [](PyValue &otherValue) -> bool { 163878f2dae0SAlex Zinenko return DerivedTy::isaFunction(otherValue); 1639*a6e7d024SStella Laurenzo }, 1640*a6e7d024SStella Laurenzo py::arg("other_value")); 1641436c6c9cSStella Laurenzo DerivedTy::bindDerived(cls); 1642436c6c9cSStella Laurenzo } 1643436c6c9cSStella Laurenzo 1644436c6c9cSStella Laurenzo /// Implemented by derived classes to add methods to the Python subclass. 1645436c6c9cSStella Laurenzo static void bindDerived(ClassTy &m) {} 1646436c6c9cSStella Laurenzo }; 1647436c6c9cSStella Laurenzo 1648436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument. 1649436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1650436c6c9cSStella Laurenzo public: 1651436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1652436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "BlockArgument"; 1653436c6c9cSStella Laurenzo using PyConcreteValue::PyConcreteValue; 1654436c6c9cSStella Laurenzo 1655436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1656436c6c9cSStella Laurenzo c.def_property_readonly("owner", [](PyBlockArgument &self) { 1657436c6c9cSStella Laurenzo return PyBlock(self.getParentOperation(), 1658436c6c9cSStella Laurenzo mlirBlockArgumentGetOwner(self.get())); 1659436c6c9cSStella Laurenzo }); 1660436c6c9cSStella Laurenzo c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1661436c6c9cSStella Laurenzo return mlirBlockArgumentGetArgNumber(self.get()); 1662436c6c9cSStella Laurenzo }); 1663*a6e7d024SStella Laurenzo c.def( 1664*a6e7d024SStella Laurenzo "set_type", 1665*a6e7d024SStella Laurenzo [](PyBlockArgument &self, PyType type) { 1666436c6c9cSStella Laurenzo return mlirBlockArgumentSetType(self.get(), type); 1667*a6e7d024SStella Laurenzo }, 1668*a6e7d024SStella Laurenzo py::arg("type")); 1669436c6c9cSStella Laurenzo } 1670436c6c9cSStella Laurenzo }; 1671436c6c9cSStella Laurenzo 1672436c6c9cSStella Laurenzo /// Python wrapper for MlirOpResult. 1673436c6c9cSStella Laurenzo class PyOpResult : public PyConcreteValue<PyOpResult> { 1674436c6c9cSStella Laurenzo public: 1675436c6c9cSStella Laurenzo static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1676436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "OpResult"; 1677436c6c9cSStella Laurenzo using PyConcreteValue::PyConcreteValue; 1678436c6c9cSStella Laurenzo 1679436c6c9cSStella Laurenzo static void bindDerived(ClassTy &c) { 1680436c6c9cSStella Laurenzo c.def_property_readonly("owner", [](PyOpResult &self) { 1681436c6c9cSStella Laurenzo assert( 1682436c6c9cSStella Laurenzo mlirOperationEqual(self.getParentOperation()->get(), 1683436c6c9cSStella Laurenzo mlirOpResultGetOwner(self.get())) && 1684436c6c9cSStella Laurenzo "expected the owner of the value in Python to match that in the IR"); 16856ff74f96SMike Urbach return self.getParentOperation().getObject(); 1686436c6c9cSStella Laurenzo }); 1687436c6c9cSStella Laurenzo c.def_property_readonly("result_number", [](PyOpResult &self) { 1688436c6c9cSStella Laurenzo return mlirOpResultGetResultNumber(self.get()); 1689436c6c9cSStella Laurenzo }); 1690436c6c9cSStella Laurenzo } 1691436c6c9cSStella Laurenzo }; 1692436c6c9cSStella Laurenzo 1693ed9e52f3SAlex Zinenko /// Returns the list of types of the values held by container. 1694ed9e52f3SAlex Zinenko template <typename Container> 1695ed9e52f3SAlex Zinenko static std::vector<PyType> getValueTypes(Container &container, 1696ed9e52f3SAlex Zinenko PyMlirContextRef &context) { 1697ed9e52f3SAlex Zinenko std::vector<PyType> result; 1698ed9e52f3SAlex Zinenko result.reserve(container.getNumElements()); 1699ed9e52f3SAlex Zinenko for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1700ed9e52f3SAlex Zinenko result.push_back( 1701ed9e52f3SAlex Zinenko PyType(context, mlirValueGetType(container.getElement(i).get()))); 1702ed9e52f3SAlex Zinenko } 1703ed9e52f3SAlex Zinenko return result; 1704ed9e52f3SAlex Zinenko } 1705ed9e52f3SAlex Zinenko 1706436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive 1707436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the 1708436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in 1709436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime. 1710afeda4b9SAlex Zinenko class PyBlockArgumentList 1711afeda4b9SAlex Zinenko : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1712436c6c9cSStella Laurenzo public: 1713afeda4b9SAlex Zinenko static constexpr const char *pyClassName = "BlockArgumentList"; 1714436c6c9cSStella Laurenzo 1715afeda4b9SAlex Zinenko PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1716afeda4b9SAlex Zinenko intptr_t startIndex = 0, intptr_t length = -1, 1717afeda4b9SAlex Zinenko intptr_t step = 1) 1718afeda4b9SAlex Zinenko : Sliceable(startIndex, 1719afeda4b9SAlex Zinenko length == -1 ? mlirBlockGetNumArguments(block) : length, 1720afeda4b9SAlex Zinenko step), 1721afeda4b9SAlex Zinenko operation(std::move(operation)), block(block) {} 1722afeda4b9SAlex Zinenko 1723afeda4b9SAlex Zinenko /// Returns the number of arguments in the list. 1724afeda4b9SAlex Zinenko intptr_t getNumElements() { 1725436c6c9cSStella Laurenzo operation->checkValid(); 1726436c6c9cSStella Laurenzo return mlirBlockGetNumArguments(block); 1727436c6c9cSStella Laurenzo } 1728436c6c9cSStella Laurenzo 1729afeda4b9SAlex Zinenko /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1730afeda4b9SAlex Zinenko PyBlockArgument getElement(intptr_t pos) { 1731afeda4b9SAlex Zinenko MlirValue argument = mlirBlockGetArgument(block, pos); 1732afeda4b9SAlex Zinenko return PyBlockArgument(operation, argument); 1733436c6c9cSStella Laurenzo } 1734436c6c9cSStella Laurenzo 1735afeda4b9SAlex Zinenko /// Returns a sublist of this list. 1736afeda4b9SAlex Zinenko PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1737afeda4b9SAlex Zinenko intptr_t step) { 1738afeda4b9SAlex Zinenko return PyBlockArgumentList(operation, block, startIndex, length, step); 1739436c6c9cSStella Laurenzo } 1740436c6c9cSStella Laurenzo 1741ed9e52f3SAlex Zinenko static void bindDerived(ClassTy &c) { 1742ed9e52f3SAlex Zinenko c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1743ed9e52f3SAlex Zinenko return getValueTypes(self, self.operation->getContext()); 1744ed9e52f3SAlex Zinenko }); 1745ed9e52f3SAlex Zinenko } 1746ed9e52f3SAlex Zinenko 1747436c6c9cSStella Laurenzo private: 1748436c6c9cSStella Laurenzo PyOperationRef operation; 1749436c6c9cSStella Laurenzo MlirBlock block; 1750436c6c9cSStella Laurenzo }; 1751436c6c9cSStella Laurenzo 1752436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive 1753436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the 1754436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this 1755436c6c9cSStella Laurenzo /// operation. 1756436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1757436c6c9cSStella Laurenzo public: 1758436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "OpOperandList"; 1759436c6c9cSStella Laurenzo 1760436c6c9cSStella Laurenzo PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1761436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 1762436c6c9cSStella Laurenzo : Sliceable(startIndex, 1763436c6c9cSStella Laurenzo length == -1 ? mlirOperationGetNumOperands(operation->get()) 1764436c6c9cSStella Laurenzo : length, 1765436c6c9cSStella Laurenzo step), 1766436c6c9cSStella Laurenzo operation(operation) {} 1767436c6c9cSStella Laurenzo 1768436c6c9cSStella Laurenzo intptr_t getNumElements() { 1769436c6c9cSStella Laurenzo operation->checkValid(); 1770436c6c9cSStella Laurenzo return mlirOperationGetNumOperands(operation->get()); 1771436c6c9cSStella Laurenzo } 1772436c6c9cSStella Laurenzo 1773436c6c9cSStella Laurenzo PyValue getElement(intptr_t pos) { 17745664c5e2SJohn Demme MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 17755664c5e2SJohn Demme MlirOperation owner; 17765664c5e2SJohn Demme if (mlirValueIsAOpResult(operand)) 17775664c5e2SJohn Demme owner = mlirOpResultGetOwner(operand); 17785664c5e2SJohn Demme else if (mlirValueIsABlockArgument(operand)) 17795664c5e2SJohn Demme owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 17805664c5e2SJohn Demme else 17815664c5e2SJohn Demme assert(false && "Value must be an block arg or op result."); 17825664c5e2SJohn Demme PyOperationRef pyOwner = 17835664c5e2SJohn Demme PyOperation::forOperation(operation->getContext(), owner); 17845664c5e2SJohn Demme return PyValue(pyOwner, operand); 1785436c6c9cSStella Laurenzo } 1786436c6c9cSStella Laurenzo 1787436c6c9cSStella Laurenzo PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1788436c6c9cSStella Laurenzo return PyOpOperandList(operation, startIndex, length, step); 1789436c6c9cSStella Laurenzo } 1790436c6c9cSStella Laurenzo 179163d16d06SMike Urbach void dunderSetItem(intptr_t index, PyValue value) { 179263d16d06SMike Urbach index = wrapIndex(index); 179363d16d06SMike Urbach mlirOperationSetOperand(operation->get(), index, value.get()); 179463d16d06SMike Urbach } 179563d16d06SMike Urbach 179663d16d06SMike Urbach static void bindDerived(ClassTy &c) { 179763d16d06SMike Urbach c.def("__setitem__", &PyOpOperandList::dunderSetItem); 179863d16d06SMike Urbach } 179963d16d06SMike Urbach 1800436c6c9cSStella Laurenzo private: 1801436c6c9cSStella Laurenzo PyOperationRef operation; 1802436c6c9cSStella Laurenzo }; 1803436c6c9cSStella Laurenzo 1804436c6c9cSStella Laurenzo /// A list of operation results. Internally, these are stored as consecutive 1805436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the 1806436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this 1807436c6c9cSStella Laurenzo /// operation. 1808436c6c9cSStella Laurenzo class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1809436c6c9cSStella Laurenzo public: 1810436c6c9cSStella Laurenzo static constexpr const char *pyClassName = "OpResultList"; 1811436c6c9cSStella Laurenzo 1812436c6c9cSStella Laurenzo PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1813436c6c9cSStella Laurenzo intptr_t length = -1, intptr_t step = 1) 1814436c6c9cSStella Laurenzo : Sliceable(startIndex, 1815436c6c9cSStella Laurenzo length == -1 ? mlirOperationGetNumResults(operation->get()) 1816436c6c9cSStella Laurenzo : length, 1817436c6c9cSStella Laurenzo step), 1818436c6c9cSStella Laurenzo operation(operation) {} 1819436c6c9cSStella Laurenzo 1820436c6c9cSStella Laurenzo intptr_t getNumElements() { 1821436c6c9cSStella Laurenzo operation->checkValid(); 1822436c6c9cSStella Laurenzo return mlirOperationGetNumResults(operation->get()); 1823436c6c9cSStella Laurenzo } 1824436c6c9cSStella Laurenzo 1825436c6c9cSStella Laurenzo PyOpResult getElement(intptr_t index) { 1826436c6c9cSStella Laurenzo PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1827436c6c9cSStella Laurenzo return PyOpResult(value); 1828436c6c9cSStella Laurenzo } 1829436c6c9cSStella Laurenzo 1830436c6c9cSStella Laurenzo PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1831436c6c9cSStella Laurenzo return PyOpResultList(operation, startIndex, length, step); 1832436c6c9cSStella Laurenzo } 1833436c6c9cSStella Laurenzo 1834ed9e52f3SAlex Zinenko static void bindDerived(ClassTy &c) { 1835ed9e52f3SAlex Zinenko c.def_property_readonly("types", [](PyOpResultList &self) { 1836ed9e52f3SAlex Zinenko return getValueTypes(self, self.operation->getContext()); 1837ed9e52f3SAlex Zinenko }); 1838ed9e52f3SAlex Zinenko } 1839ed9e52f3SAlex Zinenko 1840436c6c9cSStella Laurenzo private: 1841436c6c9cSStella Laurenzo PyOperationRef operation; 1842436c6c9cSStella Laurenzo }; 1843436c6c9cSStella Laurenzo 1844436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing 1845436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes. 1846436c6c9cSStella Laurenzo class PyOpAttributeMap { 1847436c6c9cSStella Laurenzo public: 1848436c6c9cSStella Laurenzo PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1849436c6c9cSStella Laurenzo 1850436c6c9cSStella Laurenzo PyAttribute dunderGetItemNamed(const std::string &name) { 1851436c6c9cSStella Laurenzo MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1852436c6c9cSStella Laurenzo toMlirStringRef(name)); 1853436c6c9cSStella Laurenzo if (mlirAttributeIsNull(attr)) { 1854436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 1855436c6c9cSStella Laurenzo "attempt to access a non-existent attribute"); 1856436c6c9cSStella Laurenzo } 1857436c6c9cSStella Laurenzo return PyAttribute(operation->getContext(), attr); 1858436c6c9cSStella Laurenzo } 1859436c6c9cSStella Laurenzo 1860436c6c9cSStella Laurenzo PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1861436c6c9cSStella Laurenzo if (index < 0 || index >= dunderLen()) { 1862436c6c9cSStella Laurenzo throw SetPyError(PyExc_IndexError, 1863436c6c9cSStella Laurenzo "attempt to access out of bounds attribute"); 1864436c6c9cSStella Laurenzo } 1865436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr = 1866436c6c9cSStella Laurenzo mlirOperationGetAttribute(operation->get(), index); 1867436c6c9cSStella Laurenzo return PyNamedAttribute( 1868436c6c9cSStella Laurenzo namedAttr.attribute, 1869120591e1SRiver Riddle std::string(mlirIdentifierStr(namedAttr.name).data, 1870120591e1SRiver Riddle mlirIdentifierStr(namedAttr.name).length)); 1871436c6c9cSStella Laurenzo } 1872436c6c9cSStella Laurenzo 1873436c6c9cSStella Laurenzo void dunderSetItem(const std::string &name, PyAttribute attr) { 1874436c6c9cSStella Laurenzo mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1875436c6c9cSStella Laurenzo attr); 1876436c6c9cSStella Laurenzo } 1877436c6c9cSStella Laurenzo 1878436c6c9cSStella Laurenzo void dunderDelItem(const std::string &name) { 1879436c6c9cSStella Laurenzo int removed = mlirOperationRemoveAttributeByName(operation->get(), 1880436c6c9cSStella Laurenzo toMlirStringRef(name)); 1881436c6c9cSStella Laurenzo if (!removed) 1882436c6c9cSStella Laurenzo throw SetPyError(PyExc_KeyError, 1883436c6c9cSStella Laurenzo "attempt to delete a non-existent attribute"); 1884436c6c9cSStella Laurenzo } 1885436c6c9cSStella Laurenzo 1886436c6c9cSStella Laurenzo intptr_t dunderLen() { 1887436c6c9cSStella Laurenzo return mlirOperationGetNumAttributes(operation->get()); 1888436c6c9cSStella Laurenzo } 1889436c6c9cSStella Laurenzo 1890436c6c9cSStella Laurenzo bool dunderContains(const std::string &name) { 1891436c6c9cSStella Laurenzo return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1892436c6c9cSStella Laurenzo operation->get(), toMlirStringRef(name))); 1893436c6c9cSStella Laurenzo } 1894436c6c9cSStella Laurenzo 1895436c6c9cSStella Laurenzo static void bind(py::module &m) { 1896f05ff4f7SStella Laurenzo py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1897436c6c9cSStella Laurenzo .def("__contains__", &PyOpAttributeMap::dunderContains) 1898436c6c9cSStella Laurenzo .def("__len__", &PyOpAttributeMap::dunderLen) 1899436c6c9cSStella Laurenzo .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1900436c6c9cSStella Laurenzo .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1901436c6c9cSStella Laurenzo .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1902436c6c9cSStella Laurenzo .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1903436c6c9cSStella Laurenzo } 1904436c6c9cSStella Laurenzo 1905436c6c9cSStella Laurenzo private: 1906436c6c9cSStella Laurenzo PyOperationRef operation; 1907436c6c9cSStella Laurenzo }; 1908436c6c9cSStella Laurenzo 1909436c6c9cSStella Laurenzo } // end namespace 1910436c6c9cSStella Laurenzo 1911436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1912436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule. 1913436c6c9cSStella Laurenzo //------------------------------------------------------------------------------ 1914436c6c9cSStella Laurenzo 1915436c6c9cSStella Laurenzo void mlir::python::populateIRCore(py::module &m) { 1916436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 19174acd8457SAlex Zinenko // Mapping of MlirContext. 1918436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1919f05ff4f7SStella Laurenzo py::class_<PyMlirContext>(m, "Context", py::module_local()) 1920436c6c9cSStella Laurenzo .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1921436c6c9cSStella Laurenzo .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1922436c6c9cSStella Laurenzo .def("_get_context_again", 1923436c6c9cSStella Laurenzo [](PyMlirContext &self) { 1924436c6c9cSStella Laurenzo PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1925436c6c9cSStella Laurenzo return ref.releaseObject(); 1926436c6c9cSStella Laurenzo }) 1927436c6c9cSStella Laurenzo .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1928436c6c9cSStella Laurenzo .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1929436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1930436c6c9cSStella Laurenzo &PyMlirContext::getCapsule) 1931436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1932436c6c9cSStella Laurenzo .def("__enter__", &PyMlirContext::contextEnter) 1933436c6c9cSStella Laurenzo .def("__exit__", &PyMlirContext::contextExit) 1934436c6c9cSStella Laurenzo .def_property_readonly_static( 1935436c6c9cSStella Laurenzo "current", 1936436c6c9cSStella Laurenzo [](py::object & /*class*/) { 1937436c6c9cSStella Laurenzo auto *context = PyThreadContextEntry::getDefaultContext(); 1938436c6c9cSStella Laurenzo if (!context) 1939436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "No current Context"); 1940436c6c9cSStella Laurenzo return context; 1941436c6c9cSStella Laurenzo }, 1942436c6c9cSStella Laurenzo "Gets the Context bound to the current thread or raises ValueError") 1943436c6c9cSStella Laurenzo .def_property_readonly( 1944436c6c9cSStella Laurenzo "dialects", 1945436c6c9cSStella Laurenzo [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1946436c6c9cSStella Laurenzo "Gets a container for accessing dialects by name") 1947436c6c9cSStella Laurenzo .def_property_readonly( 1948436c6c9cSStella Laurenzo "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1949436c6c9cSStella Laurenzo "Alias for 'dialect'") 1950436c6c9cSStella Laurenzo .def( 1951436c6c9cSStella Laurenzo "get_dialect_descriptor", 1952436c6c9cSStella Laurenzo [=](PyMlirContext &self, std::string &name) { 1953436c6c9cSStella Laurenzo MlirDialect dialect = mlirContextGetOrLoadDialect( 1954436c6c9cSStella Laurenzo self.get(), {name.data(), name.size()}); 1955436c6c9cSStella Laurenzo if (mlirDialectIsNull(dialect)) { 1956436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 1957436c6c9cSStella Laurenzo Twine("Dialect '") + name + "' not found"); 1958436c6c9cSStella Laurenzo } 1959436c6c9cSStella Laurenzo return PyDialectDescriptor(self.getRef(), dialect); 1960436c6c9cSStella Laurenzo }, 1961*a6e7d024SStella Laurenzo py::arg("dialect_name"), 1962436c6c9cSStella Laurenzo "Gets or loads a dialect by name, returning its descriptor object") 1963436c6c9cSStella Laurenzo .def_property( 1964436c6c9cSStella Laurenzo "allow_unregistered_dialects", 1965436c6c9cSStella Laurenzo [](PyMlirContext &self) -> bool { 1966436c6c9cSStella Laurenzo return mlirContextGetAllowUnregisteredDialects(self.get()); 1967436c6c9cSStella Laurenzo }, 1968436c6c9cSStella Laurenzo [](PyMlirContext &self, bool value) { 1969436c6c9cSStella Laurenzo mlirContextSetAllowUnregisteredDialects(self.get(), value); 19709a9214faSStella Laurenzo }) 1971*a6e7d024SStella Laurenzo .def( 1972*a6e7d024SStella Laurenzo "enable_multithreading", 1973caa159f0SNicolas Vasilache [](PyMlirContext &self, bool enable) { 1974caa159f0SNicolas Vasilache mlirContextEnableMultithreading(self.get(), enable); 1975*a6e7d024SStella Laurenzo }, 1976*a6e7d024SStella Laurenzo py::arg("enable")) 1977*a6e7d024SStella Laurenzo .def( 1978*a6e7d024SStella Laurenzo "is_registered_operation", 19799a9214faSStella Laurenzo [](PyMlirContext &self, std::string &name) { 19809a9214faSStella Laurenzo return mlirContextIsRegisteredOperation( 19819a9214faSStella Laurenzo self.get(), MlirStringRef{name.data(), name.size()}); 1982*a6e7d024SStella Laurenzo }, 1983*a6e7d024SStella Laurenzo py::arg("operation_name")); 1984436c6c9cSStella Laurenzo 1985436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1986436c6c9cSStella Laurenzo // Mapping of PyDialectDescriptor 1987436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 1988f05ff4f7SStella Laurenzo py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1989436c6c9cSStella Laurenzo .def_property_readonly("namespace", 1990436c6c9cSStella Laurenzo [](PyDialectDescriptor &self) { 1991436c6c9cSStella Laurenzo MlirStringRef ns = 1992436c6c9cSStella Laurenzo mlirDialectGetNamespace(self.get()); 1993436c6c9cSStella Laurenzo return py::str(ns.data, ns.length); 1994436c6c9cSStella Laurenzo }) 1995436c6c9cSStella Laurenzo .def("__repr__", [](PyDialectDescriptor &self) { 1996436c6c9cSStella Laurenzo MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1997436c6c9cSStella Laurenzo std::string repr("<DialectDescriptor "); 1998436c6c9cSStella Laurenzo repr.append(ns.data, ns.length); 1999436c6c9cSStella Laurenzo repr.append(">"); 2000436c6c9cSStella Laurenzo return repr; 2001436c6c9cSStella Laurenzo }); 2002436c6c9cSStella Laurenzo 2003436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2004436c6c9cSStella Laurenzo // Mapping of PyDialects 2005436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2006f05ff4f7SStella Laurenzo py::class_<PyDialects>(m, "Dialects", py::module_local()) 2007436c6c9cSStella Laurenzo .def("__getitem__", 2008436c6c9cSStella Laurenzo [=](PyDialects &self, std::string keyName) { 2009436c6c9cSStella Laurenzo MlirDialect dialect = 2010436c6c9cSStella Laurenzo self.getDialectForKey(keyName, /*attrError=*/false); 2011436c6c9cSStella Laurenzo py::object descriptor = 2012436c6c9cSStella Laurenzo py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2013436c6c9cSStella Laurenzo return createCustomDialectWrapper(keyName, std::move(descriptor)); 2014436c6c9cSStella Laurenzo }) 2015436c6c9cSStella Laurenzo .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2016436c6c9cSStella Laurenzo MlirDialect dialect = 2017436c6c9cSStella Laurenzo self.getDialectForKey(attrName, /*attrError=*/true); 2018436c6c9cSStella Laurenzo py::object descriptor = 2019436c6c9cSStella Laurenzo py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2020436c6c9cSStella Laurenzo return createCustomDialectWrapper(attrName, std::move(descriptor)); 2021436c6c9cSStella Laurenzo }); 2022436c6c9cSStella Laurenzo 2023436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2024436c6c9cSStella Laurenzo // Mapping of PyDialect 2025436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2026f05ff4f7SStella Laurenzo py::class_<PyDialect>(m, "Dialect", py::module_local()) 2027*a6e7d024SStella Laurenzo .def(py::init<py::object>(), py::arg("descriptor")) 2028436c6c9cSStella Laurenzo .def_property_readonly( 2029436c6c9cSStella Laurenzo "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2030436c6c9cSStella Laurenzo .def("__repr__", [](py::object self) { 2031436c6c9cSStella Laurenzo auto clazz = self.attr("__class__"); 2032436c6c9cSStella Laurenzo return py::str("<Dialect ") + 2033436c6c9cSStella Laurenzo self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2034436c6c9cSStella Laurenzo clazz.attr("__module__") + py::str(".") + 2035436c6c9cSStella Laurenzo clazz.attr("__name__") + py::str(")>"); 2036436c6c9cSStella Laurenzo }); 2037436c6c9cSStella Laurenzo 2038436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2039436c6c9cSStella Laurenzo // Mapping of Location 2040436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2041f05ff4f7SStella Laurenzo py::class_<PyLocation>(m, "Location", py::module_local()) 2042436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2043436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2044436c6c9cSStella Laurenzo .def("__enter__", &PyLocation::contextEnter) 2045436c6c9cSStella Laurenzo .def("__exit__", &PyLocation::contextExit) 2046436c6c9cSStella Laurenzo .def("__eq__", 2047436c6c9cSStella Laurenzo [](PyLocation &self, PyLocation &other) -> bool { 2048436c6c9cSStella Laurenzo return mlirLocationEqual(self, other); 2049436c6c9cSStella Laurenzo }) 2050436c6c9cSStella Laurenzo .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2051436c6c9cSStella Laurenzo .def_property_readonly_static( 2052436c6c9cSStella Laurenzo "current", 2053436c6c9cSStella Laurenzo [](py::object & /*class*/) { 2054436c6c9cSStella Laurenzo auto *loc = PyThreadContextEntry::getDefaultLocation(); 2055436c6c9cSStella Laurenzo if (!loc) 2056436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "No current Location"); 2057436c6c9cSStella Laurenzo return loc; 2058436c6c9cSStella Laurenzo }, 2059436c6c9cSStella Laurenzo "Gets the Location bound to the current thread or raises ValueError") 2060436c6c9cSStella Laurenzo .def_static( 2061436c6c9cSStella Laurenzo "unknown", 2062436c6c9cSStella Laurenzo [](DefaultingPyMlirContext context) { 2063436c6c9cSStella Laurenzo return PyLocation(context->getRef(), 2064436c6c9cSStella Laurenzo mlirLocationUnknownGet(context->get())); 2065436c6c9cSStella Laurenzo }, 2066436c6c9cSStella Laurenzo py::arg("context") = py::none(), 2067436c6c9cSStella Laurenzo "Gets a Location representing an unknown location") 2068436c6c9cSStella Laurenzo .def_static( 2069e67cbbefSJacques Pienaar "callsite", 2070e67cbbefSJacques Pienaar [](PyLocation callee, const std::vector<PyLocation> &frames, 2071e67cbbefSJacques Pienaar DefaultingPyMlirContext context) { 2072e67cbbefSJacques Pienaar if (frames.empty()) 2073e67cbbefSJacques Pienaar throw py::value_error("No caller frames provided"); 2074e67cbbefSJacques Pienaar MlirLocation caller = frames.back().get(); 2075e2f16be5SMehdi Amini for (const PyLocation &frame : 2076e67cbbefSJacques Pienaar llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2077e67cbbefSJacques Pienaar caller = mlirLocationCallSiteGet(frame.get(), caller); 2078e67cbbefSJacques Pienaar return PyLocation(context->getRef(), 2079e67cbbefSJacques Pienaar mlirLocationCallSiteGet(callee.get(), caller)); 2080e67cbbefSJacques Pienaar }, 2081e67cbbefSJacques Pienaar py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2082e67cbbefSJacques Pienaar kContextGetCallSiteLocationDocstring) 2083e67cbbefSJacques Pienaar .def_static( 2084436c6c9cSStella Laurenzo "file", 2085436c6c9cSStella Laurenzo [](std::string filename, int line, int col, 2086436c6c9cSStella Laurenzo DefaultingPyMlirContext context) { 2087436c6c9cSStella Laurenzo return PyLocation( 2088436c6c9cSStella Laurenzo context->getRef(), 2089436c6c9cSStella Laurenzo mlirLocationFileLineColGet( 2090436c6c9cSStella Laurenzo context->get(), toMlirStringRef(filename), line, col)); 2091436c6c9cSStella Laurenzo }, 2092436c6c9cSStella Laurenzo py::arg("filename"), py::arg("line"), py::arg("col"), 2093436c6c9cSStella Laurenzo py::arg("context") = py::none(), kContextGetFileLocationDocstring) 209404d76d36SJacques Pienaar .def_static( 209504d76d36SJacques Pienaar "name", 209604d76d36SJacques Pienaar [](std::string name, llvm::Optional<PyLocation> childLoc, 209704d76d36SJacques Pienaar DefaultingPyMlirContext context) { 209804d76d36SJacques Pienaar return PyLocation( 209904d76d36SJacques Pienaar context->getRef(), 210004d76d36SJacques Pienaar mlirLocationNameGet( 210104d76d36SJacques Pienaar context->get(), toMlirStringRef(name), 210204d76d36SJacques Pienaar childLoc ? childLoc->get() 210304d76d36SJacques Pienaar : mlirLocationUnknownGet(context->get()))); 210404d76d36SJacques Pienaar }, 210504d76d36SJacques Pienaar py::arg("name"), py::arg("childLoc") = py::none(), 210604d76d36SJacques Pienaar py::arg("context") = py::none(), kContextGetNameLocationDocString) 2107436c6c9cSStella Laurenzo .def_property_readonly( 2108436c6c9cSStella Laurenzo "context", 2109436c6c9cSStella Laurenzo [](PyLocation &self) { return self.getContext().getObject(); }, 2110436c6c9cSStella Laurenzo "Context that owns the Location") 2111436c6c9cSStella Laurenzo .def("__repr__", [](PyLocation &self) { 2112436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2113436c6c9cSStella Laurenzo mlirLocationPrint(self, printAccum.getCallback(), 2114436c6c9cSStella Laurenzo printAccum.getUserData()); 2115436c6c9cSStella Laurenzo return printAccum.join(); 2116436c6c9cSStella Laurenzo }); 2117436c6c9cSStella Laurenzo 2118436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2119436c6c9cSStella Laurenzo // Mapping of Module 2120436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2121f05ff4f7SStella Laurenzo py::class_<PyModule>(m, "Module", py::module_local()) 2122436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2123436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2124436c6c9cSStella Laurenzo .def_static( 2125436c6c9cSStella Laurenzo "parse", 2126436c6c9cSStella Laurenzo [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2127436c6c9cSStella Laurenzo MlirModule module = mlirModuleCreateParse( 2128436c6c9cSStella Laurenzo context->get(), toMlirStringRef(moduleAsm)); 2129436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 2130436c6c9cSStella Laurenzo // in C API. 2131436c6c9cSStella Laurenzo if (mlirModuleIsNull(module)) { 2132436c6c9cSStella Laurenzo throw SetPyError( 2133436c6c9cSStella Laurenzo PyExc_ValueError, 2134436c6c9cSStella Laurenzo "Unable to parse module assembly (see diagnostics)"); 2135436c6c9cSStella Laurenzo } 2136436c6c9cSStella Laurenzo return PyModule::forModule(module).releaseObject(); 2137436c6c9cSStella Laurenzo }, 2138436c6c9cSStella Laurenzo py::arg("asm"), py::arg("context") = py::none(), 2139436c6c9cSStella Laurenzo kModuleParseDocstring) 2140436c6c9cSStella Laurenzo .def_static( 2141436c6c9cSStella Laurenzo "create", 2142436c6c9cSStella Laurenzo [](DefaultingPyLocation loc) { 2143436c6c9cSStella Laurenzo MlirModule module = mlirModuleCreateEmpty(loc); 2144436c6c9cSStella Laurenzo return PyModule::forModule(module).releaseObject(); 2145436c6c9cSStella Laurenzo }, 2146436c6c9cSStella Laurenzo py::arg("loc") = py::none(), "Creates an empty module") 2147436c6c9cSStella Laurenzo .def_property_readonly( 2148436c6c9cSStella Laurenzo "context", 2149436c6c9cSStella Laurenzo [](PyModule &self) { return self.getContext().getObject(); }, 2150436c6c9cSStella Laurenzo "Context that created the Module") 2151436c6c9cSStella Laurenzo .def_property_readonly( 2152436c6c9cSStella Laurenzo "operation", 2153436c6c9cSStella Laurenzo [](PyModule &self) { 2154436c6c9cSStella Laurenzo return PyOperation::forOperation(self.getContext(), 2155436c6c9cSStella Laurenzo mlirModuleGetOperation(self.get()), 2156436c6c9cSStella Laurenzo self.getRef().releaseObject()) 2157436c6c9cSStella Laurenzo .releaseObject(); 2158436c6c9cSStella Laurenzo }, 2159436c6c9cSStella Laurenzo "Accesses the module as an operation") 2160436c6c9cSStella Laurenzo .def_property_readonly( 2161436c6c9cSStella Laurenzo "body", 2162436c6c9cSStella Laurenzo [](PyModule &self) { 2163436c6c9cSStella Laurenzo PyOperationRef module_op = PyOperation::forOperation( 2164436c6c9cSStella Laurenzo self.getContext(), mlirModuleGetOperation(self.get()), 2165436c6c9cSStella Laurenzo self.getRef().releaseObject()); 2166436c6c9cSStella Laurenzo PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2167436c6c9cSStella Laurenzo return returnBlock; 2168436c6c9cSStella Laurenzo }, 2169436c6c9cSStella Laurenzo "Return the block for this module") 2170436c6c9cSStella Laurenzo .def( 2171436c6c9cSStella Laurenzo "dump", 2172436c6c9cSStella Laurenzo [](PyModule &self) { 2173436c6c9cSStella Laurenzo mlirOperationDump(mlirModuleGetOperation(self.get())); 2174436c6c9cSStella Laurenzo }, 2175436c6c9cSStella Laurenzo kDumpDocstring) 2176436c6c9cSStella Laurenzo .def( 2177436c6c9cSStella Laurenzo "__str__", 2178ace1d0adSStella Laurenzo [](py::object self) { 2179ace1d0adSStella Laurenzo // Defer to the operation's __str__. 2180ace1d0adSStella Laurenzo return self.attr("operation").attr("__str__")(); 2181436c6c9cSStella Laurenzo }, 2182436c6c9cSStella Laurenzo kOperationStrDunderDocstring); 2183436c6c9cSStella Laurenzo 2184436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2185436c6c9cSStella Laurenzo // Mapping of Operation. 2186436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2187f05ff4f7SStella Laurenzo py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 21881fb2e842SStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 21891fb2e842SStella Laurenzo [](PyOperationBase &self) { 21901fb2e842SStella Laurenzo return self.getOperation().getCapsule(); 21911fb2e842SStella Laurenzo }) 2192436c6c9cSStella Laurenzo .def("__eq__", 2193436c6c9cSStella Laurenzo [](PyOperationBase &self, PyOperationBase &other) { 2194436c6c9cSStella Laurenzo return &self.getOperation() == &other.getOperation(); 2195436c6c9cSStella Laurenzo }) 2196436c6c9cSStella Laurenzo .def("__eq__", 2197436c6c9cSStella Laurenzo [](PyOperationBase &self, py::object other) { return false; }) 2198f78fe0b7Srkayaith .def("__hash__", 2199f78fe0b7Srkayaith [](PyOperationBase &self) { 2200f78fe0b7Srkayaith return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2201f78fe0b7Srkayaith }) 2202436c6c9cSStella Laurenzo .def_property_readonly("attributes", 2203436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2204436c6c9cSStella Laurenzo return PyOpAttributeMap( 2205436c6c9cSStella Laurenzo self.getOperation().getRef()); 2206436c6c9cSStella Laurenzo }) 2207436c6c9cSStella Laurenzo .def_property_readonly("operands", 2208436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2209436c6c9cSStella Laurenzo return PyOpOperandList( 2210436c6c9cSStella Laurenzo self.getOperation().getRef()); 2211436c6c9cSStella Laurenzo }) 2212436c6c9cSStella Laurenzo .def_property_readonly("regions", 2213436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2214436c6c9cSStella Laurenzo return PyRegionList( 2215436c6c9cSStella Laurenzo self.getOperation().getRef()); 2216436c6c9cSStella Laurenzo }) 2217436c6c9cSStella Laurenzo .def_property_readonly( 2218436c6c9cSStella Laurenzo "results", 2219436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2220436c6c9cSStella Laurenzo return PyOpResultList(self.getOperation().getRef()); 2221436c6c9cSStella Laurenzo }, 2222436c6c9cSStella Laurenzo "Returns the list of Operation results.") 2223436c6c9cSStella Laurenzo .def_property_readonly( 2224436c6c9cSStella Laurenzo "result", 2225436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2226436c6c9cSStella Laurenzo auto &operation = self.getOperation(); 2227436c6c9cSStella Laurenzo auto numResults = mlirOperationGetNumResults(operation); 2228436c6c9cSStella Laurenzo if (numResults != 1) { 2229436c6c9cSStella Laurenzo auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2230436c6c9cSStella Laurenzo throw SetPyError( 2231436c6c9cSStella Laurenzo PyExc_ValueError, 2232436c6c9cSStella Laurenzo Twine("Cannot call .result on operation ") + 2233436c6c9cSStella Laurenzo StringRef(name.data, name.length) + " which has " + 2234436c6c9cSStella Laurenzo Twine(numResults) + 2235436c6c9cSStella Laurenzo " results (it is only valid for operations with a " 2236436c6c9cSStella Laurenzo "single result)"); 2237436c6c9cSStella Laurenzo } 2238436c6c9cSStella Laurenzo return PyOpResult(operation.getRef(), 2239436c6c9cSStella Laurenzo mlirOperationGetResult(operation, 0)); 2240436c6c9cSStella Laurenzo }, 2241436c6c9cSStella Laurenzo "Shortcut to get an op result if it has only one (throws an error " 2242436c6c9cSStella Laurenzo "otherwise).") 2243d5429a13Srkayaith .def_property_readonly( 2244d5429a13Srkayaith "location", 2245d5429a13Srkayaith [](PyOperationBase &self) { 2246d5429a13Srkayaith PyOperation &operation = self.getOperation(); 2247d5429a13Srkayaith return PyLocation(operation.getContext(), 2248d5429a13Srkayaith mlirOperationGetLocation(operation.get())); 2249d5429a13Srkayaith }, 2250d5429a13Srkayaith "Returns the source location the operation was defined or derived " 2251d5429a13Srkayaith "from.") 2252436c6c9cSStella Laurenzo .def( 2253436c6c9cSStella Laurenzo "__str__", 2254436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2255436c6c9cSStella Laurenzo return self.getAsm(/*binary=*/false, 2256436c6c9cSStella Laurenzo /*largeElementsLimit=*/llvm::None, 2257436c6c9cSStella Laurenzo /*enableDebugInfo=*/false, 2258436c6c9cSStella Laurenzo /*prettyDebugInfo=*/false, 2259436c6c9cSStella Laurenzo /*printGenericOpForm=*/false, 2260ace1d0adSStella Laurenzo /*useLocalScope=*/false, 2261ace1d0adSStella Laurenzo /*assumeVerified=*/false); 2262436c6c9cSStella Laurenzo }, 2263436c6c9cSStella Laurenzo "Returns the assembly form of the operation.") 2264436c6c9cSStella Laurenzo .def("print", &PyOperationBase::print, 2265436c6c9cSStella Laurenzo // Careful: Lots of arguments must match up with print method. 2266436c6c9cSStella Laurenzo py::arg("file") = py::none(), py::arg("binary") = false, 2267436c6c9cSStella Laurenzo py::arg("large_elements_limit") = py::none(), 2268436c6c9cSStella Laurenzo py::arg("enable_debug_info") = false, 2269436c6c9cSStella Laurenzo py::arg("pretty_debug_info") = false, 2270436c6c9cSStella Laurenzo py::arg("print_generic_op_form") = false, 2271ace1d0adSStella Laurenzo py::arg("use_local_scope") = false, 2272ace1d0adSStella Laurenzo py::arg("assume_verified") = false, kOperationPrintDocstring) 2273436c6c9cSStella Laurenzo .def("get_asm", &PyOperationBase::getAsm, 2274436c6c9cSStella Laurenzo // Careful: Lots of arguments must match up with get_asm method. 2275436c6c9cSStella Laurenzo py::arg("binary") = false, 2276436c6c9cSStella Laurenzo py::arg("large_elements_limit") = py::none(), 2277436c6c9cSStella Laurenzo py::arg("enable_debug_info") = false, 2278436c6c9cSStella Laurenzo py::arg("pretty_debug_info") = false, 2279436c6c9cSStella Laurenzo py::arg("print_generic_op_form") = false, 2280ace1d0adSStella Laurenzo py::arg("use_local_scope") = false, 2281ace1d0adSStella Laurenzo py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2282436c6c9cSStella Laurenzo .def( 2283436c6c9cSStella Laurenzo "verify", 2284436c6c9cSStella Laurenzo [](PyOperationBase &self) { 2285436c6c9cSStella Laurenzo return mlirOperationVerify(self.getOperation()); 2286436c6c9cSStella Laurenzo }, 2287436c6c9cSStella Laurenzo "Verify the operation and return true if it passes, false if it " 228824685aaeSAlex Zinenko "fails.") 228924685aaeSAlex Zinenko .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 229024685aaeSAlex Zinenko "Puts self immediately after the other operation in its parent " 229124685aaeSAlex Zinenko "block.") 229224685aaeSAlex Zinenko .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 229324685aaeSAlex Zinenko "Puts self immediately before the other operation in its parent " 229424685aaeSAlex Zinenko "block.") 229524685aaeSAlex Zinenko .def( 229624685aaeSAlex Zinenko "detach_from_parent", 229724685aaeSAlex Zinenko [](PyOperationBase &self) { 229824685aaeSAlex Zinenko PyOperation &operation = self.getOperation(); 229924685aaeSAlex Zinenko operation.checkValid(); 230024685aaeSAlex Zinenko if (!operation.isAttached()) 230124685aaeSAlex Zinenko throw py::value_error("Detached operation has no parent."); 230224685aaeSAlex Zinenko 230324685aaeSAlex Zinenko operation.detachFromParent(); 230424685aaeSAlex Zinenko return operation.createOpView(); 230524685aaeSAlex Zinenko }, 230624685aaeSAlex Zinenko "Detaches the operation from its parent block."); 2307436c6c9cSStella Laurenzo 2308f05ff4f7SStella Laurenzo py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2309436c6c9cSStella Laurenzo .def_static("create", &PyOperation::create, py::arg("name"), 2310436c6c9cSStella Laurenzo py::arg("results") = py::none(), 2311436c6c9cSStella Laurenzo py::arg("operands") = py::none(), 2312436c6c9cSStella Laurenzo py::arg("attributes") = py::none(), 2313436c6c9cSStella Laurenzo py::arg("successors") = py::none(), py::arg("regions") = 0, 2314436c6c9cSStella Laurenzo py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2315436c6c9cSStella Laurenzo kOperationCreateDocstring) 2316c65bb760SJohn Demme .def_property_readonly("parent", 23171689dadeSJohn Demme [](PyOperation &self) -> py::object { 23181689dadeSJohn Demme auto parent = self.getParentOperation(); 23191689dadeSJohn Demme if (parent) 23201689dadeSJohn Demme return parent->getObject(); 23211689dadeSJohn Demme return py::none(); 2322c65bb760SJohn Demme }) 232349745f87SMike Urbach .def("erase", &PyOperation::erase) 23240126e906SJohn Demme .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 23250126e906SJohn Demme &PyOperation::getCapsule) 23260126e906SJohn Demme .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2327436c6c9cSStella Laurenzo .def_property_readonly("name", 2328436c6c9cSStella Laurenzo [](PyOperation &self) { 232949745f87SMike Urbach self.checkValid(); 2330436c6c9cSStella Laurenzo MlirOperation operation = self.get(); 2331436c6c9cSStella Laurenzo MlirStringRef name = mlirIdentifierStr( 2332436c6c9cSStella Laurenzo mlirOperationGetName(operation)); 2333436c6c9cSStella Laurenzo return py::str(name.data, name.length); 2334436c6c9cSStella Laurenzo }) 2335436c6c9cSStella Laurenzo .def_property_readonly( 2336436c6c9cSStella Laurenzo "context", 233749745f87SMike Urbach [](PyOperation &self) { 233849745f87SMike Urbach self.checkValid(); 233949745f87SMike Urbach return self.getContext().getObject(); 234049745f87SMike Urbach }, 2341436c6c9cSStella Laurenzo "Context that owns the Operation") 2342436c6c9cSStella Laurenzo .def_property_readonly("opview", &PyOperation::createOpView); 2343436c6c9cSStella Laurenzo 2344436c6c9cSStella Laurenzo auto opViewClass = 2345f05ff4f7SStella Laurenzo py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2346*a6e7d024SStella Laurenzo .def(py::init<py::object>(), py::arg("operation")) 2347436c6c9cSStella Laurenzo .def_property_readonly("operation", &PyOpView::getOperationObject) 2348436c6c9cSStella Laurenzo .def_property_readonly( 2349436c6c9cSStella Laurenzo "context", 2350436c6c9cSStella Laurenzo [](PyOpView &self) { 2351436c6c9cSStella Laurenzo return self.getOperation().getContext().getObject(); 2352436c6c9cSStella Laurenzo }, 2353436c6c9cSStella Laurenzo "Context that owns the Operation") 2354436c6c9cSStella Laurenzo .def("__str__", [](PyOpView &self) { 2355436c6c9cSStella Laurenzo return py::str(self.getOperationObject()); 2356436c6c9cSStella Laurenzo }); 2357436c6c9cSStella Laurenzo opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2358436c6c9cSStella Laurenzo opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2359436c6c9cSStella Laurenzo opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2360436c6c9cSStella Laurenzo opViewClass.attr("build_generic") = classmethod( 2361436c6c9cSStella Laurenzo &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2362436c6c9cSStella Laurenzo py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2363436c6c9cSStella Laurenzo py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2364436c6c9cSStella Laurenzo py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2365436c6c9cSStella Laurenzo "Builds a specific, generated OpView based on class level attributes."); 2366436c6c9cSStella Laurenzo 2367436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2368436c6c9cSStella Laurenzo // Mapping of PyRegion. 2369436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2370f05ff4f7SStella Laurenzo py::class_<PyRegion>(m, "Region", py::module_local()) 2371436c6c9cSStella Laurenzo .def_property_readonly( 2372436c6c9cSStella Laurenzo "blocks", 2373436c6c9cSStella Laurenzo [](PyRegion &self) { 2374436c6c9cSStella Laurenzo return PyBlockList(self.getParentOperation(), self.get()); 2375436c6c9cSStella Laurenzo }, 2376436c6c9cSStella Laurenzo "Returns a forward-optimized sequence of blocks.") 237778f2dae0SAlex Zinenko .def_property_readonly( 237878f2dae0SAlex Zinenko "owner", 237978f2dae0SAlex Zinenko [](PyRegion &self) { 238078f2dae0SAlex Zinenko return self.getParentOperation()->createOpView(); 238178f2dae0SAlex Zinenko }, 238278f2dae0SAlex Zinenko "Returns the operation owning this region.") 2383436c6c9cSStella Laurenzo .def( 2384436c6c9cSStella Laurenzo "__iter__", 2385436c6c9cSStella Laurenzo [](PyRegion &self) { 2386436c6c9cSStella Laurenzo self.checkValid(); 2387436c6c9cSStella Laurenzo MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2388436c6c9cSStella Laurenzo return PyBlockIterator(self.getParentOperation(), firstBlock); 2389436c6c9cSStella Laurenzo }, 2390436c6c9cSStella Laurenzo "Iterates over blocks in the region.") 2391436c6c9cSStella Laurenzo .def("__eq__", 2392436c6c9cSStella Laurenzo [](PyRegion &self, PyRegion &other) { 2393436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 2394436c6c9cSStella Laurenzo }) 2395436c6c9cSStella Laurenzo .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2396436c6c9cSStella Laurenzo 2397436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2398436c6c9cSStella Laurenzo // Mapping of PyBlock. 2399436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2400f05ff4f7SStella Laurenzo py::class_<PyBlock>(m, "Block", py::module_local()) 2401436c6c9cSStella Laurenzo .def_property_readonly( 240296fbd5cdSJohn Demme "owner", 240396fbd5cdSJohn Demme [](PyBlock &self) { 240496fbd5cdSJohn Demme return self.getParentOperation()->createOpView(); 240596fbd5cdSJohn Demme }, 240696fbd5cdSJohn Demme "Returns the owning operation of this block.") 240796fbd5cdSJohn Demme .def_property_readonly( 24088e6c55c9SStella Laurenzo "region", 24098e6c55c9SStella Laurenzo [](PyBlock &self) { 24108e6c55c9SStella Laurenzo MlirRegion region = mlirBlockGetParentRegion(self.get()); 24118e6c55c9SStella Laurenzo return PyRegion(self.getParentOperation(), region); 24128e6c55c9SStella Laurenzo }, 24138e6c55c9SStella Laurenzo "Returns the owning region of this block.") 24148e6c55c9SStella Laurenzo .def_property_readonly( 2415436c6c9cSStella Laurenzo "arguments", 2416436c6c9cSStella Laurenzo [](PyBlock &self) { 2417436c6c9cSStella Laurenzo return PyBlockArgumentList(self.getParentOperation(), self.get()); 2418436c6c9cSStella Laurenzo }, 2419436c6c9cSStella Laurenzo "Returns a list of block arguments.") 2420436c6c9cSStella Laurenzo .def_property_readonly( 2421436c6c9cSStella Laurenzo "operations", 2422436c6c9cSStella Laurenzo [](PyBlock &self) { 2423436c6c9cSStella Laurenzo return PyOperationList(self.getParentOperation(), self.get()); 2424436c6c9cSStella Laurenzo }, 2425436c6c9cSStella Laurenzo "Returns a forward-optimized sequence of operations.") 242678f2dae0SAlex Zinenko .def_static( 242778f2dae0SAlex Zinenko "create_at_start", 242878f2dae0SAlex Zinenko [](PyRegion &parent, py::list pyArgTypes) { 242978f2dae0SAlex Zinenko parent.checkValid(); 243078f2dae0SAlex Zinenko llvm::SmallVector<MlirType, 4> argTypes; 243178f2dae0SAlex Zinenko argTypes.reserve(pyArgTypes.size()); 243278f2dae0SAlex Zinenko for (auto &pyArg : pyArgTypes) { 243378f2dae0SAlex Zinenko argTypes.push_back(pyArg.cast<PyType &>()); 243478f2dae0SAlex Zinenko } 243578f2dae0SAlex Zinenko 243678f2dae0SAlex Zinenko MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 243778f2dae0SAlex Zinenko mlirRegionInsertOwnedBlock(parent, 0, block); 243878f2dae0SAlex Zinenko return PyBlock(parent.getParentOperation(), block); 243978f2dae0SAlex Zinenko }, 2440*a6e7d024SStella Laurenzo py::arg("parent"), py::arg("arg_types") = py::list(), 244178f2dae0SAlex Zinenko "Creates and returns a new Block at the beginning of the given " 244278f2dae0SAlex Zinenko "region (with given argument types).") 2443436c6c9cSStella Laurenzo .def( 24448e6c55c9SStella Laurenzo "create_before", 24458e6c55c9SStella Laurenzo [](PyBlock &self, py::args pyArgTypes) { 24468e6c55c9SStella Laurenzo self.checkValid(); 24478e6c55c9SStella Laurenzo llvm::SmallVector<MlirType, 4> argTypes; 24488e6c55c9SStella Laurenzo argTypes.reserve(pyArgTypes.size()); 24498e6c55c9SStella Laurenzo for (auto &pyArg : pyArgTypes) { 24508e6c55c9SStella Laurenzo argTypes.push_back(pyArg.cast<PyType &>()); 24518e6c55c9SStella Laurenzo } 24528e6c55c9SStella Laurenzo 24538e6c55c9SStella Laurenzo MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 24548e6c55c9SStella Laurenzo MlirRegion region = mlirBlockGetParentRegion(self.get()); 24558e6c55c9SStella Laurenzo mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 24568e6c55c9SStella Laurenzo return PyBlock(self.getParentOperation(), block); 24578e6c55c9SStella Laurenzo }, 24588e6c55c9SStella Laurenzo "Creates and returns a new Block before this block " 24598e6c55c9SStella Laurenzo "(with given argument types).") 24608e6c55c9SStella Laurenzo .def( 24618e6c55c9SStella Laurenzo "create_after", 24628e6c55c9SStella Laurenzo [](PyBlock &self, py::args pyArgTypes) { 24638e6c55c9SStella Laurenzo self.checkValid(); 24648e6c55c9SStella Laurenzo llvm::SmallVector<MlirType, 4> argTypes; 24658e6c55c9SStella Laurenzo argTypes.reserve(pyArgTypes.size()); 24668e6c55c9SStella Laurenzo for (auto &pyArg : pyArgTypes) { 24678e6c55c9SStella Laurenzo argTypes.push_back(pyArg.cast<PyType &>()); 24688e6c55c9SStella Laurenzo } 24698e6c55c9SStella Laurenzo 24708e6c55c9SStella Laurenzo MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 24718e6c55c9SStella Laurenzo MlirRegion region = mlirBlockGetParentRegion(self.get()); 24728e6c55c9SStella Laurenzo mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 24738e6c55c9SStella Laurenzo return PyBlock(self.getParentOperation(), block); 24748e6c55c9SStella Laurenzo }, 24758e6c55c9SStella Laurenzo "Creates and returns a new Block after this block " 24768e6c55c9SStella Laurenzo "(with given argument types).") 24778e6c55c9SStella Laurenzo .def( 2478436c6c9cSStella Laurenzo "__iter__", 2479436c6c9cSStella Laurenzo [](PyBlock &self) { 2480436c6c9cSStella Laurenzo self.checkValid(); 2481436c6c9cSStella Laurenzo MlirOperation firstOperation = 2482436c6c9cSStella Laurenzo mlirBlockGetFirstOperation(self.get()); 2483436c6c9cSStella Laurenzo return PyOperationIterator(self.getParentOperation(), 2484436c6c9cSStella Laurenzo firstOperation); 2485436c6c9cSStella Laurenzo }, 2486436c6c9cSStella Laurenzo "Iterates over operations in the block.") 2487436c6c9cSStella Laurenzo .def("__eq__", 2488436c6c9cSStella Laurenzo [](PyBlock &self, PyBlock &other) { 2489436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 2490436c6c9cSStella Laurenzo }) 2491436c6c9cSStella Laurenzo .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2492436c6c9cSStella Laurenzo .def( 2493436c6c9cSStella Laurenzo "__str__", 2494436c6c9cSStella Laurenzo [](PyBlock &self) { 2495436c6c9cSStella Laurenzo self.checkValid(); 2496436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2497436c6c9cSStella Laurenzo mlirBlockPrint(self.get(), printAccum.getCallback(), 2498436c6c9cSStella Laurenzo printAccum.getUserData()); 2499436c6c9cSStella Laurenzo return printAccum.join(); 2500436c6c9cSStella Laurenzo }, 250124685aaeSAlex Zinenko "Returns the assembly form of the block.") 250224685aaeSAlex Zinenko .def( 250324685aaeSAlex Zinenko "append", 250424685aaeSAlex Zinenko [](PyBlock &self, PyOperationBase &operation) { 250524685aaeSAlex Zinenko if (operation.getOperation().isAttached()) 250624685aaeSAlex Zinenko operation.getOperation().detachFromParent(); 250724685aaeSAlex Zinenko 250824685aaeSAlex Zinenko MlirOperation mlirOperation = operation.getOperation().get(); 250924685aaeSAlex Zinenko mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 251024685aaeSAlex Zinenko operation.getOperation().setAttached( 251124685aaeSAlex Zinenko self.getParentOperation().getObject()); 251224685aaeSAlex Zinenko }, 2513*a6e7d024SStella Laurenzo py::arg("operation"), 251424685aaeSAlex Zinenko "Appends an operation to this block. If the operation is currently " 251524685aaeSAlex Zinenko "in another block, it will be moved."); 2516436c6c9cSStella Laurenzo 2517436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2518436c6c9cSStella Laurenzo // Mapping of PyInsertionPoint. 2519436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2520436c6c9cSStella Laurenzo 2521f05ff4f7SStella Laurenzo py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2522436c6c9cSStella Laurenzo .def(py::init<PyBlock &>(), py::arg("block"), 2523436c6c9cSStella Laurenzo "Inserts after the last operation but still inside the block.") 2524436c6c9cSStella Laurenzo .def("__enter__", &PyInsertionPoint::contextEnter) 2525436c6c9cSStella Laurenzo .def("__exit__", &PyInsertionPoint::contextExit) 2526436c6c9cSStella Laurenzo .def_property_readonly_static( 2527436c6c9cSStella Laurenzo "current", 2528436c6c9cSStella Laurenzo [](py::object & /*class*/) { 2529436c6c9cSStella Laurenzo auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2530436c6c9cSStella Laurenzo if (!ip) 2531436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2532436c6c9cSStella Laurenzo return ip; 2533436c6c9cSStella Laurenzo }, 2534436c6c9cSStella Laurenzo "Gets the InsertionPoint bound to the current thread or raises " 2535436c6c9cSStella Laurenzo "ValueError if none has been set") 2536436c6c9cSStella Laurenzo .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2537436c6c9cSStella Laurenzo "Inserts before a referenced operation.") 2538436c6c9cSStella Laurenzo .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2539436c6c9cSStella Laurenzo py::arg("block"), "Inserts at the beginning of the block.") 2540436c6c9cSStella Laurenzo .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2541436c6c9cSStella Laurenzo py::arg("block"), "Inserts before the block terminator.") 2542436c6c9cSStella Laurenzo .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 25438e6c55c9SStella Laurenzo "Inserts an operation.") 25448e6c55c9SStella Laurenzo .def_property_readonly( 25458e6c55c9SStella Laurenzo "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 25468e6c55c9SStella Laurenzo "Returns the block that this InsertionPoint points to."); 2547436c6c9cSStella Laurenzo 2548436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2549436c6c9cSStella Laurenzo // Mapping of PyAttribute. 2550436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2551f05ff4f7SStella Laurenzo py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2552b57d6fe4SStella Laurenzo // Delegate to the PyAttribute copy constructor, which will also lifetime 2553b57d6fe4SStella Laurenzo // extend the backing context which owns the MlirAttribute. 2554b57d6fe4SStella Laurenzo .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2555b57d6fe4SStella Laurenzo "Casts the passed attribute to the generic Attribute") 2556436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2557436c6c9cSStella Laurenzo &PyAttribute::getCapsule) 2558436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2559436c6c9cSStella Laurenzo .def_static( 2560436c6c9cSStella Laurenzo "parse", 2561436c6c9cSStella Laurenzo [](std::string attrSpec, DefaultingPyMlirContext context) { 2562436c6c9cSStella Laurenzo MlirAttribute type = mlirAttributeParseGet( 2563436c6c9cSStella Laurenzo context->get(), toMlirStringRef(attrSpec)); 2564436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 2565436c6c9cSStella Laurenzo // in C API. 2566436c6c9cSStella Laurenzo if (mlirAttributeIsNull(type)) { 2567436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 2568436c6c9cSStella Laurenzo Twine("Unable to parse attribute: '") + 2569436c6c9cSStella Laurenzo attrSpec + "'"); 2570436c6c9cSStella Laurenzo } 2571436c6c9cSStella Laurenzo return PyAttribute(context->getRef(), type); 2572436c6c9cSStella Laurenzo }, 2573436c6c9cSStella Laurenzo py::arg("asm"), py::arg("context") = py::none(), 2574436c6c9cSStella Laurenzo "Parses an attribute from an assembly form") 2575436c6c9cSStella Laurenzo .def_property_readonly( 2576436c6c9cSStella Laurenzo "context", 2577436c6c9cSStella Laurenzo [](PyAttribute &self) { return self.getContext().getObject(); }, 2578436c6c9cSStella Laurenzo "Context that owns the Attribute") 2579436c6c9cSStella Laurenzo .def_property_readonly("type", 2580436c6c9cSStella Laurenzo [](PyAttribute &self) { 2581436c6c9cSStella Laurenzo return PyType(self.getContext()->getRef(), 2582436c6c9cSStella Laurenzo mlirAttributeGetType(self)); 2583436c6c9cSStella Laurenzo }) 2584436c6c9cSStella Laurenzo .def( 2585436c6c9cSStella Laurenzo "get_named", 2586436c6c9cSStella Laurenzo [](PyAttribute &self, std::string name) { 2587436c6c9cSStella Laurenzo return PyNamedAttribute(self, std::move(name)); 2588436c6c9cSStella Laurenzo }, 2589436c6c9cSStella Laurenzo py::keep_alive<0, 1>(), "Binds a name to the attribute") 2590436c6c9cSStella Laurenzo .def("__eq__", 2591436c6c9cSStella Laurenzo [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2592436c6c9cSStella Laurenzo .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2593f78fe0b7Srkayaith .def("__hash__", 2594f78fe0b7Srkayaith [](PyAttribute &self) { 2595f78fe0b7Srkayaith return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2596f78fe0b7Srkayaith }) 2597436c6c9cSStella Laurenzo .def( 2598436c6c9cSStella Laurenzo "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2599436c6c9cSStella Laurenzo kDumpDocstring) 2600436c6c9cSStella Laurenzo .def( 2601436c6c9cSStella Laurenzo "__str__", 2602436c6c9cSStella Laurenzo [](PyAttribute &self) { 2603436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2604436c6c9cSStella Laurenzo mlirAttributePrint(self, printAccum.getCallback(), 2605436c6c9cSStella Laurenzo printAccum.getUserData()); 2606436c6c9cSStella Laurenzo return printAccum.join(); 2607436c6c9cSStella Laurenzo }, 2608436c6c9cSStella Laurenzo "Returns the assembly form of the Attribute.") 2609436c6c9cSStella Laurenzo .def("__repr__", [](PyAttribute &self) { 2610436c6c9cSStella Laurenzo // Generally, assembly formats are not printed for __repr__ because 2611436c6c9cSStella Laurenzo // this can cause exceptionally long debug output and exceptions. 2612436c6c9cSStella Laurenzo // However, attribute values are generally considered useful and are 2613436c6c9cSStella Laurenzo // printed. This may need to be re-evaluated if debug dumps end up 2614436c6c9cSStella Laurenzo // being excessive. 2615436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2616436c6c9cSStella Laurenzo printAccum.parts.append("Attribute("); 2617436c6c9cSStella Laurenzo mlirAttributePrint(self, printAccum.getCallback(), 2618436c6c9cSStella Laurenzo printAccum.getUserData()); 2619436c6c9cSStella Laurenzo printAccum.parts.append(")"); 2620436c6c9cSStella Laurenzo return printAccum.join(); 2621436c6c9cSStella Laurenzo }); 2622436c6c9cSStella Laurenzo 2623436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2624436c6c9cSStella Laurenzo // Mapping of PyNamedAttribute 2625436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2626f05ff4f7SStella Laurenzo py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2627436c6c9cSStella Laurenzo .def("__repr__", 2628436c6c9cSStella Laurenzo [](PyNamedAttribute &self) { 2629436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2630436c6c9cSStella Laurenzo printAccum.parts.append("NamedAttribute("); 2631436c6c9cSStella Laurenzo printAccum.parts.append( 2632120591e1SRiver Riddle py::str(mlirIdentifierStr(self.namedAttr.name).data, 2633120591e1SRiver Riddle mlirIdentifierStr(self.namedAttr.name).length)); 2634436c6c9cSStella Laurenzo printAccum.parts.append("="); 2635436c6c9cSStella Laurenzo mlirAttributePrint(self.namedAttr.attribute, 2636436c6c9cSStella Laurenzo printAccum.getCallback(), 2637436c6c9cSStella Laurenzo printAccum.getUserData()); 2638436c6c9cSStella Laurenzo printAccum.parts.append(")"); 2639436c6c9cSStella Laurenzo return printAccum.join(); 2640436c6c9cSStella Laurenzo }) 2641436c6c9cSStella Laurenzo .def_property_readonly( 2642436c6c9cSStella Laurenzo "name", 2643436c6c9cSStella Laurenzo [](PyNamedAttribute &self) { 2644436c6c9cSStella Laurenzo return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2645436c6c9cSStella Laurenzo mlirIdentifierStr(self.namedAttr.name).length); 2646436c6c9cSStella Laurenzo }, 2647436c6c9cSStella Laurenzo "The name of the NamedAttribute binding") 2648436c6c9cSStella Laurenzo .def_property_readonly( 2649436c6c9cSStella Laurenzo "attr", 2650436c6c9cSStella Laurenzo [](PyNamedAttribute &self) { 2651436c6c9cSStella Laurenzo // TODO: When named attribute is removed/refactored, also remove 2652436c6c9cSStella Laurenzo // this constructor (it does an inefficient table lookup). 2653436c6c9cSStella Laurenzo auto contextRef = PyMlirContext::forContext( 2654436c6c9cSStella Laurenzo mlirAttributeGetContext(self.namedAttr.attribute)); 2655436c6c9cSStella Laurenzo return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2656436c6c9cSStella Laurenzo }, 2657436c6c9cSStella Laurenzo py::keep_alive<0, 1>(), 2658436c6c9cSStella Laurenzo "The underlying generic attribute of the NamedAttribute binding"); 2659436c6c9cSStella Laurenzo 2660436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2661436c6c9cSStella Laurenzo // Mapping of PyType. 2662436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2663f05ff4f7SStella Laurenzo py::class_<PyType>(m, "Type", py::module_local()) 2664b57d6fe4SStella Laurenzo // Delegate to the PyType copy constructor, which will also lifetime 2665b57d6fe4SStella Laurenzo // extend the backing context which owns the MlirType. 2666b57d6fe4SStella Laurenzo .def(py::init<PyType &>(), py::arg("cast_from_type"), 2667b57d6fe4SStella Laurenzo "Casts the passed type to the generic Type") 2668436c6c9cSStella Laurenzo .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2669436c6c9cSStella Laurenzo .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2670436c6c9cSStella Laurenzo .def_static( 2671436c6c9cSStella Laurenzo "parse", 2672436c6c9cSStella Laurenzo [](std::string typeSpec, DefaultingPyMlirContext context) { 2673436c6c9cSStella Laurenzo MlirType type = 2674436c6c9cSStella Laurenzo mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2675436c6c9cSStella Laurenzo // TODO: Rework error reporting once diagnostic engine is exposed 2676436c6c9cSStella Laurenzo // in C API. 2677436c6c9cSStella Laurenzo if (mlirTypeIsNull(type)) { 2678436c6c9cSStella Laurenzo throw SetPyError(PyExc_ValueError, 2679436c6c9cSStella Laurenzo Twine("Unable to parse type: '") + typeSpec + 2680436c6c9cSStella Laurenzo "'"); 2681436c6c9cSStella Laurenzo } 2682436c6c9cSStella Laurenzo return PyType(context->getRef(), type); 2683436c6c9cSStella Laurenzo }, 2684436c6c9cSStella Laurenzo py::arg("asm"), py::arg("context") = py::none(), 2685436c6c9cSStella Laurenzo kContextParseTypeDocstring) 2686436c6c9cSStella Laurenzo .def_property_readonly( 2687436c6c9cSStella Laurenzo "context", [](PyType &self) { return self.getContext().getObject(); }, 2688436c6c9cSStella Laurenzo "Context that owns the Type") 2689436c6c9cSStella Laurenzo .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2690436c6c9cSStella Laurenzo .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2691f78fe0b7Srkayaith .def("__hash__", 2692f78fe0b7Srkayaith [](PyType &self) { 2693f78fe0b7Srkayaith return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2694f78fe0b7Srkayaith }) 2695436c6c9cSStella Laurenzo .def( 2696436c6c9cSStella Laurenzo "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2697436c6c9cSStella Laurenzo .def( 2698436c6c9cSStella Laurenzo "__str__", 2699436c6c9cSStella Laurenzo [](PyType &self) { 2700436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2701436c6c9cSStella Laurenzo mlirTypePrint(self, printAccum.getCallback(), 2702436c6c9cSStella Laurenzo printAccum.getUserData()); 2703436c6c9cSStella Laurenzo return printAccum.join(); 2704436c6c9cSStella Laurenzo }, 2705436c6c9cSStella Laurenzo "Returns the assembly form of the type.") 2706436c6c9cSStella Laurenzo .def("__repr__", [](PyType &self) { 2707436c6c9cSStella Laurenzo // Generally, assembly formats are not printed for __repr__ because 2708436c6c9cSStella Laurenzo // this can cause exceptionally long debug output and exceptions. 2709436c6c9cSStella Laurenzo // However, types are an exception as they typically have compact 2710436c6c9cSStella Laurenzo // assembly forms and printing them is useful. 2711436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2712436c6c9cSStella Laurenzo printAccum.parts.append("Type("); 2713436c6c9cSStella Laurenzo mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2714436c6c9cSStella Laurenzo printAccum.parts.append(")"); 2715436c6c9cSStella Laurenzo return printAccum.join(); 2716436c6c9cSStella Laurenzo }); 2717436c6c9cSStella Laurenzo 2718436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2719436c6c9cSStella Laurenzo // Mapping of Value. 2720436c6c9cSStella Laurenzo //---------------------------------------------------------------------------- 2721f05ff4f7SStella Laurenzo py::class_<PyValue>(m, "Value", py::module_local()) 27223f3d1c90SMike Urbach .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 27233f3d1c90SMike Urbach .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2724436c6c9cSStella Laurenzo .def_property_readonly( 2725436c6c9cSStella Laurenzo "context", 2726436c6c9cSStella Laurenzo [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2727436c6c9cSStella Laurenzo "Context in which the value lives.") 2728436c6c9cSStella Laurenzo .def( 2729436c6c9cSStella Laurenzo "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2730436c6c9cSStella Laurenzo kDumpDocstring) 27315664c5e2SJohn Demme .def_property_readonly( 27325664c5e2SJohn Demme "owner", 27335664c5e2SJohn Demme [](PyValue &self) { 27345664c5e2SJohn Demme assert(mlirOperationEqual(self.getParentOperation()->get(), 27355664c5e2SJohn Demme mlirOpResultGetOwner(self.get())) && 27365664c5e2SJohn Demme "expected the owner of the value in Python to match that in " 27375664c5e2SJohn Demme "the IR"); 27385664c5e2SJohn Demme return self.getParentOperation().getObject(); 27395664c5e2SJohn Demme }) 2740436c6c9cSStella Laurenzo .def("__eq__", 2741436c6c9cSStella Laurenzo [](PyValue &self, PyValue &other) { 2742436c6c9cSStella Laurenzo return self.get().ptr == other.get().ptr; 2743436c6c9cSStella Laurenzo }) 2744436c6c9cSStella Laurenzo .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2745f78fe0b7Srkayaith .def("__hash__", 2746f78fe0b7Srkayaith [](PyValue &self) { 2747f78fe0b7Srkayaith return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2748f78fe0b7Srkayaith }) 2749436c6c9cSStella Laurenzo .def( 2750436c6c9cSStella Laurenzo "__str__", 2751436c6c9cSStella Laurenzo [](PyValue &self) { 2752436c6c9cSStella Laurenzo PyPrintAccumulator printAccum; 2753436c6c9cSStella Laurenzo printAccum.parts.append("Value("); 2754436c6c9cSStella Laurenzo mlirValuePrint(self.get(), printAccum.getCallback(), 2755436c6c9cSStella Laurenzo printAccum.getUserData()); 2756436c6c9cSStella Laurenzo printAccum.parts.append(")"); 2757436c6c9cSStella Laurenzo return printAccum.join(); 2758436c6c9cSStella Laurenzo }, 2759436c6c9cSStella Laurenzo kValueDunderStrDocstring) 2760436c6c9cSStella Laurenzo .def_property_readonly("type", [](PyValue &self) { 2761436c6c9cSStella Laurenzo return PyType(self.getParentOperation()->getContext(), 2762436c6c9cSStella Laurenzo mlirValueGetType(self.get())); 2763436c6c9cSStella Laurenzo }); 2764436c6c9cSStella Laurenzo PyBlockArgument::bind(m); 2765436c6c9cSStella Laurenzo PyOpResult::bind(m); 2766436c6c9cSStella Laurenzo 276730d61893SAlex Zinenko //---------------------------------------------------------------------------- 276830d61893SAlex Zinenko // Mapping of SymbolTable. 276930d61893SAlex Zinenko //---------------------------------------------------------------------------- 277030d61893SAlex Zinenko py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 277130d61893SAlex Zinenko .def(py::init<PyOperationBase &>()) 277230d61893SAlex Zinenko .def("__getitem__", &PySymbolTable::dunderGetItem) 2773*a6e7d024SStella Laurenzo .def("insert", &PySymbolTable::insert, py::arg("operation")) 2774*a6e7d024SStella Laurenzo .def("erase", &PySymbolTable::erase, py::arg("operation")) 277530d61893SAlex Zinenko .def("__delitem__", &PySymbolTable::dunderDel) 277630d61893SAlex Zinenko .def("__contains__", [](PySymbolTable &table, const std::string &name) { 277730d61893SAlex Zinenko return !mlirOperationIsNull(mlirSymbolTableLookup( 277830d61893SAlex Zinenko table, mlirStringRefCreate(name.data(), name.length()))); 277930d61893SAlex Zinenko }); 278030d61893SAlex Zinenko 2781436c6c9cSStella Laurenzo // Container bindings. 2782436c6c9cSStella Laurenzo PyBlockArgumentList::bind(m); 2783436c6c9cSStella Laurenzo PyBlockIterator::bind(m); 2784436c6c9cSStella Laurenzo PyBlockList::bind(m); 2785436c6c9cSStella Laurenzo PyOperationIterator::bind(m); 2786436c6c9cSStella Laurenzo PyOperationList::bind(m); 2787436c6c9cSStella Laurenzo PyOpAttributeMap::bind(m); 2788436c6c9cSStella Laurenzo PyOpOperandList::bind(m); 2789436c6c9cSStella Laurenzo PyOpResultList::bind(m); 2790436c6c9cSStella Laurenzo PyRegionIterator::bind(m); 2791436c6c9cSStella Laurenzo PyRegionList::bind(m); 27924acd8457SAlex Zinenko 27934acd8457SAlex Zinenko // Debug bindings. 27944acd8457SAlex Zinenko PyGlobalDebugFlag::bind(m); 2795436c6c9cSStella Laurenzo } 2796