1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include <pybind11/stl.h> 22 23 namespace py = pybind11; 24 using namespace mlir; 25 using namespace mlir::python; 26 27 using llvm::SmallVector; 28 using llvm::StringRef; 29 using llvm::Twine; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kContextParseTypeDocstring[] = 36 R"(Parses the assembly form of a type. 37 38 Returns a Type object or raises a ValueError if the type cannot be parsed. 39 40 See also: https://mlir.llvm.org/docs/LangRef/#type-system 41 )"; 42 43 static const char kContextGetFileLocationDocstring[] = 44 R"(Gets a Location representing a file, line and column)"; 45 46 static const char kModuleParseDocstring[] = 47 R"(Parses a module's assembly format from a string. 48 49 Returns a new MlirModule or raises a ValueError if the parsing fails. 50 51 See also: https://mlir.llvm.org/docs/LangRef/ 52 )"; 53 54 static const char kOperationCreateDocstring[] = 55 R"(Creates a new operation. 56 57 Args: 58 name: Operation name (e.g. "dialect.operation"). 59 results: Sequence of Type representing op result types. 60 attributes: Dict of str:Attribute. 61 successors: List of Block for the operation's successors. 62 regions: Number of regions to create. 63 location: A Location object (defaults to resolve from context manager). 64 ip: An InsertionPoint (defaults to resolve from context manager or set to 65 False to disable insertion, even with an insertion point set in the 66 context manager). 67 Returns: 68 A new "detached" Operation object. Detached operations can be added 69 to blocks, which causes them to become "attached." 70 )"; 71 72 static const char kOperationPrintDocstring[] = 73 R"(Prints the assembly form of the operation to a file like object. 74 75 Args: 76 file: The file like object to write to. Defaults to sys.stdout. 77 binary: Whether to write bytes (True) or str (False). Defaults to False. 78 large_elements_limit: Whether to elide elements attributes above this 79 number of elements. Defaults to None (no limit). 80 enable_debug_info: Whether to print debug/location information. Defaults 81 to False. 82 pretty_debug_info: Whether to format debug information for easier reading 83 by a human (warning: the result is unparseable). 84 print_generic_op_form: Whether to print the generic assembly forms of all 85 ops. Defaults to False. 86 use_local_Scope: Whether to print in a way that is more optimized for 87 multi-threaded access but may not be consistent with how the overall 88 module prints. 89 )"; 90 91 static const char kOperationGetAsmDocstring[] = 92 R"(Gets the assembly form of the operation with all options available. 93 94 Args: 95 binary: Whether to return a bytes (True) or str (False) object. Defaults to 96 False. 97 ... others ...: See the print() method for common keyword arguments for 98 configuring the printout. 99 Returns: 100 Either a bytes or str object, depending on the setting of the 'binary' 101 argument. 102 )"; 103 104 static const char kOperationStrDunderDocstring[] = 105 R"(Gets the assembly form of the operation with default options. 106 107 If more advanced control over the assembly formatting or I/O options is needed, 108 use the dedicated print or get_asm method, which supports keyword arguments to 109 customize behavior. 110 )"; 111 112 static const char kDumpDocstring[] = 113 R"(Dumps a debug representation of the object to stderr.)"; 114 115 static const char kAppendBlockDocstring[] = 116 R"(Appends a new block, with argument types as positional args. 117 118 Returns: 119 The created block. 120 )"; 121 122 static const char kValueDunderStrDocstring[] = 123 R"(Returns the string form of the value. 124 125 If the value is a block argument, this is the assembly form of its type and the 126 position in the argument list. If the value is an operation result, this is 127 equivalent to printing the operation that produced it. 128 )"; 129 130 //------------------------------------------------------------------------------ 131 // Utilities. 132 //------------------------------------------------------------------------------ 133 134 /// Helper for creating an @classmethod. 135 template <class Func, typename... Args> 136 py::object classmethod(Func f, Args... args) { 137 py::object cf = py::cpp_function(f, args...); 138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 139 } 140 141 static py::object 142 createCustomDialectWrapper(const std::string &dialectNamespace, 143 py::object dialectDescriptor) { 144 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 145 if (!dialectClass) { 146 // Use the base class. 147 return py::cast(PyDialect(std::move(dialectDescriptor))); 148 } 149 150 // Create the custom implementation. 151 return (*dialectClass)(std::move(dialectDescriptor)); 152 } 153 154 static MlirStringRef toMlirStringRef(const std::string &s) { 155 return mlirStringRefCreate(s.data(), s.size()); 156 } 157 158 /// Wrapper for the global LLVM debugging flag. 159 struct PyGlobalDebugFlag { 160 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 161 162 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 163 164 static void bind(py::module &m) { 165 // Debug flags. 166 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") 167 .def_property_static("flag", &PyGlobalDebugFlag::get, 168 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 169 } 170 }; 171 172 //------------------------------------------------------------------------------ 173 // Collections. 174 //------------------------------------------------------------------------------ 175 176 namespace { 177 178 class PyRegionIterator { 179 public: 180 PyRegionIterator(PyOperationRef operation) 181 : operation(std::move(operation)) {} 182 183 PyRegionIterator &dunderIter() { return *this; } 184 185 PyRegion dunderNext() { 186 operation->checkValid(); 187 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 188 throw py::stop_iteration(); 189 } 190 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 191 return PyRegion(operation, region); 192 } 193 194 static void bind(py::module &m) { 195 py::class_<PyRegionIterator>(m, "RegionIterator") 196 .def("__iter__", &PyRegionIterator::dunderIter) 197 .def("__next__", &PyRegionIterator::dunderNext); 198 } 199 200 private: 201 PyOperationRef operation; 202 int nextIndex = 0; 203 }; 204 205 /// Regions of an op are fixed length and indexed numerically so are represented 206 /// with a sequence-like container. 207 class PyRegionList { 208 public: 209 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 210 211 intptr_t dunderLen() { 212 operation->checkValid(); 213 return mlirOperationGetNumRegions(operation->get()); 214 } 215 216 PyRegion dunderGetItem(intptr_t index) { 217 // dunderLen checks validity. 218 if (index < 0 || index >= dunderLen()) { 219 throw SetPyError(PyExc_IndexError, 220 "attempt to access out of bounds region"); 221 } 222 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 223 return PyRegion(operation, region); 224 } 225 226 static void bind(py::module &m) { 227 py::class_<PyRegionList>(m, "RegionSequence") 228 .def("__len__", &PyRegionList::dunderLen) 229 .def("__getitem__", &PyRegionList::dunderGetItem); 230 } 231 232 private: 233 PyOperationRef operation; 234 }; 235 236 class PyBlockIterator { 237 public: 238 PyBlockIterator(PyOperationRef operation, MlirBlock next) 239 : operation(std::move(operation)), next(next) {} 240 241 PyBlockIterator &dunderIter() { return *this; } 242 243 PyBlock dunderNext() { 244 operation->checkValid(); 245 if (mlirBlockIsNull(next)) { 246 throw py::stop_iteration(); 247 } 248 249 PyBlock returnBlock(operation, next); 250 next = mlirBlockGetNextInRegion(next); 251 return returnBlock; 252 } 253 254 static void bind(py::module &m) { 255 py::class_<PyBlockIterator>(m, "BlockIterator") 256 .def("__iter__", &PyBlockIterator::dunderIter) 257 .def("__next__", &PyBlockIterator::dunderNext); 258 } 259 260 private: 261 PyOperationRef operation; 262 MlirBlock next; 263 }; 264 265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 266 /// we present them as a more full-featured list-like container but optimize 267 /// it for forward iteration. Blocks are always owned by a region. 268 class PyBlockList { 269 public: 270 PyBlockList(PyOperationRef operation, MlirRegion region) 271 : operation(std::move(operation)), region(region) {} 272 273 PyBlockIterator dunderIter() { 274 operation->checkValid(); 275 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 276 } 277 278 intptr_t dunderLen() { 279 operation->checkValid(); 280 intptr_t count = 0; 281 MlirBlock block = mlirRegionGetFirstBlock(region); 282 while (!mlirBlockIsNull(block)) { 283 count += 1; 284 block = mlirBlockGetNextInRegion(block); 285 } 286 return count; 287 } 288 289 PyBlock dunderGetItem(intptr_t index) { 290 operation->checkValid(); 291 if (index < 0) { 292 throw SetPyError(PyExc_IndexError, 293 "attempt to access out of bounds block"); 294 } 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 if (index == 0) { 298 return PyBlock(operation, block); 299 } 300 block = mlirBlockGetNextInRegion(block); 301 index -= 1; 302 } 303 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 304 } 305 306 PyBlock appendBlock(py::args pyArgTypes) { 307 operation->checkValid(); 308 llvm::SmallVector<MlirType, 4> argTypes; 309 argTypes.reserve(pyArgTypes.size()); 310 for (auto &pyArg : pyArgTypes) { 311 argTypes.push_back(pyArg.cast<PyType &>()); 312 } 313 314 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 315 mlirRegionAppendOwnedBlock(region, block); 316 return PyBlock(operation, block); 317 } 318 319 static void bind(py::module &m) { 320 py::class_<PyBlockList>(m, "BlockList") 321 .def("__getitem__", &PyBlockList::dunderGetItem) 322 .def("__iter__", &PyBlockList::dunderIter) 323 .def("__len__", &PyBlockList::dunderLen) 324 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 325 } 326 327 private: 328 PyOperationRef operation; 329 MlirRegion region; 330 }; 331 332 class PyOperationIterator { 333 public: 334 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 335 : parentOperation(std::move(parentOperation)), next(next) {} 336 337 PyOperationIterator &dunderIter() { return *this; } 338 339 py::object dunderNext() { 340 parentOperation->checkValid(); 341 if (mlirOperationIsNull(next)) { 342 throw py::stop_iteration(); 343 } 344 345 PyOperationRef returnOperation = 346 PyOperation::forOperation(parentOperation->getContext(), next); 347 next = mlirOperationGetNextInBlock(next); 348 return returnOperation->createOpView(); 349 } 350 351 static void bind(py::module &m) { 352 py::class_<PyOperationIterator>(m, "OperationIterator") 353 .def("__iter__", &PyOperationIterator::dunderIter) 354 .def("__next__", &PyOperationIterator::dunderNext); 355 } 356 357 private: 358 PyOperationRef parentOperation; 359 MlirOperation next; 360 }; 361 362 /// Operations are exposed by the C-API as a forward-only linked list. In 363 /// Python, we present them as a more full-featured list-like container but 364 /// optimize it for forward iteration. Iterable operations are always owned 365 /// by a block. 366 class PyOperationList { 367 public: 368 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 369 : parentOperation(std::move(parentOperation)), block(block) {} 370 371 PyOperationIterator dunderIter() { 372 parentOperation->checkValid(); 373 return PyOperationIterator(parentOperation, 374 mlirBlockGetFirstOperation(block)); 375 } 376 377 intptr_t dunderLen() { 378 parentOperation->checkValid(); 379 intptr_t count = 0; 380 MlirOperation childOp = mlirBlockGetFirstOperation(block); 381 while (!mlirOperationIsNull(childOp)) { 382 count += 1; 383 childOp = mlirOperationGetNextInBlock(childOp); 384 } 385 return count; 386 } 387 388 py::object dunderGetItem(intptr_t index) { 389 parentOperation->checkValid(); 390 if (index < 0) { 391 throw SetPyError(PyExc_IndexError, 392 "attempt to access out of bounds operation"); 393 } 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 if (index == 0) { 397 return PyOperation::forOperation(parentOperation->getContext(), childOp) 398 ->createOpView(); 399 } 400 childOp = mlirOperationGetNextInBlock(childOp); 401 index -= 1; 402 } 403 throw SetPyError(PyExc_IndexError, 404 "attempt to access out of bounds operation"); 405 } 406 407 static void bind(py::module &m) { 408 py::class_<PyOperationList>(m, "OperationList") 409 .def("__getitem__", &PyOperationList::dunderGetItem) 410 .def("__iter__", &PyOperationList::dunderIter) 411 .def("__len__", &PyOperationList::dunderLen); 412 } 413 414 private: 415 PyOperationRef parentOperation; 416 MlirBlock block; 417 }; 418 419 } // namespace 420 421 //------------------------------------------------------------------------------ 422 // PyMlirContext 423 //------------------------------------------------------------------------------ 424 425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 426 py::gil_scoped_acquire acquire; 427 auto &liveContexts = getLiveContexts(); 428 liveContexts[context.ptr] = this; 429 } 430 431 PyMlirContext::~PyMlirContext() { 432 // Note that the only public way to construct an instance is via the 433 // forContext method, which always puts the associated handle into 434 // liveContexts. 435 py::gil_scoped_acquire acquire; 436 getLiveContexts().erase(context.ptr); 437 mlirContextDestroy(context); 438 } 439 440 py::object PyMlirContext::getCapsule() { 441 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 442 } 443 444 py::object PyMlirContext::createFromCapsule(py::object capsule) { 445 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 446 if (mlirContextIsNull(rawContext)) 447 throw py::error_already_set(); 448 return forContext(rawContext).releaseObject(); 449 } 450 451 PyMlirContext *PyMlirContext::createNewContextForInit() { 452 MlirContext context = mlirContextCreate(); 453 mlirRegisterAllDialects(context); 454 return new PyMlirContext(context); 455 } 456 457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 458 py::gil_scoped_acquire acquire; 459 auto &liveContexts = getLiveContexts(); 460 auto it = liveContexts.find(context.ptr); 461 if (it == liveContexts.end()) { 462 // Create. 463 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 464 py::object pyRef = py::cast(unownedContextWrapper); 465 assert(pyRef && "cast to py::object failed"); 466 liveContexts[context.ptr] = unownedContextWrapper; 467 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 468 } 469 // Use existing. 470 py::object pyRef = py::cast(it->second); 471 return PyMlirContextRef(it->second, std::move(pyRef)); 472 } 473 474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 475 static LiveContextMap liveContexts; 476 return liveContexts; 477 } 478 479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 480 481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 482 483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 484 485 pybind11::object PyMlirContext::contextEnter() { 486 return PyThreadContextEntry::pushContext(*this); 487 } 488 489 void PyMlirContext::contextExit(pybind11::object excType, 490 pybind11::object excVal, 491 pybind11::object excTb) { 492 PyThreadContextEntry::popContext(*this); 493 } 494 495 PyMlirContext &DefaultingPyMlirContext::resolve() { 496 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 497 if (!context) { 498 throw SetPyError( 499 PyExc_RuntimeError, 500 "An MLIR function requires a Context but none was provided in the call " 501 "or from the surrounding environment. Either pass to the function with " 502 "a 'context=' argument or establish a default using 'with Context():'"); 503 } 504 return *context; 505 } 506 507 //------------------------------------------------------------------------------ 508 // PyThreadContextEntry management 509 //------------------------------------------------------------------------------ 510 511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 512 static thread_local std::vector<PyThreadContextEntry> stack; 513 return stack; 514 } 515 516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 517 auto &stack = getStack(); 518 if (stack.empty()) 519 return nullptr; 520 return &stack.back(); 521 } 522 523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 524 py::object insertionPoint, 525 py::object location) { 526 auto &stack = getStack(); 527 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 528 std::move(location)); 529 // If the new stack has more than one entry and the context of the new top 530 // entry matches the previous, copy the insertionPoint and location from the 531 // previous entry if missing from the new top entry. 532 if (stack.size() > 1) { 533 auto &prev = *(stack.rbegin() + 1); 534 auto ¤t = stack.back(); 535 if (current.context.is(prev.context)) { 536 // Default non-context objects from the previous entry. 537 if (!current.insertionPoint) 538 current.insertionPoint = prev.insertionPoint; 539 if (!current.location) 540 current.location = prev.location; 541 } 542 } 543 } 544 545 PyMlirContext *PyThreadContextEntry::getContext() { 546 if (!context) 547 return nullptr; 548 return py::cast<PyMlirContext *>(context); 549 } 550 551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 552 if (!insertionPoint) 553 return nullptr; 554 return py::cast<PyInsertionPoint *>(insertionPoint); 555 } 556 557 PyLocation *PyThreadContextEntry::getLocation() { 558 if (!location) 559 return nullptr; 560 return py::cast<PyLocation *>(location); 561 } 562 563 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 564 auto *tos = getTopOfStack(); 565 return tos ? tos->getContext() : nullptr; 566 } 567 568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 569 auto *tos = getTopOfStack(); 570 return tos ? tos->getInsertionPoint() : nullptr; 571 } 572 573 PyLocation *PyThreadContextEntry::getDefaultLocation() { 574 auto *tos = getTopOfStack(); 575 return tos ? tos->getLocation() : nullptr; 576 } 577 578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 579 py::object contextObj = py::cast(context); 580 push(FrameKind::Context, /*context=*/contextObj, 581 /*insertionPoint=*/py::object(), 582 /*location=*/py::object()); 583 return contextObj; 584 } 585 586 void PyThreadContextEntry::popContext(PyMlirContext &context) { 587 auto &stack = getStack(); 588 if (stack.empty()) 589 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 590 auto &tos = stack.back(); 591 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593 stack.pop_back(); 594 } 595 596 py::object 597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 598 py::object contextObj = 599 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 600 py::object insertionPointObj = py::cast(insertionPoint); 601 push(FrameKind::InsertionPoint, 602 /*context=*/contextObj, 603 /*insertionPoint=*/insertionPointObj, 604 /*location=*/py::object()); 605 return insertionPointObj; 606 } 607 608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 609 auto &stack = getStack(); 610 if (stack.empty()) 611 throw SetPyError(PyExc_RuntimeError, 612 "Unbalanced InsertionPoint enter/exit"); 613 auto &tos = stack.back(); 614 if (tos.frameKind != FrameKind::InsertionPoint && 615 tos.getInsertionPoint() != &insertionPoint) 616 throw SetPyError(PyExc_RuntimeError, 617 "Unbalanced InsertionPoint enter/exit"); 618 stack.pop_back(); 619 } 620 621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 622 py::object contextObj = location.getContext().getObject(); 623 py::object locationObj = py::cast(location); 624 push(FrameKind::Location, /*context=*/contextObj, 625 /*insertionPoint=*/py::object(), 626 /*location=*/locationObj); 627 return locationObj; 628 } 629 630 void PyThreadContextEntry::popLocation(PyLocation &location) { 631 auto &stack = getStack(); 632 if (stack.empty()) 633 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 634 auto &tos = stack.back(); 635 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637 stack.pop_back(); 638 } 639 640 //------------------------------------------------------------------------------ 641 // PyDialect, PyDialectDescriptor, PyDialects 642 //------------------------------------------------------------------------------ 643 644 MlirDialect PyDialects::getDialectForKey(const std::string &key, 645 bool attrError) { 646 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 647 {key.data(), key.size()}); 648 if (mlirDialectIsNull(dialect)) { 649 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 650 Twine("Dialect '") + key + "' not found"); 651 } 652 return dialect; 653 } 654 655 //------------------------------------------------------------------------------ 656 // PyLocation 657 //------------------------------------------------------------------------------ 658 659 py::object PyLocation::getCapsule() { 660 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 661 } 662 663 PyLocation PyLocation::createFromCapsule(py::object capsule) { 664 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 665 if (mlirLocationIsNull(rawLoc)) 666 throw py::error_already_set(); 667 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 668 rawLoc); 669 } 670 671 py::object PyLocation::contextEnter() { 672 return PyThreadContextEntry::pushLocation(*this); 673 } 674 675 void PyLocation::contextExit(py::object excType, py::object excVal, 676 py::object excTb) { 677 PyThreadContextEntry::popLocation(*this); 678 } 679 680 PyLocation &DefaultingPyLocation::resolve() { 681 auto *location = PyThreadContextEntry::getDefaultLocation(); 682 if (!location) { 683 throw SetPyError( 684 PyExc_RuntimeError, 685 "An MLIR function requires a Location but none was provided in the " 686 "call or from the surrounding environment. Either pass to the function " 687 "with a 'loc=' argument or establish a default using 'with loc:'"); 688 } 689 return *location; 690 } 691 692 //------------------------------------------------------------------------------ 693 // PyModule 694 //------------------------------------------------------------------------------ 695 696 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 697 : BaseContextObject(std::move(contextRef)), module(module) {} 698 699 PyModule::~PyModule() { 700 py::gil_scoped_acquire acquire; 701 auto &liveModules = getContext()->liveModules; 702 assert(liveModules.count(module.ptr) == 1 && 703 "destroying module not in live map"); 704 liveModules.erase(module.ptr); 705 mlirModuleDestroy(module); 706 } 707 708 PyModuleRef PyModule::forModule(MlirModule module) { 709 MlirContext context = mlirModuleGetContext(module); 710 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 711 712 py::gil_scoped_acquire acquire; 713 auto &liveModules = contextRef->liveModules; 714 auto it = liveModules.find(module.ptr); 715 if (it == liveModules.end()) { 716 // Create. 717 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 718 // Note that the default return value policy on cast is automatic_reference, 719 // which does not take ownership (delete will not be called). 720 // Just be explicit. 721 py::object pyRef = 722 py::cast(unownedModule, py::return_value_policy::take_ownership); 723 unownedModule->handle = pyRef; 724 liveModules[module.ptr] = 725 std::make_pair(unownedModule->handle, unownedModule); 726 return PyModuleRef(unownedModule, std::move(pyRef)); 727 } 728 // Use existing. 729 PyModule *existing = it->second.second; 730 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 731 return PyModuleRef(existing, std::move(pyRef)); 732 } 733 734 py::object PyModule::createFromCapsule(py::object capsule) { 735 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 736 if (mlirModuleIsNull(rawModule)) 737 throw py::error_already_set(); 738 return forModule(rawModule).releaseObject(); 739 } 740 741 py::object PyModule::getCapsule() { 742 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 743 } 744 745 //------------------------------------------------------------------------------ 746 // PyOperation 747 //------------------------------------------------------------------------------ 748 749 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 750 : BaseContextObject(std::move(contextRef)), operation(operation) {} 751 752 PyOperation::~PyOperation() { 753 // If the operation has already been invalidated there is nothing to do. 754 if (!valid) 755 return; 756 auto &liveOperations = getContext()->liveOperations; 757 assert(liveOperations.count(operation.ptr) == 1 && 758 "destroying operation not in live map"); 759 liveOperations.erase(operation.ptr); 760 if (!isAttached()) { 761 mlirOperationDestroy(operation); 762 } 763 } 764 765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 766 MlirOperation operation, 767 py::object parentKeepAlive) { 768 auto &liveOperations = contextRef->liveOperations; 769 // Create. 770 PyOperation *unownedOperation = 771 new PyOperation(std::move(contextRef), operation); 772 // Note that the default return value policy on cast is automatic_reference, 773 // which does not take ownership (delete will not be called). 774 // Just be explicit. 775 py::object pyRef = 776 py::cast(unownedOperation, py::return_value_policy::take_ownership); 777 unownedOperation->handle = pyRef; 778 if (parentKeepAlive) { 779 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 780 } 781 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 782 return PyOperationRef(unownedOperation, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 auto it = liveOperations.find(operation.ptr); 790 if (it == liveOperations.end()) { 791 // Create. 792 return createInstance(std::move(contextRef), operation, 793 std::move(parentKeepAlive)); 794 } 795 // Use existing. 796 PyOperation *existing = it->second.second; 797 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 798 return PyOperationRef(existing, std::move(pyRef)); 799 } 800 801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 802 MlirOperation operation, 803 py::object parentKeepAlive) { 804 auto &liveOperations = contextRef->liveOperations; 805 assert(liveOperations.count(operation.ptr) == 0 && 806 "cannot create detached operation that already exists"); 807 (void)liveOperations; 808 809 PyOperationRef created = createInstance(std::move(contextRef), operation, 810 std::move(parentKeepAlive)); 811 created->attached = false; 812 return created; 813 } 814 815 void PyOperation::checkValid() const { 816 if (!valid) { 817 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 818 } 819 } 820 821 void PyOperationBase::print(py::object fileObject, bool binary, 822 llvm::Optional<int64_t> largeElementsLimit, 823 bool enableDebugInfo, bool prettyDebugInfo, 824 bool printGenericOpForm, bool useLocalScope) { 825 PyOperation &operation = getOperation(); 826 operation.checkValid(); 827 if (fileObject.is_none()) 828 fileObject = py::module::import("sys").attr("stdout"); 829 830 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 831 fileObject.attr("write")("// Verification failed, printing generic form\n"); 832 printGenericOpForm = true; 833 } 834 835 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 836 if (largeElementsLimit) 837 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 838 if (enableDebugInfo) 839 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 840 if (printGenericOpForm) 841 mlirOpPrintingFlagsPrintGenericOpForm(flags); 842 843 PyFileAccumulator accum(fileObject, binary); 844 py::gil_scoped_release(); 845 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 846 accum.getUserData()); 847 mlirOpPrintingFlagsDestroy(flags); 848 } 849 850 py::object PyOperationBase::getAsm(bool binary, 851 llvm::Optional<int64_t> largeElementsLimit, 852 bool enableDebugInfo, bool prettyDebugInfo, 853 bool printGenericOpForm, 854 bool useLocalScope) { 855 py::object fileObject; 856 if (binary) { 857 fileObject = py::module::import("io").attr("BytesIO")(); 858 } else { 859 fileObject = py::module::import("io").attr("StringIO")(); 860 } 861 print(fileObject, /*binary=*/binary, 862 /*largeElementsLimit=*/largeElementsLimit, 863 /*enableDebugInfo=*/enableDebugInfo, 864 /*prettyDebugInfo=*/prettyDebugInfo, 865 /*printGenericOpForm=*/printGenericOpForm, 866 /*useLocalScope=*/useLocalScope); 867 868 return fileObject.attr("getvalue")(); 869 } 870 871 PyOperationRef PyOperation::getParentOperation() { 872 checkValid(); 873 if (!isAttached()) 874 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 875 MlirOperation operation = mlirOperationGetParentOperation(get()); 876 if (mlirOperationIsNull(operation)) 877 throw SetPyError(PyExc_ValueError, "Operation has no parent."); 878 return PyOperation::forOperation(getContext(), operation); 879 } 880 881 PyBlock PyOperation::getBlock() { 882 checkValid(); 883 PyOperationRef parentOperation = getParentOperation(); 884 MlirBlock block = mlirOperationGetBlock(get()); 885 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 886 return PyBlock{std::move(parentOperation), block}; 887 } 888 889 py::object PyOperation::getCapsule() { 890 checkValid(); 891 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 892 } 893 894 py::object PyOperation::createFromCapsule(py::object capsule) { 895 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 896 if (mlirOperationIsNull(rawOperation)) 897 throw py::error_already_set(); 898 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 899 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 900 .releaseObject(); 901 } 902 903 py::object PyOperation::create( 904 std::string name, llvm::Optional<std::vector<PyType *>> results, 905 llvm::Optional<std::vector<PyValue *>> operands, 906 llvm::Optional<py::dict> attributes, 907 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 908 DefaultingPyLocation location, py::object maybeIp) { 909 llvm::SmallVector<MlirValue, 4> mlirOperands; 910 llvm::SmallVector<MlirType, 4> mlirResults; 911 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 912 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 913 914 // General parameter validation. 915 if (regions < 0) 916 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 917 918 // Unpack/validate operands. 919 if (operands) { 920 mlirOperands.reserve(operands->size()); 921 for (PyValue *operand : *operands) { 922 if (!operand) 923 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 924 mlirOperands.push_back(operand->get()); 925 } 926 } 927 928 // Unpack/validate results. 929 if (results) { 930 mlirResults.reserve(results->size()); 931 for (PyType *result : *results) { 932 // TODO: Verify result type originate from the same context. 933 if (!result) 934 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 935 mlirResults.push_back(*result); 936 } 937 } 938 // Unpack/validate attributes. 939 if (attributes) { 940 mlirAttributes.reserve(attributes->size()); 941 for (auto &it : *attributes) { 942 std::string key; 943 try { 944 key = it.first.cast<std::string>(); 945 } catch (py::cast_error &err) { 946 std::string msg = "Invalid attribute key (not a string) when " 947 "attempting to create the operation \"" + 948 name + "\" (" + err.what() + ")"; 949 throw py::cast_error(msg); 950 } 951 try { 952 auto &attribute = it.second.cast<PyAttribute &>(); 953 // TODO: Verify attribute originates from the same context. 954 mlirAttributes.emplace_back(std::move(key), attribute); 955 } catch (py::reference_cast_error &) { 956 // This exception seems thrown when the value is "None". 957 std::string msg = 958 "Found an invalid (`None`?) attribute value for the key \"" + key + 959 "\" when attempting to create the operation \"" + name + "\""; 960 throw py::cast_error(msg); 961 } catch (py::cast_error &err) { 962 std::string msg = "Invalid attribute value for the key \"" + key + 963 "\" when attempting to create the operation \"" + 964 name + "\" (" + err.what() + ")"; 965 throw py::cast_error(msg); 966 } 967 } 968 } 969 // Unpack/validate successors. 970 if (successors) { 971 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 972 mlirSuccessors.reserve(successors->size()); 973 for (auto *successor : *successors) { 974 // TODO: Verify successor originate from the same context. 975 if (!successor) 976 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 977 mlirSuccessors.push_back(successor->get()); 978 } 979 } 980 981 // Apply unpacked/validated to the operation state. Beyond this 982 // point, exceptions cannot be thrown or else the state will leak. 983 MlirOperationState state = 984 mlirOperationStateGet(toMlirStringRef(name), location); 985 if (!mlirOperands.empty()) 986 mlirOperationStateAddOperands(&state, mlirOperands.size(), 987 mlirOperands.data()); 988 if (!mlirResults.empty()) 989 mlirOperationStateAddResults(&state, mlirResults.size(), 990 mlirResults.data()); 991 if (!mlirAttributes.empty()) { 992 // Note that the attribute names directly reference bytes in 993 // mlirAttributes, so that vector must not be changed from here 994 // on. 995 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 996 mlirNamedAttributes.reserve(mlirAttributes.size()); 997 for (auto &it : mlirAttributes) 998 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 999 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1000 toMlirStringRef(it.first)), 1001 it.second)); 1002 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1003 mlirNamedAttributes.data()); 1004 } 1005 if (!mlirSuccessors.empty()) 1006 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1007 mlirSuccessors.data()); 1008 if (regions) { 1009 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1010 mlirRegions.resize(regions); 1011 for (int i = 0; i < regions; ++i) 1012 mlirRegions[i] = mlirRegionCreate(); 1013 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1014 mlirRegions.data()); 1015 } 1016 1017 // Construct the operation. 1018 MlirOperation operation = mlirOperationCreate(&state); 1019 PyOperationRef created = 1020 PyOperation::createDetached(location->getContext(), operation); 1021 1022 // InsertPoint active? 1023 if (!maybeIp.is(py::cast(false))) { 1024 PyInsertionPoint *ip; 1025 if (maybeIp.is_none()) { 1026 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1027 } else { 1028 ip = py::cast<PyInsertionPoint *>(maybeIp); 1029 } 1030 if (ip) 1031 ip->insert(*created.get()); 1032 } 1033 1034 return created->createOpView(); 1035 } 1036 1037 py::object PyOperation::createOpView() { 1038 checkValid(); 1039 MlirIdentifier ident = mlirOperationGetName(get()); 1040 MlirStringRef identStr = mlirIdentifierStr(ident); 1041 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1042 StringRef(identStr.data, identStr.length)); 1043 if (opViewClass) 1044 return (*opViewClass)(getRef().getObject()); 1045 return py::cast(PyOpView(getRef().getObject())); 1046 } 1047 1048 void PyOperation::erase() { 1049 checkValid(); 1050 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1051 // Python reference to a child operation is live. All children should also 1052 // have their `valid` bit set to false. 1053 auto &liveOperations = getContext()->liveOperations; 1054 if (liveOperations.count(operation.ptr)) 1055 liveOperations.erase(operation.ptr); 1056 mlirOperationDestroy(operation); 1057 valid = false; 1058 } 1059 1060 //------------------------------------------------------------------------------ 1061 // PyOpView 1062 //------------------------------------------------------------------------------ 1063 1064 py::object 1065 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1066 py::list operandList, 1067 llvm::Optional<py::dict> attributes, 1068 llvm::Optional<std::vector<PyBlock *>> successors, 1069 llvm::Optional<int> regions, 1070 DefaultingPyLocation location, py::object maybeIp) { 1071 PyMlirContextRef context = location->getContext(); 1072 // Class level operation construction metadata. 1073 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1074 // Operand and result segment specs are either none, which does no 1075 // variadic unpacking, or a list of ints with segment sizes, where each 1076 // element is either a positive number (typically 1 for a scalar) or -1 to 1077 // indicate that it is derived from the length of the same-indexed operand 1078 // or result (implying that it is a list at that position). 1079 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1080 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1081 1082 std::vector<uint32_t> operandSegmentLengths; 1083 std::vector<uint32_t> resultSegmentLengths; 1084 1085 // Validate/determine region count. 1086 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1087 int opMinRegionCount = std::get<0>(opRegionSpec); 1088 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1089 if (!regions) { 1090 regions = opMinRegionCount; 1091 } 1092 if (*regions < opMinRegionCount) { 1093 throw py::value_error( 1094 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1095 llvm::Twine(opMinRegionCount) + 1096 " regions but was built with regions=" + llvm::Twine(*regions)) 1097 .str()); 1098 } 1099 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1100 throw py::value_error( 1101 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1102 llvm::Twine(opMinRegionCount) + 1103 " regions but was built with regions=" + llvm::Twine(*regions)) 1104 .str()); 1105 } 1106 1107 // Unpack results. 1108 std::vector<PyType *> resultTypes; 1109 resultTypes.reserve(resultTypeList.size()); 1110 if (resultSegmentSpecObj.is_none()) { 1111 // Non-variadic result unpacking. 1112 for (auto it : llvm::enumerate(resultTypeList)) { 1113 try { 1114 resultTypes.push_back(py::cast<PyType *>(it.value())); 1115 if (!resultTypes.back()) 1116 throw py::cast_error(); 1117 } catch (py::cast_error &err) { 1118 throw py::value_error((llvm::Twine("Result ") + 1119 llvm::Twine(it.index()) + " of operation \"" + 1120 name + "\" must be a Type (" + err.what() + ")") 1121 .str()); 1122 } 1123 } 1124 } else { 1125 // Sized result unpacking. 1126 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1127 if (resultSegmentSpec.size() != resultTypeList.size()) { 1128 throw py::value_error((llvm::Twine("Operation \"") + name + 1129 "\" requires " + 1130 llvm::Twine(resultSegmentSpec.size()) + 1131 "result segments but was provided " + 1132 llvm::Twine(resultTypeList.size())) 1133 .str()); 1134 } 1135 resultSegmentLengths.reserve(resultTypeList.size()); 1136 for (auto it : 1137 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1138 int segmentSpec = std::get<1>(it.value()); 1139 if (segmentSpec == 1 || segmentSpec == 0) { 1140 // Unpack unary element. 1141 try { 1142 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1143 if (resultType) { 1144 resultTypes.push_back(resultType); 1145 resultSegmentLengths.push_back(1); 1146 } else if (segmentSpec == 0) { 1147 // Allowed to be optional. 1148 resultSegmentLengths.push_back(0); 1149 } else { 1150 throw py::cast_error("was None and result is not optional"); 1151 } 1152 } catch (py::cast_error &err) { 1153 throw py::value_error((llvm::Twine("Result ") + 1154 llvm::Twine(it.index()) + " of operation \"" + 1155 name + "\" must be a Type (" + err.what() + 1156 ")") 1157 .str()); 1158 } 1159 } else if (segmentSpec == -1) { 1160 // Unpack sequence by appending. 1161 try { 1162 if (std::get<0>(it.value()).is_none()) { 1163 // Treat it as an empty list. 1164 resultSegmentLengths.push_back(0); 1165 } else { 1166 // Unpack the list. 1167 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1168 for (py::object segmentItem : segment) { 1169 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1170 if (!resultTypes.back()) { 1171 throw py::cast_error("contained a None item"); 1172 } 1173 } 1174 resultSegmentLengths.push_back(segment.size()); 1175 } 1176 } catch (std::exception &err) { 1177 // NOTE: Sloppy to be using a catch-all here, but there are at least 1178 // three different unrelated exceptions that can be thrown in the 1179 // above "casts". Just keep the scope above small and catch them all. 1180 throw py::value_error((llvm::Twine("Result ") + 1181 llvm::Twine(it.index()) + " of operation \"" + 1182 name + "\" must be a Sequence of Types (" + 1183 err.what() + ")") 1184 .str()); 1185 } 1186 } else { 1187 throw py::value_error("Unexpected segment spec"); 1188 } 1189 } 1190 } 1191 1192 // Unpack operands. 1193 std::vector<PyValue *> operands; 1194 operands.reserve(operands.size()); 1195 if (operandSegmentSpecObj.is_none()) { 1196 // Non-sized operand unpacking. 1197 for (auto it : llvm::enumerate(operandList)) { 1198 try { 1199 operands.push_back(py::cast<PyValue *>(it.value())); 1200 if (!operands.back()) 1201 throw py::cast_error(); 1202 } catch (py::cast_error &err) { 1203 throw py::value_error((llvm::Twine("Operand ") + 1204 llvm::Twine(it.index()) + " of operation \"" + 1205 name + "\" must be a Value (" + err.what() + ")") 1206 .str()); 1207 } 1208 } 1209 } else { 1210 // Sized operand unpacking. 1211 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1212 if (operandSegmentSpec.size() != operandList.size()) { 1213 throw py::value_error((llvm::Twine("Operation \"") + name + 1214 "\" requires " + 1215 llvm::Twine(operandSegmentSpec.size()) + 1216 "operand segments but was provided " + 1217 llvm::Twine(operandList.size())) 1218 .str()); 1219 } 1220 operandSegmentLengths.reserve(operandList.size()); 1221 for (auto it : 1222 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1223 int segmentSpec = std::get<1>(it.value()); 1224 if (segmentSpec == 1 || segmentSpec == 0) { 1225 // Unpack unary element. 1226 try { 1227 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1228 if (operandValue) { 1229 operands.push_back(operandValue); 1230 operandSegmentLengths.push_back(1); 1231 } else if (segmentSpec == 0) { 1232 // Allowed to be optional. 1233 operandSegmentLengths.push_back(0); 1234 } else { 1235 throw py::cast_error("was None and operand is not optional"); 1236 } 1237 } catch (py::cast_error &err) { 1238 throw py::value_error((llvm::Twine("Operand ") + 1239 llvm::Twine(it.index()) + " of operation \"" + 1240 name + "\" must be a Value (" + err.what() + 1241 ")") 1242 .str()); 1243 } 1244 } else if (segmentSpec == -1) { 1245 // Unpack sequence by appending. 1246 try { 1247 if (std::get<0>(it.value()).is_none()) { 1248 // Treat it as an empty list. 1249 operandSegmentLengths.push_back(0); 1250 } else { 1251 // Unpack the list. 1252 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1253 for (py::object segmentItem : segment) { 1254 operands.push_back(py::cast<PyValue *>(segmentItem)); 1255 if (!operands.back()) { 1256 throw py::cast_error("contained a None item"); 1257 } 1258 } 1259 operandSegmentLengths.push_back(segment.size()); 1260 } 1261 } catch (std::exception &err) { 1262 // NOTE: Sloppy to be using a catch-all here, but there are at least 1263 // three different unrelated exceptions that can be thrown in the 1264 // above "casts". Just keep the scope above small and catch them all. 1265 throw py::value_error((llvm::Twine("Operand ") + 1266 llvm::Twine(it.index()) + " of operation \"" + 1267 name + "\" must be a Sequence of Values (" + 1268 err.what() + ")") 1269 .str()); 1270 } 1271 } else { 1272 throw py::value_error("Unexpected segment spec"); 1273 } 1274 } 1275 } 1276 1277 // Merge operand/result segment lengths into attributes if needed. 1278 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1279 // Dup. 1280 if (attributes) { 1281 attributes = py::dict(*attributes); 1282 } else { 1283 attributes = py::dict(); 1284 } 1285 if (attributes->contains("result_segment_sizes") || 1286 attributes->contains("operand_segment_sizes")) { 1287 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1288 "'operand_segment_sizes' attribute is unsupported. " 1289 "Use Operation.create for such low-level access."); 1290 } 1291 1292 // Add result_segment_sizes attribute. 1293 if (!resultSegmentLengths.empty()) { 1294 int64_t size = resultSegmentLengths.size(); 1295 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1296 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1297 resultSegmentLengths.size(), resultSegmentLengths.data()); 1298 (*attributes)["result_segment_sizes"] = 1299 PyAttribute(context, segmentLengthAttr); 1300 } 1301 1302 // Add operand_segment_sizes attribute. 1303 if (!operandSegmentLengths.empty()) { 1304 int64_t size = operandSegmentLengths.size(); 1305 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1306 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1307 operandSegmentLengths.size(), operandSegmentLengths.data()); 1308 (*attributes)["operand_segment_sizes"] = 1309 PyAttribute(context, segmentLengthAttr); 1310 } 1311 } 1312 1313 // Delegate to create. 1314 return PyOperation::create(std::move(name), 1315 /*results=*/std::move(resultTypes), 1316 /*operands=*/std::move(operands), 1317 /*attributes=*/std::move(attributes), 1318 /*successors=*/std::move(successors), 1319 /*regions=*/*regions, location, maybeIp); 1320 } 1321 1322 PyOpView::PyOpView(py::object operationObject) 1323 // Casting through the PyOperationBase base-class and then back to the 1324 // Operation lets us accept any PyOperationBase subclass. 1325 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1326 operationObject(operation.getRef().getObject()) {} 1327 1328 py::object PyOpView::createRawSubclass(py::object userClass) { 1329 // This is... a little gross. The typical pattern is to have a pure python 1330 // class that extends OpView like: 1331 // class AddFOp(_cext.ir.OpView): 1332 // def __init__(self, loc, lhs, rhs): 1333 // operation = loc.context.create_operation( 1334 // "addf", lhs, rhs, results=[lhs.type]) 1335 // super().__init__(operation) 1336 // 1337 // I.e. The goal of the user facing type is to provide a nice constructor 1338 // that has complete freedom for the op under construction. This is at odds 1339 // with our other desire to sometimes create this object by just passing an 1340 // operation (to initialize the base class). We could do *arg and **kwargs 1341 // munging to try to make it work, but instead, we synthesize a new class 1342 // on the fly which extends this user class (AddFOp in this example) and 1343 // *give it* the base class's __init__ method, thus bypassing the 1344 // intermediate subclass's __init__ method entirely. While slightly, 1345 // underhanded, this is safe/legal because the type hierarchy has not changed 1346 // (we just added a new leaf) and we aren't mucking around with __new__. 1347 // Typically, this new class will be stored on the original as "_Raw" and will 1348 // be used for casts and other things that need a variant of the class that 1349 // is initialized purely from an operation. 1350 py::object parentMetaclass = 1351 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1352 py::dict attributes; 1353 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1354 // now. 1355 // auto opViewType = py::type::of<PyOpView>(); 1356 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1357 attributes["__init__"] = opViewType.attr("__init__"); 1358 py::str origName = userClass.attr("__name__"); 1359 py::str newName = py::str("_") + origName; 1360 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1361 } 1362 1363 //------------------------------------------------------------------------------ 1364 // PyInsertionPoint. 1365 //------------------------------------------------------------------------------ 1366 1367 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1368 1369 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1370 : refOperation(beforeOperationBase.getOperation().getRef()), 1371 block((*refOperation)->getBlock()) {} 1372 1373 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1374 PyOperation &operation = operationBase.getOperation(); 1375 if (operation.isAttached()) 1376 throw SetPyError(PyExc_ValueError, 1377 "Attempt to insert operation that is already attached"); 1378 block.getParentOperation()->checkValid(); 1379 MlirOperation beforeOp = {nullptr}; 1380 if (refOperation) { 1381 // Insert before operation. 1382 (*refOperation)->checkValid(); 1383 beforeOp = (*refOperation)->get(); 1384 } else { 1385 // Insert at end (before null) is only valid if the block does not 1386 // already end in a known terminator (violating this will cause assertion 1387 // failures later). 1388 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1389 throw py::index_error("Cannot insert operation at the end of a block " 1390 "that already has a terminator. Did you mean to " 1391 "use 'InsertionPoint.at_block_terminator(block)' " 1392 "versus 'InsertionPoint(block)'?"); 1393 } 1394 } 1395 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1396 operation.setAttached(); 1397 } 1398 1399 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1400 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1401 if (mlirOperationIsNull(firstOp)) { 1402 // Just insert at end. 1403 return PyInsertionPoint(block); 1404 } 1405 1406 // Insert before first op. 1407 PyOperationRef firstOpRef = PyOperation::forOperation( 1408 block.getParentOperation()->getContext(), firstOp); 1409 return PyInsertionPoint{block, std::move(firstOpRef)}; 1410 } 1411 1412 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1413 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1414 if (mlirOperationIsNull(terminator)) 1415 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1416 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1417 block.getParentOperation()->getContext(), terminator); 1418 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1419 } 1420 1421 py::object PyInsertionPoint::contextEnter() { 1422 return PyThreadContextEntry::pushInsertionPoint(*this); 1423 } 1424 1425 void PyInsertionPoint::contextExit(pybind11::object excType, 1426 pybind11::object excVal, 1427 pybind11::object excTb) { 1428 PyThreadContextEntry::popInsertionPoint(*this); 1429 } 1430 1431 //------------------------------------------------------------------------------ 1432 // PyAttribute. 1433 //------------------------------------------------------------------------------ 1434 1435 bool PyAttribute::operator==(const PyAttribute &other) { 1436 return mlirAttributeEqual(attr, other.attr); 1437 } 1438 1439 py::object PyAttribute::getCapsule() { 1440 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1441 } 1442 1443 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1444 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1445 if (mlirAttributeIsNull(rawAttr)) 1446 throw py::error_already_set(); 1447 return PyAttribute( 1448 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1449 } 1450 1451 //------------------------------------------------------------------------------ 1452 // PyNamedAttribute. 1453 //------------------------------------------------------------------------------ 1454 1455 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1456 : ownedName(new std::string(std::move(ownedName))) { 1457 namedAttr = mlirNamedAttributeGet( 1458 mlirIdentifierGet(mlirAttributeGetContext(attr), 1459 toMlirStringRef(*this->ownedName)), 1460 attr); 1461 } 1462 1463 //------------------------------------------------------------------------------ 1464 // PyType. 1465 //------------------------------------------------------------------------------ 1466 1467 bool PyType::operator==(const PyType &other) { 1468 return mlirTypeEqual(type, other.type); 1469 } 1470 1471 py::object PyType::getCapsule() { 1472 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1473 } 1474 1475 PyType PyType::createFromCapsule(py::object capsule) { 1476 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1477 if (mlirTypeIsNull(rawType)) 1478 throw py::error_already_set(); 1479 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1480 rawType); 1481 } 1482 1483 //------------------------------------------------------------------------------ 1484 // PyValue and subclases. 1485 //------------------------------------------------------------------------------ 1486 1487 pybind11::object PyValue::getCapsule() { 1488 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1489 } 1490 1491 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1492 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1493 if (mlirValueIsNull(value)) 1494 throw py::error_already_set(); 1495 MlirOperation owner; 1496 if (mlirValueIsAOpResult(value)) 1497 owner = mlirOpResultGetOwner(value); 1498 if (mlirValueIsABlockArgument(value)) 1499 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1500 if (mlirOperationIsNull(owner)) 1501 throw py::error_already_set(); 1502 MlirContext ctx = mlirOperationGetContext(owner); 1503 PyOperationRef ownerRef = 1504 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1505 return PyValue(ownerRef, value); 1506 } 1507 1508 namespace { 1509 /// CRTP base class for Python MLIR values that subclass Value and should be 1510 /// castable from it. The value hierarchy is one level deep and is not supposed 1511 /// to accommodate other levels unless core MLIR changes. 1512 template <typename DerivedTy> 1513 class PyConcreteValue : public PyValue { 1514 public: 1515 // Derived classes must define statics for: 1516 // IsAFunctionTy isaFunction 1517 // const char *pyClassName 1518 // and redefine bindDerived. 1519 using ClassTy = py::class_<DerivedTy, PyValue>; 1520 using IsAFunctionTy = bool (*)(MlirValue); 1521 1522 PyConcreteValue() = default; 1523 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1524 : PyValue(operationRef, value) {} 1525 PyConcreteValue(PyValue &orig) 1526 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1527 1528 /// Attempts to cast the original value to the derived type and throws on 1529 /// type mismatches. 1530 static MlirValue castFrom(PyValue &orig) { 1531 if (!DerivedTy::isaFunction(orig.get())) { 1532 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1533 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1534 DerivedTy::pyClassName + 1535 " (from " + origRepr + ")"); 1536 } 1537 return orig.get(); 1538 } 1539 1540 /// Binds the Python module objects to functions of this class. 1541 static void bind(py::module &m) { 1542 auto cls = ClassTy(m, DerivedTy::pyClassName); 1543 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1544 DerivedTy::bindDerived(cls); 1545 } 1546 1547 /// Implemented by derived classes to add methods to the Python subclass. 1548 static void bindDerived(ClassTy &m) {} 1549 }; 1550 1551 /// Python wrapper for MlirBlockArgument. 1552 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1553 public: 1554 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1555 static constexpr const char *pyClassName = "BlockArgument"; 1556 using PyConcreteValue::PyConcreteValue; 1557 1558 static void bindDerived(ClassTy &c) { 1559 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1560 return PyBlock(self.getParentOperation(), 1561 mlirBlockArgumentGetOwner(self.get())); 1562 }); 1563 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1564 return mlirBlockArgumentGetArgNumber(self.get()); 1565 }); 1566 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1567 return mlirBlockArgumentSetType(self.get(), type); 1568 }); 1569 } 1570 }; 1571 1572 /// Python wrapper for MlirOpResult. 1573 class PyOpResult : public PyConcreteValue<PyOpResult> { 1574 public: 1575 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1576 static constexpr const char *pyClassName = "OpResult"; 1577 using PyConcreteValue::PyConcreteValue; 1578 1579 static void bindDerived(ClassTy &c) { 1580 c.def_property_readonly("owner", [](PyOpResult &self) { 1581 assert( 1582 mlirOperationEqual(self.getParentOperation()->get(), 1583 mlirOpResultGetOwner(self.get())) && 1584 "expected the owner of the value in Python to match that in the IR"); 1585 return self.getParentOperation().getObject(); 1586 }); 1587 c.def_property_readonly("result_number", [](PyOpResult &self) { 1588 return mlirOpResultGetResultNumber(self.get()); 1589 }); 1590 } 1591 }; 1592 1593 /// A list of block arguments. Internally, these are stored as consecutive 1594 /// elements, random access is cheap. The argument list is associated with the 1595 /// operation that contains the block (detached blocks are not allowed in 1596 /// Python bindings) and extends its lifetime. 1597 class PyBlockArgumentList { 1598 public: 1599 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1600 : operation(std::move(operation)), block(block) {} 1601 1602 /// Returns the length of the block argument list. 1603 intptr_t dunderLen() { 1604 operation->checkValid(); 1605 return mlirBlockGetNumArguments(block); 1606 } 1607 1608 /// Returns `index`-th element of the block argument list. 1609 PyBlockArgument dunderGetItem(intptr_t index) { 1610 if (index < 0 || index >= dunderLen()) { 1611 throw SetPyError(PyExc_IndexError, 1612 "attempt to access out of bounds region"); 1613 } 1614 PyValue value(operation, mlirBlockGetArgument(block, index)); 1615 return PyBlockArgument(value); 1616 } 1617 1618 /// Defines a Python class in the bindings. 1619 static void bind(py::module &m) { 1620 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1621 .def("__len__", &PyBlockArgumentList::dunderLen) 1622 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1623 } 1624 1625 private: 1626 PyOperationRef operation; 1627 MlirBlock block; 1628 }; 1629 1630 /// A list of operation operands. Internally, these are stored as consecutive 1631 /// elements, random access is cheap. The result list is associated with the 1632 /// operation whose results these are, and extends the lifetime of this 1633 /// operation. 1634 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1635 public: 1636 static constexpr const char *pyClassName = "OpOperandList"; 1637 1638 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1639 intptr_t length = -1, intptr_t step = 1) 1640 : Sliceable(startIndex, 1641 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1642 : length, 1643 step), 1644 operation(operation) {} 1645 1646 intptr_t getNumElements() { 1647 operation->checkValid(); 1648 return mlirOperationGetNumOperands(operation->get()); 1649 } 1650 1651 PyValue getElement(intptr_t pos) { 1652 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1653 MlirOperation owner; 1654 if (mlirValueIsAOpResult(operand)) 1655 owner = mlirOpResultGetOwner(operand); 1656 else if (mlirValueIsABlockArgument(operand)) 1657 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1658 else 1659 assert(false && "Value must be an block arg or op result."); 1660 PyOperationRef pyOwner = 1661 PyOperation::forOperation(operation->getContext(), owner); 1662 return PyValue(pyOwner, operand); 1663 } 1664 1665 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1666 return PyOpOperandList(operation, startIndex, length, step); 1667 } 1668 1669 void dunderSetItem(intptr_t index, PyValue value) { 1670 index = wrapIndex(index); 1671 mlirOperationSetOperand(operation->get(), index, value.get()); 1672 } 1673 1674 static void bindDerived(ClassTy &c) { 1675 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1676 } 1677 1678 private: 1679 PyOperationRef operation; 1680 }; 1681 1682 /// A list of operation results. Internally, these are stored as consecutive 1683 /// elements, random access is cheap. The result list is associated with the 1684 /// operation whose results these are, and extends the lifetime of this 1685 /// operation. 1686 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1687 public: 1688 static constexpr const char *pyClassName = "OpResultList"; 1689 1690 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1691 intptr_t length = -1, intptr_t step = 1) 1692 : Sliceable(startIndex, 1693 length == -1 ? mlirOperationGetNumResults(operation->get()) 1694 : length, 1695 step), 1696 operation(operation) {} 1697 1698 intptr_t getNumElements() { 1699 operation->checkValid(); 1700 return mlirOperationGetNumResults(operation->get()); 1701 } 1702 1703 PyOpResult getElement(intptr_t index) { 1704 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1705 return PyOpResult(value); 1706 } 1707 1708 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1709 return PyOpResultList(operation, startIndex, length, step); 1710 } 1711 1712 private: 1713 PyOperationRef operation; 1714 }; 1715 1716 /// A list of operation attributes. Can be indexed by name, producing 1717 /// attributes, or by index, producing named attributes. 1718 class PyOpAttributeMap { 1719 public: 1720 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1721 1722 PyAttribute dunderGetItemNamed(const std::string &name) { 1723 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1724 toMlirStringRef(name)); 1725 if (mlirAttributeIsNull(attr)) { 1726 throw SetPyError(PyExc_KeyError, 1727 "attempt to access a non-existent attribute"); 1728 } 1729 return PyAttribute(operation->getContext(), attr); 1730 } 1731 1732 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1733 if (index < 0 || index >= dunderLen()) { 1734 throw SetPyError(PyExc_IndexError, 1735 "attempt to access out of bounds attribute"); 1736 } 1737 MlirNamedAttribute namedAttr = 1738 mlirOperationGetAttribute(operation->get(), index); 1739 return PyNamedAttribute( 1740 namedAttr.attribute, 1741 std::string(mlirIdentifierStr(namedAttr.name).data)); 1742 } 1743 1744 void dunderSetItem(const std::string &name, PyAttribute attr) { 1745 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1746 attr); 1747 } 1748 1749 void dunderDelItem(const std::string &name) { 1750 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1751 toMlirStringRef(name)); 1752 if (!removed) 1753 throw SetPyError(PyExc_KeyError, 1754 "attempt to delete a non-existent attribute"); 1755 } 1756 1757 intptr_t dunderLen() { 1758 return mlirOperationGetNumAttributes(operation->get()); 1759 } 1760 1761 bool dunderContains(const std::string &name) { 1762 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1763 operation->get(), toMlirStringRef(name))); 1764 } 1765 1766 static void bind(py::module &m) { 1767 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1768 .def("__contains__", &PyOpAttributeMap::dunderContains) 1769 .def("__len__", &PyOpAttributeMap::dunderLen) 1770 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1771 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1772 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1773 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1774 } 1775 1776 private: 1777 PyOperationRef operation; 1778 }; 1779 1780 } // end namespace 1781 1782 //------------------------------------------------------------------------------ 1783 // Populates the core exports of the 'ir' submodule. 1784 //------------------------------------------------------------------------------ 1785 1786 void mlir::python::populateIRCore(py::module &m) { 1787 //---------------------------------------------------------------------------- 1788 // Mapping of MlirContext. 1789 //---------------------------------------------------------------------------- 1790 py::class_<PyMlirContext>(m, "Context") 1791 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1792 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1793 .def("_get_context_again", 1794 [](PyMlirContext &self) { 1795 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1796 return ref.releaseObject(); 1797 }) 1798 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1799 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1800 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1801 &PyMlirContext::getCapsule) 1802 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1803 .def("__enter__", &PyMlirContext::contextEnter) 1804 .def("__exit__", &PyMlirContext::contextExit) 1805 .def_property_readonly_static( 1806 "current", 1807 [](py::object & /*class*/) { 1808 auto *context = PyThreadContextEntry::getDefaultContext(); 1809 if (!context) 1810 throw SetPyError(PyExc_ValueError, "No current Context"); 1811 return context; 1812 }, 1813 "Gets the Context bound to the current thread or raises ValueError") 1814 .def_property_readonly( 1815 "dialects", 1816 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1817 "Gets a container for accessing dialects by name") 1818 .def_property_readonly( 1819 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1820 "Alias for 'dialect'") 1821 .def( 1822 "get_dialect_descriptor", 1823 [=](PyMlirContext &self, std::string &name) { 1824 MlirDialect dialect = mlirContextGetOrLoadDialect( 1825 self.get(), {name.data(), name.size()}); 1826 if (mlirDialectIsNull(dialect)) { 1827 throw SetPyError(PyExc_ValueError, 1828 Twine("Dialect '") + name + "' not found"); 1829 } 1830 return PyDialectDescriptor(self.getRef(), dialect); 1831 }, 1832 "Gets or loads a dialect by name, returning its descriptor object") 1833 .def_property( 1834 "allow_unregistered_dialects", 1835 [](PyMlirContext &self) -> bool { 1836 return mlirContextGetAllowUnregisteredDialects(self.get()); 1837 }, 1838 [](PyMlirContext &self, bool value) { 1839 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1840 }) 1841 .def("enable_multithreading", 1842 [](PyMlirContext &self, bool enable) { 1843 mlirContextEnableMultithreading(self.get(), enable); 1844 }) 1845 .def("is_registered_operation", 1846 [](PyMlirContext &self, std::string &name) { 1847 return mlirContextIsRegisteredOperation( 1848 self.get(), MlirStringRef{name.data(), name.size()}); 1849 }); 1850 1851 //---------------------------------------------------------------------------- 1852 // Mapping of PyDialectDescriptor 1853 //---------------------------------------------------------------------------- 1854 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1855 .def_property_readonly("namespace", 1856 [](PyDialectDescriptor &self) { 1857 MlirStringRef ns = 1858 mlirDialectGetNamespace(self.get()); 1859 return py::str(ns.data, ns.length); 1860 }) 1861 .def("__repr__", [](PyDialectDescriptor &self) { 1862 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1863 std::string repr("<DialectDescriptor "); 1864 repr.append(ns.data, ns.length); 1865 repr.append(">"); 1866 return repr; 1867 }); 1868 1869 //---------------------------------------------------------------------------- 1870 // Mapping of PyDialects 1871 //---------------------------------------------------------------------------- 1872 py::class_<PyDialects>(m, "Dialects") 1873 .def("__getitem__", 1874 [=](PyDialects &self, std::string keyName) { 1875 MlirDialect dialect = 1876 self.getDialectForKey(keyName, /*attrError=*/false); 1877 py::object descriptor = 1878 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1879 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1880 }) 1881 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1882 MlirDialect dialect = 1883 self.getDialectForKey(attrName, /*attrError=*/true); 1884 py::object descriptor = 1885 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1886 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1887 }); 1888 1889 //---------------------------------------------------------------------------- 1890 // Mapping of PyDialect 1891 //---------------------------------------------------------------------------- 1892 py::class_<PyDialect>(m, "Dialect") 1893 .def(py::init<py::object>(), "descriptor") 1894 .def_property_readonly( 1895 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1896 .def("__repr__", [](py::object self) { 1897 auto clazz = self.attr("__class__"); 1898 return py::str("<Dialect ") + 1899 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1900 clazz.attr("__module__") + py::str(".") + 1901 clazz.attr("__name__") + py::str(")>"); 1902 }); 1903 1904 //---------------------------------------------------------------------------- 1905 // Mapping of Location 1906 //---------------------------------------------------------------------------- 1907 py::class_<PyLocation>(m, "Location") 1908 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1909 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1910 .def("__enter__", &PyLocation::contextEnter) 1911 .def("__exit__", &PyLocation::contextExit) 1912 .def("__eq__", 1913 [](PyLocation &self, PyLocation &other) -> bool { 1914 return mlirLocationEqual(self, other); 1915 }) 1916 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1917 .def_property_readonly_static( 1918 "current", 1919 [](py::object & /*class*/) { 1920 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1921 if (!loc) 1922 throw SetPyError(PyExc_ValueError, "No current Location"); 1923 return loc; 1924 }, 1925 "Gets the Location bound to the current thread or raises ValueError") 1926 .def_static( 1927 "unknown", 1928 [](DefaultingPyMlirContext context) { 1929 return PyLocation(context->getRef(), 1930 mlirLocationUnknownGet(context->get())); 1931 }, 1932 py::arg("context") = py::none(), 1933 "Gets a Location representing an unknown location") 1934 .def_static( 1935 "file", 1936 [](std::string filename, int line, int col, 1937 DefaultingPyMlirContext context) { 1938 return PyLocation( 1939 context->getRef(), 1940 mlirLocationFileLineColGet( 1941 context->get(), toMlirStringRef(filename), line, col)); 1942 }, 1943 py::arg("filename"), py::arg("line"), py::arg("col"), 1944 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1945 .def_property_readonly( 1946 "context", 1947 [](PyLocation &self) { return self.getContext().getObject(); }, 1948 "Context that owns the Location") 1949 .def("__repr__", [](PyLocation &self) { 1950 PyPrintAccumulator printAccum; 1951 mlirLocationPrint(self, printAccum.getCallback(), 1952 printAccum.getUserData()); 1953 return printAccum.join(); 1954 }); 1955 1956 //---------------------------------------------------------------------------- 1957 // Mapping of Module 1958 //---------------------------------------------------------------------------- 1959 py::class_<PyModule>(m, "Module") 1960 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1961 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1962 .def_static( 1963 "parse", 1964 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1965 MlirModule module = mlirModuleCreateParse( 1966 context->get(), toMlirStringRef(moduleAsm)); 1967 // TODO: Rework error reporting once diagnostic engine is exposed 1968 // in C API. 1969 if (mlirModuleIsNull(module)) { 1970 throw SetPyError( 1971 PyExc_ValueError, 1972 "Unable to parse module assembly (see diagnostics)"); 1973 } 1974 return PyModule::forModule(module).releaseObject(); 1975 }, 1976 py::arg("asm"), py::arg("context") = py::none(), 1977 kModuleParseDocstring) 1978 .def_static( 1979 "create", 1980 [](DefaultingPyLocation loc) { 1981 MlirModule module = mlirModuleCreateEmpty(loc); 1982 return PyModule::forModule(module).releaseObject(); 1983 }, 1984 py::arg("loc") = py::none(), "Creates an empty module") 1985 .def_property_readonly( 1986 "context", 1987 [](PyModule &self) { return self.getContext().getObject(); }, 1988 "Context that created the Module") 1989 .def_property_readonly( 1990 "operation", 1991 [](PyModule &self) { 1992 return PyOperation::forOperation(self.getContext(), 1993 mlirModuleGetOperation(self.get()), 1994 self.getRef().releaseObject()) 1995 .releaseObject(); 1996 }, 1997 "Accesses the module as an operation") 1998 .def_property_readonly( 1999 "body", 2000 [](PyModule &self) { 2001 PyOperationRef module_op = PyOperation::forOperation( 2002 self.getContext(), mlirModuleGetOperation(self.get()), 2003 self.getRef().releaseObject()); 2004 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2005 return returnBlock; 2006 }, 2007 "Return the block for this module") 2008 .def( 2009 "dump", 2010 [](PyModule &self) { 2011 mlirOperationDump(mlirModuleGetOperation(self.get())); 2012 }, 2013 kDumpDocstring) 2014 .def( 2015 "__str__", 2016 [](PyModule &self) { 2017 MlirOperation operation = mlirModuleGetOperation(self.get()); 2018 PyPrintAccumulator printAccum; 2019 mlirOperationPrint(operation, printAccum.getCallback(), 2020 printAccum.getUserData()); 2021 return printAccum.join(); 2022 }, 2023 kOperationStrDunderDocstring); 2024 2025 //---------------------------------------------------------------------------- 2026 // Mapping of Operation. 2027 //---------------------------------------------------------------------------- 2028 py::class_<PyOperationBase>(m, "_OperationBase") 2029 .def("__eq__", 2030 [](PyOperationBase &self, PyOperationBase &other) { 2031 return &self.getOperation() == &other.getOperation(); 2032 }) 2033 .def("__eq__", 2034 [](PyOperationBase &self, py::object other) { return false; }) 2035 .def_property_readonly("attributes", 2036 [](PyOperationBase &self) { 2037 return PyOpAttributeMap( 2038 self.getOperation().getRef()); 2039 }) 2040 .def_property_readonly("operands", 2041 [](PyOperationBase &self) { 2042 return PyOpOperandList( 2043 self.getOperation().getRef()); 2044 }) 2045 .def_property_readonly("regions", 2046 [](PyOperationBase &self) { 2047 return PyRegionList( 2048 self.getOperation().getRef()); 2049 }) 2050 .def_property_readonly( 2051 "results", 2052 [](PyOperationBase &self) { 2053 return PyOpResultList(self.getOperation().getRef()); 2054 }, 2055 "Returns the list of Operation results.") 2056 .def_property_readonly( 2057 "result", 2058 [](PyOperationBase &self) { 2059 auto &operation = self.getOperation(); 2060 auto numResults = mlirOperationGetNumResults(operation); 2061 if (numResults != 1) { 2062 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2063 throw SetPyError( 2064 PyExc_ValueError, 2065 Twine("Cannot call .result on operation ") + 2066 StringRef(name.data, name.length) + " which has " + 2067 Twine(numResults) + 2068 " results (it is only valid for operations with a " 2069 "single result)"); 2070 } 2071 return PyOpResult(operation.getRef(), 2072 mlirOperationGetResult(operation, 0)); 2073 }, 2074 "Shortcut to get an op result if it has only one (throws an error " 2075 "otherwise).") 2076 .def("__iter__", 2077 [](PyOperationBase &self) { 2078 return PyRegionIterator(self.getOperation().getRef()); 2079 }) 2080 .def( 2081 "__str__", 2082 [](PyOperationBase &self) { 2083 return self.getAsm(/*binary=*/false, 2084 /*largeElementsLimit=*/llvm::None, 2085 /*enableDebugInfo=*/false, 2086 /*prettyDebugInfo=*/false, 2087 /*printGenericOpForm=*/false, 2088 /*useLocalScope=*/false); 2089 }, 2090 "Returns the assembly form of the operation.") 2091 .def("print", &PyOperationBase::print, 2092 // Careful: Lots of arguments must match up with print method. 2093 py::arg("file") = py::none(), py::arg("binary") = false, 2094 py::arg("large_elements_limit") = py::none(), 2095 py::arg("enable_debug_info") = false, 2096 py::arg("pretty_debug_info") = false, 2097 py::arg("print_generic_op_form") = false, 2098 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2099 .def("get_asm", &PyOperationBase::getAsm, 2100 // Careful: Lots of arguments must match up with get_asm method. 2101 py::arg("binary") = false, 2102 py::arg("large_elements_limit") = py::none(), 2103 py::arg("enable_debug_info") = false, 2104 py::arg("pretty_debug_info") = false, 2105 py::arg("print_generic_op_form") = false, 2106 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2107 .def( 2108 "verify", 2109 [](PyOperationBase &self) { 2110 return mlirOperationVerify(self.getOperation()); 2111 }, 2112 "Verify the operation and return true if it passes, false if it " 2113 "fails."); 2114 2115 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2116 .def_static("create", &PyOperation::create, py::arg("name"), 2117 py::arg("results") = py::none(), 2118 py::arg("operands") = py::none(), 2119 py::arg("attributes") = py::none(), 2120 py::arg("successors") = py::none(), py::arg("regions") = 0, 2121 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2122 kOperationCreateDocstring) 2123 .def_property_readonly("parent", 2124 [](PyOperation &self) { 2125 return self.getParentOperation().getObject(); 2126 }) 2127 .def("erase", &PyOperation::erase) 2128 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2129 &PyOperation::getCapsule) 2130 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2131 .def_property_readonly("name", 2132 [](PyOperation &self) { 2133 self.checkValid(); 2134 MlirOperation operation = self.get(); 2135 MlirStringRef name = mlirIdentifierStr( 2136 mlirOperationGetName(operation)); 2137 return py::str(name.data, name.length); 2138 }) 2139 .def_property_readonly( 2140 "context", 2141 [](PyOperation &self) { 2142 self.checkValid(); 2143 return self.getContext().getObject(); 2144 }, 2145 "Context that owns the Operation") 2146 .def_property_readonly("opview", &PyOperation::createOpView); 2147 2148 auto opViewClass = 2149 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2150 .def(py::init<py::object>()) 2151 .def_property_readonly("operation", &PyOpView::getOperationObject) 2152 .def_property_readonly( 2153 "context", 2154 [](PyOpView &self) { 2155 return self.getOperation().getContext().getObject(); 2156 }, 2157 "Context that owns the Operation") 2158 .def("__str__", [](PyOpView &self) { 2159 return py::str(self.getOperationObject()); 2160 }); 2161 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2162 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2163 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2164 opViewClass.attr("build_generic") = classmethod( 2165 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2166 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2167 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2168 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2169 "Builds a specific, generated OpView based on class level attributes."); 2170 2171 //---------------------------------------------------------------------------- 2172 // Mapping of PyRegion. 2173 //---------------------------------------------------------------------------- 2174 py::class_<PyRegion>(m, "Region") 2175 .def_property_readonly( 2176 "blocks", 2177 [](PyRegion &self) { 2178 return PyBlockList(self.getParentOperation(), self.get()); 2179 }, 2180 "Returns a forward-optimized sequence of blocks.") 2181 .def( 2182 "__iter__", 2183 [](PyRegion &self) { 2184 self.checkValid(); 2185 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2186 return PyBlockIterator(self.getParentOperation(), firstBlock); 2187 }, 2188 "Iterates over blocks in the region.") 2189 .def("__eq__", 2190 [](PyRegion &self, PyRegion &other) { 2191 return self.get().ptr == other.get().ptr; 2192 }) 2193 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2194 2195 //---------------------------------------------------------------------------- 2196 // Mapping of PyBlock. 2197 //---------------------------------------------------------------------------- 2198 py::class_<PyBlock>(m, "Block") 2199 .def_property_readonly( 2200 "arguments", 2201 [](PyBlock &self) { 2202 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2203 }, 2204 "Returns a list of block arguments.") 2205 .def_property_readonly( 2206 "operations", 2207 [](PyBlock &self) { 2208 return PyOperationList(self.getParentOperation(), self.get()); 2209 }, 2210 "Returns a forward-optimized sequence of operations.") 2211 .def( 2212 "__iter__", 2213 [](PyBlock &self) { 2214 self.checkValid(); 2215 MlirOperation firstOperation = 2216 mlirBlockGetFirstOperation(self.get()); 2217 return PyOperationIterator(self.getParentOperation(), 2218 firstOperation); 2219 }, 2220 "Iterates over operations in the block.") 2221 .def("__eq__", 2222 [](PyBlock &self, PyBlock &other) { 2223 return self.get().ptr == other.get().ptr; 2224 }) 2225 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2226 .def( 2227 "__str__", 2228 [](PyBlock &self) { 2229 self.checkValid(); 2230 PyPrintAccumulator printAccum; 2231 mlirBlockPrint(self.get(), printAccum.getCallback(), 2232 printAccum.getUserData()); 2233 return printAccum.join(); 2234 }, 2235 "Returns the assembly form of the block."); 2236 2237 //---------------------------------------------------------------------------- 2238 // Mapping of PyInsertionPoint. 2239 //---------------------------------------------------------------------------- 2240 2241 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2242 .def(py::init<PyBlock &>(), py::arg("block"), 2243 "Inserts after the last operation but still inside the block.") 2244 .def("__enter__", &PyInsertionPoint::contextEnter) 2245 .def("__exit__", &PyInsertionPoint::contextExit) 2246 .def_property_readonly_static( 2247 "current", 2248 [](py::object & /*class*/) { 2249 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2250 if (!ip) 2251 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2252 return ip; 2253 }, 2254 "Gets the InsertionPoint bound to the current thread or raises " 2255 "ValueError if none has been set") 2256 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2257 "Inserts before a referenced operation.") 2258 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2259 py::arg("block"), "Inserts at the beginning of the block.") 2260 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2261 py::arg("block"), "Inserts before the block terminator.") 2262 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2263 "Inserts an operation."); 2264 2265 //---------------------------------------------------------------------------- 2266 // Mapping of PyAttribute. 2267 //---------------------------------------------------------------------------- 2268 py::class_<PyAttribute>(m, "Attribute") 2269 // Delegate to the PyAttribute copy constructor, which will also lifetime 2270 // extend the backing context which owns the MlirAttribute. 2271 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2272 "Casts the passed attribute to the generic Attribute") 2273 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2274 &PyAttribute::getCapsule) 2275 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2276 .def_static( 2277 "parse", 2278 [](std::string attrSpec, DefaultingPyMlirContext context) { 2279 MlirAttribute type = mlirAttributeParseGet( 2280 context->get(), toMlirStringRef(attrSpec)); 2281 // TODO: Rework error reporting once diagnostic engine is exposed 2282 // in C API. 2283 if (mlirAttributeIsNull(type)) { 2284 throw SetPyError(PyExc_ValueError, 2285 Twine("Unable to parse attribute: '") + 2286 attrSpec + "'"); 2287 } 2288 return PyAttribute(context->getRef(), type); 2289 }, 2290 py::arg("asm"), py::arg("context") = py::none(), 2291 "Parses an attribute from an assembly form") 2292 .def_property_readonly( 2293 "context", 2294 [](PyAttribute &self) { return self.getContext().getObject(); }, 2295 "Context that owns the Attribute") 2296 .def_property_readonly("type", 2297 [](PyAttribute &self) { 2298 return PyType(self.getContext()->getRef(), 2299 mlirAttributeGetType(self)); 2300 }) 2301 .def( 2302 "get_named", 2303 [](PyAttribute &self, std::string name) { 2304 return PyNamedAttribute(self, std::move(name)); 2305 }, 2306 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2307 .def("__eq__", 2308 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2309 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2310 .def( 2311 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2312 kDumpDocstring) 2313 .def( 2314 "__str__", 2315 [](PyAttribute &self) { 2316 PyPrintAccumulator printAccum; 2317 mlirAttributePrint(self, printAccum.getCallback(), 2318 printAccum.getUserData()); 2319 return printAccum.join(); 2320 }, 2321 "Returns the assembly form of the Attribute.") 2322 .def("__repr__", [](PyAttribute &self) { 2323 // Generally, assembly formats are not printed for __repr__ because 2324 // this can cause exceptionally long debug output and exceptions. 2325 // However, attribute values are generally considered useful and are 2326 // printed. This may need to be re-evaluated if debug dumps end up 2327 // being excessive. 2328 PyPrintAccumulator printAccum; 2329 printAccum.parts.append("Attribute("); 2330 mlirAttributePrint(self, printAccum.getCallback(), 2331 printAccum.getUserData()); 2332 printAccum.parts.append(")"); 2333 return printAccum.join(); 2334 }); 2335 2336 //---------------------------------------------------------------------------- 2337 // Mapping of PyNamedAttribute 2338 //---------------------------------------------------------------------------- 2339 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2340 .def("__repr__", 2341 [](PyNamedAttribute &self) { 2342 PyPrintAccumulator printAccum; 2343 printAccum.parts.append("NamedAttribute("); 2344 printAccum.parts.append( 2345 mlirIdentifierStr(self.namedAttr.name).data); 2346 printAccum.parts.append("="); 2347 mlirAttributePrint(self.namedAttr.attribute, 2348 printAccum.getCallback(), 2349 printAccum.getUserData()); 2350 printAccum.parts.append(")"); 2351 return printAccum.join(); 2352 }) 2353 .def_property_readonly( 2354 "name", 2355 [](PyNamedAttribute &self) { 2356 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2357 mlirIdentifierStr(self.namedAttr.name).length); 2358 }, 2359 "The name of the NamedAttribute binding") 2360 .def_property_readonly( 2361 "attr", 2362 [](PyNamedAttribute &self) { 2363 // TODO: When named attribute is removed/refactored, also remove 2364 // this constructor (it does an inefficient table lookup). 2365 auto contextRef = PyMlirContext::forContext( 2366 mlirAttributeGetContext(self.namedAttr.attribute)); 2367 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2368 }, 2369 py::keep_alive<0, 1>(), 2370 "The underlying generic attribute of the NamedAttribute binding"); 2371 2372 //---------------------------------------------------------------------------- 2373 // Mapping of PyType. 2374 //---------------------------------------------------------------------------- 2375 py::class_<PyType>(m, "Type") 2376 // Delegate to the PyType copy constructor, which will also lifetime 2377 // extend the backing context which owns the MlirType. 2378 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2379 "Casts the passed type to the generic Type") 2380 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2381 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2382 .def_static( 2383 "parse", 2384 [](std::string typeSpec, DefaultingPyMlirContext context) { 2385 MlirType type = 2386 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2387 // TODO: Rework error reporting once diagnostic engine is exposed 2388 // in C API. 2389 if (mlirTypeIsNull(type)) { 2390 throw SetPyError(PyExc_ValueError, 2391 Twine("Unable to parse type: '") + typeSpec + 2392 "'"); 2393 } 2394 return PyType(context->getRef(), type); 2395 }, 2396 py::arg("asm"), py::arg("context") = py::none(), 2397 kContextParseTypeDocstring) 2398 .def_property_readonly( 2399 "context", [](PyType &self) { return self.getContext().getObject(); }, 2400 "Context that owns the Type") 2401 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2402 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2403 .def( 2404 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2405 .def( 2406 "__str__", 2407 [](PyType &self) { 2408 PyPrintAccumulator printAccum; 2409 mlirTypePrint(self, printAccum.getCallback(), 2410 printAccum.getUserData()); 2411 return printAccum.join(); 2412 }, 2413 "Returns the assembly form of the type.") 2414 .def("__repr__", [](PyType &self) { 2415 // Generally, assembly formats are not printed for __repr__ because 2416 // this can cause exceptionally long debug output and exceptions. 2417 // However, types are an exception as they typically have compact 2418 // assembly forms and printing them is useful. 2419 PyPrintAccumulator printAccum; 2420 printAccum.parts.append("Type("); 2421 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2422 printAccum.parts.append(")"); 2423 return printAccum.join(); 2424 }); 2425 2426 //---------------------------------------------------------------------------- 2427 // Mapping of Value. 2428 //---------------------------------------------------------------------------- 2429 py::class_<PyValue>(m, "Value") 2430 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2431 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2432 .def_property_readonly( 2433 "context", 2434 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2435 "Context in which the value lives.") 2436 .def( 2437 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2438 kDumpDocstring) 2439 .def_property_readonly( 2440 "owner", 2441 [](PyValue &self) { 2442 assert(mlirOperationEqual(self.getParentOperation()->get(), 2443 mlirOpResultGetOwner(self.get())) && 2444 "expected the owner of the value in Python to match that in " 2445 "the IR"); 2446 return self.getParentOperation().getObject(); 2447 }) 2448 .def("__eq__", 2449 [](PyValue &self, PyValue &other) { 2450 return self.get().ptr == other.get().ptr; 2451 }) 2452 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2453 .def( 2454 "__str__", 2455 [](PyValue &self) { 2456 PyPrintAccumulator printAccum; 2457 printAccum.parts.append("Value("); 2458 mlirValuePrint(self.get(), printAccum.getCallback(), 2459 printAccum.getUserData()); 2460 printAccum.parts.append(")"); 2461 return printAccum.join(); 2462 }, 2463 kValueDunderStrDocstring) 2464 .def_property_readonly("type", [](PyValue &self) { 2465 return PyType(self.getParentOperation()->getContext(), 2466 mlirValueGetType(self.get())); 2467 }); 2468 PyBlockArgument::bind(m); 2469 PyOpResult::bind(m); 2470 2471 // Container bindings. 2472 PyBlockArgumentList::bind(m); 2473 PyBlockIterator::bind(m); 2474 PyBlockList::bind(m); 2475 PyOperationIterator::bind(m); 2476 PyOperationList::bind(m); 2477 PyOpAttributeMap::bind(m); 2478 PyOpOperandList::bind(m); 2479 PyOpResultList::bind(m); 2480 PyRegionIterator::bind(m); 2481 PyRegionList::bind(m); 2482 2483 // Debug bindings. 2484 PyGlobalDebugFlag::bind(m); 2485 } 2486