1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include <pybind11/stl.h> 22 23 namespace py = pybind11; 24 using namespace mlir; 25 using namespace mlir::python; 26 27 using llvm::SmallVector; 28 using llvm::StringRef; 29 using llvm::Twine; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kContextParseTypeDocstring[] = 36 R"(Parses the assembly form of a type. 37 38 Returns a Type object or raises a ValueError if the type cannot be parsed. 39 40 See also: https://mlir.llvm.org/docs/LangRef/#type-system 41 )"; 42 43 static const char kContextGetFileLocationDocstring[] = 44 R"(Gets a Location representing a file, line and column)"; 45 46 static const char kModuleParseDocstring[] = 47 R"(Parses a module's assembly format from a string. 48 49 Returns a new MlirModule or raises a ValueError if the parsing fails. 50 51 See also: https://mlir.llvm.org/docs/LangRef/ 52 )"; 53 54 static const char kOperationCreateDocstring[] = 55 R"(Creates a new operation. 56 57 Args: 58 name: Operation name (e.g. "dialect.operation"). 59 results: Sequence of Type representing op result types. 60 attributes: Dict of str:Attribute. 61 successors: List of Block for the operation's successors. 62 regions: Number of regions to create. 63 location: A Location object (defaults to resolve from context manager). 64 ip: An InsertionPoint (defaults to resolve from context manager or set to 65 False to disable insertion, even with an insertion point set in the 66 context manager). 67 Returns: 68 A new "detached" Operation object. Detached operations can be added 69 to blocks, which causes them to become "attached." 70 )"; 71 72 static const char kOperationPrintDocstring[] = 73 R"(Prints the assembly form of the operation to a file like object. 74 75 Args: 76 file: The file like object to write to. Defaults to sys.stdout. 77 binary: Whether to write bytes (True) or str (False). Defaults to False. 78 large_elements_limit: Whether to elide elements attributes above this 79 number of elements. Defaults to None (no limit). 80 enable_debug_info: Whether to print debug/location information. Defaults 81 to False. 82 pretty_debug_info: Whether to format debug information for easier reading 83 by a human (warning: the result is unparseable). 84 print_generic_op_form: Whether to print the generic assembly forms of all 85 ops. Defaults to False. 86 use_local_Scope: Whether to print in a way that is more optimized for 87 multi-threaded access but may not be consistent with how the overall 88 module prints. 89 )"; 90 91 static const char kOperationGetAsmDocstring[] = 92 R"(Gets the assembly form of the operation with all options available. 93 94 Args: 95 binary: Whether to return a bytes (True) or str (False) object. Defaults to 96 False. 97 ... others ...: See the print() method for common keyword arguments for 98 configuring the printout. 99 Returns: 100 Either a bytes or str object, depending on the setting of the 'binary' 101 argument. 102 )"; 103 104 static const char kOperationStrDunderDocstring[] = 105 R"(Gets the assembly form of the operation with default options. 106 107 If more advanced control over the assembly formatting or I/O options is needed, 108 use the dedicated print or get_asm method, which supports keyword arguments to 109 customize behavior. 110 )"; 111 112 static const char kDumpDocstring[] = 113 R"(Dumps a debug representation of the object to stderr.)"; 114 115 static const char kAppendBlockDocstring[] = 116 R"(Appends a new block, with argument types as positional args. 117 118 Returns: 119 The created block. 120 )"; 121 122 static const char kValueDunderStrDocstring[] = 123 R"(Returns the string form of the value. 124 125 If the value is a block argument, this is the assembly form of its type and the 126 position in the argument list. If the value is an operation result, this is 127 equivalent to printing the operation that produced it. 128 )"; 129 130 //------------------------------------------------------------------------------ 131 // Utilities. 132 //------------------------------------------------------------------------------ 133 134 /// Helper for creating an @classmethod. 135 template <class Func, typename... Args> 136 py::object classmethod(Func f, Args... args) { 137 py::object cf = py::cpp_function(f, args...); 138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 139 } 140 141 static py::object 142 createCustomDialectWrapper(const std::string &dialectNamespace, 143 py::object dialectDescriptor) { 144 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 145 if (!dialectClass) { 146 // Use the base class. 147 return py::cast(PyDialect(std::move(dialectDescriptor))); 148 } 149 150 // Create the custom implementation. 151 return (*dialectClass)(std::move(dialectDescriptor)); 152 } 153 154 static MlirStringRef toMlirStringRef(const std::string &s) { 155 return mlirStringRefCreate(s.data(), s.size()); 156 } 157 158 /// Wrapper for the global LLVM debugging flag. 159 struct PyGlobalDebugFlag { 160 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 161 162 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 163 164 static void bind(py::module &m) { 165 // Debug flags. 166 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") 167 .def_property_static("flag", &PyGlobalDebugFlag::get, 168 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 169 } 170 }; 171 172 //------------------------------------------------------------------------------ 173 // Collections. 174 //------------------------------------------------------------------------------ 175 176 namespace { 177 178 class PyRegionIterator { 179 public: 180 PyRegionIterator(PyOperationRef operation) 181 : operation(std::move(operation)) {} 182 183 PyRegionIterator &dunderIter() { return *this; } 184 185 PyRegion dunderNext() { 186 operation->checkValid(); 187 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 188 throw py::stop_iteration(); 189 } 190 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 191 return PyRegion(operation, region); 192 } 193 194 static void bind(py::module &m) { 195 py::class_<PyRegionIterator>(m, "RegionIterator") 196 .def("__iter__", &PyRegionIterator::dunderIter) 197 .def("__next__", &PyRegionIterator::dunderNext); 198 } 199 200 private: 201 PyOperationRef operation; 202 int nextIndex = 0; 203 }; 204 205 /// Regions of an op are fixed length and indexed numerically so are represented 206 /// with a sequence-like container. 207 class PyRegionList { 208 public: 209 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 210 211 intptr_t dunderLen() { 212 operation->checkValid(); 213 return mlirOperationGetNumRegions(operation->get()); 214 } 215 216 PyRegion dunderGetItem(intptr_t index) { 217 // dunderLen checks validity. 218 if (index < 0 || index >= dunderLen()) { 219 throw SetPyError(PyExc_IndexError, 220 "attempt to access out of bounds region"); 221 } 222 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 223 return PyRegion(operation, region); 224 } 225 226 static void bind(py::module &m) { 227 py::class_<PyRegionList>(m, "RegionSequence") 228 .def("__len__", &PyRegionList::dunderLen) 229 .def("__getitem__", &PyRegionList::dunderGetItem); 230 } 231 232 private: 233 PyOperationRef operation; 234 }; 235 236 class PyBlockIterator { 237 public: 238 PyBlockIterator(PyOperationRef operation, MlirBlock next) 239 : operation(std::move(operation)), next(next) {} 240 241 PyBlockIterator &dunderIter() { return *this; } 242 243 PyBlock dunderNext() { 244 operation->checkValid(); 245 if (mlirBlockIsNull(next)) { 246 throw py::stop_iteration(); 247 } 248 249 PyBlock returnBlock(operation, next); 250 next = mlirBlockGetNextInRegion(next); 251 return returnBlock; 252 } 253 254 static void bind(py::module &m) { 255 py::class_<PyBlockIterator>(m, "BlockIterator") 256 .def("__iter__", &PyBlockIterator::dunderIter) 257 .def("__next__", &PyBlockIterator::dunderNext); 258 } 259 260 private: 261 PyOperationRef operation; 262 MlirBlock next; 263 }; 264 265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 266 /// we present them as a more full-featured list-like container but optimize 267 /// it for forward iteration. Blocks are always owned by a region. 268 class PyBlockList { 269 public: 270 PyBlockList(PyOperationRef operation, MlirRegion region) 271 : operation(std::move(operation)), region(region) {} 272 273 PyBlockIterator dunderIter() { 274 operation->checkValid(); 275 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 276 } 277 278 intptr_t dunderLen() { 279 operation->checkValid(); 280 intptr_t count = 0; 281 MlirBlock block = mlirRegionGetFirstBlock(region); 282 while (!mlirBlockIsNull(block)) { 283 count += 1; 284 block = mlirBlockGetNextInRegion(block); 285 } 286 return count; 287 } 288 289 PyBlock dunderGetItem(intptr_t index) { 290 operation->checkValid(); 291 if (index < 0) { 292 throw SetPyError(PyExc_IndexError, 293 "attempt to access out of bounds block"); 294 } 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 if (index == 0) { 298 return PyBlock(operation, block); 299 } 300 block = mlirBlockGetNextInRegion(block); 301 index -= 1; 302 } 303 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 304 } 305 306 PyBlock appendBlock(py::args pyArgTypes) { 307 operation->checkValid(); 308 llvm::SmallVector<MlirType, 4> argTypes; 309 argTypes.reserve(pyArgTypes.size()); 310 for (auto &pyArg : pyArgTypes) { 311 argTypes.push_back(pyArg.cast<PyType &>()); 312 } 313 314 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 315 mlirRegionAppendOwnedBlock(region, block); 316 return PyBlock(operation, block); 317 } 318 319 static void bind(py::module &m) { 320 py::class_<PyBlockList>(m, "BlockList") 321 .def("__getitem__", &PyBlockList::dunderGetItem) 322 .def("__iter__", &PyBlockList::dunderIter) 323 .def("__len__", &PyBlockList::dunderLen) 324 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 325 } 326 327 private: 328 PyOperationRef operation; 329 MlirRegion region; 330 }; 331 332 class PyOperationIterator { 333 public: 334 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 335 : parentOperation(std::move(parentOperation)), next(next) {} 336 337 PyOperationIterator &dunderIter() { return *this; } 338 339 py::object dunderNext() { 340 parentOperation->checkValid(); 341 if (mlirOperationIsNull(next)) { 342 throw py::stop_iteration(); 343 } 344 345 PyOperationRef returnOperation = 346 PyOperation::forOperation(parentOperation->getContext(), next); 347 next = mlirOperationGetNextInBlock(next); 348 return returnOperation->createOpView(); 349 } 350 351 static void bind(py::module &m) { 352 py::class_<PyOperationIterator>(m, "OperationIterator") 353 .def("__iter__", &PyOperationIterator::dunderIter) 354 .def("__next__", &PyOperationIterator::dunderNext); 355 } 356 357 private: 358 PyOperationRef parentOperation; 359 MlirOperation next; 360 }; 361 362 /// Operations are exposed by the C-API as a forward-only linked list. In 363 /// Python, we present them as a more full-featured list-like container but 364 /// optimize it for forward iteration. Iterable operations are always owned 365 /// by a block. 366 class PyOperationList { 367 public: 368 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 369 : parentOperation(std::move(parentOperation)), block(block) {} 370 371 PyOperationIterator dunderIter() { 372 parentOperation->checkValid(); 373 return PyOperationIterator(parentOperation, 374 mlirBlockGetFirstOperation(block)); 375 } 376 377 intptr_t dunderLen() { 378 parentOperation->checkValid(); 379 intptr_t count = 0; 380 MlirOperation childOp = mlirBlockGetFirstOperation(block); 381 while (!mlirOperationIsNull(childOp)) { 382 count += 1; 383 childOp = mlirOperationGetNextInBlock(childOp); 384 } 385 return count; 386 } 387 388 py::object dunderGetItem(intptr_t index) { 389 parentOperation->checkValid(); 390 if (index < 0) { 391 throw SetPyError(PyExc_IndexError, 392 "attempt to access out of bounds operation"); 393 } 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 if (index == 0) { 397 return PyOperation::forOperation(parentOperation->getContext(), childOp) 398 ->createOpView(); 399 } 400 childOp = mlirOperationGetNextInBlock(childOp); 401 index -= 1; 402 } 403 throw SetPyError(PyExc_IndexError, 404 "attempt to access out of bounds operation"); 405 } 406 407 static void bind(py::module &m) { 408 py::class_<PyOperationList>(m, "OperationList") 409 .def("__getitem__", &PyOperationList::dunderGetItem) 410 .def("__iter__", &PyOperationList::dunderIter) 411 .def("__len__", &PyOperationList::dunderLen); 412 } 413 414 private: 415 PyOperationRef parentOperation; 416 MlirBlock block; 417 }; 418 419 } // namespace 420 421 //------------------------------------------------------------------------------ 422 // PyMlirContext 423 //------------------------------------------------------------------------------ 424 425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 426 py::gil_scoped_acquire acquire; 427 auto &liveContexts = getLiveContexts(); 428 liveContexts[context.ptr] = this; 429 } 430 431 PyMlirContext::~PyMlirContext() { 432 // Note that the only public way to construct an instance is via the 433 // forContext method, which always puts the associated handle into 434 // liveContexts. 435 py::gil_scoped_acquire acquire; 436 getLiveContexts().erase(context.ptr); 437 mlirContextDestroy(context); 438 } 439 440 py::object PyMlirContext::getCapsule() { 441 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 442 } 443 444 py::object PyMlirContext::createFromCapsule(py::object capsule) { 445 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 446 if (mlirContextIsNull(rawContext)) 447 throw py::error_already_set(); 448 return forContext(rawContext).releaseObject(); 449 } 450 451 PyMlirContext *PyMlirContext::createNewContextForInit() { 452 MlirContext context = mlirContextCreate(); 453 mlirRegisterAllDialects(context); 454 return new PyMlirContext(context); 455 } 456 457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 458 py::gil_scoped_acquire acquire; 459 auto &liveContexts = getLiveContexts(); 460 auto it = liveContexts.find(context.ptr); 461 if (it == liveContexts.end()) { 462 // Create. 463 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 464 py::object pyRef = py::cast(unownedContextWrapper); 465 assert(pyRef && "cast to py::object failed"); 466 liveContexts[context.ptr] = unownedContextWrapper; 467 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 468 } 469 // Use existing. 470 py::object pyRef = py::cast(it->second); 471 return PyMlirContextRef(it->second, std::move(pyRef)); 472 } 473 474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 475 static LiveContextMap liveContexts; 476 return liveContexts; 477 } 478 479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 480 481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 482 483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 484 485 pybind11::object PyMlirContext::contextEnter() { 486 return PyThreadContextEntry::pushContext(*this); 487 } 488 489 void PyMlirContext::contextExit(pybind11::object excType, 490 pybind11::object excVal, 491 pybind11::object excTb) { 492 PyThreadContextEntry::popContext(*this); 493 } 494 495 PyMlirContext &DefaultingPyMlirContext::resolve() { 496 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 497 if (!context) { 498 throw SetPyError( 499 PyExc_RuntimeError, 500 "An MLIR function requires a Context but none was provided in the call " 501 "or from the surrounding environment. Either pass to the function with " 502 "a 'context=' argument or establish a default using 'with Context():'"); 503 } 504 return *context; 505 } 506 507 //------------------------------------------------------------------------------ 508 // PyThreadContextEntry management 509 //------------------------------------------------------------------------------ 510 511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 512 static thread_local std::vector<PyThreadContextEntry> stack; 513 return stack; 514 } 515 516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 517 auto &stack = getStack(); 518 if (stack.empty()) 519 return nullptr; 520 return &stack.back(); 521 } 522 523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 524 py::object insertionPoint, 525 py::object location) { 526 auto &stack = getStack(); 527 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 528 std::move(location)); 529 // If the new stack has more than one entry and the context of the new top 530 // entry matches the previous, copy the insertionPoint and location from the 531 // previous entry if missing from the new top entry. 532 if (stack.size() > 1) { 533 auto &prev = *(stack.rbegin() + 1); 534 auto ¤t = stack.back(); 535 if (current.context.is(prev.context)) { 536 // Default non-context objects from the previous entry. 537 if (!current.insertionPoint) 538 current.insertionPoint = prev.insertionPoint; 539 if (!current.location) 540 current.location = prev.location; 541 } 542 } 543 } 544 545 PyMlirContext *PyThreadContextEntry::getContext() { 546 if (!context) 547 return nullptr; 548 return py::cast<PyMlirContext *>(context); 549 } 550 551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 552 if (!insertionPoint) 553 return nullptr; 554 return py::cast<PyInsertionPoint *>(insertionPoint); 555 } 556 557 PyLocation *PyThreadContextEntry::getLocation() { 558 if (!location) 559 return nullptr; 560 return py::cast<PyLocation *>(location); 561 } 562 563 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 564 auto *tos = getTopOfStack(); 565 return tos ? tos->getContext() : nullptr; 566 } 567 568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 569 auto *tos = getTopOfStack(); 570 return tos ? tos->getInsertionPoint() : nullptr; 571 } 572 573 PyLocation *PyThreadContextEntry::getDefaultLocation() { 574 auto *tos = getTopOfStack(); 575 return tos ? tos->getLocation() : nullptr; 576 } 577 578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 579 py::object contextObj = py::cast(context); 580 push(FrameKind::Context, /*context=*/contextObj, 581 /*insertionPoint=*/py::object(), 582 /*location=*/py::object()); 583 return contextObj; 584 } 585 586 void PyThreadContextEntry::popContext(PyMlirContext &context) { 587 auto &stack = getStack(); 588 if (stack.empty()) 589 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 590 auto &tos = stack.back(); 591 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593 stack.pop_back(); 594 } 595 596 py::object 597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 598 py::object contextObj = 599 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 600 py::object insertionPointObj = py::cast(insertionPoint); 601 push(FrameKind::InsertionPoint, 602 /*context=*/contextObj, 603 /*insertionPoint=*/insertionPointObj, 604 /*location=*/py::object()); 605 return insertionPointObj; 606 } 607 608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 609 auto &stack = getStack(); 610 if (stack.empty()) 611 throw SetPyError(PyExc_RuntimeError, 612 "Unbalanced InsertionPoint enter/exit"); 613 auto &tos = stack.back(); 614 if (tos.frameKind != FrameKind::InsertionPoint && 615 tos.getInsertionPoint() != &insertionPoint) 616 throw SetPyError(PyExc_RuntimeError, 617 "Unbalanced InsertionPoint enter/exit"); 618 stack.pop_back(); 619 } 620 621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 622 py::object contextObj = location.getContext().getObject(); 623 py::object locationObj = py::cast(location); 624 push(FrameKind::Location, /*context=*/contextObj, 625 /*insertionPoint=*/py::object(), 626 /*location=*/locationObj); 627 return locationObj; 628 } 629 630 void PyThreadContextEntry::popLocation(PyLocation &location) { 631 auto &stack = getStack(); 632 if (stack.empty()) 633 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 634 auto &tos = stack.back(); 635 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637 stack.pop_back(); 638 } 639 640 //------------------------------------------------------------------------------ 641 // PyDialect, PyDialectDescriptor, PyDialects 642 //------------------------------------------------------------------------------ 643 644 MlirDialect PyDialects::getDialectForKey(const std::string &key, 645 bool attrError) { 646 // If the "std" dialect was asked for, substitute the empty namespace :( 647 static const std::string emptyKey; 648 const std::string *canonKey = key == "std" ? &emptyKey : &key; 649 MlirDialect dialect = mlirContextGetOrLoadDialect( 650 getContext()->get(), {canonKey->data(), canonKey->size()}); 651 if (mlirDialectIsNull(dialect)) { 652 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 653 Twine("Dialect '") + key + "' not found"); 654 } 655 return dialect; 656 } 657 658 //------------------------------------------------------------------------------ 659 // PyLocation 660 //------------------------------------------------------------------------------ 661 662 py::object PyLocation::getCapsule() { 663 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 664 } 665 666 PyLocation PyLocation::createFromCapsule(py::object capsule) { 667 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 668 if (mlirLocationIsNull(rawLoc)) 669 throw py::error_already_set(); 670 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 671 rawLoc); 672 } 673 674 py::object PyLocation::contextEnter() { 675 return PyThreadContextEntry::pushLocation(*this); 676 } 677 678 void PyLocation::contextExit(py::object excType, py::object excVal, 679 py::object excTb) { 680 PyThreadContextEntry::popLocation(*this); 681 } 682 683 PyLocation &DefaultingPyLocation::resolve() { 684 auto *location = PyThreadContextEntry::getDefaultLocation(); 685 if (!location) { 686 throw SetPyError( 687 PyExc_RuntimeError, 688 "An MLIR function requires a Location but none was provided in the " 689 "call or from the surrounding environment. Either pass to the function " 690 "with a 'loc=' argument or establish a default using 'with loc:'"); 691 } 692 return *location; 693 } 694 695 //------------------------------------------------------------------------------ 696 // PyModule 697 //------------------------------------------------------------------------------ 698 699 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 700 : BaseContextObject(std::move(contextRef)), module(module) {} 701 702 PyModule::~PyModule() { 703 py::gil_scoped_acquire acquire; 704 auto &liveModules = getContext()->liveModules; 705 assert(liveModules.count(module.ptr) == 1 && 706 "destroying module not in live map"); 707 liveModules.erase(module.ptr); 708 mlirModuleDestroy(module); 709 } 710 711 PyModuleRef PyModule::forModule(MlirModule module) { 712 MlirContext context = mlirModuleGetContext(module); 713 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 714 715 py::gil_scoped_acquire acquire; 716 auto &liveModules = contextRef->liveModules; 717 auto it = liveModules.find(module.ptr); 718 if (it == liveModules.end()) { 719 // Create. 720 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 721 // Note that the default return value policy on cast is automatic_reference, 722 // which does not take ownership (delete will not be called). 723 // Just be explicit. 724 py::object pyRef = 725 py::cast(unownedModule, py::return_value_policy::take_ownership); 726 unownedModule->handle = pyRef; 727 liveModules[module.ptr] = 728 std::make_pair(unownedModule->handle, unownedModule); 729 return PyModuleRef(unownedModule, std::move(pyRef)); 730 } 731 // Use existing. 732 PyModule *existing = it->second.second; 733 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 734 return PyModuleRef(existing, std::move(pyRef)); 735 } 736 737 py::object PyModule::createFromCapsule(py::object capsule) { 738 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 739 if (mlirModuleIsNull(rawModule)) 740 throw py::error_already_set(); 741 return forModule(rawModule).releaseObject(); 742 } 743 744 py::object PyModule::getCapsule() { 745 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 746 } 747 748 //------------------------------------------------------------------------------ 749 // PyOperation 750 //------------------------------------------------------------------------------ 751 752 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 753 : BaseContextObject(std::move(contextRef)), operation(operation) {} 754 755 PyOperation::~PyOperation() { 756 auto &liveOperations = getContext()->liveOperations; 757 assert(liveOperations.count(operation.ptr) == 1 && 758 "destroying operation not in live map"); 759 liveOperations.erase(operation.ptr); 760 if (!isAttached()) { 761 mlirOperationDestroy(operation); 762 } 763 } 764 765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 766 MlirOperation operation, 767 py::object parentKeepAlive) { 768 auto &liveOperations = contextRef->liveOperations; 769 // Create. 770 PyOperation *unownedOperation = 771 new PyOperation(std::move(contextRef), operation); 772 // Note that the default return value policy on cast is automatic_reference, 773 // which does not take ownership (delete will not be called). 774 // Just be explicit. 775 py::object pyRef = 776 py::cast(unownedOperation, py::return_value_policy::take_ownership); 777 unownedOperation->handle = pyRef; 778 if (parentKeepAlive) { 779 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 780 } 781 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 782 return PyOperationRef(unownedOperation, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 auto it = liveOperations.find(operation.ptr); 790 if (it == liveOperations.end()) { 791 // Create. 792 return createInstance(std::move(contextRef), operation, 793 std::move(parentKeepAlive)); 794 } 795 // Use existing. 796 PyOperation *existing = it->second.second; 797 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 798 return PyOperationRef(existing, std::move(pyRef)); 799 } 800 801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 802 MlirOperation operation, 803 py::object parentKeepAlive) { 804 auto &liveOperations = contextRef->liveOperations; 805 assert(liveOperations.count(operation.ptr) == 0 && 806 "cannot create detached operation that already exists"); 807 (void)liveOperations; 808 809 PyOperationRef created = createInstance(std::move(contextRef), operation, 810 std::move(parentKeepAlive)); 811 created->attached = false; 812 return created; 813 } 814 815 void PyOperation::checkValid() const { 816 if (!valid) { 817 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 818 } 819 } 820 821 void PyOperationBase::print(py::object fileObject, bool binary, 822 llvm::Optional<int64_t> largeElementsLimit, 823 bool enableDebugInfo, bool prettyDebugInfo, 824 bool printGenericOpForm, bool useLocalScope) { 825 PyOperation &operation = getOperation(); 826 operation.checkValid(); 827 if (fileObject.is_none()) 828 fileObject = py::module::import("sys").attr("stdout"); 829 830 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 831 fileObject.attr("write")("// Verification failed, printing generic form\n"); 832 printGenericOpForm = true; 833 } 834 835 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 836 if (largeElementsLimit) 837 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 838 if (enableDebugInfo) 839 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 840 if (printGenericOpForm) 841 mlirOpPrintingFlagsPrintGenericOpForm(flags); 842 843 PyFileAccumulator accum(fileObject, binary); 844 py::gil_scoped_release(); 845 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 846 accum.getUserData()); 847 mlirOpPrintingFlagsDestroy(flags); 848 } 849 850 py::object PyOperationBase::getAsm(bool binary, 851 llvm::Optional<int64_t> largeElementsLimit, 852 bool enableDebugInfo, bool prettyDebugInfo, 853 bool printGenericOpForm, 854 bool useLocalScope) { 855 py::object fileObject; 856 if (binary) { 857 fileObject = py::module::import("io").attr("BytesIO")(); 858 } else { 859 fileObject = py::module::import("io").attr("StringIO")(); 860 } 861 print(fileObject, /*binary=*/binary, 862 /*largeElementsLimit=*/largeElementsLimit, 863 /*enableDebugInfo=*/enableDebugInfo, 864 /*prettyDebugInfo=*/prettyDebugInfo, 865 /*printGenericOpForm=*/printGenericOpForm, 866 /*useLocalScope=*/useLocalScope); 867 868 return fileObject.attr("getvalue")(); 869 } 870 871 PyOperationRef PyOperation::getParentOperation() { 872 if (!isAttached()) 873 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 874 MlirOperation operation = mlirOperationGetParentOperation(get()); 875 if (mlirOperationIsNull(operation)) 876 throw SetPyError(PyExc_ValueError, "Operation has no parent."); 877 return PyOperation::forOperation(getContext(), operation); 878 } 879 880 PyBlock PyOperation::getBlock() { 881 PyOperationRef parentOperation = getParentOperation(); 882 MlirBlock block = mlirOperationGetBlock(get()); 883 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 884 return PyBlock{std::move(parentOperation), block}; 885 } 886 887 py::object PyOperation::getCapsule() { 888 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 889 } 890 891 py::object PyOperation::createFromCapsule(py::object capsule) { 892 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 893 if (mlirOperationIsNull(rawOperation)) 894 throw py::error_already_set(); 895 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 896 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 897 .releaseObject(); 898 } 899 900 py::object PyOperation::create( 901 std::string name, llvm::Optional<std::vector<PyType *>> results, 902 llvm::Optional<std::vector<PyValue *>> operands, 903 llvm::Optional<py::dict> attributes, 904 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 905 DefaultingPyLocation location, py::object maybeIp) { 906 llvm::SmallVector<MlirValue, 4> mlirOperands; 907 llvm::SmallVector<MlirType, 4> mlirResults; 908 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 909 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 910 911 // General parameter validation. 912 if (regions < 0) 913 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 914 915 // Unpack/validate operands. 916 if (operands) { 917 mlirOperands.reserve(operands->size()); 918 for (PyValue *operand : *operands) { 919 if (!operand) 920 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 921 mlirOperands.push_back(operand->get()); 922 } 923 } 924 925 // Unpack/validate results. 926 if (results) { 927 mlirResults.reserve(results->size()); 928 for (PyType *result : *results) { 929 // TODO: Verify result type originate from the same context. 930 if (!result) 931 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 932 mlirResults.push_back(*result); 933 } 934 } 935 // Unpack/validate attributes. 936 if (attributes) { 937 mlirAttributes.reserve(attributes->size()); 938 for (auto &it : *attributes) { 939 std::string key; 940 try { 941 key = it.first.cast<std::string>(); 942 } catch (py::cast_error &err) { 943 std::string msg = "Invalid attribute key (not a string) when " 944 "attempting to create the operation \"" + 945 name + "\" (" + err.what() + ")"; 946 throw py::cast_error(msg); 947 } 948 try { 949 auto &attribute = it.second.cast<PyAttribute &>(); 950 // TODO: Verify attribute originates from the same context. 951 mlirAttributes.emplace_back(std::move(key), attribute); 952 } catch (py::reference_cast_error &) { 953 // This exception seems thrown when the value is "None". 954 std::string msg = 955 "Found an invalid (`None`?) attribute value for the key \"" + key + 956 "\" when attempting to create the operation \"" + name + "\""; 957 throw py::cast_error(msg); 958 } catch (py::cast_error &err) { 959 std::string msg = "Invalid attribute value for the key \"" + key + 960 "\" when attempting to create the operation \"" + 961 name + "\" (" + err.what() + ")"; 962 throw py::cast_error(msg); 963 } 964 } 965 } 966 // Unpack/validate successors. 967 if (successors) { 968 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 969 mlirSuccessors.reserve(successors->size()); 970 for (auto *successor : *successors) { 971 // TODO: Verify successor originate from the same context. 972 if (!successor) 973 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 974 mlirSuccessors.push_back(successor->get()); 975 } 976 } 977 978 // Apply unpacked/validated to the operation state. Beyond this 979 // point, exceptions cannot be thrown or else the state will leak. 980 MlirOperationState state = 981 mlirOperationStateGet(toMlirStringRef(name), location); 982 if (!mlirOperands.empty()) 983 mlirOperationStateAddOperands(&state, mlirOperands.size(), 984 mlirOperands.data()); 985 if (!mlirResults.empty()) 986 mlirOperationStateAddResults(&state, mlirResults.size(), 987 mlirResults.data()); 988 if (!mlirAttributes.empty()) { 989 // Note that the attribute names directly reference bytes in 990 // mlirAttributes, so that vector must not be changed from here 991 // on. 992 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 993 mlirNamedAttributes.reserve(mlirAttributes.size()); 994 for (auto &it : mlirAttributes) 995 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 996 mlirIdentifierGet(mlirAttributeGetContext(it.second), 997 toMlirStringRef(it.first)), 998 it.second)); 999 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1000 mlirNamedAttributes.data()); 1001 } 1002 if (!mlirSuccessors.empty()) 1003 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1004 mlirSuccessors.data()); 1005 if (regions) { 1006 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1007 mlirRegions.resize(regions); 1008 for (int i = 0; i < regions; ++i) 1009 mlirRegions[i] = mlirRegionCreate(); 1010 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1011 mlirRegions.data()); 1012 } 1013 1014 // Construct the operation. 1015 MlirOperation operation = mlirOperationCreate(&state); 1016 PyOperationRef created = 1017 PyOperation::createDetached(location->getContext(), operation); 1018 1019 // InsertPoint active? 1020 if (!maybeIp.is(py::cast(false))) { 1021 PyInsertionPoint *ip; 1022 if (maybeIp.is_none()) { 1023 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1024 } else { 1025 ip = py::cast<PyInsertionPoint *>(maybeIp); 1026 } 1027 if (ip) 1028 ip->insert(*created.get()); 1029 } 1030 1031 return created->createOpView(); 1032 } 1033 1034 py::object PyOperation::createOpView() { 1035 MlirIdentifier ident = mlirOperationGetName(get()); 1036 MlirStringRef identStr = mlirIdentifierStr(ident); 1037 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1038 StringRef(identStr.data, identStr.length)); 1039 if (opViewClass) 1040 return (*opViewClass)(getRef().getObject()); 1041 return py::cast(PyOpView(getRef().getObject())); 1042 } 1043 1044 //------------------------------------------------------------------------------ 1045 // PyOpView 1046 //------------------------------------------------------------------------------ 1047 1048 py::object 1049 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1050 py::list operandList, 1051 llvm::Optional<py::dict> attributes, 1052 llvm::Optional<std::vector<PyBlock *>> successors, 1053 llvm::Optional<int> regions, 1054 DefaultingPyLocation location, py::object maybeIp) { 1055 PyMlirContextRef context = location->getContext(); 1056 // Class level operation construction metadata. 1057 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1058 // Operand and result segment specs are either none, which does no 1059 // variadic unpacking, or a list of ints with segment sizes, where each 1060 // element is either a positive number (typically 1 for a scalar) or -1 to 1061 // indicate that it is derived from the length of the same-indexed operand 1062 // or result (implying that it is a list at that position). 1063 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1064 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1065 1066 std::vector<uint32_t> operandSegmentLengths; 1067 std::vector<uint32_t> resultSegmentLengths; 1068 1069 // Validate/determine region count. 1070 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1071 int opMinRegionCount = std::get<0>(opRegionSpec); 1072 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1073 if (!regions) { 1074 regions = opMinRegionCount; 1075 } 1076 if (*regions < opMinRegionCount) { 1077 throw py::value_error( 1078 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1079 llvm::Twine(opMinRegionCount) + 1080 " regions but was built with regions=" + llvm::Twine(*regions)) 1081 .str()); 1082 } 1083 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1084 throw py::value_error( 1085 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1086 llvm::Twine(opMinRegionCount) + 1087 " regions but was built with regions=" + llvm::Twine(*regions)) 1088 .str()); 1089 } 1090 1091 // Unpack results. 1092 std::vector<PyType *> resultTypes; 1093 resultTypes.reserve(resultTypeList.size()); 1094 if (resultSegmentSpecObj.is_none()) { 1095 // Non-variadic result unpacking. 1096 for (auto it : llvm::enumerate(resultTypeList)) { 1097 try { 1098 resultTypes.push_back(py::cast<PyType *>(it.value())); 1099 if (!resultTypes.back()) 1100 throw py::cast_error(); 1101 } catch (py::cast_error &err) { 1102 throw py::value_error((llvm::Twine("Result ") + 1103 llvm::Twine(it.index()) + " of operation \"" + 1104 name + "\" must be a Type (" + err.what() + ")") 1105 .str()); 1106 } 1107 } 1108 } else { 1109 // Sized result unpacking. 1110 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1111 if (resultSegmentSpec.size() != resultTypeList.size()) { 1112 throw py::value_error((llvm::Twine("Operation \"") + name + 1113 "\" requires " + 1114 llvm::Twine(resultSegmentSpec.size()) + 1115 "result segments but was provided " + 1116 llvm::Twine(resultTypeList.size())) 1117 .str()); 1118 } 1119 resultSegmentLengths.reserve(resultTypeList.size()); 1120 for (auto it : 1121 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1122 int segmentSpec = std::get<1>(it.value()); 1123 if (segmentSpec == 1 || segmentSpec == 0) { 1124 // Unpack unary element. 1125 try { 1126 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1127 if (resultType) { 1128 resultTypes.push_back(resultType); 1129 resultSegmentLengths.push_back(1); 1130 } else if (segmentSpec == 0) { 1131 // Allowed to be optional. 1132 resultSegmentLengths.push_back(0); 1133 } else { 1134 throw py::cast_error("was None and result is not optional"); 1135 } 1136 } catch (py::cast_error &err) { 1137 throw py::value_error((llvm::Twine("Result ") + 1138 llvm::Twine(it.index()) + " of operation \"" + 1139 name + "\" must be a Type (" + err.what() + 1140 ")") 1141 .str()); 1142 } 1143 } else if (segmentSpec == -1) { 1144 // Unpack sequence by appending. 1145 try { 1146 if (std::get<0>(it.value()).is_none()) { 1147 // Treat it as an empty list. 1148 resultSegmentLengths.push_back(0); 1149 } else { 1150 // Unpack the list. 1151 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1152 for (py::object segmentItem : segment) { 1153 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1154 if (!resultTypes.back()) { 1155 throw py::cast_error("contained a None item"); 1156 } 1157 } 1158 resultSegmentLengths.push_back(segment.size()); 1159 } 1160 } catch (std::exception &err) { 1161 // NOTE: Sloppy to be using a catch-all here, but there are at least 1162 // three different unrelated exceptions that can be thrown in the 1163 // above "casts". Just keep the scope above small and catch them all. 1164 throw py::value_error((llvm::Twine("Result ") + 1165 llvm::Twine(it.index()) + " of operation \"" + 1166 name + "\" must be a Sequence of Types (" + 1167 err.what() + ")") 1168 .str()); 1169 } 1170 } else { 1171 throw py::value_error("Unexpected segment spec"); 1172 } 1173 } 1174 } 1175 1176 // Unpack operands. 1177 std::vector<PyValue *> operands; 1178 operands.reserve(operands.size()); 1179 if (operandSegmentSpecObj.is_none()) { 1180 // Non-sized operand unpacking. 1181 for (auto it : llvm::enumerate(operandList)) { 1182 try { 1183 operands.push_back(py::cast<PyValue *>(it.value())); 1184 if (!operands.back()) 1185 throw py::cast_error(); 1186 } catch (py::cast_error &err) { 1187 throw py::value_error((llvm::Twine("Operand ") + 1188 llvm::Twine(it.index()) + " of operation \"" + 1189 name + "\" must be a Value (" + err.what() + ")") 1190 .str()); 1191 } 1192 } 1193 } else { 1194 // Sized operand unpacking. 1195 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1196 if (operandSegmentSpec.size() != operandList.size()) { 1197 throw py::value_error((llvm::Twine("Operation \"") + name + 1198 "\" requires " + 1199 llvm::Twine(operandSegmentSpec.size()) + 1200 "operand segments but was provided " + 1201 llvm::Twine(operandList.size())) 1202 .str()); 1203 } 1204 operandSegmentLengths.reserve(operandList.size()); 1205 for (auto it : 1206 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1207 int segmentSpec = std::get<1>(it.value()); 1208 if (segmentSpec == 1 || segmentSpec == 0) { 1209 // Unpack unary element. 1210 try { 1211 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1212 if (operandValue) { 1213 operands.push_back(operandValue); 1214 operandSegmentLengths.push_back(1); 1215 } else if (segmentSpec == 0) { 1216 // Allowed to be optional. 1217 operandSegmentLengths.push_back(0); 1218 } else { 1219 throw py::cast_error("was None and operand is not optional"); 1220 } 1221 } catch (py::cast_error &err) { 1222 throw py::value_error((llvm::Twine("Operand ") + 1223 llvm::Twine(it.index()) + " of operation \"" + 1224 name + "\" must be a Value (" + err.what() + 1225 ")") 1226 .str()); 1227 } 1228 } else if (segmentSpec == -1) { 1229 // Unpack sequence by appending. 1230 try { 1231 if (std::get<0>(it.value()).is_none()) { 1232 // Treat it as an empty list. 1233 operandSegmentLengths.push_back(0); 1234 } else { 1235 // Unpack the list. 1236 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1237 for (py::object segmentItem : segment) { 1238 operands.push_back(py::cast<PyValue *>(segmentItem)); 1239 if (!operands.back()) { 1240 throw py::cast_error("contained a None item"); 1241 } 1242 } 1243 operandSegmentLengths.push_back(segment.size()); 1244 } 1245 } catch (std::exception &err) { 1246 // NOTE: Sloppy to be using a catch-all here, but there are at least 1247 // three different unrelated exceptions that can be thrown in the 1248 // above "casts". Just keep the scope above small and catch them all. 1249 throw py::value_error((llvm::Twine("Operand ") + 1250 llvm::Twine(it.index()) + " of operation \"" + 1251 name + "\" must be a Sequence of Values (" + 1252 err.what() + ")") 1253 .str()); 1254 } 1255 } else { 1256 throw py::value_error("Unexpected segment spec"); 1257 } 1258 } 1259 } 1260 1261 // Merge operand/result segment lengths into attributes if needed. 1262 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1263 // Dup. 1264 if (attributes) { 1265 attributes = py::dict(*attributes); 1266 } else { 1267 attributes = py::dict(); 1268 } 1269 if (attributes->contains("result_segment_sizes") || 1270 attributes->contains("operand_segment_sizes")) { 1271 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1272 "'operand_segment_sizes' attribute is unsupported. " 1273 "Use Operation.create for such low-level access."); 1274 } 1275 1276 // Add result_segment_sizes attribute. 1277 if (!resultSegmentLengths.empty()) { 1278 int64_t size = resultSegmentLengths.size(); 1279 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1280 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1281 resultSegmentLengths.size(), resultSegmentLengths.data()); 1282 (*attributes)["result_segment_sizes"] = 1283 PyAttribute(context, segmentLengthAttr); 1284 } 1285 1286 // Add operand_segment_sizes attribute. 1287 if (!operandSegmentLengths.empty()) { 1288 int64_t size = operandSegmentLengths.size(); 1289 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1290 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1291 operandSegmentLengths.size(), operandSegmentLengths.data()); 1292 (*attributes)["operand_segment_sizes"] = 1293 PyAttribute(context, segmentLengthAttr); 1294 } 1295 } 1296 1297 // Delegate to create. 1298 return PyOperation::create(std::move(name), 1299 /*results=*/std::move(resultTypes), 1300 /*operands=*/std::move(operands), 1301 /*attributes=*/std::move(attributes), 1302 /*successors=*/std::move(successors), 1303 /*regions=*/*regions, location, maybeIp); 1304 } 1305 1306 PyOpView::PyOpView(py::object operationObject) 1307 // Casting through the PyOperationBase base-class and then back to the 1308 // Operation lets us accept any PyOperationBase subclass. 1309 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1310 operationObject(operation.getRef().getObject()) {} 1311 1312 py::object PyOpView::createRawSubclass(py::object userClass) { 1313 // This is... a little gross. The typical pattern is to have a pure python 1314 // class that extends OpView like: 1315 // class AddFOp(_cext.ir.OpView): 1316 // def __init__(self, loc, lhs, rhs): 1317 // operation = loc.context.create_operation( 1318 // "addf", lhs, rhs, results=[lhs.type]) 1319 // super().__init__(operation) 1320 // 1321 // I.e. The goal of the user facing type is to provide a nice constructor 1322 // that has complete freedom for the op under construction. This is at odds 1323 // with our other desire to sometimes create this object by just passing an 1324 // operation (to initialize the base class). We could do *arg and **kwargs 1325 // munging to try to make it work, but instead, we synthesize a new class 1326 // on the fly which extends this user class (AddFOp in this example) and 1327 // *give it* the base class's __init__ method, thus bypassing the 1328 // intermediate subclass's __init__ method entirely. While slightly, 1329 // underhanded, this is safe/legal because the type hierarchy has not changed 1330 // (we just added a new leaf) and we aren't mucking around with __new__. 1331 // Typically, this new class will be stored on the original as "_Raw" and will 1332 // be used for casts and other things that need a variant of the class that 1333 // is initialized purely from an operation. 1334 py::object parentMetaclass = 1335 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1336 py::dict attributes; 1337 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1338 // now. 1339 // auto opViewType = py::type::of<PyOpView>(); 1340 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1341 attributes["__init__"] = opViewType.attr("__init__"); 1342 py::str origName = userClass.attr("__name__"); 1343 py::str newName = py::str("_") + origName; 1344 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1345 } 1346 1347 //------------------------------------------------------------------------------ 1348 // PyInsertionPoint. 1349 //------------------------------------------------------------------------------ 1350 1351 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1352 1353 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1354 : refOperation(beforeOperationBase.getOperation().getRef()), 1355 block((*refOperation)->getBlock()) {} 1356 1357 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1358 PyOperation &operation = operationBase.getOperation(); 1359 if (operation.isAttached()) 1360 throw SetPyError(PyExc_ValueError, 1361 "Attempt to insert operation that is already attached"); 1362 block.getParentOperation()->checkValid(); 1363 MlirOperation beforeOp = {nullptr}; 1364 if (refOperation) { 1365 // Insert before operation. 1366 (*refOperation)->checkValid(); 1367 beforeOp = (*refOperation)->get(); 1368 } else { 1369 // Insert at end (before null) is only valid if the block does not 1370 // already end in a known terminator (violating this will cause assertion 1371 // failures later). 1372 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1373 throw py::index_error("Cannot insert operation at the end of a block " 1374 "that already has a terminator. Did you mean to " 1375 "use 'InsertionPoint.at_block_terminator(block)' " 1376 "versus 'InsertionPoint(block)'?"); 1377 } 1378 } 1379 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1380 operation.setAttached(); 1381 } 1382 1383 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1384 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1385 if (mlirOperationIsNull(firstOp)) { 1386 // Just insert at end. 1387 return PyInsertionPoint(block); 1388 } 1389 1390 // Insert before first op. 1391 PyOperationRef firstOpRef = PyOperation::forOperation( 1392 block.getParentOperation()->getContext(), firstOp); 1393 return PyInsertionPoint{block, std::move(firstOpRef)}; 1394 } 1395 1396 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1397 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1398 if (mlirOperationIsNull(terminator)) 1399 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1400 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1401 block.getParentOperation()->getContext(), terminator); 1402 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1403 } 1404 1405 py::object PyInsertionPoint::contextEnter() { 1406 return PyThreadContextEntry::pushInsertionPoint(*this); 1407 } 1408 1409 void PyInsertionPoint::contextExit(pybind11::object excType, 1410 pybind11::object excVal, 1411 pybind11::object excTb) { 1412 PyThreadContextEntry::popInsertionPoint(*this); 1413 } 1414 1415 //------------------------------------------------------------------------------ 1416 // PyAttribute. 1417 //------------------------------------------------------------------------------ 1418 1419 bool PyAttribute::operator==(const PyAttribute &other) { 1420 return mlirAttributeEqual(attr, other.attr); 1421 } 1422 1423 py::object PyAttribute::getCapsule() { 1424 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1425 } 1426 1427 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1428 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1429 if (mlirAttributeIsNull(rawAttr)) 1430 throw py::error_already_set(); 1431 return PyAttribute( 1432 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1433 } 1434 1435 //------------------------------------------------------------------------------ 1436 // PyNamedAttribute. 1437 //------------------------------------------------------------------------------ 1438 1439 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1440 : ownedName(new std::string(std::move(ownedName))) { 1441 namedAttr = mlirNamedAttributeGet( 1442 mlirIdentifierGet(mlirAttributeGetContext(attr), 1443 toMlirStringRef(*this->ownedName)), 1444 attr); 1445 } 1446 1447 //------------------------------------------------------------------------------ 1448 // PyType. 1449 //------------------------------------------------------------------------------ 1450 1451 bool PyType::operator==(const PyType &other) { 1452 return mlirTypeEqual(type, other.type); 1453 } 1454 1455 py::object PyType::getCapsule() { 1456 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1457 } 1458 1459 PyType PyType::createFromCapsule(py::object capsule) { 1460 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1461 if (mlirTypeIsNull(rawType)) 1462 throw py::error_already_set(); 1463 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1464 rawType); 1465 } 1466 1467 //------------------------------------------------------------------------------ 1468 // PyValue and subclases. 1469 //------------------------------------------------------------------------------ 1470 1471 pybind11::object PyValue::getCapsule() { 1472 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1473 } 1474 1475 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1476 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1477 if (mlirValueIsNull(value)) 1478 throw py::error_already_set(); 1479 MlirOperation owner; 1480 if (mlirValueIsAOpResult(value)) 1481 owner = mlirOpResultGetOwner(value); 1482 if (mlirValueIsABlockArgument(value)) 1483 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1484 if (mlirOperationIsNull(owner)) 1485 throw py::error_already_set(); 1486 MlirContext ctx = mlirOperationGetContext(owner); 1487 PyOperationRef ownerRef = 1488 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1489 return PyValue(ownerRef, value); 1490 } 1491 1492 namespace { 1493 /// CRTP base class for Python MLIR values that subclass Value and should be 1494 /// castable from it. The value hierarchy is one level deep and is not supposed 1495 /// to accommodate other levels unless core MLIR changes. 1496 template <typename DerivedTy> 1497 class PyConcreteValue : public PyValue { 1498 public: 1499 // Derived classes must define statics for: 1500 // IsAFunctionTy isaFunction 1501 // const char *pyClassName 1502 // and redefine bindDerived. 1503 using ClassTy = py::class_<DerivedTy, PyValue>; 1504 using IsAFunctionTy = bool (*)(MlirValue); 1505 1506 PyConcreteValue() = default; 1507 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1508 : PyValue(operationRef, value) {} 1509 PyConcreteValue(PyValue &orig) 1510 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1511 1512 /// Attempts to cast the original value to the derived type and throws on 1513 /// type mismatches. 1514 static MlirValue castFrom(PyValue &orig) { 1515 if (!DerivedTy::isaFunction(orig.get())) { 1516 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1517 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1518 DerivedTy::pyClassName + 1519 " (from " + origRepr + ")"); 1520 } 1521 return orig.get(); 1522 } 1523 1524 /// Binds the Python module objects to functions of this class. 1525 static void bind(py::module &m) { 1526 auto cls = ClassTy(m, DerivedTy::pyClassName); 1527 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1528 DerivedTy::bindDerived(cls); 1529 } 1530 1531 /// Implemented by derived classes to add methods to the Python subclass. 1532 static void bindDerived(ClassTy &m) {} 1533 }; 1534 1535 /// Python wrapper for MlirBlockArgument. 1536 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1537 public: 1538 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1539 static constexpr const char *pyClassName = "BlockArgument"; 1540 using PyConcreteValue::PyConcreteValue; 1541 1542 static void bindDerived(ClassTy &c) { 1543 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1544 return PyBlock(self.getParentOperation(), 1545 mlirBlockArgumentGetOwner(self.get())); 1546 }); 1547 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1548 return mlirBlockArgumentGetArgNumber(self.get()); 1549 }); 1550 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1551 return mlirBlockArgumentSetType(self.get(), type); 1552 }); 1553 } 1554 }; 1555 1556 /// Python wrapper for MlirOpResult. 1557 class PyOpResult : public PyConcreteValue<PyOpResult> { 1558 public: 1559 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1560 static constexpr const char *pyClassName = "OpResult"; 1561 using PyConcreteValue::PyConcreteValue; 1562 1563 static void bindDerived(ClassTy &c) { 1564 c.def_property_readonly("owner", [](PyOpResult &self) { 1565 assert( 1566 mlirOperationEqual(self.getParentOperation()->get(), 1567 mlirOpResultGetOwner(self.get())) && 1568 "expected the owner of the value in Python to match that in the IR"); 1569 return self.getParentOperation(); 1570 }); 1571 c.def_property_readonly("result_number", [](PyOpResult &self) { 1572 return mlirOpResultGetResultNumber(self.get()); 1573 }); 1574 } 1575 }; 1576 1577 /// A list of block arguments. Internally, these are stored as consecutive 1578 /// elements, random access is cheap. The argument list is associated with the 1579 /// operation that contains the block (detached blocks are not allowed in 1580 /// Python bindings) and extends its lifetime. 1581 class PyBlockArgumentList { 1582 public: 1583 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1584 : operation(std::move(operation)), block(block) {} 1585 1586 /// Returns the length of the block argument list. 1587 intptr_t dunderLen() { 1588 operation->checkValid(); 1589 return mlirBlockGetNumArguments(block); 1590 } 1591 1592 /// Returns `index`-th element of the block argument list. 1593 PyBlockArgument dunderGetItem(intptr_t index) { 1594 if (index < 0 || index >= dunderLen()) { 1595 throw SetPyError(PyExc_IndexError, 1596 "attempt to access out of bounds region"); 1597 } 1598 PyValue value(operation, mlirBlockGetArgument(block, index)); 1599 return PyBlockArgument(value); 1600 } 1601 1602 /// Defines a Python class in the bindings. 1603 static void bind(py::module &m) { 1604 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1605 .def("__len__", &PyBlockArgumentList::dunderLen) 1606 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1607 } 1608 1609 private: 1610 PyOperationRef operation; 1611 MlirBlock block; 1612 }; 1613 1614 /// A list of operation operands. Internally, these are stored as consecutive 1615 /// elements, random access is cheap. The result list is associated with the 1616 /// operation whose results these are, and extends the lifetime of this 1617 /// operation. 1618 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1619 public: 1620 static constexpr const char *pyClassName = "OpOperandList"; 1621 1622 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1623 intptr_t length = -1, intptr_t step = 1) 1624 : Sliceable(startIndex, 1625 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1626 : length, 1627 step), 1628 operation(operation) {} 1629 1630 intptr_t getNumElements() { 1631 operation->checkValid(); 1632 return mlirOperationGetNumOperands(operation->get()); 1633 } 1634 1635 PyValue getElement(intptr_t pos) { 1636 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); 1637 } 1638 1639 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1640 return PyOpOperandList(operation, startIndex, length, step); 1641 } 1642 1643 void dunderSetItem(intptr_t index, PyValue value) { 1644 index = wrapIndex(index); 1645 mlirOperationSetOperand(operation->get(), index, value.get()); 1646 } 1647 1648 static void bindDerived(ClassTy &c) { 1649 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1650 } 1651 1652 private: 1653 PyOperationRef operation; 1654 }; 1655 1656 /// A list of operation results. Internally, these are stored as consecutive 1657 /// elements, random access is cheap. The result list is associated with the 1658 /// operation whose results these are, and extends the lifetime of this 1659 /// operation. 1660 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1661 public: 1662 static constexpr const char *pyClassName = "OpResultList"; 1663 1664 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1665 intptr_t length = -1, intptr_t step = 1) 1666 : Sliceable(startIndex, 1667 length == -1 ? mlirOperationGetNumResults(operation->get()) 1668 : length, 1669 step), 1670 operation(operation) {} 1671 1672 intptr_t getNumElements() { 1673 operation->checkValid(); 1674 return mlirOperationGetNumResults(operation->get()); 1675 } 1676 1677 PyOpResult getElement(intptr_t index) { 1678 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1679 return PyOpResult(value); 1680 } 1681 1682 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1683 return PyOpResultList(operation, startIndex, length, step); 1684 } 1685 1686 private: 1687 PyOperationRef operation; 1688 }; 1689 1690 /// A list of operation attributes. Can be indexed by name, producing 1691 /// attributes, or by index, producing named attributes. 1692 class PyOpAttributeMap { 1693 public: 1694 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1695 1696 PyAttribute dunderGetItemNamed(const std::string &name) { 1697 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1698 toMlirStringRef(name)); 1699 if (mlirAttributeIsNull(attr)) { 1700 throw SetPyError(PyExc_KeyError, 1701 "attempt to access a non-existent attribute"); 1702 } 1703 return PyAttribute(operation->getContext(), attr); 1704 } 1705 1706 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1707 if (index < 0 || index >= dunderLen()) { 1708 throw SetPyError(PyExc_IndexError, 1709 "attempt to access out of bounds attribute"); 1710 } 1711 MlirNamedAttribute namedAttr = 1712 mlirOperationGetAttribute(operation->get(), index); 1713 return PyNamedAttribute( 1714 namedAttr.attribute, 1715 std::string(mlirIdentifierStr(namedAttr.name).data)); 1716 } 1717 1718 void dunderSetItem(const std::string &name, PyAttribute attr) { 1719 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1720 attr); 1721 } 1722 1723 void dunderDelItem(const std::string &name) { 1724 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1725 toMlirStringRef(name)); 1726 if (!removed) 1727 throw SetPyError(PyExc_KeyError, 1728 "attempt to delete a non-existent attribute"); 1729 } 1730 1731 intptr_t dunderLen() { 1732 return mlirOperationGetNumAttributes(operation->get()); 1733 } 1734 1735 bool dunderContains(const std::string &name) { 1736 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1737 operation->get(), toMlirStringRef(name))); 1738 } 1739 1740 static void bind(py::module &m) { 1741 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1742 .def("__contains__", &PyOpAttributeMap::dunderContains) 1743 .def("__len__", &PyOpAttributeMap::dunderLen) 1744 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1745 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1746 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1747 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1748 } 1749 1750 private: 1751 PyOperationRef operation; 1752 }; 1753 1754 } // end namespace 1755 1756 //------------------------------------------------------------------------------ 1757 // Populates the core exports of the 'ir' submodule. 1758 //------------------------------------------------------------------------------ 1759 1760 void mlir::python::populateIRCore(py::module &m) { 1761 //---------------------------------------------------------------------------- 1762 // Mapping of MlirContext. 1763 //---------------------------------------------------------------------------- 1764 py::class_<PyMlirContext>(m, "Context") 1765 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1766 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1767 .def("_get_context_again", 1768 [](PyMlirContext &self) { 1769 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1770 return ref.releaseObject(); 1771 }) 1772 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1773 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1774 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1775 &PyMlirContext::getCapsule) 1776 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1777 .def("__enter__", &PyMlirContext::contextEnter) 1778 .def("__exit__", &PyMlirContext::contextExit) 1779 .def_property_readonly_static( 1780 "current", 1781 [](py::object & /*class*/) { 1782 auto *context = PyThreadContextEntry::getDefaultContext(); 1783 if (!context) 1784 throw SetPyError(PyExc_ValueError, "No current Context"); 1785 return context; 1786 }, 1787 "Gets the Context bound to the current thread or raises ValueError") 1788 .def_property_readonly( 1789 "dialects", 1790 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1791 "Gets a container for accessing dialects by name") 1792 .def_property_readonly( 1793 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1794 "Alias for 'dialect'") 1795 .def( 1796 "get_dialect_descriptor", 1797 [=](PyMlirContext &self, std::string &name) { 1798 MlirDialect dialect = mlirContextGetOrLoadDialect( 1799 self.get(), {name.data(), name.size()}); 1800 if (mlirDialectIsNull(dialect)) { 1801 throw SetPyError(PyExc_ValueError, 1802 Twine("Dialect '") + name + "' not found"); 1803 } 1804 return PyDialectDescriptor(self.getRef(), dialect); 1805 }, 1806 "Gets or loads a dialect by name, returning its descriptor object") 1807 .def_property( 1808 "allow_unregistered_dialects", 1809 [](PyMlirContext &self) -> bool { 1810 return mlirContextGetAllowUnregisteredDialects(self.get()); 1811 }, 1812 [](PyMlirContext &self, bool value) { 1813 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1814 }) 1815 .def("enable_multithreading", 1816 [](PyMlirContext &self, bool enable) { 1817 mlirContextEnableMultithreading(self.get(), enable); 1818 }) 1819 .def("is_registered_operation", 1820 [](PyMlirContext &self, std::string &name) { 1821 return mlirContextIsRegisteredOperation( 1822 self.get(), MlirStringRef{name.data(), name.size()}); 1823 }); 1824 1825 //---------------------------------------------------------------------------- 1826 // Mapping of PyDialectDescriptor 1827 //---------------------------------------------------------------------------- 1828 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1829 .def_property_readonly("namespace", 1830 [](PyDialectDescriptor &self) { 1831 MlirStringRef ns = 1832 mlirDialectGetNamespace(self.get()); 1833 return py::str(ns.data, ns.length); 1834 }) 1835 .def("__repr__", [](PyDialectDescriptor &self) { 1836 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1837 std::string repr("<DialectDescriptor "); 1838 repr.append(ns.data, ns.length); 1839 repr.append(">"); 1840 return repr; 1841 }); 1842 1843 //---------------------------------------------------------------------------- 1844 // Mapping of PyDialects 1845 //---------------------------------------------------------------------------- 1846 py::class_<PyDialects>(m, "Dialects") 1847 .def("__getitem__", 1848 [=](PyDialects &self, std::string keyName) { 1849 MlirDialect dialect = 1850 self.getDialectForKey(keyName, /*attrError=*/false); 1851 py::object descriptor = 1852 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1853 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1854 }) 1855 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1856 MlirDialect dialect = 1857 self.getDialectForKey(attrName, /*attrError=*/true); 1858 py::object descriptor = 1859 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1860 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1861 }); 1862 1863 //---------------------------------------------------------------------------- 1864 // Mapping of PyDialect 1865 //---------------------------------------------------------------------------- 1866 py::class_<PyDialect>(m, "Dialect") 1867 .def(py::init<py::object>(), "descriptor") 1868 .def_property_readonly( 1869 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1870 .def("__repr__", [](py::object self) { 1871 auto clazz = self.attr("__class__"); 1872 return py::str("<Dialect ") + 1873 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1874 clazz.attr("__module__") + py::str(".") + 1875 clazz.attr("__name__") + py::str(")>"); 1876 }); 1877 1878 //---------------------------------------------------------------------------- 1879 // Mapping of Location 1880 //---------------------------------------------------------------------------- 1881 py::class_<PyLocation>(m, "Location") 1882 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1883 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1884 .def("__enter__", &PyLocation::contextEnter) 1885 .def("__exit__", &PyLocation::contextExit) 1886 .def("__eq__", 1887 [](PyLocation &self, PyLocation &other) -> bool { 1888 return mlirLocationEqual(self, other); 1889 }) 1890 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1891 .def_property_readonly_static( 1892 "current", 1893 [](py::object & /*class*/) { 1894 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1895 if (!loc) 1896 throw SetPyError(PyExc_ValueError, "No current Location"); 1897 return loc; 1898 }, 1899 "Gets the Location bound to the current thread or raises ValueError") 1900 .def_static( 1901 "unknown", 1902 [](DefaultingPyMlirContext context) { 1903 return PyLocation(context->getRef(), 1904 mlirLocationUnknownGet(context->get())); 1905 }, 1906 py::arg("context") = py::none(), 1907 "Gets a Location representing an unknown location") 1908 .def_static( 1909 "file", 1910 [](std::string filename, int line, int col, 1911 DefaultingPyMlirContext context) { 1912 return PyLocation( 1913 context->getRef(), 1914 mlirLocationFileLineColGet( 1915 context->get(), toMlirStringRef(filename), line, col)); 1916 }, 1917 py::arg("filename"), py::arg("line"), py::arg("col"), 1918 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1919 .def_property_readonly( 1920 "context", 1921 [](PyLocation &self) { return self.getContext().getObject(); }, 1922 "Context that owns the Location") 1923 .def("__repr__", [](PyLocation &self) { 1924 PyPrintAccumulator printAccum; 1925 mlirLocationPrint(self, printAccum.getCallback(), 1926 printAccum.getUserData()); 1927 return printAccum.join(); 1928 }); 1929 1930 //---------------------------------------------------------------------------- 1931 // Mapping of Module 1932 //---------------------------------------------------------------------------- 1933 py::class_<PyModule>(m, "Module") 1934 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1935 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1936 .def_static( 1937 "parse", 1938 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1939 MlirModule module = mlirModuleCreateParse( 1940 context->get(), toMlirStringRef(moduleAsm)); 1941 // TODO: Rework error reporting once diagnostic engine is exposed 1942 // in C API. 1943 if (mlirModuleIsNull(module)) { 1944 throw SetPyError( 1945 PyExc_ValueError, 1946 "Unable to parse module assembly (see diagnostics)"); 1947 } 1948 return PyModule::forModule(module).releaseObject(); 1949 }, 1950 py::arg("asm"), py::arg("context") = py::none(), 1951 kModuleParseDocstring) 1952 .def_static( 1953 "create", 1954 [](DefaultingPyLocation loc) { 1955 MlirModule module = mlirModuleCreateEmpty(loc); 1956 return PyModule::forModule(module).releaseObject(); 1957 }, 1958 py::arg("loc") = py::none(), "Creates an empty module") 1959 .def_property_readonly( 1960 "context", 1961 [](PyModule &self) { return self.getContext().getObject(); }, 1962 "Context that created the Module") 1963 .def_property_readonly( 1964 "operation", 1965 [](PyModule &self) { 1966 return PyOperation::forOperation(self.getContext(), 1967 mlirModuleGetOperation(self.get()), 1968 self.getRef().releaseObject()) 1969 .releaseObject(); 1970 }, 1971 "Accesses the module as an operation") 1972 .def_property_readonly( 1973 "body", 1974 [](PyModule &self) { 1975 PyOperationRef module_op = PyOperation::forOperation( 1976 self.getContext(), mlirModuleGetOperation(self.get()), 1977 self.getRef().releaseObject()); 1978 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 1979 return returnBlock; 1980 }, 1981 "Return the block for this module") 1982 .def( 1983 "dump", 1984 [](PyModule &self) { 1985 mlirOperationDump(mlirModuleGetOperation(self.get())); 1986 }, 1987 kDumpDocstring) 1988 .def( 1989 "__str__", 1990 [](PyModule &self) { 1991 MlirOperation operation = mlirModuleGetOperation(self.get()); 1992 PyPrintAccumulator printAccum; 1993 mlirOperationPrint(operation, printAccum.getCallback(), 1994 printAccum.getUserData()); 1995 return printAccum.join(); 1996 }, 1997 kOperationStrDunderDocstring); 1998 1999 //---------------------------------------------------------------------------- 2000 // Mapping of Operation. 2001 //---------------------------------------------------------------------------- 2002 py::class_<PyOperationBase>(m, "_OperationBase") 2003 .def("__eq__", 2004 [](PyOperationBase &self, PyOperationBase &other) { 2005 return &self.getOperation() == &other.getOperation(); 2006 }) 2007 .def("__eq__", 2008 [](PyOperationBase &self, py::object other) { return false; }) 2009 .def_property_readonly("attributes", 2010 [](PyOperationBase &self) { 2011 return PyOpAttributeMap( 2012 self.getOperation().getRef()); 2013 }) 2014 .def_property_readonly("operands", 2015 [](PyOperationBase &self) { 2016 return PyOpOperandList( 2017 self.getOperation().getRef()); 2018 }) 2019 .def_property_readonly("regions", 2020 [](PyOperationBase &self) { 2021 return PyRegionList( 2022 self.getOperation().getRef()); 2023 }) 2024 .def_property_readonly( 2025 "results", 2026 [](PyOperationBase &self) { 2027 return PyOpResultList(self.getOperation().getRef()); 2028 }, 2029 "Returns the list of Operation results.") 2030 .def_property_readonly( 2031 "result", 2032 [](PyOperationBase &self) { 2033 auto &operation = self.getOperation(); 2034 auto numResults = mlirOperationGetNumResults(operation); 2035 if (numResults != 1) { 2036 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2037 throw SetPyError( 2038 PyExc_ValueError, 2039 Twine("Cannot call .result on operation ") + 2040 StringRef(name.data, name.length) + " which has " + 2041 Twine(numResults) + 2042 " results (it is only valid for operations with a " 2043 "single result)"); 2044 } 2045 return PyOpResult(operation.getRef(), 2046 mlirOperationGetResult(operation, 0)); 2047 }, 2048 "Shortcut to get an op result if it has only one (throws an error " 2049 "otherwise).") 2050 .def("__iter__", 2051 [](PyOperationBase &self) { 2052 return PyRegionIterator(self.getOperation().getRef()); 2053 }) 2054 .def( 2055 "__str__", 2056 [](PyOperationBase &self) { 2057 return self.getAsm(/*binary=*/false, 2058 /*largeElementsLimit=*/llvm::None, 2059 /*enableDebugInfo=*/false, 2060 /*prettyDebugInfo=*/false, 2061 /*printGenericOpForm=*/false, 2062 /*useLocalScope=*/false); 2063 }, 2064 "Returns the assembly form of the operation.") 2065 .def("print", &PyOperationBase::print, 2066 // Careful: Lots of arguments must match up with print method. 2067 py::arg("file") = py::none(), py::arg("binary") = false, 2068 py::arg("large_elements_limit") = py::none(), 2069 py::arg("enable_debug_info") = false, 2070 py::arg("pretty_debug_info") = false, 2071 py::arg("print_generic_op_form") = false, 2072 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2073 .def("get_asm", &PyOperationBase::getAsm, 2074 // Careful: Lots of arguments must match up with get_asm method. 2075 py::arg("binary") = false, 2076 py::arg("large_elements_limit") = py::none(), 2077 py::arg("enable_debug_info") = false, 2078 py::arg("pretty_debug_info") = false, 2079 py::arg("print_generic_op_form") = false, 2080 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2081 .def( 2082 "verify", 2083 [](PyOperationBase &self) { 2084 return mlirOperationVerify(self.getOperation()); 2085 }, 2086 "Verify the operation and return true if it passes, false if it " 2087 "fails."); 2088 2089 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2090 .def_static("create", &PyOperation::create, py::arg("name"), 2091 py::arg("results") = py::none(), 2092 py::arg("operands") = py::none(), 2093 py::arg("attributes") = py::none(), 2094 py::arg("successors") = py::none(), py::arg("regions") = 0, 2095 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2096 kOperationCreateDocstring) 2097 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2098 &PyOperation::getCapsule) 2099 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2100 .def_property_readonly("name", 2101 [](PyOperation &self) { 2102 MlirOperation operation = self.get(); 2103 MlirStringRef name = mlirIdentifierStr( 2104 mlirOperationGetName(operation)); 2105 return py::str(name.data, name.length); 2106 }) 2107 .def_property_readonly( 2108 "context", 2109 [](PyOperation &self) { return self.getContext().getObject(); }, 2110 "Context that owns the Operation") 2111 .def_property_readonly("opview", &PyOperation::createOpView); 2112 2113 auto opViewClass = 2114 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2115 .def(py::init<py::object>()) 2116 .def_property_readonly("operation", &PyOpView::getOperationObject) 2117 .def_property_readonly( 2118 "context", 2119 [](PyOpView &self) { 2120 return self.getOperation().getContext().getObject(); 2121 }, 2122 "Context that owns the Operation") 2123 .def("__str__", [](PyOpView &self) { 2124 return py::str(self.getOperationObject()); 2125 }); 2126 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2127 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2128 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2129 opViewClass.attr("build_generic") = classmethod( 2130 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2131 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2132 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2133 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2134 "Builds a specific, generated OpView based on class level attributes."); 2135 2136 //---------------------------------------------------------------------------- 2137 // Mapping of PyRegion. 2138 //---------------------------------------------------------------------------- 2139 py::class_<PyRegion>(m, "Region") 2140 .def_property_readonly( 2141 "blocks", 2142 [](PyRegion &self) { 2143 return PyBlockList(self.getParentOperation(), self.get()); 2144 }, 2145 "Returns a forward-optimized sequence of blocks.") 2146 .def( 2147 "__iter__", 2148 [](PyRegion &self) { 2149 self.checkValid(); 2150 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2151 return PyBlockIterator(self.getParentOperation(), firstBlock); 2152 }, 2153 "Iterates over blocks in the region.") 2154 .def("__eq__", 2155 [](PyRegion &self, PyRegion &other) { 2156 return self.get().ptr == other.get().ptr; 2157 }) 2158 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2159 2160 //---------------------------------------------------------------------------- 2161 // Mapping of PyBlock. 2162 //---------------------------------------------------------------------------- 2163 py::class_<PyBlock>(m, "Block") 2164 .def_property_readonly( 2165 "arguments", 2166 [](PyBlock &self) { 2167 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2168 }, 2169 "Returns a list of block arguments.") 2170 .def_property_readonly( 2171 "operations", 2172 [](PyBlock &self) { 2173 return PyOperationList(self.getParentOperation(), self.get()); 2174 }, 2175 "Returns a forward-optimized sequence of operations.") 2176 .def( 2177 "__iter__", 2178 [](PyBlock &self) { 2179 self.checkValid(); 2180 MlirOperation firstOperation = 2181 mlirBlockGetFirstOperation(self.get()); 2182 return PyOperationIterator(self.getParentOperation(), 2183 firstOperation); 2184 }, 2185 "Iterates over operations in the block.") 2186 .def("__eq__", 2187 [](PyBlock &self, PyBlock &other) { 2188 return self.get().ptr == other.get().ptr; 2189 }) 2190 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2191 .def( 2192 "__str__", 2193 [](PyBlock &self) { 2194 self.checkValid(); 2195 PyPrintAccumulator printAccum; 2196 mlirBlockPrint(self.get(), printAccum.getCallback(), 2197 printAccum.getUserData()); 2198 return printAccum.join(); 2199 }, 2200 "Returns the assembly form of the block."); 2201 2202 //---------------------------------------------------------------------------- 2203 // Mapping of PyInsertionPoint. 2204 //---------------------------------------------------------------------------- 2205 2206 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2207 .def(py::init<PyBlock &>(), py::arg("block"), 2208 "Inserts after the last operation but still inside the block.") 2209 .def("__enter__", &PyInsertionPoint::contextEnter) 2210 .def("__exit__", &PyInsertionPoint::contextExit) 2211 .def_property_readonly_static( 2212 "current", 2213 [](py::object & /*class*/) { 2214 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2215 if (!ip) 2216 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2217 return ip; 2218 }, 2219 "Gets the InsertionPoint bound to the current thread or raises " 2220 "ValueError if none has been set") 2221 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2222 "Inserts before a referenced operation.") 2223 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2224 py::arg("block"), "Inserts at the beginning of the block.") 2225 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2226 py::arg("block"), "Inserts before the block terminator.") 2227 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2228 "Inserts an operation."); 2229 2230 //---------------------------------------------------------------------------- 2231 // Mapping of PyAttribute. 2232 //---------------------------------------------------------------------------- 2233 py::class_<PyAttribute>(m, "Attribute") 2234 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2235 &PyAttribute::getCapsule) 2236 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2237 .def_static( 2238 "parse", 2239 [](std::string attrSpec, DefaultingPyMlirContext context) { 2240 MlirAttribute type = mlirAttributeParseGet( 2241 context->get(), toMlirStringRef(attrSpec)); 2242 // TODO: Rework error reporting once diagnostic engine is exposed 2243 // in C API. 2244 if (mlirAttributeIsNull(type)) { 2245 throw SetPyError(PyExc_ValueError, 2246 Twine("Unable to parse attribute: '") + 2247 attrSpec + "'"); 2248 } 2249 return PyAttribute(context->getRef(), type); 2250 }, 2251 py::arg("asm"), py::arg("context") = py::none(), 2252 "Parses an attribute from an assembly form") 2253 .def_property_readonly( 2254 "context", 2255 [](PyAttribute &self) { return self.getContext().getObject(); }, 2256 "Context that owns the Attribute") 2257 .def_property_readonly("type", 2258 [](PyAttribute &self) { 2259 return PyType(self.getContext()->getRef(), 2260 mlirAttributeGetType(self)); 2261 }) 2262 .def( 2263 "get_named", 2264 [](PyAttribute &self, std::string name) { 2265 return PyNamedAttribute(self, std::move(name)); 2266 }, 2267 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2268 .def("__eq__", 2269 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2270 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2271 .def( 2272 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2273 kDumpDocstring) 2274 .def( 2275 "__str__", 2276 [](PyAttribute &self) { 2277 PyPrintAccumulator printAccum; 2278 mlirAttributePrint(self, printAccum.getCallback(), 2279 printAccum.getUserData()); 2280 return printAccum.join(); 2281 }, 2282 "Returns the assembly form of the Attribute.") 2283 .def("__repr__", [](PyAttribute &self) { 2284 // Generally, assembly formats are not printed for __repr__ because 2285 // this can cause exceptionally long debug output and exceptions. 2286 // However, attribute values are generally considered useful and are 2287 // printed. This may need to be re-evaluated if debug dumps end up 2288 // being excessive. 2289 PyPrintAccumulator printAccum; 2290 printAccum.parts.append("Attribute("); 2291 mlirAttributePrint(self, printAccum.getCallback(), 2292 printAccum.getUserData()); 2293 printAccum.parts.append(")"); 2294 return printAccum.join(); 2295 }); 2296 2297 //---------------------------------------------------------------------------- 2298 // Mapping of PyNamedAttribute 2299 //---------------------------------------------------------------------------- 2300 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2301 .def("__repr__", 2302 [](PyNamedAttribute &self) { 2303 PyPrintAccumulator printAccum; 2304 printAccum.parts.append("NamedAttribute("); 2305 printAccum.parts.append( 2306 mlirIdentifierStr(self.namedAttr.name).data); 2307 printAccum.parts.append("="); 2308 mlirAttributePrint(self.namedAttr.attribute, 2309 printAccum.getCallback(), 2310 printAccum.getUserData()); 2311 printAccum.parts.append(")"); 2312 return printAccum.join(); 2313 }) 2314 .def_property_readonly( 2315 "name", 2316 [](PyNamedAttribute &self) { 2317 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2318 mlirIdentifierStr(self.namedAttr.name).length); 2319 }, 2320 "The name of the NamedAttribute binding") 2321 .def_property_readonly( 2322 "attr", 2323 [](PyNamedAttribute &self) { 2324 // TODO: When named attribute is removed/refactored, also remove 2325 // this constructor (it does an inefficient table lookup). 2326 auto contextRef = PyMlirContext::forContext( 2327 mlirAttributeGetContext(self.namedAttr.attribute)); 2328 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2329 }, 2330 py::keep_alive<0, 1>(), 2331 "The underlying generic attribute of the NamedAttribute binding"); 2332 2333 //---------------------------------------------------------------------------- 2334 // Mapping of PyType. 2335 //---------------------------------------------------------------------------- 2336 py::class_<PyType>(m, "Type") 2337 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2338 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2339 .def_static( 2340 "parse", 2341 [](std::string typeSpec, DefaultingPyMlirContext context) { 2342 MlirType type = 2343 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2344 // TODO: Rework error reporting once diagnostic engine is exposed 2345 // in C API. 2346 if (mlirTypeIsNull(type)) { 2347 throw SetPyError(PyExc_ValueError, 2348 Twine("Unable to parse type: '") + typeSpec + 2349 "'"); 2350 } 2351 return PyType(context->getRef(), type); 2352 }, 2353 py::arg("asm"), py::arg("context") = py::none(), 2354 kContextParseTypeDocstring) 2355 .def_property_readonly( 2356 "context", [](PyType &self) { return self.getContext().getObject(); }, 2357 "Context that owns the Type") 2358 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2359 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2360 .def( 2361 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2362 .def( 2363 "__str__", 2364 [](PyType &self) { 2365 PyPrintAccumulator printAccum; 2366 mlirTypePrint(self, printAccum.getCallback(), 2367 printAccum.getUserData()); 2368 return printAccum.join(); 2369 }, 2370 "Returns the assembly form of the type.") 2371 .def("__repr__", [](PyType &self) { 2372 // Generally, assembly formats are not printed for __repr__ because 2373 // this can cause exceptionally long debug output and exceptions. 2374 // However, types are an exception as they typically have compact 2375 // assembly forms and printing them is useful. 2376 PyPrintAccumulator printAccum; 2377 printAccum.parts.append("Type("); 2378 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2379 printAccum.parts.append(")"); 2380 return printAccum.join(); 2381 }); 2382 2383 //---------------------------------------------------------------------------- 2384 // Mapping of Value. 2385 //---------------------------------------------------------------------------- 2386 py::class_<PyValue>(m, "Value") 2387 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2388 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2389 .def_property_readonly( 2390 "context", 2391 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2392 "Context in which the value lives.") 2393 .def( 2394 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2395 kDumpDocstring) 2396 .def("__eq__", 2397 [](PyValue &self, PyValue &other) { 2398 return self.get().ptr == other.get().ptr; 2399 }) 2400 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2401 .def( 2402 "__str__", 2403 [](PyValue &self) { 2404 PyPrintAccumulator printAccum; 2405 printAccum.parts.append("Value("); 2406 mlirValuePrint(self.get(), printAccum.getCallback(), 2407 printAccum.getUserData()); 2408 printAccum.parts.append(")"); 2409 return printAccum.join(); 2410 }, 2411 kValueDunderStrDocstring) 2412 .def_property_readonly("type", [](PyValue &self) { 2413 return PyType(self.getParentOperation()->getContext(), 2414 mlirValueGetType(self.get())); 2415 }); 2416 PyBlockArgument::bind(m); 2417 PyOpResult::bind(m); 2418 2419 // Container bindings. 2420 PyBlockArgumentList::bind(m); 2421 PyBlockIterator::bind(m); 2422 PyBlockList::bind(m); 2423 PyOperationIterator::bind(m); 2424 PyOperationList::bind(m); 2425 PyOpAttributeMap::bind(m); 2426 PyOpOperandList::bind(m); 2427 PyOpResultList::bind(m); 2428 PyRegionIterator::bind(m); 2429 PyRegionList::bind(m); 2430 2431 // Debug bindings. 2432 PyGlobalDebugFlag::bind(m); 2433 } 2434