1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include <pybind11/stl.h> 22 23 namespace py = pybind11; 24 using namespace mlir; 25 using namespace mlir::python; 26 27 using llvm::SmallVector; 28 using llvm::StringRef; 29 using llvm::Twine; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kContextParseTypeDocstring[] = 36 R"(Parses the assembly form of a type. 37 38 Returns a Type object or raises a ValueError if the type cannot be parsed. 39 40 See also: https://mlir.llvm.org/docs/LangRef/#type-system 41 )"; 42 43 static const char kContextGetFileLocationDocstring[] = 44 R"(Gets a Location representing a file, line and column)"; 45 46 static const char kModuleParseDocstring[] = 47 R"(Parses a module's assembly format from a string. 48 49 Returns a new MlirModule or raises a ValueError if the parsing fails. 50 51 See also: https://mlir.llvm.org/docs/LangRef/ 52 )"; 53 54 static const char kOperationCreateDocstring[] = 55 R"(Creates a new operation. 56 57 Args: 58 name: Operation name (e.g. "dialect.operation"). 59 results: Sequence of Type representing op result types. 60 attributes: Dict of str:Attribute. 61 successors: List of Block for the operation's successors. 62 regions: Number of regions to create. 63 location: A Location object (defaults to resolve from context manager). 64 ip: An InsertionPoint (defaults to resolve from context manager or set to 65 False to disable insertion, even with an insertion point set in the 66 context manager). 67 Returns: 68 A new "detached" Operation object. Detached operations can be added 69 to blocks, which causes them to become "attached." 70 )"; 71 72 static const char kOperationPrintDocstring[] = 73 R"(Prints the assembly form of the operation to a file like object. 74 75 Args: 76 file: The file like object to write to. Defaults to sys.stdout. 77 binary: Whether to write bytes (True) or str (False). Defaults to False. 78 large_elements_limit: Whether to elide elements attributes above this 79 number of elements. Defaults to None (no limit). 80 enable_debug_info: Whether to print debug/location information. Defaults 81 to False. 82 pretty_debug_info: Whether to format debug information for easier reading 83 by a human (warning: the result is unparseable). 84 print_generic_op_form: Whether to print the generic assembly forms of all 85 ops. Defaults to False. 86 use_local_Scope: Whether to print in a way that is more optimized for 87 multi-threaded access but may not be consistent with how the overall 88 module prints. 89 )"; 90 91 static const char kOperationGetAsmDocstring[] = 92 R"(Gets the assembly form of the operation with all options available. 93 94 Args: 95 binary: Whether to return a bytes (True) or str (False) object. Defaults to 96 False. 97 ... others ...: See the print() method for common keyword arguments for 98 configuring the printout. 99 Returns: 100 Either a bytes or str object, depending on the setting of the 'binary' 101 argument. 102 )"; 103 104 static const char kOperationStrDunderDocstring[] = 105 R"(Gets the assembly form of the operation with default options. 106 107 If more advanced control over the assembly formatting or I/O options is needed, 108 use the dedicated print or get_asm method, which supports keyword arguments to 109 customize behavior. 110 )"; 111 112 static const char kDumpDocstring[] = 113 R"(Dumps a debug representation of the object to stderr.)"; 114 115 static const char kAppendBlockDocstring[] = 116 R"(Appends a new block, with argument types as positional args. 117 118 Returns: 119 The created block. 120 )"; 121 122 static const char kValueDunderStrDocstring[] = 123 R"(Returns the string form of the value. 124 125 If the value is a block argument, this is the assembly form of its type and the 126 position in the argument list. If the value is an operation result, this is 127 equivalent to printing the operation that produced it. 128 )"; 129 130 //------------------------------------------------------------------------------ 131 // Utilities. 132 //------------------------------------------------------------------------------ 133 134 /// Helper for creating an @classmethod. 135 template <class Func, typename... Args> 136 py::object classmethod(Func f, Args... args) { 137 py::object cf = py::cpp_function(f, args...); 138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 139 } 140 141 static py::object 142 createCustomDialectWrapper(const std::string &dialectNamespace, 143 py::object dialectDescriptor) { 144 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 145 if (!dialectClass) { 146 // Use the base class. 147 return py::cast(PyDialect(std::move(dialectDescriptor))); 148 } 149 150 // Create the custom implementation. 151 return (*dialectClass)(std::move(dialectDescriptor)); 152 } 153 154 static MlirStringRef toMlirStringRef(const std::string &s) { 155 return mlirStringRefCreate(s.data(), s.size()); 156 } 157 158 /// Wrapper for the global LLVM debugging flag. 159 struct PyGlobalDebugFlag { 160 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 161 162 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 163 164 static void bind(py::module &m) { 165 // Debug flags. 166 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 167 .def_property_static("flag", &PyGlobalDebugFlag::get, 168 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 169 } 170 }; 171 172 //------------------------------------------------------------------------------ 173 // Collections. 174 //------------------------------------------------------------------------------ 175 176 namespace { 177 178 class PyRegionIterator { 179 public: 180 PyRegionIterator(PyOperationRef operation) 181 : operation(std::move(operation)) {} 182 183 PyRegionIterator &dunderIter() { return *this; } 184 185 PyRegion dunderNext() { 186 operation->checkValid(); 187 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 188 throw py::stop_iteration(); 189 } 190 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 191 return PyRegion(operation, region); 192 } 193 194 static void bind(py::module &m) { 195 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 196 .def("__iter__", &PyRegionIterator::dunderIter) 197 .def("__next__", &PyRegionIterator::dunderNext); 198 } 199 200 private: 201 PyOperationRef operation; 202 int nextIndex = 0; 203 }; 204 205 /// Regions of an op are fixed length and indexed numerically so are represented 206 /// with a sequence-like container. 207 class PyRegionList { 208 public: 209 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 210 211 intptr_t dunderLen() { 212 operation->checkValid(); 213 return mlirOperationGetNumRegions(operation->get()); 214 } 215 216 PyRegion dunderGetItem(intptr_t index) { 217 // dunderLen checks validity. 218 if (index < 0 || index >= dunderLen()) { 219 throw SetPyError(PyExc_IndexError, 220 "attempt to access out of bounds region"); 221 } 222 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 223 return PyRegion(operation, region); 224 } 225 226 static void bind(py::module &m) { 227 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 228 .def("__len__", &PyRegionList::dunderLen) 229 .def("__getitem__", &PyRegionList::dunderGetItem); 230 } 231 232 private: 233 PyOperationRef operation; 234 }; 235 236 class PyBlockIterator { 237 public: 238 PyBlockIterator(PyOperationRef operation, MlirBlock next) 239 : operation(std::move(operation)), next(next) {} 240 241 PyBlockIterator &dunderIter() { return *this; } 242 243 PyBlock dunderNext() { 244 operation->checkValid(); 245 if (mlirBlockIsNull(next)) { 246 throw py::stop_iteration(); 247 } 248 249 PyBlock returnBlock(operation, next); 250 next = mlirBlockGetNextInRegion(next); 251 return returnBlock; 252 } 253 254 static void bind(py::module &m) { 255 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 256 .def("__iter__", &PyBlockIterator::dunderIter) 257 .def("__next__", &PyBlockIterator::dunderNext); 258 } 259 260 private: 261 PyOperationRef operation; 262 MlirBlock next; 263 }; 264 265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 266 /// we present them as a more full-featured list-like container but optimize 267 /// it for forward iteration. Blocks are always owned by a region. 268 class PyBlockList { 269 public: 270 PyBlockList(PyOperationRef operation, MlirRegion region) 271 : operation(std::move(operation)), region(region) {} 272 273 PyBlockIterator dunderIter() { 274 operation->checkValid(); 275 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 276 } 277 278 intptr_t dunderLen() { 279 operation->checkValid(); 280 intptr_t count = 0; 281 MlirBlock block = mlirRegionGetFirstBlock(region); 282 while (!mlirBlockIsNull(block)) { 283 count += 1; 284 block = mlirBlockGetNextInRegion(block); 285 } 286 return count; 287 } 288 289 PyBlock dunderGetItem(intptr_t index) { 290 operation->checkValid(); 291 if (index < 0) { 292 throw SetPyError(PyExc_IndexError, 293 "attempt to access out of bounds block"); 294 } 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 if (index == 0) { 298 return PyBlock(operation, block); 299 } 300 block = mlirBlockGetNextInRegion(block); 301 index -= 1; 302 } 303 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 304 } 305 306 PyBlock appendBlock(py::args pyArgTypes) { 307 operation->checkValid(); 308 llvm::SmallVector<MlirType, 4> argTypes; 309 argTypes.reserve(pyArgTypes.size()); 310 for (auto &pyArg : pyArgTypes) { 311 argTypes.push_back(pyArg.cast<PyType &>()); 312 } 313 314 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 315 mlirRegionAppendOwnedBlock(region, block); 316 return PyBlock(operation, block); 317 } 318 319 static void bind(py::module &m) { 320 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 321 .def("__getitem__", &PyBlockList::dunderGetItem) 322 .def("__iter__", &PyBlockList::dunderIter) 323 .def("__len__", &PyBlockList::dunderLen) 324 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 325 } 326 327 private: 328 PyOperationRef operation; 329 MlirRegion region; 330 }; 331 332 class PyOperationIterator { 333 public: 334 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 335 : parentOperation(std::move(parentOperation)), next(next) {} 336 337 PyOperationIterator &dunderIter() { return *this; } 338 339 py::object dunderNext() { 340 parentOperation->checkValid(); 341 if (mlirOperationIsNull(next)) { 342 throw py::stop_iteration(); 343 } 344 345 PyOperationRef returnOperation = 346 PyOperation::forOperation(parentOperation->getContext(), next); 347 next = mlirOperationGetNextInBlock(next); 348 return returnOperation->createOpView(); 349 } 350 351 static void bind(py::module &m) { 352 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 353 .def("__iter__", &PyOperationIterator::dunderIter) 354 .def("__next__", &PyOperationIterator::dunderNext); 355 } 356 357 private: 358 PyOperationRef parentOperation; 359 MlirOperation next; 360 }; 361 362 /// Operations are exposed by the C-API as a forward-only linked list. In 363 /// Python, we present them as a more full-featured list-like container but 364 /// optimize it for forward iteration. Iterable operations are always owned 365 /// by a block. 366 class PyOperationList { 367 public: 368 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 369 : parentOperation(std::move(parentOperation)), block(block) {} 370 371 PyOperationIterator dunderIter() { 372 parentOperation->checkValid(); 373 return PyOperationIterator(parentOperation, 374 mlirBlockGetFirstOperation(block)); 375 } 376 377 intptr_t dunderLen() { 378 parentOperation->checkValid(); 379 intptr_t count = 0; 380 MlirOperation childOp = mlirBlockGetFirstOperation(block); 381 while (!mlirOperationIsNull(childOp)) { 382 count += 1; 383 childOp = mlirOperationGetNextInBlock(childOp); 384 } 385 return count; 386 } 387 388 py::object dunderGetItem(intptr_t index) { 389 parentOperation->checkValid(); 390 if (index < 0) { 391 throw SetPyError(PyExc_IndexError, 392 "attempt to access out of bounds operation"); 393 } 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 if (index == 0) { 397 return PyOperation::forOperation(parentOperation->getContext(), childOp) 398 ->createOpView(); 399 } 400 childOp = mlirOperationGetNextInBlock(childOp); 401 index -= 1; 402 } 403 throw SetPyError(PyExc_IndexError, 404 "attempt to access out of bounds operation"); 405 } 406 407 static void bind(py::module &m) { 408 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 409 .def("__getitem__", &PyOperationList::dunderGetItem) 410 .def("__iter__", &PyOperationList::dunderIter) 411 .def("__len__", &PyOperationList::dunderLen); 412 } 413 414 private: 415 PyOperationRef parentOperation; 416 MlirBlock block; 417 }; 418 419 } // namespace 420 421 //------------------------------------------------------------------------------ 422 // PyMlirContext 423 //------------------------------------------------------------------------------ 424 425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 426 py::gil_scoped_acquire acquire; 427 auto &liveContexts = getLiveContexts(); 428 liveContexts[context.ptr] = this; 429 } 430 431 PyMlirContext::~PyMlirContext() { 432 // Note that the only public way to construct an instance is via the 433 // forContext method, which always puts the associated handle into 434 // liveContexts. 435 py::gil_scoped_acquire acquire; 436 getLiveContexts().erase(context.ptr); 437 mlirContextDestroy(context); 438 } 439 440 py::object PyMlirContext::getCapsule() { 441 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 442 } 443 444 py::object PyMlirContext::createFromCapsule(py::object capsule) { 445 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 446 if (mlirContextIsNull(rawContext)) 447 throw py::error_already_set(); 448 return forContext(rawContext).releaseObject(); 449 } 450 451 PyMlirContext *PyMlirContext::createNewContextForInit() { 452 MlirContext context = mlirContextCreate(); 453 mlirRegisterAllDialects(context); 454 return new PyMlirContext(context); 455 } 456 457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 458 py::gil_scoped_acquire acquire; 459 auto &liveContexts = getLiveContexts(); 460 auto it = liveContexts.find(context.ptr); 461 if (it == liveContexts.end()) { 462 // Create. 463 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 464 py::object pyRef = py::cast(unownedContextWrapper); 465 assert(pyRef && "cast to py::object failed"); 466 liveContexts[context.ptr] = unownedContextWrapper; 467 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 468 } 469 // Use existing. 470 py::object pyRef = py::cast(it->second); 471 return PyMlirContextRef(it->second, std::move(pyRef)); 472 } 473 474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 475 static LiveContextMap liveContexts; 476 return liveContexts; 477 } 478 479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 480 481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 482 483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 484 485 pybind11::object PyMlirContext::contextEnter() { 486 return PyThreadContextEntry::pushContext(*this); 487 } 488 489 void PyMlirContext::contextExit(pybind11::object excType, 490 pybind11::object excVal, 491 pybind11::object excTb) { 492 PyThreadContextEntry::popContext(*this); 493 } 494 495 PyMlirContext &DefaultingPyMlirContext::resolve() { 496 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 497 if (!context) { 498 throw SetPyError( 499 PyExc_RuntimeError, 500 "An MLIR function requires a Context but none was provided in the call " 501 "or from the surrounding environment. Either pass to the function with " 502 "a 'context=' argument or establish a default using 'with Context():'"); 503 } 504 return *context; 505 } 506 507 //------------------------------------------------------------------------------ 508 // PyThreadContextEntry management 509 //------------------------------------------------------------------------------ 510 511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 512 static thread_local std::vector<PyThreadContextEntry> stack; 513 return stack; 514 } 515 516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 517 auto &stack = getStack(); 518 if (stack.empty()) 519 return nullptr; 520 return &stack.back(); 521 } 522 523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 524 py::object insertionPoint, 525 py::object location) { 526 auto &stack = getStack(); 527 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 528 std::move(location)); 529 // If the new stack has more than one entry and the context of the new top 530 // entry matches the previous, copy the insertionPoint and location from the 531 // previous entry if missing from the new top entry. 532 if (stack.size() > 1) { 533 auto &prev = *(stack.rbegin() + 1); 534 auto ¤t = stack.back(); 535 if (current.context.is(prev.context)) { 536 // Default non-context objects from the previous entry. 537 if (!current.insertionPoint) 538 current.insertionPoint = prev.insertionPoint; 539 if (!current.location) 540 current.location = prev.location; 541 } 542 } 543 } 544 545 PyMlirContext *PyThreadContextEntry::getContext() { 546 if (!context) 547 return nullptr; 548 return py::cast<PyMlirContext *>(context); 549 } 550 551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 552 if (!insertionPoint) 553 return nullptr; 554 return py::cast<PyInsertionPoint *>(insertionPoint); 555 } 556 557 PyLocation *PyThreadContextEntry::getLocation() { 558 if (!location) 559 return nullptr; 560 return py::cast<PyLocation *>(location); 561 } 562 563 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 564 auto *tos = getTopOfStack(); 565 return tos ? tos->getContext() : nullptr; 566 } 567 568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 569 auto *tos = getTopOfStack(); 570 return tos ? tos->getInsertionPoint() : nullptr; 571 } 572 573 PyLocation *PyThreadContextEntry::getDefaultLocation() { 574 auto *tos = getTopOfStack(); 575 return tos ? tos->getLocation() : nullptr; 576 } 577 578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 579 py::object contextObj = py::cast(context); 580 push(FrameKind::Context, /*context=*/contextObj, 581 /*insertionPoint=*/py::object(), 582 /*location=*/py::object()); 583 return contextObj; 584 } 585 586 void PyThreadContextEntry::popContext(PyMlirContext &context) { 587 auto &stack = getStack(); 588 if (stack.empty()) 589 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 590 auto &tos = stack.back(); 591 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593 stack.pop_back(); 594 } 595 596 py::object 597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 598 py::object contextObj = 599 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 600 py::object insertionPointObj = py::cast(insertionPoint); 601 push(FrameKind::InsertionPoint, 602 /*context=*/contextObj, 603 /*insertionPoint=*/insertionPointObj, 604 /*location=*/py::object()); 605 return insertionPointObj; 606 } 607 608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 609 auto &stack = getStack(); 610 if (stack.empty()) 611 throw SetPyError(PyExc_RuntimeError, 612 "Unbalanced InsertionPoint enter/exit"); 613 auto &tos = stack.back(); 614 if (tos.frameKind != FrameKind::InsertionPoint && 615 tos.getInsertionPoint() != &insertionPoint) 616 throw SetPyError(PyExc_RuntimeError, 617 "Unbalanced InsertionPoint enter/exit"); 618 stack.pop_back(); 619 } 620 621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 622 py::object contextObj = location.getContext().getObject(); 623 py::object locationObj = py::cast(location); 624 push(FrameKind::Location, /*context=*/contextObj, 625 /*insertionPoint=*/py::object(), 626 /*location=*/locationObj); 627 return locationObj; 628 } 629 630 void PyThreadContextEntry::popLocation(PyLocation &location) { 631 auto &stack = getStack(); 632 if (stack.empty()) 633 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 634 auto &tos = stack.back(); 635 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637 stack.pop_back(); 638 } 639 640 //------------------------------------------------------------------------------ 641 // PyDialect, PyDialectDescriptor, PyDialects 642 //------------------------------------------------------------------------------ 643 644 MlirDialect PyDialects::getDialectForKey(const std::string &key, 645 bool attrError) { 646 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 647 {key.data(), key.size()}); 648 if (mlirDialectIsNull(dialect)) { 649 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 650 Twine("Dialect '") + key + "' not found"); 651 } 652 return dialect; 653 } 654 655 //------------------------------------------------------------------------------ 656 // PyLocation 657 //------------------------------------------------------------------------------ 658 659 py::object PyLocation::getCapsule() { 660 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 661 } 662 663 PyLocation PyLocation::createFromCapsule(py::object capsule) { 664 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 665 if (mlirLocationIsNull(rawLoc)) 666 throw py::error_already_set(); 667 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 668 rawLoc); 669 } 670 671 py::object PyLocation::contextEnter() { 672 return PyThreadContextEntry::pushLocation(*this); 673 } 674 675 void PyLocation::contextExit(py::object excType, py::object excVal, 676 py::object excTb) { 677 PyThreadContextEntry::popLocation(*this); 678 } 679 680 PyLocation &DefaultingPyLocation::resolve() { 681 auto *location = PyThreadContextEntry::getDefaultLocation(); 682 if (!location) { 683 throw SetPyError( 684 PyExc_RuntimeError, 685 "An MLIR function requires a Location but none was provided in the " 686 "call or from the surrounding environment. Either pass to the function " 687 "with a 'loc=' argument or establish a default using 'with loc:'"); 688 } 689 return *location; 690 } 691 692 //------------------------------------------------------------------------------ 693 // PyModule 694 //------------------------------------------------------------------------------ 695 696 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 697 : BaseContextObject(std::move(contextRef)), module(module) {} 698 699 PyModule::~PyModule() { 700 py::gil_scoped_acquire acquire; 701 auto &liveModules = getContext()->liveModules; 702 assert(liveModules.count(module.ptr) == 1 && 703 "destroying module not in live map"); 704 liveModules.erase(module.ptr); 705 mlirModuleDestroy(module); 706 } 707 708 PyModuleRef PyModule::forModule(MlirModule module) { 709 MlirContext context = mlirModuleGetContext(module); 710 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 711 712 py::gil_scoped_acquire acquire; 713 auto &liveModules = contextRef->liveModules; 714 auto it = liveModules.find(module.ptr); 715 if (it == liveModules.end()) { 716 // Create. 717 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 718 // Note that the default return value policy on cast is automatic_reference, 719 // which does not take ownership (delete will not be called). 720 // Just be explicit. 721 py::object pyRef = 722 py::cast(unownedModule, py::return_value_policy::take_ownership); 723 unownedModule->handle = pyRef; 724 liveModules[module.ptr] = 725 std::make_pair(unownedModule->handle, unownedModule); 726 return PyModuleRef(unownedModule, std::move(pyRef)); 727 } 728 // Use existing. 729 PyModule *existing = it->second.second; 730 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 731 return PyModuleRef(existing, std::move(pyRef)); 732 } 733 734 py::object PyModule::createFromCapsule(py::object capsule) { 735 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 736 if (mlirModuleIsNull(rawModule)) 737 throw py::error_already_set(); 738 return forModule(rawModule).releaseObject(); 739 } 740 741 py::object PyModule::getCapsule() { 742 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 743 } 744 745 //------------------------------------------------------------------------------ 746 // PyOperation 747 //------------------------------------------------------------------------------ 748 749 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 750 : BaseContextObject(std::move(contextRef)), operation(operation) {} 751 752 PyOperation::~PyOperation() { 753 // If the operation has already been invalidated there is nothing to do. 754 if (!valid) 755 return; 756 auto &liveOperations = getContext()->liveOperations; 757 assert(liveOperations.count(operation.ptr) == 1 && 758 "destroying operation not in live map"); 759 liveOperations.erase(operation.ptr); 760 if (!isAttached()) { 761 mlirOperationDestroy(operation); 762 } 763 } 764 765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 766 MlirOperation operation, 767 py::object parentKeepAlive) { 768 auto &liveOperations = contextRef->liveOperations; 769 // Create. 770 PyOperation *unownedOperation = 771 new PyOperation(std::move(contextRef), operation); 772 // Note that the default return value policy on cast is automatic_reference, 773 // which does not take ownership (delete will not be called). 774 // Just be explicit. 775 py::object pyRef = 776 py::cast(unownedOperation, py::return_value_policy::take_ownership); 777 unownedOperation->handle = pyRef; 778 if (parentKeepAlive) { 779 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 780 } 781 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 782 return PyOperationRef(unownedOperation, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 auto it = liveOperations.find(operation.ptr); 790 if (it == liveOperations.end()) { 791 // Create. 792 return createInstance(std::move(contextRef), operation, 793 std::move(parentKeepAlive)); 794 } 795 // Use existing. 796 PyOperation *existing = it->second.second; 797 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 798 return PyOperationRef(existing, std::move(pyRef)); 799 } 800 801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 802 MlirOperation operation, 803 py::object parentKeepAlive) { 804 auto &liveOperations = contextRef->liveOperations; 805 assert(liveOperations.count(operation.ptr) == 0 && 806 "cannot create detached operation that already exists"); 807 (void)liveOperations; 808 809 PyOperationRef created = createInstance(std::move(contextRef), operation, 810 std::move(parentKeepAlive)); 811 created->attached = false; 812 return created; 813 } 814 815 void PyOperation::checkValid() const { 816 if (!valid) { 817 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 818 } 819 } 820 821 void PyOperationBase::print(py::object fileObject, bool binary, 822 llvm::Optional<int64_t> largeElementsLimit, 823 bool enableDebugInfo, bool prettyDebugInfo, 824 bool printGenericOpForm, bool useLocalScope) { 825 PyOperation &operation = getOperation(); 826 operation.checkValid(); 827 if (fileObject.is_none()) 828 fileObject = py::module::import("sys").attr("stdout"); 829 830 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 831 fileObject.attr("write")("// Verification failed, printing generic form\n"); 832 printGenericOpForm = true; 833 } 834 835 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 836 if (largeElementsLimit) 837 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 838 if (enableDebugInfo) 839 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 840 if (printGenericOpForm) 841 mlirOpPrintingFlagsPrintGenericOpForm(flags); 842 843 PyFileAccumulator accum(fileObject, binary); 844 py::gil_scoped_release(); 845 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 846 accum.getUserData()); 847 mlirOpPrintingFlagsDestroy(flags); 848 } 849 850 py::object PyOperationBase::getAsm(bool binary, 851 llvm::Optional<int64_t> largeElementsLimit, 852 bool enableDebugInfo, bool prettyDebugInfo, 853 bool printGenericOpForm, 854 bool useLocalScope) { 855 py::object fileObject; 856 if (binary) { 857 fileObject = py::module::import("io").attr("BytesIO")(); 858 } else { 859 fileObject = py::module::import("io").attr("StringIO")(); 860 } 861 print(fileObject, /*binary=*/binary, 862 /*largeElementsLimit=*/largeElementsLimit, 863 /*enableDebugInfo=*/enableDebugInfo, 864 /*prettyDebugInfo=*/prettyDebugInfo, 865 /*printGenericOpForm=*/printGenericOpForm, 866 /*useLocalScope=*/useLocalScope); 867 868 return fileObject.attr("getvalue")(); 869 } 870 871 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 872 checkValid(); 873 if (!isAttached()) 874 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 875 MlirOperation operation = mlirOperationGetParentOperation(get()); 876 if (mlirOperationIsNull(operation)) 877 return {}; 878 return PyOperation::forOperation(getContext(), operation); 879 } 880 881 PyBlock PyOperation::getBlock() { 882 checkValid(); 883 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 884 MlirBlock block = mlirOperationGetBlock(get()); 885 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 886 assert(parentOperation && "Operation has no parent"); 887 return PyBlock{std::move(*parentOperation), block}; 888 } 889 890 py::object PyOperation::getCapsule() { 891 checkValid(); 892 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 893 } 894 895 py::object PyOperation::createFromCapsule(py::object capsule) { 896 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 897 if (mlirOperationIsNull(rawOperation)) 898 throw py::error_already_set(); 899 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 900 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 901 .releaseObject(); 902 } 903 904 py::object PyOperation::create( 905 std::string name, llvm::Optional<std::vector<PyType *>> results, 906 llvm::Optional<std::vector<PyValue *>> operands, 907 llvm::Optional<py::dict> attributes, 908 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 909 DefaultingPyLocation location, py::object maybeIp) { 910 llvm::SmallVector<MlirValue, 4> mlirOperands; 911 llvm::SmallVector<MlirType, 4> mlirResults; 912 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 913 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 914 915 // General parameter validation. 916 if (regions < 0) 917 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 918 919 // Unpack/validate operands. 920 if (operands) { 921 mlirOperands.reserve(operands->size()); 922 for (PyValue *operand : *operands) { 923 if (!operand) 924 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 925 mlirOperands.push_back(operand->get()); 926 } 927 } 928 929 // Unpack/validate results. 930 if (results) { 931 mlirResults.reserve(results->size()); 932 for (PyType *result : *results) { 933 // TODO: Verify result type originate from the same context. 934 if (!result) 935 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 936 mlirResults.push_back(*result); 937 } 938 } 939 // Unpack/validate attributes. 940 if (attributes) { 941 mlirAttributes.reserve(attributes->size()); 942 for (auto &it : *attributes) { 943 std::string key; 944 try { 945 key = it.first.cast<std::string>(); 946 } catch (py::cast_error &err) { 947 std::string msg = "Invalid attribute key (not a string) when " 948 "attempting to create the operation \"" + 949 name + "\" (" + err.what() + ")"; 950 throw py::cast_error(msg); 951 } 952 try { 953 auto &attribute = it.second.cast<PyAttribute &>(); 954 // TODO: Verify attribute originates from the same context. 955 mlirAttributes.emplace_back(std::move(key), attribute); 956 } catch (py::reference_cast_error &) { 957 // This exception seems thrown when the value is "None". 958 std::string msg = 959 "Found an invalid (`None`?) attribute value for the key \"" + key + 960 "\" when attempting to create the operation \"" + name + "\""; 961 throw py::cast_error(msg); 962 } catch (py::cast_error &err) { 963 std::string msg = "Invalid attribute value for the key \"" + key + 964 "\" when attempting to create the operation \"" + 965 name + "\" (" + err.what() + ")"; 966 throw py::cast_error(msg); 967 } 968 } 969 } 970 // Unpack/validate successors. 971 if (successors) { 972 mlirSuccessors.reserve(successors->size()); 973 for (auto *successor : *successors) { 974 // TODO: Verify successor originate from the same context. 975 if (!successor) 976 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 977 mlirSuccessors.push_back(successor->get()); 978 } 979 } 980 981 // Apply unpacked/validated to the operation state. Beyond this 982 // point, exceptions cannot be thrown or else the state will leak. 983 MlirOperationState state = 984 mlirOperationStateGet(toMlirStringRef(name), location); 985 if (!mlirOperands.empty()) 986 mlirOperationStateAddOperands(&state, mlirOperands.size(), 987 mlirOperands.data()); 988 if (!mlirResults.empty()) 989 mlirOperationStateAddResults(&state, mlirResults.size(), 990 mlirResults.data()); 991 if (!mlirAttributes.empty()) { 992 // Note that the attribute names directly reference bytes in 993 // mlirAttributes, so that vector must not be changed from here 994 // on. 995 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 996 mlirNamedAttributes.reserve(mlirAttributes.size()); 997 for (auto &it : mlirAttributes) 998 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 999 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1000 toMlirStringRef(it.first)), 1001 it.second)); 1002 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1003 mlirNamedAttributes.data()); 1004 } 1005 if (!mlirSuccessors.empty()) 1006 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1007 mlirSuccessors.data()); 1008 if (regions) { 1009 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1010 mlirRegions.resize(regions); 1011 for (int i = 0; i < regions; ++i) 1012 mlirRegions[i] = mlirRegionCreate(); 1013 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1014 mlirRegions.data()); 1015 } 1016 1017 // Construct the operation. 1018 MlirOperation operation = mlirOperationCreate(&state); 1019 PyOperationRef created = 1020 PyOperation::createDetached(location->getContext(), operation); 1021 1022 // InsertPoint active? 1023 if (!maybeIp.is(py::cast(false))) { 1024 PyInsertionPoint *ip; 1025 if (maybeIp.is_none()) { 1026 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1027 } else { 1028 ip = py::cast<PyInsertionPoint *>(maybeIp); 1029 } 1030 if (ip) 1031 ip->insert(*created.get()); 1032 } 1033 1034 return created->createOpView(); 1035 } 1036 1037 py::object PyOperation::createOpView() { 1038 checkValid(); 1039 MlirIdentifier ident = mlirOperationGetName(get()); 1040 MlirStringRef identStr = mlirIdentifierStr(ident); 1041 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1042 StringRef(identStr.data, identStr.length)); 1043 if (opViewClass) 1044 return (*opViewClass)(getRef().getObject()); 1045 return py::cast(PyOpView(getRef().getObject())); 1046 } 1047 1048 void PyOperation::erase() { 1049 checkValid(); 1050 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1051 // Python reference to a child operation is live. All children should also 1052 // have their `valid` bit set to false. 1053 auto &liveOperations = getContext()->liveOperations; 1054 if (liveOperations.count(operation.ptr)) 1055 liveOperations.erase(operation.ptr); 1056 mlirOperationDestroy(operation); 1057 valid = false; 1058 } 1059 1060 //------------------------------------------------------------------------------ 1061 // PyOpView 1062 //------------------------------------------------------------------------------ 1063 1064 py::object 1065 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1066 py::list operandList, 1067 llvm::Optional<py::dict> attributes, 1068 llvm::Optional<std::vector<PyBlock *>> successors, 1069 llvm::Optional<int> regions, 1070 DefaultingPyLocation location, py::object maybeIp) { 1071 PyMlirContextRef context = location->getContext(); 1072 // Class level operation construction metadata. 1073 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1074 // Operand and result segment specs are either none, which does no 1075 // variadic unpacking, or a list of ints with segment sizes, where each 1076 // element is either a positive number (typically 1 for a scalar) or -1 to 1077 // indicate that it is derived from the length of the same-indexed operand 1078 // or result (implying that it is a list at that position). 1079 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1080 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1081 1082 std::vector<uint32_t> operandSegmentLengths; 1083 std::vector<uint32_t> resultSegmentLengths; 1084 1085 // Validate/determine region count. 1086 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1087 int opMinRegionCount = std::get<0>(opRegionSpec); 1088 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1089 if (!regions) { 1090 regions = opMinRegionCount; 1091 } 1092 if (*regions < opMinRegionCount) { 1093 throw py::value_error( 1094 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1095 llvm::Twine(opMinRegionCount) + 1096 " regions but was built with regions=" + llvm::Twine(*regions)) 1097 .str()); 1098 } 1099 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1100 throw py::value_error( 1101 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1102 llvm::Twine(opMinRegionCount) + 1103 " regions but was built with regions=" + llvm::Twine(*regions)) 1104 .str()); 1105 } 1106 1107 // Unpack results. 1108 std::vector<PyType *> resultTypes; 1109 resultTypes.reserve(resultTypeList.size()); 1110 if (resultSegmentSpecObj.is_none()) { 1111 // Non-variadic result unpacking. 1112 for (auto it : llvm::enumerate(resultTypeList)) { 1113 try { 1114 resultTypes.push_back(py::cast<PyType *>(it.value())); 1115 if (!resultTypes.back()) 1116 throw py::cast_error(); 1117 } catch (py::cast_error &err) { 1118 throw py::value_error((llvm::Twine("Result ") + 1119 llvm::Twine(it.index()) + " of operation \"" + 1120 name + "\" must be a Type (" + err.what() + ")") 1121 .str()); 1122 } 1123 } 1124 } else { 1125 // Sized result unpacking. 1126 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1127 if (resultSegmentSpec.size() != resultTypeList.size()) { 1128 throw py::value_error((llvm::Twine("Operation \"") + name + 1129 "\" requires " + 1130 llvm::Twine(resultSegmentSpec.size()) + 1131 "result segments but was provided " + 1132 llvm::Twine(resultTypeList.size())) 1133 .str()); 1134 } 1135 resultSegmentLengths.reserve(resultTypeList.size()); 1136 for (auto it : 1137 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1138 int segmentSpec = std::get<1>(it.value()); 1139 if (segmentSpec == 1 || segmentSpec == 0) { 1140 // Unpack unary element. 1141 try { 1142 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1143 if (resultType) { 1144 resultTypes.push_back(resultType); 1145 resultSegmentLengths.push_back(1); 1146 } else if (segmentSpec == 0) { 1147 // Allowed to be optional. 1148 resultSegmentLengths.push_back(0); 1149 } else { 1150 throw py::cast_error("was None and result is not optional"); 1151 } 1152 } catch (py::cast_error &err) { 1153 throw py::value_error((llvm::Twine("Result ") + 1154 llvm::Twine(it.index()) + " of operation \"" + 1155 name + "\" must be a Type (" + err.what() + 1156 ")") 1157 .str()); 1158 } 1159 } else if (segmentSpec == -1) { 1160 // Unpack sequence by appending. 1161 try { 1162 if (std::get<0>(it.value()).is_none()) { 1163 // Treat it as an empty list. 1164 resultSegmentLengths.push_back(0); 1165 } else { 1166 // Unpack the list. 1167 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1168 for (py::object segmentItem : segment) { 1169 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1170 if (!resultTypes.back()) { 1171 throw py::cast_error("contained a None item"); 1172 } 1173 } 1174 resultSegmentLengths.push_back(segment.size()); 1175 } 1176 } catch (std::exception &err) { 1177 // NOTE: Sloppy to be using a catch-all here, but there are at least 1178 // three different unrelated exceptions that can be thrown in the 1179 // above "casts". Just keep the scope above small and catch them all. 1180 throw py::value_error((llvm::Twine("Result ") + 1181 llvm::Twine(it.index()) + " of operation \"" + 1182 name + "\" must be a Sequence of Types (" + 1183 err.what() + ")") 1184 .str()); 1185 } 1186 } else { 1187 throw py::value_error("Unexpected segment spec"); 1188 } 1189 } 1190 } 1191 1192 // Unpack operands. 1193 std::vector<PyValue *> operands; 1194 operands.reserve(operands.size()); 1195 if (operandSegmentSpecObj.is_none()) { 1196 // Non-sized operand unpacking. 1197 for (auto it : llvm::enumerate(operandList)) { 1198 try { 1199 operands.push_back(py::cast<PyValue *>(it.value())); 1200 if (!operands.back()) 1201 throw py::cast_error(); 1202 } catch (py::cast_error &err) { 1203 throw py::value_error((llvm::Twine("Operand ") + 1204 llvm::Twine(it.index()) + " of operation \"" + 1205 name + "\" must be a Value (" + err.what() + ")") 1206 .str()); 1207 } 1208 } 1209 } else { 1210 // Sized operand unpacking. 1211 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1212 if (operandSegmentSpec.size() != operandList.size()) { 1213 throw py::value_error((llvm::Twine("Operation \"") + name + 1214 "\" requires " + 1215 llvm::Twine(operandSegmentSpec.size()) + 1216 "operand segments but was provided " + 1217 llvm::Twine(operandList.size())) 1218 .str()); 1219 } 1220 operandSegmentLengths.reserve(operandList.size()); 1221 for (auto it : 1222 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1223 int segmentSpec = std::get<1>(it.value()); 1224 if (segmentSpec == 1 || segmentSpec == 0) { 1225 // Unpack unary element. 1226 try { 1227 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1228 if (operandValue) { 1229 operands.push_back(operandValue); 1230 operandSegmentLengths.push_back(1); 1231 } else if (segmentSpec == 0) { 1232 // Allowed to be optional. 1233 operandSegmentLengths.push_back(0); 1234 } else { 1235 throw py::cast_error("was None and operand is not optional"); 1236 } 1237 } catch (py::cast_error &err) { 1238 throw py::value_error((llvm::Twine("Operand ") + 1239 llvm::Twine(it.index()) + " of operation \"" + 1240 name + "\" must be a Value (" + err.what() + 1241 ")") 1242 .str()); 1243 } 1244 } else if (segmentSpec == -1) { 1245 // Unpack sequence by appending. 1246 try { 1247 if (std::get<0>(it.value()).is_none()) { 1248 // Treat it as an empty list. 1249 operandSegmentLengths.push_back(0); 1250 } else { 1251 // Unpack the list. 1252 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1253 for (py::object segmentItem : segment) { 1254 operands.push_back(py::cast<PyValue *>(segmentItem)); 1255 if (!operands.back()) { 1256 throw py::cast_error("contained a None item"); 1257 } 1258 } 1259 operandSegmentLengths.push_back(segment.size()); 1260 } 1261 } catch (std::exception &err) { 1262 // NOTE: Sloppy to be using a catch-all here, but there are at least 1263 // three different unrelated exceptions that can be thrown in the 1264 // above "casts". Just keep the scope above small and catch them all. 1265 throw py::value_error((llvm::Twine("Operand ") + 1266 llvm::Twine(it.index()) + " of operation \"" + 1267 name + "\" must be a Sequence of Values (" + 1268 err.what() + ")") 1269 .str()); 1270 } 1271 } else { 1272 throw py::value_error("Unexpected segment spec"); 1273 } 1274 } 1275 } 1276 1277 // Merge operand/result segment lengths into attributes if needed. 1278 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1279 // Dup. 1280 if (attributes) { 1281 attributes = py::dict(*attributes); 1282 } else { 1283 attributes = py::dict(); 1284 } 1285 if (attributes->contains("result_segment_sizes") || 1286 attributes->contains("operand_segment_sizes")) { 1287 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1288 "'operand_segment_sizes' attribute is unsupported. " 1289 "Use Operation.create for such low-level access."); 1290 } 1291 1292 // Add result_segment_sizes attribute. 1293 if (!resultSegmentLengths.empty()) { 1294 int64_t size = resultSegmentLengths.size(); 1295 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1296 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1297 resultSegmentLengths.size(), resultSegmentLengths.data()); 1298 (*attributes)["result_segment_sizes"] = 1299 PyAttribute(context, segmentLengthAttr); 1300 } 1301 1302 // Add operand_segment_sizes attribute. 1303 if (!operandSegmentLengths.empty()) { 1304 int64_t size = operandSegmentLengths.size(); 1305 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1306 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1307 operandSegmentLengths.size(), operandSegmentLengths.data()); 1308 (*attributes)["operand_segment_sizes"] = 1309 PyAttribute(context, segmentLengthAttr); 1310 } 1311 } 1312 1313 // Delegate to create. 1314 return PyOperation::create(std::move(name), 1315 /*results=*/std::move(resultTypes), 1316 /*operands=*/std::move(operands), 1317 /*attributes=*/std::move(attributes), 1318 /*successors=*/std::move(successors), 1319 /*regions=*/*regions, location, maybeIp); 1320 } 1321 1322 PyOpView::PyOpView(py::object operationObject) 1323 // Casting through the PyOperationBase base-class and then back to the 1324 // Operation lets us accept any PyOperationBase subclass. 1325 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1326 operationObject(operation.getRef().getObject()) {} 1327 1328 py::object PyOpView::createRawSubclass(py::object userClass) { 1329 // This is... a little gross. The typical pattern is to have a pure python 1330 // class that extends OpView like: 1331 // class AddFOp(_cext.ir.OpView): 1332 // def __init__(self, loc, lhs, rhs): 1333 // operation = loc.context.create_operation( 1334 // "addf", lhs, rhs, results=[lhs.type]) 1335 // super().__init__(operation) 1336 // 1337 // I.e. The goal of the user facing type is to provide a nice constructor 1338 // that has complete freedom for the op under construction. This is at odds 1339 // with our other desire to sometimes create this object by just passing an 1340 // operation (to initialize the base class). We could do *arg and **kwargs 1341 // munging to try to make it work, but instead, we synthesize a new class 1342 // on the fly which extends this user class (AddFOp in this example) and 1343 // *give it* the base class's __init__ method, thus bypassing the 1344 // intermediate subclass's __init__ method entirely. While slightly, 1345 // underhanded, this is safe/legal because the type hierarchy has not changed 1346 // (we just added a new leaf) and we aren't mucking around with __new__. 1347 // Typically, this new class will be stored on the original as "_Raw" and will 1348 // be used for casts and other things that need a variant of the class that 1349 // is initialized purely from an operation. 1350 py::object parentMetaclass = 1351 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1352 py::dict attributes; 1353 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1354 // now. 1355 // auto opViewType = py::type::of<PyOpView>(); 1356 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1357 attributes["__init__"] = opViewType.attr("__init__"); 1358 py::str origName = userClass.attr("__name__"); 1359 py::str newName = py::str("_") + origName; 1360 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1361 } 1362 1363 //------------------------------------------------------------------------------ 1364 // PyInsertionPoint. 1365 //------------------------------------------------------------------------------ 1366 1367 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1368 1369 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1370 : refOperation(beforeOperationBase.getOperation().getRef()), 1371 block((*refOperation)->getBlock()) {} 1372 1373 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1374 PyOperation &operation = operationBase.getOperation(); 1375 if (operation.isAttached()) 1376 throw SetPyError(PyExc_ValueError, 1377 "Attempt to insert operation that is already attached"); 1378 block.getParentOperation()->checkValid(); 1379 MlirOperation beforeOp = {nullptr}; 1380 if (refOperation) { 1381 // Insert before operation. 1382 (*refOperation)->checkValid(); 1383 beforeOp = (*refOperation)->get(); 1384 } else { 1385 // Insert at end (before null) is only valid if the block does not 1386 // already end in a known terminator (violating this will cause assertion 1387 // failures later). 1388 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1389 throw py::index_error("Cannot insert operation at the end of a block " 1390 "that already has a terminator. Did you mean to " 1391 "use 'InsertionPoint.at_block_terminator(block)' " 1392 "versus 'InsertionPoint(block)'?"); 1393 } 1394 } 1395 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1396 operation.setAttached(); 1397 } 1398 1399 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1400 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1401 if (mlirOperationIsNull(firstOp)) { 1402 // Just insert at end. 1403 return PyInsertionPoint(block); 1404 } 1405 1406 // Insert before first op. 1407 PyOperationRef firstOpRef = PyOperation::forOperation( 1408 block.getParentOperation()->getContext(), firstOp); 1409 return PyInsertionPoint{block, std::move(firstOpRef)}; 1410 } 1411 1412 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1413 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1414 if (mlirOperationIsNull(terminator)) 1415 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1416 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1417 block.getParentOperation()->getContext(), terminator); 1418 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1419 } 1420 1421 py::object PyInsertionPoint::contextEnter() { 1422 return PyThreadContextEntry::pushInsertionPoint(*this); 1423 } 1424 1425 void PyInsertionPoint::contextExit(pybind11::object excType, 1426 pybind11::object excVal, 1427 pybind11::object excTb) { 1428 PyThreadContextEntry::popInsertionPoint(*this); 1429 } 1430 1431 //------------------------------------------------------------------------------ 1432 // PyAttribute. 1433 //------------------------------------------------------------------------------ 1434 1435 bool PyAttribute::operator==(const PyAttribute &other) { 1436 return mlirAttributeEqual(attr, other.attr); 1437 } 1438 1439 py::object PyAttribute::getCapsule() { 1440 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1441 } 1442 1443 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1444 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1445 if (mlirAttributeIsNull(rawAttr)) 1446 throw py::error_already_set(); 1447 return PyAttribute( 1448 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1449 } 1450 1451 //------------------------------------------------------------------------------ 1452 // PyNamedAttribute. 1453 //------------------------------------------------------------------------------ 1454 1455 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1456 : ownedName(new std::string(std::move(ownedName))) { 1457 namedAttr = mlirNamedAttributeGet( 1458 mlirIdentifierGet(mlirAttributeGetContext(attr), 1459 toMlirStringRef(*this->ownedName)), 1460 attr); 1461 } 1462 1463 //------------------------------------------------------------------------------ 1464 // PyType. 1465 //------------------------------------------------------------------------------ 1466 1467 bool PyType::operator==(const PyType &other) { 1468 return mlirTypeEqual(type, other.type); 1469 } 1470 1471 py::object PyType::getCapsule() { 1472 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1473 } 1474 1475 PyType PyType::createFromCapsule(py::object capsule) { 1476 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1477 if (mlirTypeIsNull(rawType)) 1478 throw py::error_already_set(); 1479 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1480 rawType); 1481 } 1482 1483 //------------------------------------------------------------------------------ 1484 // PyValue and subclases. 1485 //------------------------------------------------------------------------------ 1486 1487 pybind11::object PyValue::getCapsule() { 1488 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1489 } 1490 1491 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1492 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1493 if (mlirValueIsNull(value)) 1494 throw py::error_already_set(); 1495 MlirOperation owner; 1496 if (mlirValueIsAOpResult(value)) 1497 owner = mlirOpResultGetOwner(value); 1498 if (mlirValueIsABlockArgument(value)) 1499 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1500 if (mlirOperationIsNull(owner)) 1501 throw py::error_already_set(); 1502 MlirContext ctx = mlirOperationGetContext(owner); 1503 PyOperationRef ownerRef = 1504 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1505 return PyValue(ownerRef, value); 1506 } 1507 1508 namespace { 1509 /// CRTP base class for Python MLIR values that subclass Value and should be 1510 /// castable from it. The value hierarchy is one level deep and is not supposed 1511 /// to accommodate other levels unless core MLIR changes. 1512 template <typename DerivedTy> 1513 class PyConcreteValue : public PyValue { 1514 public: 1515 // Derived classes must define statics for: 1516 // IsAFunctionTy isaFunction 1517 // const char *pyClassName 1518 // and redefine bindDerived. 1519 using ClassTy = py::class_<DerivedTy, PyValue>; 1520 using IsAFunctionTy = bool (*)(MlirValue); 1521 1522 PyConcreteValue() = default; 1523 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1524 : PyValue(operationRef, value) {} 1525 PyConcreteValue(PyValue &orig) 1526 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1527 1528 /// Attempts to cast the original value to the derived type and throws on 1529 /// type mismatches. 1530 static MlirValue castFrom(PyValue &orig) { 1531 if (!DerivedTy::isaFunction(orig.get())) { 1532 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1533 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1534 DerivedTy::pyClassName + 1535 " (from " + origRepr + ")"); 1536 } 1537 return orig.get(); 1538 } 1539 1540 /// Binds the Python module objects to functions of this class. 1541 static void bind(py::module &m) { 1542 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1543 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1544 DerivedTy::bindDerived(cls); 1545 } 1546 1547 /// Implemented by derived classes to add methods to the Python subclass. 1548 static void bindDerived(ClassTy &m) {} 1549 }; 1550 1551 /// Python wrapper for MlirBlockArgument. 1552 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1553 public: 1554 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1555 static constexpr const char *pyClassName = "BlockArgument"; 1556 using PyConcreteValue::PyConcreteValue; 1557 1558 static void bindDerived(ClassTy &c) { 1559 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1560 return PyBlock(self.getParentOperation(), 1561 mlirBlockArgumentGetOwner(self.get())); 1562 }); 1563 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1564 return mlirBlockArgumentGetArgNumber(self.get()); 1565 }); 1566 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1567 return mlirBlockArgumentSetType(self.get(), type); 1568 }); 1569 } 1570 }; 1571 1572 /// Python wrapper for MlirOpResult. 1573 class PyOpResult : public PyConcreteValue<PyOpResult> { 1574 public: 1575 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1576 static constexpr const char *pyClassName = "OpResult"; 1577 using PyConcreteValue::PyConcreteValue; 1578 1579 static void bindDerived(ClassTy &c) { 1580 c.def_property_readonly("owner", [](PyOpResult &self) { 1581 assert( 1582 mlirOperationEqual(self.getParentOperation()->get(), 1583 mlirOpResultGetOwner(self.get())) && 1584 "expected the owner of the value in Python to match that in the IR"); 1585 return self.getParentOperation().getObject(); 1586 }); 1587 c.def_property_readonly("result_number", [](PyOpResult &self) { 1588 return mlirOpResultGetResultNumber(self.get()); 1589 }); 1590 } 1591 }; 1592 1593 /// Returns the list of types of the values held by container. 1594 template <typename Container> 1595 static std::vector<PyType> getValueTypes(Container &container, 1596 PyMlirContextRef &context) { 1597 std::vector<PyType> result; 1598 result.reserve(container.getNumElements()); 1599 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1600 result.push_back( 1601 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1602 } 1603 return result; 1604 } 1605 1606 /// A list of block arguments. Internally, these are stored as consecutive 1607 /// elements, random access is cheap. The argument list is associated with the 1608 /// operation that contains the block (detached blocks are not allowed in 1609 /// Python bindings) and extends its lifetime. 1610 class PyBlockArgumentList 1611 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1612 public: 1613 static constexpr const char *pyClassName = "BlockArgumentList"; 1614 1615 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1616 intptr_t startIndex = 0, intptr_t length = -1, 1617 intptr_t step = 1) 1618 : Sliceable(startIndex, 1619 length == -1 ? mlirBlockGetNumArguments(block) : length, 1620 step), 1621 operation(std::move(operation)), block(block) {} 1622 1623 /// Returns the number of arguments in the list. 1624 intptr_t getNumElements() { 1625 operation->checkValid(); 1626 return mlirBlockGetNumArguments(block); 1627 } 1628 1629 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1630 PyBlockArgument getElement(intptr_t pos) { 1631 MlirValue argument = mlirBlockGetArgument(block, pos); 1632 return PyBlockArgument(operation, argument); 1633 } 1634 1635 /// Returns a sublist of this list. 1636 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1637 intptr_t step) { 1638 return PyBlockArgumentList(operation, block, startIndex, length, step); 1639 } 1640 1641 static void bindDerived(ClassTy &c) { 1642 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1643 return getValueTypes(self, self.operation->getContext()); 1644 }); 1645 } 1646 1647 private: 1648 PyOperationRef operation; 1649 MlirBlock block; 1650 }; 1651 1652 /// A list of operation operands. Internally, these are stored as consecutive 1653 /// elements, random access is cheap. The result list is associated with the 1654 /// operation whose results these are, and extends the lifetime of this 1655 /// operation. 1656 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1657 public: 1658 static constexpr const char *pyClassName = "OpOperandList"; 1659 1660 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1661 intptr_t length = -1, intptr_t step = 1) 1662 : Sliceable(startIndex, 1663 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1664 : length, 1665 step), 1666 operation(operation) {} 1667 1668 intptr_t getNumElements() { 1669 operation->checkValid(); 1670 return mlirOperationGetNumOperands(operation->get()); 1671 } 1672 1673 PyValue getElement(intptr_t pos) { 1674 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1675 MlirOperation owner; 1676 if (mlirValueIsAOpResult(operand)) 1677 owner = mlirOpResultGetOwner(operand); 1678 else if (mlirValueIsABlockArgument(operand)) 1679 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1680 else 1681 assert(false && "Value must be an block arg or op result."); 1682 PyOperationRef pyOwner = 1683 PyOperation::forOperation(operation->getContext(), owner); 1684 return PyValue(pyOwner, operand); 1685 } 1686 1687 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1688 return PyOpOperandList(operation, startIndex, length, step); 1689 } 1690 1691 void dunderSetItem(intptr_t index, PyValue value) { 1692 index = wrapIndex(index); 1693 mlirOperationSetOperand(operation->get(), index, value.get()); 1694 } 1695 1696 static void bindDerived(ClassTy &c) { 1697 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1698 } 1699 1700 private: 1701 PyOperationRef operation; 1702 }; 1703 1704 /// A list of operation results. Internally, these are stored as consecutive 1705 /// elements, random access is cheap. The result list is associated with the 1706 /// operation whose results these are, and extends the lifetime of this 1707 /// operation. 1708 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1709 public: 1710 static constexpr const char *pyClassName = "OpResultList"; 1711 1712 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1713 intptr_t length = -1, intptr_t step = 1) 1714 : Sliceable(startIndex, 1715 length == -1 ? mlirOperationGetNumResults(operation->get()) 1716 : length, 1717 step), 1718 operation(operation) {} 1719 1720 intptr_t getNumElements() { 1721 operation->checkValid(); 1722 return mlirOperationGetNumResults(operation->get()); 1723 } 1724 1725 PyOpResult getElement(intptr_t index) { 1726 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1727 return PyOpResult(value); 1728 } 1729 1730 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1731 return PyOpResultList(operation, startIndex, length, step); 1732 } 1733 1734 static void bindDerived(ClassTy &c) { 1735 c.def_property_readonly("types", [](PyOpResultList &self) { 1736 return getValueTypes(self, self.operation->getContext()); 1737 }); 1738 } 1739 1740 private: 1741 PyOperationRef operation; 1742 }; 1743 1744 /// A list of operation attributes. Can be indexed by name, producing 1745 /// attributes, or by index, producing named attributes. 1746 class PyOpAttributeMap { 1747 public: 1748 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1749 1750 PyAttribute dunderGetItemNamed(const std::string &name) { 1751 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1752 toMlirStringRef(name)); 1753 if (mlirAttributeIsNull(attr)) { 1754 throw SetPyError(PyExc_KeyError, 1755 "attempt to access a non-existent attribute"); 1756 } 1757 return PyAttribute(operation->getContext(), attr); 1758 } 1759 1760 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1761 if (index < 0 || index >= dunderLen()) { 1762 throw SetPyError(PyExc_IndexError, 1763 "attempt to access out of bounds attribute"); 1764 } 1765 MlirNamedAttribute namedAttr = 1766 mlirOperationGetAttribute(operation->get(), index); 1767 return PyNamedAttribute( 1768 namedAttr.attribute, 1769 std::string(mlirIdentifierStr(namedAttr.name).data)); 1770 } 1771 1772 void dunderSetItem(const std::string &name, PyAttribute attr) { 1773 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1774 attr); 1775 } 1776 1777 void dunderDelItem(const std::string &name) { 1778 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1779 toMlirStringRef(name)); 1780 if (!removed) 1781 throw SetPyError(PyExc_KeyError, 1782 "attempt to delete a non-existent attribute"); 1783 } 1784 1785 intptr_t dunderLen() { 1786 return mlirOperationGetNumAttributes(operation->get()); 1787 } 1788 1789 bool dunderContains(const std::string &name) { 1790 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1791 operation->get(), toMlirStringRef(name))); 1792 } 1793 1794 static void bind(py::module &m) { 1795 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1796 .def("__contains__", &PyOpAttributeMap::dunderContains) 1797 .def("__len__", &PyOpAttributeMap::dunderLen) 1798 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1799 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1800 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1801 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1802 } 1803 1804 private: 1805 PyOperationRef operation; 1806 }; 1807 1808 } // end namespace 1809 1810 //------------------------------------------------------------------------------ 1811 // Populates the core exports of the 'ir' submodule. 1812 //------------------------------------------------------------------------------ 1813 1814 void mlir::python::populateIRCore(py::module &m) { 1815 //---------------------------------------------------------------------------- 1816 // Mapping of MlirContext. 1817 //---------------------------------------------------------------------------- 1818 py::class_<PyMlirContext>(m, "Context", py::module_local()) 1819 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1820 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1821 .def("_get_context_again", 1822 [](PyMlirContext &self) { 1823 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1824 return ref.releaseObject(); 1825 }) 1826 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1827 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1828 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1829 &PyMlirContext::getCapsule) 1830 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1831 .def("__enter__", &PyMlirContext::contextEnter) 1832 .def("__exit__", &PyMlirContext::contextExit) 1833 .def_property_readonly_static( 1834 "current", 1835 [](py::object & /*class*/) { 1836 auto *context = PyThreadContextEntry::getDefaultContext(); 1837 if (!context) 1838 throw SetPyError(PyExc_ValueError, "No current Context"); 1839 return context; 1840 }, 1841 "Gets the Context bound to the current thread or raises ValueError") 1842 .def_property_readonly( 1843 "dialects", 1844 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1845 "Gets a container for accessing dialects by name") 1846 .def_property_readonly( 1847 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1848 "Alias for 'dialect'") 1849 .def( 1850 "get_dialect_descriptor", 1851 [=](PyMlirContext &self, std::string &name) { 1852 MlirDialect dialect = mlirContextGetOrLoadDialect( 1853 self.get(), {name.data(), name.size()}); 1854 if (mlirDialectIsNull(dialect)) { 1855 throw SetPyError(PyExc_ValueError, 1856 Twine("Dialect '") + name + "' not found"); 1857 } 1858 return PyDialectDescriptor(self.getRef(), dialect); 1859 }, 1860 "Gets or loads a dialect by name, returning its descriptor object") 1861 .def_property( 1862 "allow_unregistered_dialects", 1863 [](PyMlirContext &self) -> bool { 1864 return mlirContextGetAllowUnregisteredDialects(self.get()); 1865 }, 1866 [](PyMlirContext &self, bool value) { 1867 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1868 }) 1869 .def("enable_multithreading", 1870 [](PyMlirContext &self, bool enable) { 1871 mlirContextEnableMultithreading(self.get(), enable); 1872 }) 1873 .def("is_registered_operation", 1874 [](PyMlirContext &self, std::string &name) { 1875 return mlirContextIsRegisteredOperation( 1876 self.get(), MlirStringRef{name.data(), name.size()}); 1877 }); 1878 1879 //---------------------------------------------------------------------------- 1880 // Mapping of PyDialectDescriptor 1881 //---------------------------------------------------------------------------- 1882 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1883 .def_property_readonly("namespace", 1884 [](PyDialectDescriptor &self) { 1885 MlirStringRef ns = 1886 mlirDialectGetNamespace(self.get()); 1887 return py::str(ns.data, ns.length); 1888 }) 1889 .def("__repr__", [](PyDialectDescriptor &self) { 1890 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1891 std::string repr("<DialectDescriptor "); 1892 repr.append(ns.data, ns.length); 1893 repr.append(">"); 1894 return repr; 1895 }); 1896 1897 //---------------------------------------------------------------------------- 1898 // Mapping of PyDialects 1899 //---------------------------------------------------------------------------- 1900 py::class_<PyDialects>(m, "Dialects", py::module_local()) 1901 .def("__getitem__", 1902 [=](PyDialects &self, std::string keyName) { 1903 MlirDialect dialect = 1904 self.getDialectForKey(keyName, /*attrError=*/false); 1905 py::object descriptor = 1906 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1907 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1908 }) 1909 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1910 MlirDialect dialect = 1911 self.getDialectForKey(attrName, /*attrError=*/true); 1912 py::object descriptor = 1913 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1914 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1915 }); 1916 1917 //---------------------------------------------------------------------------- 1918 // Mapping of PyDialect 1919 //---------------------------------------------------------------------------- 1920 py::class_<PyDialect>(m, "Dialect", py::module_local()) 1921 .def(py::init<py::object>(), "descriptor") 1922 .def_property_readonly( 1923 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1924 .def("__repr__", [](py::object self) { 1925 auto clazz = self.attr("__class__"); 1926 return py::str("<Dialect ") + 1927 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1928 clazz.attr("__module__") + py::str(".") + 1929 clazz.attr("__name__") + py::str(")>"); 1930 }); 1931 1932 //---------------------------------------------------------------------------- 1933 // Mapping of Location 1934 //---------------------------------------------------------------------------- 1935 py::class_<PyLocation>(m, "Location", py::module_local()) 1936 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1937 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1938 .def("__enter__", &PyLocation::contextEnter) 1939 .def("__exit__", &PyLocation::contextExit) 1940 .def("__eq__", 1941 [](PyLocation &self, PyLocation &other) -> bool { 1942 return mlirLocationEqual(self, other); 1943 }) 1944 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1945 .def_property_readonly_static( 1946 "current", 1947 [](py::object & /*class*/) { 1948 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1949 if (!loc) 1950 throw SetPyError(PyExc_ValueError, "No current Location"); 1951 return loc; 1952 }, 1953 "Gets the Location bound to the current thread or raises ValueError") 1954 .def_static( 1955 "unknown", 1956 [](DefaultingPyMlirContext context) { 1957 return PyLocation(context->getRef(), 1958 mlirLocationUnknownGet(context->get())); 1959 }, 1960 py::arg("context") = py::none(), 1961 "Gets a Location representing an unknown location") 1962 .def_static( 1963 "file", 1964 [](std::string filename, int line, int col, 1965 DefaultingPyMlirContext context) { 1966 return PyLocation( 1967 context->getRef(), 1968 mlirLocationFileLineColGet( 1969 context->get(), toMlirStringRef(filename), line, col)); 1970 }, 1971 py::arg("filename"), py::arg("line"), py::arg("col"), 1972 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1973 .def_property_readonly( 1974 "context", 1975 [](PyLocation &self) { return self.getContext().getObject(); }, 1976 "Context that owns the Location") 1977 .def("__repr__", [](PyLocation &self) { 1978 PyPrintAccumulator printAccum; 1979 mlirLocationPrint(self, printAccum.getCallback(), 1980 printAccum.getUserData()); 1981 return printAccum.join(); 1982 }); 1983 1984 //---------------------------------------------------------------------------- 1985 // Mapping of Module 1986 //---------------------------------------------------------------------------- 1987 py::class_<PyModule>(m, "Module", py::module_local()) 1988 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1989 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1990 .def_static( 1991 "parse", 1992 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1993 MlirModule module = mlirModuleCreateParse( 1994 context->get(), toMlirStringRef(moduleAsm)); 1995 // TODO: Rework error reporting once diagnostic engine is exposed 1996 // in C API. 1997 if (mlirModuleIsNull(module)) { 1998 throw SetPyError( 1999 PyExc_ValueError, 2000 "Unable to parse module assembly (see diagnostics)"); 2001 } 2002 return PyModule::forModule(module).releaseObject(); 2003 }, 2004 py::arg("asm"), py::arg("context") = py::none(), 2005 kModuleParseDocstring) 2006 .def_static( 2007 "create", 2008 [](DefaultingPyLocation loc) { 2009 MlirModule module = mlirModuleCreateEmpty(loc); 2010 return PyModule::forModule(module).releaseObject(); 2011 }, 2012 py::arg("loc") = py::none(), "Creates an empty module") 2013 .def_property_readonly( 2014 "context", 2015 [](PyModule &self) { return self.getContext().getObject(); }, 2016 "Context that created the Module") 2017 .def_property_readonly( 2018 "operation", 2019 [](PyModule &self) { 2020 return PyOperation::forOperation(self.getContext(), 2021 mlirModuleGetOperation(self.get()), 2022 self.getRef().releaseObject()) 2023 .releaseObject(); 2024 }, 2025 "Accesses the module as an operation") 2026 .def_property_readonly( 2027 "body", 2028 [](PyModule &self) { 2029 PyOperationRef module_op = PyOperation::forOperation( 2030 self.getContext(), mlirModuleGetOperation(self.get()), 2031 self.getRef().releaseObject()); 2032 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2033 return returnBlock; 2034 }, 2035 "Return the block for this module") 2036 .def( 2037 "dump", 2038 [](PyModule &self) { 2039 mlirOperationDump(mlirModuleGetOperation(self.get())); 2040 }, 2041 kDumpDocstring) 2042 .def( 2043 "__str__", 2044 [](PyModule &self) { 2045 MlirOperation operation = mlirModuleGetOperation(self.get()); 2046 PyPrintAccumulator printAccum; 2047 mlirOperationPrint(operation, printAccum.getCallback(), 2048 printAccum.getUserData()); 2049 return printAccum.join(); 2050 }, 2051 kOperationStrDunderDocstring); 2052 2053 //---------------------------------------------------------------------------- 2054 // Mapping of Operation. 2055 //---------------------------------------------------------------------------- 2056 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2057 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2058 [](PyOperationBase &self) { 2059 return self.getOperation().getCapsule(); 2060 }) 2061 .def("__eq__", 2062 [](PyOperationBase &self, PyOperationBase &other) { 2063 return &self.getOperation() == &other.getOperation(); 2064 }) 2065 .def("__eq__", 2066 [](PyOperationBase &self, py::object other) { return false; }) 2067 .def_property_readonly("attributes", 2068 [](PyOperationBase &self) { 2069 return PyOpAttributeMap( 2070 self.getOperation().getRef()); 2071 }) 2072 .def_property_readonly("operands", 2073 [](PyOperationBase &self) { 2074 return PyOpOperandList( 2075 self.getOperation().getRef()); 2076 }) 2077 .def_property_readonly("regions", 2078 [](PyOperationBase &self) { 2079 return PyRegionList( 2080 self.getOperation().getRef()); 2081 }) 2082 .def_property_readonly( 2083 "results", 2084 [](PyOperationBase &self) { 2085 return PyOpResultList(self.getOperation().getRef()); 2086 }, 2087 "Returns the list of Operation results.") 2088 .def_property_readonly( 2089 "result", 2090 [](PyOperationBase &self) { 2091 auto &operation = self.getOperation(); 2092 auto numResults = mlirOperationGetNumResults(operation); 2093 if (numResults != 1) { 2094 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2095 throw SetPyError( 2096 PyExc_ValueError, 2097 Twine("Cannot call .result on operation ") + 2098 StringRef(name.data, name.length) + " which has " + 2099 Twine(numResults) + 2100 " results (it is only valid for operations with a " 2101 "single result)"); 2102 } 2103 return PyOpResult(operation.getRef(), 2104 mlirOperationGetResult(operation, 0)); 2105 }, 2106 "Shortcut to get an op result if it has only one (throws an error " 2107 "otherwise).") 2108 .def("__iter__", 2109 [](PyOperationBase &self) { 2110 return PyRegionIterator(self.getOperation().getRef()); 2111 }) 2112 .def( 2113 "__str__", 2114 [](PyOperationBase &self) { 2115 return self.getAsm(/*binary=*/false, 2116 /*largeElementsLimit=*/llvm::None, 2117 /*enableDebugInfo=*/false, 2118 /*prettyDebugInfo=*/false, 2119 /*printGenericOpForm=*/false, 2120 /*useLocalScope=*/false); 2121 }, 2122 "Returns the assembly form of the operation.") 2123 .def("print", &PyOperationBase::print, 2124 // Careful: Lots of arguments must match up with print method. 2125 py::arg("file") = py::none(), py::arg("binary") = false, 2126 py::arg("large_elements_limit") = py::none(), 2127 py::arg("enable_debug_info") = false, 2128 py::arg("pretty_debug_info") = false, 2129 py::arg("print_generic_op_form") = false, 2130 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2131 .def("get_asm", &PyOperationBase::getAsm, 2132 // Careful: Lots of arguments must match up with get_asm method. 2133 py::arg("binary") = false, 2134 py::arg("large_elements_limit") = py::none(), 2135 py::arg("enable_debug_info") = false, 2136 py::arg("pretty_debug_info") = false, 2137 py::arg("print_generic_op_form") = false, 2138 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2139 .def( 2140 "verify", 2141 [](PyOperationBase &self) { 2142 return mlirOperationVerify(self.getOperation()); 2143 }, 2144 "Verify the operation and return true if it passes, false if it " 2145 "fails."); 2146 2147 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2148 .def_static("create", &PyOperation::create, py::arg("name"), 2149 py::arg("results") = py::none(), 2150 py::arg("operands") = py::none(), 2151 py::arg("attributes") = py::none(), 2152 py::arg("successors") = py::none(), py::arg("regions") = 0, 2153 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2154 kOperationCreateDocstring) 2155 .def_property_readonly("parent", 2156 [](PyOperation &self) -> py::object { 2157 auto parent = self.getParentOperation(); 2158 if (parent) 2159 return parent->getObject(); 2160 return py::none(); 2161 }) 2162 .def("erase", &PyOperation::erase) 2163 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2164 &PyOperation::getCapsule) 2165 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2166 .def_property_readonly("name", 2167 [](PyOperation &self) { 2168 self.checkValid(); 2169 MlirOperation operation = self.get(); 2170 MlirStringRef name = mlirIdentifierStr( 2171 mlirOperationGetName(operation)); 2172 return py::str(name.data, name.length); 2173 }) 2174 .def_property_readonly( 2175 "context", 2176 [](PyOperation &self) { 2177 self.checkValid(); 2178 return self.getContext().getObject(); 2179 }, 2180 "Context that owns the Operation") 2181 .def_property_readonly("opview", &PyOperation::createOpView); 2182 2183 auto opViewClass = 2184 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2185 .def(py::init<py::object>()) 2186 .def_property_readonly("operation", &PyOpView::getOperationObject) 2187 .def_property_readonly( 2188 "context", 2189 [](PyOpView &self) { 2190 return self.getOperation().getContext().getObject(); 2191 }, 2192 "Context that owns the Operation") 2193 .def("__str__", [](PyOpView &self) { 2194 return py::str(self.getOperationObject()); 2195 }); 2196 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2197 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2198 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2199 opViewClass.attr("build_generic") = classmethod( 2200 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2201 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2202 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2203 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2204 "Builds a specific, generated OpView based on class level attributes."); 2205 2206 //---------------------------------------------------------------------------- 2207 // Mapping of PyRegion. 2208 //---------------------------------------------------------------------------- 2209 py::class_<PyRegion>(m, "Region", py::module_local()) 2210 .def_property_readonly( 2211 "blocks", 2212 [](PyRegion &self) { 2213 return PyBlockList(self.getParentOperation(), self.get()); 2214 }, 2215 "Returns a forward-optimized sequence of blocks.") 2216 .def( 2217 "__iter__", 2218 [](PyRegion &self) { 2219 self.checkValid(); 2220 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2221 return PyBlockIterator(self.getParentOperation(), firstBlock); 2222 }, 2223 "Iterates over blocks in the region.") 2224 .def("__eq__", 2225 [](PyRegion &self, PyRegion &other) { 2226 return self.get().ptr == other.get().ptr; 2227 }) 2228 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2229 2230 //---------------------------------------------------------------------------- 2231 // Mapping of PyBlock. 2232 //---------------------------------------------------------------------------- 2233 py::class_<PyBlock>(m, "Block", py::module_local()) 2234 .def_property_readonly( 2235 "owner", 2236 [](PyBlock &self) { 2237 return self.getParentOperation()->createOpView(); 2238 }, 2239 "Returns the owning operation of this block.") 2240 .def_property_readonly( 2241 "region", 2242 [](PyBlock &self) { 2243 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2244 return PyRegion(self.getParentOperation(), region); 2245 }, 2246 "Returns the owning region of this block.") 2247 .def_property_readonly( 2248 "arguments", 2249 [](PyBlock &self) { 2250 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2251 }, 2252 "Returns a list of block arguments.") 2253 .def_property_readonly( 2254 "operations", 2255 [](PyBlock &self) { 2256 return PyOperationList(self.getParentOperation(), self.get()); 2257 }, 2258 "Returns a forward-optimized sequence of operations.") 2259 .def( 2260 "create_before", 2261 [](PyBlock &self, py::args pyArgTypes) { 2262 self.checkValid(); 2263 llvm::SmallVector<MlirType, 4> argTypes; 2264 argTypes.reserve(pyArgTypes.size()); 2265 for (auto &pyArg : pyArgTypes) { 2266 argTypes.push_back(pyArg.cast<PyType &>()); 2267 } 2268 2269 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2270 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2271 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2272 return PyBlock(self.getParentOperation(), block); 2273 }, 2274 "Creates and returns a new Block before this block " 2275 "(with given argument types).") 2276 .def( 2277 "create_after", 2278 [](PyBlock &self, py::args pyArgTypes) { 2279 self.checkValid(); 2280 llvm::SmallVector<MlirType, 4> argTypes; 2281 argTypes.reserve(pyArgTypes.size()); 2282 for (auto &pyArg : pyArgTypes) { 2283 argTypes.push_back(pyArg.cast<PyType &>()); 2284 } 2285 2286 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2287 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2288 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2289 return PyBlock(self.getParentOperation(), block); 2290 }, 2291 "Creates and returns a new Block after this block " 2292 "(with given argument types).") 2293 .def( 2294 "__iter__", 2295 [](PyBlock &self) { 2296 self.checkValid(); 2297 MlirOperation firstOperation = 2298 mlirBlockGetFirstOperation(self.get()); 2299 return PyOperationIterator(self.getParentOperation(), 2300 firstOperation); 2301 }, 2302 "Iterates over operations in the block.") 2303 .def("__eq__", 2304 [](PyBlock &self, PyBlock &other) { 2305 return self.get().ptr == other.get().ptr; 2306 }) 2307 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2308 .def( 2309 "__str__", 2310 [](PyBlock &self) { 2311 self.checkValid(); 2312 PyPrintAccumulator printAccum; 2313 mlirBlockPrint(self.get(), printAccum.getCallback(), 2314 printAccum.getUserData()); 2315 return printAccum.join(); 2316 }, 2317 "Returns the assembly form of the block."); 2318 2319 //---------------------------------------------------------------------------- 2320 // Mapping of PyInsertionPoint. 2321 //---------------------------------------------------------------------------- 2322 2323 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2324 .def(py::init<PyBlock &>(), py::arg("block"), 2325 "Inserts after the last operation but still inside the block.") 2326 .def("__enter__", &PyInsertionPoint::contextEnter) 2327 .def("__exit__", &PyInsertionPoint::contextExit) 2328 .def_property_readonly_static( 2329 "current", 2330 [](py::object & /*class*/) { 2331 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2332 if (!ip) 2333 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2334 return ip; 2335 }, 2336 "Gets the InsertionPoint bound to the current thread or raises " 2337 "ValueError if none has been set") 2338 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2339 "Inserts before a referenced operation.") 2340 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2341 py::arg("block"), "Inserts at the beginning of the block.") 2342 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2343 py::arg("block"), "Inserts before the block terminator.") 2344 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2345 "Inserts an operation.") 2346 .def_property_readonly( 2347 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2348 "Returns the block that this InsertionPoint points to."); 2349 2350 //---------------------------------------------------------------------------- 2351 // Mapping of PyAttribute. 2352 //---------------------------------------------------------------------------- 2353 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2354 // Delegate to the PyAttribute copy constructor, which will also lifetime 2355 // extend the backing context which owns the MlirAttribute. 2356 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2357 "Casts the passed attribute to the generic Attribute") 2358 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2359 &PyAttribute::getCapsule) 2360 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2361 .def_static( 2362 "parse", 2363 [](std::string attrSpec, DefaultingPyMlirContext context) { 2364 MlirAttribute type = mlirAttributeParseGet( 2365 context->get(), toMlirStringRef(attrSpec)); 2366 // TODO: Rework error reporting once diagnostic engine is exposed 2367 // in C API. 2368 if (mlirAttributeIsNull(type)) { 2369 throw SetPyError(PyExc_ValueError, 2370 Twine("Unable to parse attribute: '") + 2371 attrSpec + "'"); 2372 } 2373 return PyAttribute(context->getRef(), type); 2374 }, 2375 py::arg("asm"), py::arg("context") = py::none(), 2376 "Parses an attribute from an assembly form") 2377 .def_property_readonly( 2378 "context", 2379 [](PyAttribute &self) { return self.getContext().getObject(); }, 2380 "Context that owns the Attribute") 2381 .def_property_readonly("type", 2382 [](PyAttribute &self) { 2383 return PyType(self.getContext()->getRef(), 2384 mlirAttributeGetType(self)); 2385 }) 2386 .def( 2387 "get_named", 2388 [](PyAttribute &self, std::string name) { 2389 return PyNamedAttribute(self, std::move(name)); 2390 }, 2391 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2392 .def("__eq__", 2393 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2394 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2395 .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; }) 2396 .def( 2397 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2398 kDumpDocstring) 2399 .def( 2400 "__str__", 2401 [](PyAttribute &self) { 2402 PyPrintAccumulator printAccum; 2403 mlirAttributePrint(self, printAccum.getCallback(), 2404 printAccum.getUserData()); 2405 return printAccum.join(); 2406 }, 2407 "Returns the assembly form of the Attribute.") 2408 .def("__repr__", [](PyAttribute &self) { 2409 // Generally, assembly formats are not printed for __repr__ because 2410 // this can cause exceptionally long debug output and exceptions. 2411 // However, attribute values are generally considered useful and are 2412 // printed. This may need to be re-evaluated if debug dumps end up 2413 // being excessive. 2414 PyPrintAccumulator printAccum; 2415 printAccum.parts.append("Attribute("); 2416 mlirAttributePrint(self, printAccum.getCallback(), 2417 printAccum.getUserData()); 2418 printAccum.parts.append(")"); 2419 return printAccum.join(); 2420 }); 2421 2422 //---------------------------------------------------------------------------- 2423 // Mapping of PyNamedAttribute 2424 //---------------------------------------------------------------------------- 2425 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2426 .def("__repr__", 2427 [](PyNamedAttribute &self) { 2428 PyPrintAccumulator printAccum; 2429 printAccum.parts.append("NamedAttribute("); 2430 printAccum.parts.append( 2431 mlirIdentifierStr(self.namedAttr.name).data); 2432 printAccum.parts.append("="); 2433 mlirAttributePrint(self.namedAttr.attribute, 2434 printAccum.getCallback(), 2435 printAccum.getUserData()); 2436 printAccum.parts.append(")"); 2437 return printAccum.join(); 2438 }) 2439 .def_property_readonly( 2440 "name", 2441 [](PyNamedAttribute &self) { 2442 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2443 mlirIdentifierStr(self.namedAttr.name).length); 2444 }, 2445 "The name of the NamedAttribute binding") 2446 .def_property_readonly( 2447 "attr", 2448 [](PyNamedAttribute &self) { 2449 // TODO: When named attribute is removed/refactored, also remove 2450 // this constructor (it does an inefficient table lookup). 2451 auto contextRef = PyMlirContext::forContext( 2452 mlirAttributeGetContext(self.namedAttr.attribute)); 2453 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2454 }, 2455 py::keep_alive<0, 1>(), 2456 "The underlying generic attribute of the NamedAttribute binding"); 2457 2458 //---------------------------------------------------------------------------- 2459 // Mapping of PyType. 2460 //---------------------------------------------------------------------------- 2461 py::class_<PyType>(m, "Type", py::module_local()) 2462 // Delegate to the PyType copy constructor, which will also lifetime 2463 // extend the backing context which owns the MlirType. 2464 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2465 "Casts the passed type to the generic Type") 2466 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2467 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2468 .def_static( 2469 "parse", 2470 [](std::string typeSpec, DefaultingPyMlirContext context) { 2471 MlirType type = 2472 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2473 // TODO: Rework error reporting once diagnostic engine is exposed 2474 // in C API. 2475 if (mlirTypeIsNull(type)) { 2476 throw SetPyError(PyExc_ValueError, 2477 Twine("Unable to parse type: '") + typeSpec + 2478 "'"); 2479 } 2480 return PyType(context->getRef(), type); 2481 }, 2482 py::arg("asm"), py::arg("context") = py::none(), 2483 kContextParseTypeDocstring) 2484 .def_property_readonly( 2485 "context", [](PyType &self) { return self.getContext().getObject(); }, 2486 "Context that owns the Type") 2487 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2488 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2489 .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; }) 2490 .def( 2491 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2492 .def( 2493 "__str__", 2494 [](PyType &self) { 2495 PyPrintAccumulator printAccum; 2496 mlirTypePrint(self, printAccum.getCallback(), 2497 printAccum.getUserData()); 2498 return printAccum.join(); 2499 }, 2500 "Returns the assembly form of the type.") 2501 .def("__repr__", [](PyType &self) { 2502 // Generally, assembly formats are not printed for __repr__ because 2503 // this can cause exceptionally long debug output and exceptions. 2504 // However, types are an exception as they typically have compact 2505 // assembly forms and printing them is useful. 2506 PyPrintAccumulator printAccum; 2507 printAccum.parts.append("Type("); 2508 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2509 printAccum.parts.append(")"); 2510 return printAccum.join(); 2511 }); 2512 2513 //---------------------------------------------------------------------------- 2514 // Mapping of Value. 2515 //---------------------------------------------------------------------------- 2516 py::class_<PyValue>(m, "Value", py::module_local()) 2517 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2518 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2519 .def_property_readonly( 2520 "context", 2521 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2522 "Context in which the value lives.") 2523 .def( 2524 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2525 kDumpDocstring) 2526 .def_property_readonly( 2527 "owner", 2528 [](PyValue &self) { 2529 assert(mlirOperationEqual(self.getParentOperation()->get(), 2530 mlirOpResultGetOwner(self.get())) && 2531 "expected the owner of the value in Python to match that in " 2532 "the IR"); 2533 return self.getParentOperation().getObject(); 2534 }) 2535 .def("__eq__", 2536 [](PyValue &self, PyValue &other) { 2537 return self.get().ptr == other.get().ptr; 2538 }) 2539 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2540 .def( 2541 "__str__", 2542 [](PyValue &self) { 2543 PyPrintAccumulator printAccum; 2544 printAccum.parts.append("Value("); 2545 mlirValuePrint(self.get(), printAccum.getCallback(), 2546 printAccum.getUserData()); 2547 printAccum.parts.append(")"); 2548 return printAccum.join(); 2549 }, 2550 kValueDunderStrDocstring) 2551 .def_property_readonly("type", [](PyValue &self) { 2552 return PyType(self.getParentOperation()->getContext(), 2553 mlirValueGetType(self.get())); 2554 }); 2555 PyBlockArgument::bind(m); 2556 PyOpResult::bind(m); 2557 2558 // Container bindings. 2559 PyBlockArgumentList::bind(m); 2560 PyBlockIterator::bind(m); 2561 PyBlockList::bind(m); 2562 PyOperationIterator::bind(m); 2563 PyOperationList::bind(m); 2564 PyOpAttributeMap::bind(m); 2565 PyOpOperandList::bind(m); 2566 PyOpResultList::bind(m); 2567 PyRegionIterator::bind(m); 2568 PyRegionList::bind(m); 2569 2570 // Debug bindings. 2571 PyGlobalDebugFlag::bind(m); 2572 } 2573