1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include <pybind11/stl.h> 22 23 namespace py = pybind11; 24 using namespace mlir; 25 using namespace mlir::python; 26 27 using llvm::SmallVector; 28 using llvm::StringRef; 29 using llvm::Twine; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kContextParseTypeDocstring[] = 36 R"(Parses the assembly form of a type. 37 38 Returns a Type object or raises a ValueError if the type cannot be parsed. 39 40 See also: https://mlir.llvm.org/docs/LangRef/#type-system 41 )"; 42 43 static const char kContextGetFileLocationDocstring[] = 44 R"(Gets a Location representing a file, line and column)"; 45 46 static const char kModuleParseDocstring[] = 47 R"(Parses a module's assembly format from a string. 48 49 Returns a new MlirModule or raises a ValueError if the parsing fails. 50 51 See also: https://mlir.llvm.org/docs/LangRef/ 52 )"; 53 54 static const char kOperationCreateDocstring[] = 55 R"(Creates a new operation. 56 57 Args: 58 name: Operation name (e.g. "dialect.operation"). 59 results: Sequence of Type representing op result types. 60 attributes: Dict of str:Attribute. 61 successors: List of Block for the operation's successors. 62 regions: Number of regions to create. 63 location: A Location object (defaults to resolve from context manager). 64 ip: An InsertionPoint (defaults to resolve from context manager or set to 65 False to disable insertion, even with an insertion point set in the 66 context manager). 67 Returns: 68 A new "detached" Operation object. Detached operations can be added 69 to blocks, which causes them to become "attached." 70 )"; 71 72 static const char kOperationPrintDocstring[] = 73 R"(Prints the assembly form of the operation to a file like object. 74 75 Args: 76 file: The file like object to write to. Defaults to sys.stdout. 77 binary: Whether to write bytes (True) or str (False). Defaults to False. 78 large_elements_limit: Whether to elide elements attributes above this 79 number of elements. Defaults to None (no limit). 80 enable_debug_info: Whether to print debug/location information. Defaults 81 to False. 82 pretty_debug_info: Whether to format debug information for easier reading 83 by a human (warning: the result is unparseable). 84 print_generic_op_form: Whether to print the generic assembly forms of all 85 ops. Defaults to False. 86 use_local_Scope: Whether to print in a way that is more optimized for 87 multi-threaded access but may not be consistent with how the overall 88 module prints. 89 )"; 90 91 static const char kOperationGetAsmDocstring[] = 92 R"(Gets the assembly form of the operation with all options available. 93 94 Args: 95 binary: Whether to return a bytes (True) or str (False) object. Defaults to 96 False. 97 ... others ...: See the print() method for common keyword arguments for 98 configuring the printout. 99 Returns: 100 Either a bytes or str object, depending on the setting of the 'binary' 101 argument. 102 )"; 103 104 static const char kOperationStrDunderDocstring[] = 105 R"(Gets the assembly form of the operation with default options. 106 107 If more advanced control over the assembly formatting or I/O options is needed, 108 use the dedicated print or get_asm method, which supports keyword arguments to 109 customize behavior. 110 )"; 111 112 static const char kDumpDocstring[] = 113 R"(Dumps a debug representation of the object to stderr.)"; 114 115 static const char kAppendBlockDocstring[] = 116 R"(Appends a new block, with argument types as positional args. 117 118 Returns: 119 The created block. 120 )"; 121 122 static const char kValueDunderStrDocstring[] = 123 R"(Returns the string form of the value. 124 125 If the value is a block argument, this is the assembly form of its type and the 126 position in the argument list. If the value is an operation result, this is 127 equivalent to printing the operation that produced it. 128 )"; 129 130 //------------------------------------------------------------------------------ 131 // Utilities. 132 //------------------------------------------------------------------------------ 133 134 /// Helper for creating an @classmethod. 135 template <class Func, typename... Args> 136 py::object classmethod(Func f, Args... args) { 137 py::object cf = py::cpp_function(f, args...); 138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 139 } 140 141 static py::object 142 createCustomDialectWrapper(const std::string &dialectNamespace, 143 py::object dialectDescriptor) { 144 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 145 if (!dialectClass) { 146 // Use the base class. 147 return py::cast(PyDialect(std::move(dialectDescriptor))); 148 } 149 150 // Create the custom implementation. 151 return (*dialectClass)(std::move(dialectDescriptor)); 152 } 153 154 static MlirStringRef toMlirStringRef(const std::string &s) { 155 return mlirStringRefCreate(s.data(), s.size()); 156 } 157 158 /// Wrapper for the global LLVM debugging flag. 159 struct PyGlobalDebugFlag { 160 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 161 162 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 163 164 static void bind(py::module &m) { 165 // Debug flags. 166 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") 167 .def_property_static("flag", &PyGlobalDebugFlag::get, 168 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 169 } 170 }; 171 172 //------------------------------------------------------------------------------ 173 // Collections. 174 //------------------------------------------------------------------------------ 175 176 namespace { 177 178 class PyRegionIterator { 179 public: 180 PyRegionIterator(PyOperationRef operation) 181 : operation(std::move(operation)) {} 182 183 PyRegionIterator &dunderIter() { return *this; } 184 185 PyRegion dunderNext() { 186 operation->checkValid(); 187 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 188 throw py::stop_iteration(); 189 } 190 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 191 return PyRegion(operation, region); 192 } 193 194 static void bind(py::module &m) { 195 py::class_<PyRegionIterator>(m, "RegionIterator") 196 .def("__iter__", &PyRegionIterator::dunderIter) 197 .def("__next__", &PyRegionIterator::dunderNext); 198 } 199 200 private: 201 PyOperationRef operation; 202 int nextIndex = 0; 203 }; 204 205 /// Regions of an op are fixed length and indexed numerically so are represented 206 /// with a sequence-like container. 207 class PyRegionList { 208 public: 209 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 210 211 intptr_t dunderLen() { 212 operation->checkValid(); 213 return mlirOperationGetNumRegions(operation->get()); 214 } 215 216 PyRegion dunderGetItem(intptr_t index) { 217 // dunderLen checks validity. 218 if (index < 0 || index >= dunderLen()) { 219 throw SetPyError(PyExc_IndexError, 220 "attempt to access out of bounds region"); 221 } 222 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 223 return PyRegion(operation, region); 224 } 225 226 static void bind(py::module &m) { 227 py::class_<PyRegionList>(m, "RegionSequence") 228 .def("__len__", &PyRegionList::dunderLen) 229 .def("__getitem__", &PyRegionList::dunderGetItem); 230 } 231 232 private: 233 PyOperationRef operation; 234 }; 235 236 class PyBlockIterator { 237 public: 238 PyBlockIterator(PyOperationRef operation, MlirBlock next) 239 : operation(std::move(operation)), next(next) {} 240 241 PyBlockIterator &dunderIter() { return *this; } 242 243 PyBlock dunderNext() { 244 operation->checkValid(); 245 if (mlirBlockIsNull(next)) { 246 throw py::stop_iteration(); 247 } 248 249 PyBlock returnBlock(operation, next); 250 next = mlirBlockGetNextInRegion(next); 251 return returnBlock; 252 } 253 254 static void bind(py::module &m) { 255 py::class_<PyBlockIterator>(m, "BlockIterator") 256 .def("__iter__", &PyBlockIterator::dunderIter) 257 .def("__next__", &PyBlockIterator::dunderNext); 258 } 259 260 private: 261 PyOperationRef operation; 262 MlirBlock next; 263 }; 264 265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 266 /// we present them as a more full-featured list-like container but optimize 267 /// it for forward iteration. Blocks are always owned by a region. 268 class PyBlockList { 269 public: 270 PyBlockList(PyOperationRef operation, MlirRegion region) 271 : operation(std::move(operation)), region(region) {} 272 273 PyBlockIterator dunderIter() { 274 operation->checkValid(); 275 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 276 } 277 278 intptr_t dunderLen() { 279 operation->checkValid(); 280 intptr_t count = 0; 281 MlirBlock block = mlirRegionGetFirstBlock(region); 282 while (!mlirBlockIsNull(block)) { 283 count += 1; 284 block = mlirBlockGetNextInRegion(block); 285 } 286 return count; 287 } 288 289 PyBlock dunderGetItem(intptr_t index) { 290 operation->checkValid(); 291 if (index < 0) { 292 throw SetPyError(PyExc_IndexError, 293 "attempt to access out of bounds block"); 294 } 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 if (index == 0) { 298 return PyBlock(operation, block); 299 } 300 block = mlirBlockGetNextInRegion(block); 301 index -= 1; 302 } 303 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 304 } 305 306 PyBlock appendBlock(py::args pyArgTypes) { 307 operation->checkValid(); 308 llvm::SmallVector<MlirType, 4> argTypes; 309 argTypes.reserve(pyArgTypes.size()); 310 for (auto &pyArg : pyArgTypes) { 311 argTypes.push_back(pyArg.cast<PyType &>()); 312 } 313 314 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 315 mlirRegionAppendOwnedBlock(region, block); 316 return PyBlock(operation, block); 317 } 318 319 static void bind(py::module &m) { 320 py::class_<PyBlockList>(m, "BlockList") 321 .def("__getitem__", &PyBlockList::dunderGetItem) 322 .def("__iter__", &PyBlockList::dunderIter) 323 .def("__len__", &PyBlockList::dunderLen) 324 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 325 } 326 327 private: 328 PyOperationRef operation; 329 MlirRegion region; 330 }; 331 332 class PyOperationIterator { 333 public: 334 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 335 : parentOperation(std::move(parentOperation)), next(next) {} 336 337 PyOperationIterator &dunderIter() { return *this; } 338 339 py::object dunderNext() { 340 parentOperation->checkValid(); 341 if (mlirOperationIsNull(next)) { 342 throw py::stop_iteration(); 343 } 344 345 PyOperationRef returnOperation = 346 PyOperation::forOperation(parentOperation->getContext(), next); 347 next = mlirOperationGetNextInBlock(next); 348 return returnOperation->createOpView(); 349 } 350 351 static void bind(py::module &m) { 352 py::class_<PyOperationIterator>(m, "OperationIterator") 353 .def("__iter__", &PyOperationIterator::dunderIter) 354 .def("__next__", &PyOperationIterator::dunderNext); 355 } 356 357 private: 358 PyOperationRef parentOperation; 359 MlirOperation next; 360 }; 361 362 /// Operations are exposed by the C-API as a forward-only linked list. In 363 /// Python, we present them as a more full-featured list-like container but 364 /// optimize it for forward iteration. Iterable operations are always owned 365 /// by a block. 366 class PyOperationList { 367 public: 368 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 369 : parentOperation(std::move(parentOperation)), block(block) {} 370 371 PyOperationIterator dunderIter() { 372 parentOperation->checkValid(); 373 return PyOperationIterator(parentOperation, 374 mlirBlockGetFirstOperation(block)); 375 } 376 377 intptr_t dunderLen() { 378 parentOperation->checkValid(); 379 intptr_t count = 0; 380 MlirOperation childOp = mlirBlockGetFirstOperation(block); 381 while (!mlirOperationIsNull(childOp)) { 382 count += 1; 383 childOp = mlirOperationGetNextInBlock(childOp); 384 } 385 return count; 386 } 387 388 py::object dunderGetItem(intptr_t index) { 389 parentOperation->checkValid(); 390 if (index < 0) { 391 throw SetPyError(PyExc_IndexError, 392 "attempt to access out of bounds operation"); 393 } 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 if (index == 0) { 397 return PyOperation::forOperation(parentOperation->getContext(), childOp) 398 ->createOpView(); 399 } 400 childOp = mlirOperationGetNextInBlock(childOp); 401 index -= 1; 402 } 403 throw SetPyError(PyExc_IndexError, 404 "attempt to access out of bounds operation"); 405 } 406 407 static void bind(py::module &m) { 408 py::class_<PyOperationList>(m, "OperationList") 409 .def("__getitem__", &PyOperationList::dunderGetItem) 410 .def("__iter__", &PyOperationList::dunderIter) 411 .def("__len__", &PyOperationList::dunderLen); 412 } 413 414 private: 415 PyOperationRef parentOperation; 416 MlirBlock block; 417 }; 418 419 } // namespace 420 421 //------------------------------------------------------------------------------ 422 // PyMlirContext 423 //------------------------------------------------------------------------------ 424 425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 426 py::gil_scoped_acquire acquire; 427 auto &liveContexts = getLiveContexts(); 428 liveContexts[context.ptr] = this; 429 } 430 431 PyMlirContext::~PyMlirContext() { 432 // Note that the only public way to construct an instance is via the 433 // forContext method, which always puts the associated handle into 434 // liveContexts. 435 py::gil_scoped_acquire acquire; 436 getLiveContexts().erase(context.ptr); 437 mlirContextDestroy(context); 438 } 439 440 py::object PyMlirContext::getCapsule() { 441 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 442 } 443 444 py::object PyMlirContext::createFromCapsule(py::object capsule) { 445 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 446 if (mlirContextIsNull(rawContext)) 447 throw py::error_already_set(); 448 return forContext(rawContext).releaseObject(); 449 } 450 451 PyMlirContext *PyMlirContext::createNewContextForInit() { 452 MlirContext context = mlirContextCreate(); 453 mlirRegisterAllDialects(context); 454 return new PyMlirContext(context); 455 } 456 457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 458 py::gil_scoped_acquire acquire; 459 auto &liveContexts = getLiveContexts(); 460 auto it = liveContexts.find(context.ptr); 461 if (it == liveContexts.end()) { 462 // Create. 463 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 464 py::object pyRef = py::cast(unownedContextWrapper); 465 assert(pyRef && "cast to py::object failed"); 466 liveContexts[context.ptr] = unownedContextWrapper; 467 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 468 } 469 // Use existing. 470 py::object pyRef = py::cast(it->second); 471 return PyMlirContextRef(it->second, std::move(pyRef)); 472 } 473 474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 475 static LiveContextMap liveContexts; 476 return liveContexts; 477 } 478 479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 480 481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 482 483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 484 485 pybind11::object PyMlirContext::contextEnter() { 486 return PyThreadContextEntry::pushContext(*this); 487 } 488 489 void PyMlirContext::contextExit(pybind11::object excType, 490 pybind11::object excVal, 491 pybind11::object excTb) { 492 PyThreadContextEntry::popContext(*this); 493 } 494 495 PyMlirContext &DefaultingPyMlirContext::resolve() { 496 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 497 if (!context) { 498 throw SetPyError( 499 PyExc_RuntimeError, 500 "An MLIR function requires a Context but none was provided in the call " 501 "or from the surrounding environment. Either pass to the function with " 502 "a 'context=' argument or establish a default using 'with Context():'"); 503 } 504 return *context; 505 } 506 507 //------------------------------------------------------------------------------ 508 // PyThreadContextEntry management 509 //------------------------------------------------------------------------------ 510 511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 512 static thread_local std::vector<PyThreadContextEntry> stack; 513 return stack; 514 } 515 516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 517 auto &stack = getStack(); 518 if (stack.empty()) 519 return nullptr; 520 return &stack.back(); 521 } 522 523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 524 py::object insertionPoint, 525 py::object location) { 526 auto &stack = getStack(); 527 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 528 std::move(location)); 529 // If the new stack has more than one entry and the context of the new top 530 // entry matches the previous, copy the insertionPoint and location from the 531 // previous entry if missing from the new top entry. 532 if (stack.size() > 1) { 533 auto &prev = *(stack.rbegin() + 1); 534 auto ¤t = stack.back(); 535 if (current.context.is(prev.context)) { 536 // Default non-context objects from the previous entry. 537 if (!current.insertionPoint) 538 current.insertionPoint = prev.insertionPoint; 539 if (!current.location) 540 current.location = prev.location; 541 } 542 } 543 } 544 545 PyMlirContext *PyThreadContextEntry::getContext() { 546 if (!context) 547 return nullptr; 548 return py::cast<PyMlirContext *>(context); 549 } 550 551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 552 if (!insertionPoint) 553 return nullptr; 554 return py::cast<PyInsertionPoint *>(insertionPoint); 555 } 556 557 PyLocation *PyThreadContextEntry::getLocation() { 558 if (!location) 559 return nullptr; 560 return py::cast<PyLocation *>(location); 561 } 562 563 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 564 auto *tos = getTopOfStack(); 565 return tos ? tos->getContext() : nullptr; 566 } 567 568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 569 auto *tos = getTopOfStack(); 570 return tos ? tos->getInsertionPoint() : nullptr; 571 } 572 573 PyLocation *PyThreadContextEntry::getDefaultLocation() { 574 auto *tos = getTopOfStack(); 575 return tos ? tos->getLocation() : nullptr; 576 } 577 578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 579 py::object contextObj = py::cast(context); 580 push(FrameKind::Context, /*context=*/contextObj, 581 /*insertionPoint=*/py::object(), 582 /*location=*/py::object()); 583 return contextObj; 584 } 585 586 void PyThreadContextEntry::popContext(PyMlirContext &context) { 587 auto &stack = getStack(); 588 if (stack.empty()) 589 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 590 auto &tos = stack.back(); 591 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593 stack.pop_back(); 594 } 595 596 py::object 597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 598 py::object contextObj = 599 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 600 py::object insertionPointObj = py::cast(insertionPoint); 601 push(FrameKind::InsertionPoint, 602 /*context=*/contextObj, 603 /*insertionPoint=*/insertionPointObj, 604 /*location=*/py::object()); 605 return insertionPointObj; 606 } 607 608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 609 auto &stack = getStack(); 610 if (stack.empty()) 611 throw SetPyError(PyExc_RuntimeError, 612 "Unbalanced InsertionPoint enter/exit"); 613 auto &tos = stack.back(); 614 if (tos.frameKind != FrameKind::InsertionPoint && 615 tos.getInsertionPoint() != &insertionPoint) 616 throw SetPyError(PyExc_RuntimeError, 617 "Unbalanced InsertionPoint enter/exit"); 618 stack.pop_back(); 619 } 620 621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 622 py::object contextObj = location.getContext().getObject(); 623 py::object locationObj = py::cast(location); 624 push(FrameKind::Location, /*context=*/contextObj, 625 /*insertionPoint=*/py::object(), 626 /*location=*/locationObj); 627 return locationObj; 628 } 629 630 void PyThreadContextEntry::popLocation(PyLocation &location) { 631 auto &stack = getStack(); 632 if (stack.empty()) 633 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 634 auto &tos = stack.back(); 635 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637 stack.pop_back(); 638 } 639 640 //------------------------------------------------------------------------------ 641 // PyDialect, PyDialectDescriptor, PyDialects 642 //------------------------------------------------------------------------------ 643 644 MlirDialect PyDialects::getDialectForKey(const std::string &key, 645 bool attrError) { 646 // If the "std" dialect was asked for, substitute the empty namespace :( 647 static const std::string emptyKey; 648 const std::string *canonKey = key == "std" ? &emptyKey : &key; 649 MlirDialect dialect = mlirContextGetOrLoadDialect( 650 getContext()->get(), {canonKey->data(), canonKey->size()}); 651 if (mlirDialectIsNull(dialect)) { 652 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 653 Twine("Dialect '") + key + "' not found"); 654 } 655 return dialect; 656 } 657 658 //------------------------------------------------------------------------------ 659 // PyLocation 660 //------------------------------------------------------------------------------ 661 662 py::object PyLocation::getCapsule() { 663 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 664 } 665 666 PyLocation PyLocation::createFromCapsule(py::object capsule) { 667 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 668 if (mlirLocationIsNull(rawLoc)) 669 throw py::error_already_set(); 670 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 671 rawLoc); 672 } 673 674 py::object PyLocation::contextEnter() { 675 return PyThreadContextEntry::pushLocation(*this); 676 } 677 678 void PyLocation::contextExit(py::object excType, py::object excVal, 679 py::object excTb) { 680 PyThreadContextEntry::popLocation(*this); 681 } 682 683 PyLocation &DefaultingPyLocation::resolve() { 684 auto *location = PyThreadContextEntry::getDefaultLocation(); 685 if (!location) { 686 throw SetPyError( 687 PyExc_RuntimeError, 688 "An MLIR function requires a Location but none was provided in the " 689 "call or from the surrounding environment. Either pass to the function " 690 "with a 'loc=' argument or establish a default using 'with loc:'"); 691 } 692 return *location; 693 } 694 695 //------------------------------------------------------------------------------ 696 // PyModule 697 //------------------------------------------------------------------------------ 698 699 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 700 : BaseContextObject(std::move(contextRef)), module(module) {} 701 702 PyModule::~PyModule() { 703 py::gil_scoped_acquire acquire; 704 auto &liveModules = getContext()->liveModules; 705 assert(liveModules.count(module.ptr) == 1 && 706 "destroying module not in live map"); 707 liveModules.erase(module.ptr); 708 mlirModuleDestroy(module); 709 } 710 711 PyModuleRef PyModule::forModule(MlirModule module) { 712 MlirContext context = mlirModuleGetContext(module); 713 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 714 715 py::gil_scoped_acquire acquire; 716 auto &liveModules = contextRef->liveModules; 717 auto it = liveModules.find(module.ptr); 718 if (it == liveModules.end()) { 719 // Create. 720 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 721 // Note that the default return value policy on cast is automatic_reference, 722 // which does not take ownership (delete will not be called). 723 // Just be explicit. 724 py::object pyRef = 725 py::cast(unownedModule, py::return_value_policy::take_ownership); 726 unownedModule->handle = pyRef; 727 liveModules[module.ptr] = 728 std::make_pair(unownedModule->handle, unownedModule); 729 return PyModuleRef(unownedModule, std::move(pyRef)); 730 } 731 // Use existing. 732 PyModule *existing = it->second.second; 733 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 734 return PyModuleRef(existing, std::move(pyRef)); 735 } 736 737 py::object PyModule::createFromCapsule(py::object capsule) { 738 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 739 if (mlirModuleIsNull(rawModule)) 740 throw py::error_already_set(); 741 return forModule(rawModule).releaseObject(); 742 } 743 744 py::object PyModule::getCapsule() { 745 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 746 } 747 748 //------------------------------------------------------------------------------ 749 // PyOperation 750 //------------------------------------------------------------------------------ 751 752 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 753 : BaseContextObject(std::move(contextRef)), operation(operation) {} 754 755 PyOperation::~PyOperation() { 756 auto &liveOperations = getContext()->liveOperations; 757 assert(liveOperations.count(operation.ptr) == 1 && 758 "destroying operation not in live map"); 759 liveOperations.erase(operation.ptr); 760 if (!isAttached()) { 761 mlirOperationDestroy(operation); 762 } 763 } 764 765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 766 MlirOperation operation, 767 py::object parentKeepAlive) { 768 auto &liveOperations = contextRef->liveOperations; 769 // Create. 770 PyOperation *unownedOperation = 771 new PyOperation(std::move(contextRef), operation); 772 // Note that the default return value policy on cast is automatic_reference, 773 // which does not take ownership (delete will not be called). 774 // Just be explicit. 775 py::object pyRef = 776 py::cast(unownedOperation, py::return_value_policy::take_ownership); 777 unownedOperation->handle = pyRef; 778 if (parentKeepAlive) { 779 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 780 } 781 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 782 return PyOperationRef(unownedOperation, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 auto it = liveOperations.find(operation.ptr); 790 if (it == liveOperations.end()) { 791 // Create. 792 return createInstance(std::move(contextRef), operation, 793 std::move(parentKeepAlive)); 794 } 795 // Use existing. 796 PyOperation *existing = it->second.second; 797 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 798 return PyOperationRef(existing, std::move(pyRef)); 799 } 800 801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 802 MlirOperation operation, 803 py::object parentKeepAlive) { 804 auto &liveOperations = contextRef->liveOperations; 805 assert(liveOperations.count(operation.ptr) == 0 && 806 "cannot create detached operation that already exists"); 807 (void)liveOperations; 808 809 PyOperationRef created = createInstance(std::move(contextRef), operation, 810 std::move(parentKeepAlive)); 811 created->attached = false; 812 return created; 813 } 814 815 void PyOperation::checkValid() const { 816 if (!valid) { 817 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 818 } 819 } 820 821 void PyOperationBase::print(py::object fileObject, bool binary, 822 llvm::Optional<int64_t> largeElementsLimit, 823 bool enableDebugInfo, bool prettyDebugInfo, 824 bool printGenericOpForm, bool useLocalScope) { 825 PyOperation &operation = getOperation(); 826 operation.checkValid(); 827 if (fileObject.is_none()) 828 fileObject = py::module::import("sys").attr("stdout"); 829 830 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 831 fileObject.attr("write")("// Verification failed, printing generic form\n"); 832 printGenericOpForm = true; 833 } 834 835 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 836 if (largeElementsLimit) 837 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 838 if (enableDebugInfo) 839 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 840 if (printGenericOpForm) 841 mlirOpPrintingFlagsPrintGenericOpForm(flags); 842 843 PyFileAccumulator accum(fileObject, binary); 844 py::gil_scoped_release(); 845 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 846 accum.getUserData()); 847 mlirOpPrintingFlagsDestroy(flags); 848 } 849 850 py::object PyOperationBase::getAsm(bool binary, 851 llvm::Optional<int64_t> largeElementsLimit, 852 bool enableDebugInfo, bool prettyDebugInfo, 853 bool printGenericOpForm, 854 bool useLocalScope) { 855 py::object fileObject; 856 if (binary) { 857 fileObject = py::module::import("io").attr("BytesIO")(); 858 } else { 859 fileObject = py::module::import("io").attr("StringIO")(); 860 } 861 print(fileObject, /*binary=*/binary, 862 /*largeElementsLimit=*/largeElementsLimit, 863 /*enableDebugInfo=*/enableDebugInfo, 864 /*prettyDebugInfo=*/prettyDebugInfo, 865 /*printGenericOpForm=*/printGenericOpForm, 866 /*useLocalScope=*/useLocalScope); 867 868 return fileObject.attr("getvalue")(); 869 } 870 871 PyOperationRef PyOperation::getParentOperation() { 872 if (!isAttached()) 873 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 874 MlirOperation operation = mlirOperationGetParentOperation(get()); 875 if (mlirOperationIsNull(operation)) 876 throw SetPyError(PyExc_ValueError, "Operation has no parent."); 877 return PyOperation::forOperation(getContext(), operation); 878 } 879 880 PyBlock PyOperation::getBlock() { 881 PyOperationRef parentOperation = getParentOperation(); 882 MlirBlock block = mlirOperationGetBlock(get()); 883 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 884 return PyBlock{std::move(parentOperation), block}; 885 } 886 887 py::object PyOperation::getCapsule() { 888 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 889 } 890 891 py::object PyOperation::createFromCapsule(py::object capsule) { 892 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 893 if (mlirOperationIsNull(rawOperation)) 894 throw py::error_already_set(); 895 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 896 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 897 .releaseObject(); 898 } 899 900 py::object PyOperation::create( 901 std::string name, llvm::Optional<std::vector<PyType *>> results, 902 llvm::Optional<std::vector<PyValue *>> operands, 903 llvm::Optional<py::dict> attributes, 904 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 905 DefaultingPyLocation location, py::object maybeIp) { 906 llvm::SmallVector<MlirValue, 4> mlirOperands; 907 llvm::SmallVector<MlirType, 4> mlirResults; 908 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 909 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 910 911 // General parameter validation. 912 if (regions < 0) 913 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 914 915 // Unpack/validate operands. 916 if (operands) { 917 mlirOperands.reserve(operands->size()); 918 for (PyValue *operand : *operands) { 919 if (!operand) 920 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 921 mlirOperands.push_back(operand->get()); 922 } 923 } 924 925 // Unpack/validate results. 926 if (results) { 927 mlirResults.reserve(results->size()); 928 for (PyType *result : *results) { 929 // TODO: Verify result type originate from the same context. 930 if (!result) 931 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 932 mlirResults.push_back(*result); 933 } 934 } 935 // Unpack/validate attributes. 936 if (attributes) { 937 mlirAttributes.reserve(attributes->size()); 938 for (auto &it : *attributes) { 939 std::string key; 940 try { 941 key = it.first.cast<std::string>(); 942 } catch (py::cast_error &err) { 943 std::string msg = "Invalid attribute key (not a string) when " 944 "attempting to create the operation \"" + 945 name + "\" (" + err.what() + ")"; 946 throw py::cast_error(msg); 947 } 948 try { 949 auto &attribute = it.second.cast<PyAttribute &>(); 950 // TODO: Verify attribute originates from the same context. 951 mlirAttributes.emplace_back(std::move(key), attribute); 952 } catch (py::reference_cast_error &) { 953 // This exception seems thrown when the value is "None". 954 std::string msg = 955 "Found an invalid (`None`?) attribute value for the key \"" + key + 956 "\" when attempting to create the operation \"" + name + "\""; 957 throw py::cast_error(msg); 958 } catch (py::cast_error &err) { 959 std::string msg = "Invalid attribute value for the key \"" + key + 960 "\" when attempting to create the operation \"" + 961 name + "\" (" + err.what() + ")"; 962 throw py::cast_error(msg); 963 } 964 } 965 } 966 // Unpack/validate successors. 967 if (successors) { 968 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 969 mlirSuccessors.reserve(successors->size()); 970 for (auto *successor : *successors) { 971 // TODO: Verify successor originate from the same context. 972 if (!successor) 973 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 974 mlirSuccessors.push_back(successor->get()); 975 } 976 } 977 978 // Apply unpacked/validated to the operation state. Beyond this 979 // point, exceptions cannot be thrown or else the state will leak. 980 MlirOperationState state = 981 mlirOperationStateGet(toMlirStringRef(name), location); 982 if (!mlirOperands.empty()) 983 mlirOperationStateAddOperands(&state, mlirOperands.size(), 984 mlirOperands.data()); 985 if (!mlirResults.empty()) 986 mlirOperationStateAddResults(&state, mlirResults.size(), 987 mlirResults.data()); 988 if (!mlirAttributes.empty()) { 989 // Note that the attribute names directly reference bytes in 990 // mlirAttributes, so that vector must not be changed from here 991 // on. 992 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 993 mlirNamedAttributes.reserve(mlirAttributes.size()); 994 for (auto &it : mlirAttributes) 995 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 996 mlirIdentifierGet(mlirAttributeGetContext(it.second), 997 toMlirStringRef(it.first)), 998 it.second)); 999 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1000 mlirNamedAttributes.data()); 1001 } 1002 if (!mlirSuccessors.empty()) 1003 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1004 mlirSuccessors.data()); 1005 if (regions) { 1006 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1007 mlirRegions.resize(regions); 1008 for (int i = 0; i < regions; ++i) 1009 mlirRegions[i] = mlirRegionCreate(); 1010 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1011 mlirRegions.data()); 1012 } 1013 1014 // Construct the operation. 1015 MlirOperation operation = mlirOperationCreate(&state); 1016 PyOperationRef created = 1017 PyOperation::createDetached(location->getContext(), operation); 1018 1019 // InsertPoint active? 1020 if (!maybeIp.is(py::cast(false))) { 1021 PyInsertionPoint *ip; 1022 if (maybeIp.is_none()) { 1023 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1024 } else { 1025 ip = py::cast<PyInsertionPoint *>(maybeIp); 1026 } 1027 if (ip) 1028 ip->insert(*created.get()); 1029 } 1030 1031 return created->createOpView(); 1032 } 1033 1034 py::object PyOperation::createOpView() { 1035 MlirIdentifier ident = mlirOperationGetName(get()); 1036 MlirStringRef identStr = mlirIdentifierStr(ident); 1037 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1038 StringRef(identStr.data, identStr.length)); 1039 if (opViewClass) 1040 return (*opViewClass)(getRef().getObject()); 1041 return py::cast(PyOpView(getRef().getObject())); 1042 } 1043 1044 //------------------------------------------------------------------------------ 1045 // PyOpView 1046 //------------------------------------------------------------------------------ 1047 1048 py::object 1049 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1050 py::list operandList, 1051 llvm::Optional<py::dict> attributes, 1052 llvm::Optional<std::vector<PyBlock *>> successors, 1053 llvm::Optional<int> regions, 1054 DefaultingPyLocation location, py::object maybeIp) { 1055 PyMlirContextRef context = location->getContext(); 1056 // Class level operation construction metadata. 1057 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1058 // Operand and result segment specs are either none, which does no 1059 // variadic unpacking, or a list of ints with segment sizes, where each 1060 // element is either a positive number (typically 1 for a scalar) or -1 to 1061 // indicate that it is derived from the length of the same-indexed operand 1062 // or result (implying that it is a list at that position). 1063 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1064 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1065 1066 std::vector<uint32_t> operandSegmentLengths; 1067 std::vector<uint32_t> resultSegmentLengths; 1068 1069 // Validate/determine region count. 1070 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1071 int opMinRegionCount = std::get<0>(opRegionSpec); 1072 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1073 if (!regions) { 1074 regions = opMinRegionCount; 1075 } 1076 if (*regions < opMinRegionCount) { 1077 throw py::value_error( 1078 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1079 llvm::Twine(opMinRegionCount) + 1080 " regions but was built with regions=" + llvm::Twine(*regions)) 1081 .str()); 1082 } 1083 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1084 throw py::value_error( 1085 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1086 llvm::Twine(opMinRegionCount) + 1087 " regions but was built with regions=" + llvm::Twine(*regions)) 1088 .str()); 1089 } 1090 1091 // Unpack results. 1092 std::vector<PyType *> resultTypes; 1093 resultTypes.reserve(resultTypeList.size()); 1094 if (resultSegmentSpecObj.is_none()) { 1095 // Non-variadic result unpacking. 1096 for (auto it : llvm::enumerate(resultTypeList)) { 1097 try { 1098 resultTypes.push_back(py::cast<PyType *>(it.value())); 1099 if (!resultTypes.back()) 1100 throw py::cast_error(); 1101 } catch (py::cast_error &err) { 1102 throw py::value_error((llvm::Twine("Result ") + 1103 llvm::Twine(it.index()) + " of operation \"" + 1104 name + "\" must be a Type (" + err.what() + ")") 1105 .str()); 1106 } 1107 } 1108 } else { 1109 // Sized result unpacking. 1110 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1111 if (resultSegmentSpec.size() != resultTypeList.size()) { 1112 throw py::value_error((llvm::Twine("Operation \"") + name + 1113 "\" requires " + 1114 llvm::Twine(resultSegmentSpec.size()) + 1115 "result segments but was provided " + 1116 llvm::Twine(resultTypeList.size())) 1117 .str()); 1118 } 1119 resultSegmentLengths.reserve(resultTypeList.size()); 1120 for (auto it : 1121 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1122 int segmentSpec = std::get<1>(it.value()); 1123 if (segmentSpec == 1 || segmentSpec == 0) { 1124 // Unpack unary element. 1125 try { 1126 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1127 if (resultType) { 1128 resultTypes.push_back(resultType); 1129 resultSegmentLengths.push_back(1); 1130 } else if (segmentSpec == 0) { 1131 // Allowed to be optional. 1132 resultSegmentLengths.push_back(0); 1133 } else { 1134 throw py::cast_error("was None and result is not optional"); 1135 } 1136 } catch (py::cast_error &err) { 1137 throw py::value_error((llvm::Twine("Result ") + 1138 llvm::Twine(it.index()) + " of operation \"" + 1139 name + "\" must be a Type (" + err.what() + 1140 ")") 1141 .str()); 1142 } 1143 } else if (segmentSpec == -1) { 1144 // Unpack sequence by appending. 1145 try { 1146 if (std::get<0>(it.value()).is_none()) { 1147 // Treat it as an empty list. 1148 resultSegmentLengths.push_back(0); 1149 } else { 1150 // Unpack the list. 1151 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1152 for (py::object segmentItem : segment) { 1153 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1154 if (!resultTypes.back()) { 1155 throw py::cast_error("contained a None item"); 1156 } 1157 } 1158 resultSegmentLengths.push_back(segment.size()); 1159 } 1160 } catch (std::exception &err) { 1161 // NOTE: Sloppy to be using a catch-all here, but there are at least 1162 // three different unrelated exceptions that can be thrown in the 1163 // above "casts". Just keep the scope above small and catch them all. 1164 throw py::value_error((llvm::Twine("Result ") + 1165 llvm::Twine(it.index()) + " of operation \"" + 1166 name + "\" must be a Sequence of Types (" + 1167 err.what() + ")") 1168 .str()); 1169 } 1170 } else { 1171 throw py::value_error("Unexpected segment spec"); 1172 } 1173 } 1174 } 1175 1176 // Unpack operands. 1177 std::vector<PyValue *> operands; 1178 operands.reserve(operands.size()); 1179 if (operandSegmentSpecObj.is_none()) { 1180 // Non-sized operand unpacking. 1181 for (auto it : llvm::enumerate(operandList)) { 1182 try { 1183 operands.push_back(py::cast<PyValue *>(it.value())); 1184 if (!operands.back()) 1185 throw py::cast_error(); 1186 } catch (py::cast_error &err) { 1187 throw py::value_error((llvm::Twine("Operand ") + 1188 llvm::Twine(it.index()) + " of operation \"" + 1189 name + "\" must be a Value (" + err.what() + ")") 1190 .str()); 1191 } 1192 } 1193 } else { 1194 // Sized operand unpacking. 1195 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1196 if (operandSegmentSpec.size() != operandList.size()) { 1197 throw py::value_error((llvm::Twine("Operation \"") + name + 1198 "\" requires " + 1199 llvm::Twine(operandSegmentSpec.size()) + 1200 "operand segments but was provided " + 1201 llvm::Twine(operandList.size())) 1202 .str()); 1203 } 1204 operandSegmentLengths.reserve(operandList.size()); 1205 for (auto it : 1206 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1207 int segmentSpec = std::get<1>(it.value()); 1208 if (segmentSpec == 1 || segmentSpec == 0) { 1209 // Unpack unary element. 1210 try { 1211 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1212 if (operandValue) { 1213 operands.push_back(operandValue); 1214 operandSegmentLengths.push_back(1); 1215 } else if (segmentSpec == 0) { 1216 // Allowed to be optional. 1217 operandSegmentLengths.push_back(0); 1218 } else { 1219 throw py::cast_error("was None and operand is not optional"); 1220 } 1221 } catch (py::cast_error &err) { 1222 throw py::value_error((llvm::Twine("Operand ") + 1223 llvm::Twine(it.index()) + " of operation \"" + 1224 name + "\" must be a Value (" + err.what() + 1225 ")") 1226 .str()); 1227 } 1228 } else if (segmentSpec == -1) { 1229 // Unpack sequence by appending. 1230 try { 1231 if (std::get<0>(it.value()).is_none()) { 1232 // Treat it as an empty list. 1233 operandSegmentLengths.push_back(0); 1234 } else { 1235 // Unpack the list. 1236 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1237 for (py::object segmentItem : segment) { 1238 operands.push_back(py::cast<PyValue *>(segmentItem)); 1239 if (!operands.back()) { 1240 throw py::cast_error("contained a None item"); 1241 } 1242 } 1243 operandSegmentLengths.push_back(segment.size()); 1244 } 1245 } catch (std::exception &err) { 1246 // NOTE: Sloppy to be using a catch-all here, but there are at least 1247 // three different unrelated exceptions that can be thrown in the 1248 // above "casts". Just keep the scope above small and catch them all. 1249 throw py::value_error((llvm::Twine("Operand ") + 1250 llvm::Twine(it.index()) + " of operation \"" + 1251 name + "\" must be a Sequence of Values (" + 1252 err.what() + ")") 1253 .str()); 1254 } 1255 } else { 1256 throw py::value_error("Unexpected segment spec"); 1257 } 1258 } 1259 } 1260 1261 // Merge operand/result segment lengths into attributes if needed. 1262 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1263 // Dup. 1264 if (attributes) { 1265 attributes = py::dict(*attributes); 1266 } else { 1267 attributes = py::dict(); 1268 } 1269 if (attributes->contains("result_segment_sizes") || 1270 attributes->contains("operand_segment_sizes")) { 1271 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1272 "'operand_segment_sizes' attribute is unsupported. " 1273 "Use Operation.create for such low-level access."); 1274 } 1275 1276 // Add result_segment_sizes attribute. 1277 if (!resultSegmentLengths.empty()) { 1278 int64_t size = resultSegmentLengths.size(); 1279 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1280 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1281 resultSegmentLengths.size(), resultSegmentLengths.data()); 1282 (*attributes)["result_segment_sizes"] = 1283 PyAttribute(context, segmentLengthAttr); 1284 } 1285 1286 // Add operand_segment_sizes attribute. 1287 if (!operandSegmentLengths.empty()) { 1288 int64_t size = operandSegmentLengths.size(); 1289 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1290 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1291 operandSegmentLengths.size(), operandSegmentLengths.data()); 1292 (*attributes)["operand_segment_sizes"] = 1293 PyAttribute(context, segmentLengthAttr); 1294 } 1295 } 1296 1297 // Delegate to create. 1298 return PyOperation::create(std::move(name), 1299 /*results=*/std::move(resultTypes), 1300 /*operands=*/std::move(operands), 1301 /*attributes=*/std::move(attributes), 1302 /*successors=*/std::move(successors), 1303 /*regions=*/*regions, location, maybeIp); 1304 } 1305 1306 PyOpView::PyOpView(py::object operationObject) 1307 // Casting through the PyOperationBase base-class and then back to the 1308 // Operation lets us accept any PyOperationBase subclass. 1309 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1310 operationObject(operation.getRef().getObject()) {} 1311 1312 py::object PyOpView::createRawSubclass(py::object userClass) { 1313 // This is... a little gross. The typical pattern is to have a pure python 1314 // class that extends OpView like: 1315 // class AddFOp(_cext.ir.OpView): 1316 // def __init__(self, loc, lhs, rhs): 1317 // operation = loc.context.create_operation( 1318 // "addf", lhs, rhs, results=[lhs.type]) 1319 // super().__init__(operation) 1320 // 1321 // I.e. The goal of the user facing type is to provide a nice constructor 1322 // that has complete freedom for the op under construction. This is at odds 1323 // with our other desire to sometimes create this object by just passing an 1324 // operation (to initialize the base class). We could do *arg and **kwargs 1325 // munging to try to make it work, but instead, we synthesize a new class 1326 // on the fly which extends this user class (AddFOp in this example) and 1327 // *give it* the base class's __init__ method, thus bypassing the 1328 // intermediate subclass's __init__ method entirely. While slightly, 1329 // underhanded, this is safe/legal because the type hierarchy has not changed 1330 // (we just added a new leaf) and we aren't mucking around with __new__. 1331 // Typically, this new class will be stored on the original as "_Raw" and will 1332 // be used for casts and other things that need a variant of the class that 1333 // is initialized purely from an operation. 1334 py::object parentMetaclass = 1335 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1336 py::dict attributes; 1337 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1338 // now. 1339 // auto opViewType = py::type::of<PyOpView>(); 1340 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1341 attributes["__init__"] = opViewType.attr("__init__"); 1342 py::str origName = userClass.attr("__name__"); 1343 py::str newName = py::str("_") + origName; 1344 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1345 } 1346 1347 //------------------------------------------------------------------------------ 1348 // PyInsertionPoint. 1349 //------------------------------------------------------------------------------ 1350 1351 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1352 1353 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1354 : refOperation(beforeOperationBase.getOperation().getRef()), 1355 block((*refOperation)->getBlock()) {} 1356 1357 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1358 PyOperation &operation = operationBase.getOperation(); 1359 if (operation.isAttached()) 1360 throw SetPyError(PyExc_ValueError, 1361 "Attempt to insert operation that is already attached"); 1362 block.getParentOperation()->checkValid(); 1363 MlirOperation beforeOp = {nullptr}; 1364 if (refOperation) { 1365 // Insert before operation. 1366 (*refOperation)->checkValid(); 1367 beforeOp = (*refOperation)->get(); 1368 } else { 1369 // Insert at end (before null) is only valid if the block does not 1370 // already end in a known terminator (violating this will cause assertion 1371 // failures later). 1372 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1373 throw py::index_error("Cannot insert operation at the end of a block " 1374 "that already has a terminator. Did you mean to " 1375 "use 'InsertionPoint.at_block_terminator(block)' " 1376 "versus 'InsertionPoint(block)'?"); 1377 } 1378 } 1379 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1380 operation.setAttached(); 1381 } 1382 1383 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1384 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1385 if (mlirOperationIsNull(firstOp)) { 1386 // Just insert at end. 1387 return PyInsertionPoint(block); 1388 } 1389 1390 // Insert before first op. 1391 PyOperationRef firstOpRef = PyOperation::forOperation( 1392 block.getParentOperation()->getContext(), firstOp); 1393 return PyInsertionPoint{block, std::move(firstOpRef)}; 1394 } 1395 1396 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1397 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1398 if (mlirOperationIsNull(terminator)) 1399 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1400 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1401 block.getParentOperation()->getContext(), terminator); 1402 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1403 } 1404 1405 py::object PyInsertionPoint::contextEnter() { 1406 return PyThreadContextEntry::pushInsertionPoint(*this); 1407 } 1408 1409 void PyInsertionPoint::contextExit(pybind11::object excType, 1410 pybind11::object excVal, 1411 pybind11::object excTb) { 1412 PyThreadContextEntry::popInsertionPoint(*this); 1413 } 1414 1415 //------------------------------------------------------------------------------ 1416 // PyAttribute. 1417 //------------------------------------------------------------------------------ 1418 1419 bool PyAttribute::operator==(const PyAttribute &other) { 1420 return mlirAttributeEqual(attr, other.attr); 1421 } 1422 1423 py::object PyAttribute::getCapsule() { 1424 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1425 } 1426 1427 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1428 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1429 if (mlirAttributeIsNull(rawAttr)) 1430 throw py::error_already_set(); 1431 return PyAttribute( 1432 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1433 } 1434 1435 //------------------------------------------------------------------------------ 1436 // PyNamedAttribute. 1437 //------------------------------------------------------------------------------ 1438 1439 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1440 : ownedName(new std::string(std::move(ownedName))) { 1441 namedAttr = mlirNamedAttributeGet( 1442 mlirIdentifierGet(mlirAttributeGetContext(attr), 1443 toMlirStringRef(*this->ownedName)), 1444 attr); 1445 } 1446 1447 //------------------------------------------------------------------------------ 1448 // PyType. 1449 //------------------------------------------------------------------------------ 1450 1451 bool PyType::operator==(const PyType &other) { 1452 return mlirTypeEqual(type, other.type); 1453 } 1454 1455 py::object PyType::getCapsule() { 1456 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1457 } 1458 1459 PyType PyType::createFromCapsule(py::object capsule) { 1460 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1461 if (mlirTypeIsNull(rawType)) 1462 throw py::error_already_set(); 1463 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1464 rawType); 1465 } 1466 1467 //------------------------------------------------------------------------------ 1468 // PyValue and subclases. 1469 //------------------------------------------------------------------------------ 1470 1471 pybind11::object PyValue::getCapsule() { 1472 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1473 } 1474 1475 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1476 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1477 if (mlirValueIsNull(value)) 1478 throw py::error_already_set(); 1479 MlirOperation owner; 1480 if (mlirValueIsAOpResult(value)) 1481 owner = mlirOpResultGetOwner(value); 1482 if (mlirValueIsABlockArgument(value)) 1483 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1484 if (mlirOperationIsNull(owner)) 1485 throw py::error_already_set(); 1486 MlirContext ctx = mlirOperationGetContext(owner); 1487 PyOperationRef ownerRef = 1488 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1489 return PyValue(ownerRef, value); 1490 } 1491 1492 namespace { 1493 /// CRTP base class for Python MLIR values that subclass Value and should be 1494 /// castable from it. The value hierarchy is one level deep and is not supposed 1495 /// to accommodate other levels unless core MLIR changes. 1496 template <typename DerivedTy> 1497 class PyConcreteValue : public PyValue { 1498 public: 1499 // Derived classes must define statics for: 1500 // IsAFunctionTy isaFunction 1501 // const char *pyClassName 1502 // and redefine bindDerived. 1503 using ClassTy = py::class_<DerivedTy, PyValue>; 1504 using IsAFunctionTy = bool (*)(MlirValue); 1505 1506 PyConcreteValue() = default; 1507 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1508 : PyValue(operationRef, value) {} 1509 PyConcreteValue(PyValue &orig) 1510 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1511 1512 /// Attempts to cast the original value to the derived type and throws on 1513 /// type mismatches. 1514 static MlirValue castFrom(PyValue &orig) { 1515 if (!DerivedTy::isaFunction(orig.get())) { 1516 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1517 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1518 DerivedTy::pyClassName + 1519 " (from " + origRepr + ")"); 1520 } 1521 return orig.get(); 1522 } 1523 1524 /// Binds the Python module objects to functions of this class. 1525 static void bind(py::module &m) { 1526 auto cls = ClassTy(m, DerivedTy::pyClassName); 1527 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1528 DerivedTy::bindDerived(cls); 1529 } 1530 1531 /// Implemented by derived classes to add methods to the Python subclass. 1532 static void bindDerived(ClassTy &m) {} 1533 }; 1534 1535 /// Python wrapper for MlirBlockArgument. 1536 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1537 public: 1538 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1539 static constexpr const char *pyClassName = "BlockArgument"; 1540 using PyConcreteValue::PyConcreteValue; 1541 1542 static void bindDerived(ClassTy &c) { 1543 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1544 return PyBlock(self.getParentOperation(), 1545 mlirBlockArgumentGetOwner(self.get())); 1546 }); 1547 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1548 return mlirBlockArgumentGetArgNumber(self.get()); 1549 }); 1550 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1551 return mlirBlockArgumentSetType(self.get(), type); 1552 }); 1553 } 1554 }; 1555 1556 /// Python wrapper for MlirOpResult. 1557 class PyOpResult : public PyConcreteValue<PyOpResult> { 1558 public: 1559 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1560 static constexpr const char *pyClassName = "OpResult"; 1561 using PyConcreteValue::PyConcreteValue; 1562 1563 static void bindDerived(ClassTy &c) { 1564 c.def_property_readonly("owner", [](PyOpResult &self) { 1565 assert( 1566 mlirOperationEqual(self.getParentOperation()->get(), 1567 mlirOpResultGetOwner(self.get())) && 1568 "expected the owner of the value in Python to match that in the IR"); 1569 return self.getParentOperation(); 1570 }); 1571 c.def_property_readonly("result_number", [](PyOpResult &self) { 1572 return mlirOpResultGetResultNumber(self.get()); 1573 }); 1574 } 1575 }; 1576 1577 /// A list of block arguments. Internally, these are stored as consecutive 1578 /// elements, random access is cheap. The argument list is associated with the 1579 /// operation that contains the block (detached blocks are not allowed in 1580 /// Python bindings) and extends its lifetime. 1581 class PyBlockArgumentList { 1582 public: 1583 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1584 : operation(std::move(operation)), block(block) {} 1585 1586 /// Returns the length of the block argument list. 1587 intptr_t dunderLen() { 1588 operation->checkValid(); 1589 return mlirBlockGetNumArguments(block); 1590 } 1591 1592 /// Returns `index`-th element of the block argument list. 1593 PyBlockArgument dunderGetItem(intptr_t index) { 1594 if (index < 0 || index >= dunderLen()) { 1595 throw SetPyError(PyExc_IndexError, 1596 "attempt to access out of bounds region"); 1597 } 1598 PyValue value(operation, mlirBlockGetArgument(block, index)); 1599 return PyBlockArgument(value); 1600 } 1601 1602 /// Defines a Python class in the bindings. 1603 static void bind(py::module &m) { 1604 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1605 .def("__len__", &PyBlockArgumentList::dunderLen) 1606 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1607 } 1608 1609 private: 1610 PyOperationRef operation; 1611 MlirBlock block; 1612 }; 1613 1614 /// A list of operation operands. Internally, these are stored as consecutive 1615 /// elements, random access is cheap. The result list is associated with the 1616 /// operation whose results these are, and extends the lifetime of this 1617 /// operation. 1618 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1619 public: 1620 static constexpr const char *pyClassName = "OpOperandList"; 1621 1622 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1623 intptr_t length = -1, intptr_t step = 1) 1624 : Sliceable(startIndex, 1625 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1626 : length, 1627 step), 1628 operation(operation) {} 1629 1630 intptr_t getNumElements() { 1631 operation->checkValid(); 1632 return mlirOperationGetNumOperands(operation->get()); 1633 } 1634 1635 PyValue getElement(intptr_t pos) { 1636 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); 1637 } 1638 1639 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1640 return PyOpOperandList(operation, startIndex, length, step); 1641 } 1642 1643 private: 1644 PyOperationRef operation; 1645 }; 1646 1647 /// A list of operation results. Internally, these are stored as consecutive 1648 /// elements, random access is cheap. The result list is associated with the 1649 /// operation whose results these are, and extends the lifetime of this 1650 /// operation. 1651 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1652 public: 1653 static constexpr const char *pyClassName = "OpResultList"; 1654 1655 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1656 intptr_t length = -1, intptr_t step = 1) 1657 : Sliceable(startIndex, 1658 length == -1 ? mlirOperationGetNumResults(operation->get()) 1659 : length, 1660 step), 1661 operation(operation) {} 1662 1663 intptr_t getNumElements() { 1664 operation->checkValid(); 1665 return mlirOperationGetNumResults(operation->get()); 1666 } 1667 1668 PyOpResult getElement(intptr_t index) { 1669 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1670 return PyOpResult(value); 1671 } 1672 1673 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1674 return PyOpResultList(operation, startIndex, length, step); 1675 } 1676 1677 private: 1678 PyOperationRef operation; 1679 }; 1680 1681 /// A list of operation attributes. Can be indexed by name, producing 1682 /// attributes, or by index, producing named attributes. 1683 class PyOpAttributeMap { 1684 public: 1685 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1686 1687 PyAttribute dunderGetItemNamed(const std::string &name) { 1688 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1689 toMlirStringRef(name)); 1690 if (mlirAttributeIsNull(attr)) { 1691 throw SetPyError(PyExc_KeyError, 1692 "attempt to access a non-existent attribute"); 1693 } 1694 return PyAttribute(operation->getContext(), attr); 1695 } 1696 1697 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1698 if (index < 0 || index >= dunderLen()) { 1699 throw SetPyError(PyExc_IndexError, 1700 "attempt to access out of bounds attribute"); 1701 } 1702 MlirNamedAttribute namedAttr = 1703 mlirOperationGetAttribute(operation->get(), index); 1704 return PyNamedAttribute( 1705 namedAttr.attribute, 1706 std::string(mlirIdentifierStr(namedAttr.name).data)); 1707 } 1708 1709 void dunderSetItem(const std::string &name, PyAttribute attr) { 1710 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1711 attr); 1712 } 1713 1714 void dunderDelItem(const std::string &name) { 1715 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1716 toMlirStringRef(name)); 1717 if (!removed) 1718 throw SetPyError(PyExc_KeyError, 1719 "attempt to delete a non-existent attribute"); 1720 } 1721 1722 intptr_t dunderLen() { 1723 return mlirOperationGetNumAttributes(operation->get()); 1724 } 1725 1726 bool dunderContains(const std::string &name) { 1727 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1728 operation->get(), toMlirStringRef(name))); 1729 } 1730 1731 static void bind(py::module &m) { 1732 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1733 .def("__contains__", &PyOpAttributeMap::dunderContains) 1734 .def("__len__", &PyOpAttributeMap::dunderLen) 1735 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1736 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1737 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1738 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1739 } 1740 1741 private: 1742 PyOperationRef operation; 1743 }; 1744 1745 } // end namespace 1746 1747 //------------------------------------------------------------------------------ 1748 // Populates the core exports of the 'ir' submodule. 1749 //------------------------------------------------------------------------------ 1750 1751 void mlir::python::populateIRCore(py::module &m) { 1752 //---------------------------------------------------------------------------- 1753 // Mapping of MlirContext. 1754 //---------------------------------------------------------------------------- 1755 py::class_<PyMlirContext>(m, "Context") 1756 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1757 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1758 .def("_get_context_again", 1759 [](PyMlirContext &self) { 1760 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1761 return ref.releaseObject(); 1762 }) 1763 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1764 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1765 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1766 &PyMlirContext::getCapsule) 1767 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1768 .def("__enter__", &PyMlirContext::contextEnter) 1769 .def("__exit__", &PyMlirContext::contextExit) 1770 .def_property_readonly_static( 1771 "current", 1772 [](py::object & /*class*/) { 1773 auto *context = PyThreadContextEntry::getDefaultContext(); 1774 if (!context) 1775 throw SetPyError(PyExc_ValueError, "No current Context"); 1776 return context; 1777 }, 1778 "Gets the Context bound to the current thread or raises ValueError") 1779 .def_property_readonly( 1780 "dialects", 1781 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1782 "Gets a container for accessing dialects by name") 1783 .def_property_readonly( 1784 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1785 "Alias for 'dialect'") 1786 .def( 1787 "get_dialect_descriptor", 1788 [=](PyMlirContext &self, std::string &name) { 1789 MlirDialect dialect = mlirContextGetOrLoadDialect( 1790 self.get(), {name.data(), name.size()}); 1791 if (mlirDialectIsNull(dialect)) { 1792 throw SetPyError(PyExc_ValueError, 1793 Twine("Dialect '") + name + "' not found"); 1794 } 1795 return PyDialectDescriptor(self.getRef(), dialect); 1796 }, 1797 "Gets or loads a dialect by name, returning its descriptor object") 1798 .def_property( 1799 "allow_unregistered_dialects", 1800 [](PyMlirContext &self) -> bool { 1801 return mlirContextGetAllowUnregisteredDialects(self.get()); 1802 }, 1803 [](PyMlirContext &self, bool value) { 1804 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1805 }) 1806 .def("enable_multithreading", 1807 [](PyMlirContext &self, bool enable) { 1808 mlirContextEnableMultithreading(self.get(), enable); 1809 }) 1810 .def("is_registered_operation", 1811 [](PyMlirContext &self, std::string &name) { 1812 return mlirContextIsRegisteredOperation( 1813 self.get(), MlirStringRef{name.data(), name.size()}); 1814 }); 1815 1816 //---------------------------------------------------------------------------- 1817 // Mapping of PyDialectDescriptor 1818 //---------------------------------------------------------------------------- 1819 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1820 .def_property_readonly("namespace", 1821 [](PyDialectDescriptor &self) { 1822 MlirStringRef ns = 1823 mlirDialectGetNamespace(self.get()); 1824 return py::str(ns.data, ns.length); 1825 }) 1826 .def("__repr__", [](PyDialectDescriptor &self) { 1827 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1828 std::string repr("<DialectDescriptor "); 1829 repr.append(ns.data, ns.length); 1830 repr.append(">"); 1831 return repr; 1832 }); 1833 1834 //---------------------------------------------------------------------------- 1835 // Mapping of PyDialects 1836 //---------------------------------------------------------------------------- 1837 py::class_<PyDialects>(m, "Dialects") 1838 .def("__getitem__", 1839 [=](PyDialects &self, std::string keyName) { 1840 MlirDialect dialect = 1841 self.getDialectForKey(keyName, /*attrError=*/false); 1842 py::object descriptor = 1843 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1844 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1845 }) 1846 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1847 MlirDialect dialect = 1848 self.getDialectForKey(attrName, /*attrError=*/true); 1849 py::object descriptor = 1850 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1851 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1852 }); 1853 1854 //---------------------------------------------------------------------------- 1855 // Mapping of PyDialect 1856 //---------------------------------------------------------------------------- 1857 py::class_<PyDialect>(m, "Dialect") 1858 .def(py::init<py::object>(), "descriptor") 1859 .def_property_readonly( 1860 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1861 .def("__repr__", [](py::object self) { 1862 auto clazz = self.attr("__class__"); 1863 return py::str("<Dialect ") + 1864 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1865 clazz.attr("__module__") + py::str(".") + 1866 clazz.attr("__name__") + py::str(")>"); 1867 }); 1868 1869 //---------------------------------------------------------------------------- 1870 // Mapping of Location 1871 //---------------------------------------------------------------------------- 1872 py::class_<PyLocation>(m, "Location") 1873 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1874 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1875 .def("__enter__", &PyLocation::contextEnter) 1876 .def("__exit__", &PyLocation::contextExit) 1877 .def("__eq__", 1878 [](PyLocation &self, PyLocation &other) -> bool { 1879 return mlirLocationEqual(self, other); 1880 }) 1881 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1882 .def_property_readonly_static( 1883 "current", 1884 [](py::object & /*class*/) { 1885 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1886 if (!loc) 1887 throw SetPyError(PyExc_ValueError, "No current Location"); 1888 return loc; 1889 }, 1890 "Gets the Location bound to the current thread or raises ValueError") 1891 .def_static( 1892 "unknown", 1893 [](DefaultingPyMlirContext context) { 1894 return PyLocation(context->getRef(), 1895 mlirLocationUnknownGet(context->get())); 1896 }, 1897 py::arg("context") = py::none(), 1898 "Gets a Location representing an unknown location") 1899 .def_static( 1900 "file", 1901 [](std::string filename, int line, int col, 1902 DefaultingPyMlirContext context) { 1903 return PyLocation( 1904 context->getRef(), 1905 mlirLocationFileLineColGet( 1906 context->get(), toMlirStringRef(filename), line, col)); 1907 }, 1908 py::arg("filename"), py::arg("line"), py::arg("col"), 1909 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1910 .def_property_readonly( 1911 "context", 1912 [](PyLocation &self) { return self.getContext().getObject(); }, 1913 "Context that owns the Location") 1914 .def("__repr__", [](PyLocation &self) { 1915 PyPrintAccumulator printAccum; 1916 mlirLocationPrint(self, printAccum.getCallback(), 1917 printAccum.getUserData()); 1918 return printAccum.join(); 1919 }); 1920 1921 //---------------------------------------------------------------------------- 1922 // Mapping of Module 1923 //---------------------------------------------------------------------------- 1924 py::class_<PyModule>(m, "Module") 1925 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1926 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1927 .def_static( 1928 "parse", 1929 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1930 MlirModule module = mlirModuleCreateParse( 1931 context->get(), toMlirStringRef(moduleAsm)); 1932 // TODO: Rework error reporting once diagnostic engine is exposed 1933 // in C API. 1934 if (mlirModuleIsNull(module)) { 1935 throw SetPyError( 1936 PyExc_ValueError, 1937 "Unable to parse module assembly (see diagnostics)"); 1938 } 1939 return PyModule::forModule(module).releaseObject(); 1940 }, 1941 py::arg("asm"), py::arg("context") = py::none(), 1942 kModuleParseDocstring) 1943 .def_static( 1944 "create", 1945 [](DefaultingPyLocation loc) { 1946 MlirModule module = mlirModuleCreateEmpty(loc); 1947 return PyModule::forModule(module).releaseObject(); 1948 }, 1949 py::arg("loc") = py::none(), "Creates an empty module") 1950 .def_property_readonly( 1951 "context", 1952 [](PyModule &self) { return self.getContext().getObject(); }, 1953 "Context that created the Module") 1954 .def_property_readonly( 1955 "operation", 1956 [](PyModule &self) { 1957 return PyOperation::forOperation(self.getContext(), 1958 mlirModuleGetOperation(self.get()), 1959 self.getRef().releaseObject()) 1960 .releaseObject(); 1961 }, 1962 "Accesses the module as an operation") 1963 .def_property_readonly( 1964 "body", 1965 [](PyModule &self) { 1966 PyOperationRef module_op = PyOperation::forOperation( 1967 self.getContext(), mlirModuleGetOperation(self.get()), 1968 self.getRef().releaseObject()); 1969 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 1970 return returnBlock; 1971 }, 1972 "Return the block for this module") 1973 .def( 1974 "dump", 1975 [](PyModule &self) { 1976 mlirOperationDump(mlirModuleGetOperation(self.get())); 1977 }, 1978 kDumpDocstring) 1979 .def( 1980 "__str__", 1981 [](PyModule &self) { 1982 MlirOperation operation = mlirModuleGetOperation(self.get()); 1983 PyPrintAccumulator printAccum; 1984 mlirOperationPrint(operation, printAccum.getCallback(), 1985 printAccum.getUserData()); 1986 return printAccum.join(); 1987 }, 1988 kOperationStrDunderDocstring); 1989 1990 //---------------------------------------------------------------------------- 1991 // Mapping of Operation. 1992 //---------------------------------------------------------------------------- 1993 py::class_<PyOperationBase>(m, "_OperationBase") 1994 .def("__eq__", 1995 [](PyOperationBase &self, PyOperationBase &other) { 1996 return &self.getOperation() == &other.getOperation(); 1997 }) 1998 .def("__eq__", 1999 [](PyOperationBase &self, py::object other) { return false; }) 2000 .def_property_readonly("attributes", 2001 [](PyOperationBase &self) { 2002 return PyOpAttributeMap( 2003 self.getOperation().getRef()); 2004 }) 2005 .def_property_readonly("operands", 2006 [](PyOperationBase &self) { 2007 return PyOpOperandList( 2008 self.getOperation().getRef()); 2009 }) 2010 .def_property_readonly("regions", 2011 [](PyOperationBase &self) { 2012 return PyRegionList( 2013 self.getOperation().getRef()); 2014 }) 2015 .def_property_readonly( 2016 "results", 2017 [](PyOperationBase &self) { 2018 return PyOpResultList(self.getOperation().getRef()); 2019 }, 2020 "Returns the list of Operation results.") 2021 .def_property_readonly( 2022 "result", 2023 [](PyOperationBase &self) { 2024 auto &operation = self.getOperation(); 2025 auto numResults = mlirOperationGetNumResults(operation); 2026 if (numResults != 1) { 2027 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2028 throw SetPyError( 2029 PyExc_ValueError, 2030 Twine("Cannot call .result on operation ") + 2031 StringRef(name.data, name.length) + " which has " + 2032 Twine(numResults) + 2033 " results (it is only valid for operations with a " 2034 "single result)"); 2035 } 2036 return PyOpResult(operation.getRef(), 2037 mlirOperationGetResult(operation, 0)); 2038 }, 2039 "Shortcut to get an op result if it has only one (throws an error " 2040 "otherwise).") 2041 .def("__iter__", 2042 [](PyOperationBase &self) { 2043 return PyRegionIterator(self.getOperation().getRef()); 2044 }) 2045 .def( 2046 "__str__", 2047 [](PyOperationBase &self) { 2048 return self.getAsm(/*binary=*/false, 2049 /*largeElementsLimit=*/llvm::None, 2050 /*enableDebugInfo=*/false, 2051 /*prettyDebugInfo=*/false, 2052 /*printGenericOpForm=*/false, 2053 /*useLocalScope=*/false); 2054 }, 2055 "Returns the assembly form of the operation.") 2056 .def("print", &PyOperationBase::print, 2057 // Careful: Lots of arguments must match up with print method. 2058 py::arg("file") = py::none(), py::arg("binary") = false, 2059 py::arg("large_elements_limit") = py::none(), 2060 py::arg("enable_debug_info") = false, 2061 py::arg("pretty_debug_info") = false, 2062 py::arg("print_generic_op_form") = false, 2063 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2064 .def("get_asm", &PyOperationBase::getAsm, 2065 // Careful: Lots of arguments must match up with get_asm method. 2066 py::arg("binary") = false, 2067 py::arg("large_elements_limit") = py::none(), 2068 py::arg("enable_debug_info") = false, 2069 py::arg("pretty_debug_info") = false, 2070 py::arg("print_generic_op_form") = false, 2071 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2072 .def( 2073 "verify", 2074 [](PyOperationBase &self) { 2075 return mlirOperationVerify(self.getOperation()); 2076 }, 2077 "Verify the operation and return true if it passes, false if it " 2078 "fails."); 2079 2080 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2081 .def_static("create", &PyOperation::create, py::arg("name"), 2082 py::arg("results") = py::none(), 2083 py::arg("operands") = py::none(), 2084 py::arg("attributes") = py::none(), 2085 py::arg("successors") = py::none(), py::arg("regions") = 0, 2086 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2087 kOperationCreateDocstring) 2088 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2089 &PyOperation::getCapsule) 2090 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2091 .def_property_readonly("name", 2092 [](PyOperation &self) { 2093 MlirOperation operation = self.get(); 2094 MlirStringRef name = mlirIdentifierStr( 2095 mlirOperationGetName(operation)); 2096 return py::str(name.data, name.length); 2097 }) 2098 .def_property_readonly( 2099 "context", 2100 [](PyOperation &self) { return self.getContext().getObject(); }, 2101 "Context that owns the Operation") 2102 .def_property_readonly("opview", &PyOperation::createOpView); 2103 2104 auto opViewClass = 2105 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2106 .def(py::init<py::object>()) 2107 .def_property_readonly("operation", &PyOpView::getOperationObject) 2108 .def_property_readonly( 2109 "context", 2110 [](PyOpView &self) { 2111 return self.getOperation().getContext().getObject(); 2112 }, 2113 "Context that owns the Operation") 2114 .def("__str__", [](PyOpView &self) { 2115 return py::str(self.getOperationObject()); 2116 }); 2117 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2118 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2119 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2120 opViewClass.attr("build_generic") = classmethod( 2121 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2122 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2123 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2124 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2125 "Builds a specific, generated OpView based on class level attributes."); 2126 2127 //---------------------------------------------------------------------------- 2128 // Mapping of PyRegion. 2129 //---------------------------------------------------------------------------- 2130 py::class_<PyRegion>(m, "Region") 2131 .def_property_readonly( 2132 "blocks", 2133 [](PyRegion &self) { 2134 return PyBlockList(self.getParentOperation(), self.get()); 2135 }, 2136 "Returns a forward-optimized sequence of blocks.") 2137 .def( 2138 "__iter__", 2139 [](PyRegion &self) { 2140 self.checkValid(); 2141 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2142 return PyBlockIterator(self.getParentOperation(), firstBlock); 2143 }, 2144 "Iterates over blocks in the region.") 2145 .def("__eq__", 2146 [](PyRegion &self, PyRegion &other) { 2147 return self.get().ptr == other.get().ptr; 2148 }) 2149 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2150 2151 //---------------------------------------------------------------------------- 2152 // Mapping of PyBlock. 2153 //---------------------------------------------------------------------------- 2154 py::class_<PyBlock>(m, "Block") 2155 .def_property_readonly( 2156 "arguments", 2157 [](PyBlock &self) { 2158 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2159 }, 2160 "Returns a list of block arguments.") 2161 .def_property_readonly( 2162 "operations", 2163 [](PyBlock &self) { 2164 return PyOperationList(self.getParentOperation(), self.get()); 2165 }, 2166 "Returns a forward-optimized sequence of operations.") 2167 .def( 2168 "__iter__", 2169 [](PyBlock &self) { 2170 self.checkValid(); 2171 MlirOperation firstOperation = 2172 mlirBlockGetFirstOperation(self.get()); 2173 return PyOperationIterator(self.getParentOperation(), 2174 firstOperation); 2175 }, 2176 "Iterates over operations in the block.") 2177 .def("__eq__", 2178 [](PyBlock &self, PyBlock &other) { 2179 return self.get().ptr == other.get().ptr; 2180 }) 2181 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2182 .def( 2183 "__str__", 2184 [](PyBlock &self) { 2185 self.checkValid(); 2186 PyPrintAccumulator printAccum; 2187 mlirBlockPrint(self.get(), printAccum.getCallback(), 2188 printAccum.getUserData()); 2189 return printAccum.join(); 2190 }, 2191 "Returns the assembly form of the block."); 2192 2193 //---------------------------------------------------------------------------- 2194 // Mapping of PyInsertionPoint. 2195 //---------------------------------------------------------------------------- 2196 2197 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2198 .def(py::init<PyBlock &>(), py::arg("block"), 2199 "Inserts after the last operation but still inside the block.") 2200 .def("__enter__", &PyInsertionPoint::contextEnter) 2201 .def("__exit__", &PyInsertionPoint::contextExit) 2202 .def_property_readonly_static( 2203 "current", 2204 [](py::object & /*class*/) { 2205 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2206 if (!ip) 2207 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2208 return ip; 2209 }, 2210 "Gets the InsertionPoint bound to the current thread or raises " 2211 "ValueError if none has been set") 2212 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2213 "Inserts before a referenced operation.") 2214 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2215 py::arg("block"), "Inserts at the beginning of the block.") 2216 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2217 py::arg("block"), "Inserts before the block terminator.") 2218 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2219 "Inserts an operation."); 2220 2221 //---------------------------------------------------------------------------- 2222 // Mapping of PyAttribute. 2223 //---------------------------------------------------------------------------- 2224 py::class_<PyAttribute>(m, "Attribute") 2225 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2226 &PyAttribute::getCapsule) 2227 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2228 .def_static( 2229 "parse", 2230 [](std::string attrSpec, DefaultingPyMlirContext context) { 2231 MlirAttribute type = mlirAttributeParseGet( 2232 context->get(), toMlirStringRef(attrSpec)); 2233 // TODO: Rework error reporting once diagnostic engine is exposed 2234 // in C API. 2235 if (mlirAttributeIsNull(type)) { 2236 throw SetPyError(PyExc_ValueError, 2237 Twine("Unable to parse attribute: '") + 2238 attrSpec + "'"); 2239 } 2240 return PyAttribute(context->getRef(), type); 2241 }, 2242 py::arg("asm"), py::arg("context") = py::none(), 2243 "Parses an attribute from an assembly form") 2244 .def_property_readonly( 2245 "context", 2246 [](PyAttribute &self) { return self.getContext().getObject(); }, 2247 "Context that owns the Attribute") 2248 .def_property_readonly("type", 2249 [](PyAttribute &self) { 2250 return PyType(self.getContext()->getRef(), 2251 mlirAttributeGetType(self)); 2252 }) 2253 .def( 2254 "get_named", 2255 [](PyAttribute &self, std::string name) { 2256 return PyNamedAttribute(self, std::move(name)); 2257 }, 2258 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2259 .def("__eq__", 2260 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2261 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2262 .def( 2263 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2264 kDumpDocstring) 2265 .def( 2266 "__str__", 2267 [](PyAttribute &self) { 2268 PyPrintAccumulator printAccum; 2269 mlirAttributePrint(self, printAccum.getCallback(), 2270 printAccum.getUserData()); 2271 return printAccum.join(); 2272 }, 2273 "Returns the assembly form of the Attribute.") 2274 .def("__repr__", [](PyAttribute &self) { 2275 // Generally, assembly formats are not printed for __repr__ because 2276 // this can cause exceptionally long debug output and exceptions. 2277 // However, attribute values are generally considered useful and are 2278 // printed. This may need to be re-evaluated if debug dumps end up 2279 // being excessive. 2280 PyPrintAccumulator printAccum; 2281 printAccum.parts.append("Attribute("); 2282 mlirAttributePrint(self, printAccum.getCallback(), 2283 printAccum.getUserData()); 2284 printAccum.parts.append(")"); 2285 return printAccum.join(); 2286 }); 2287 2288 //---------------------------------------------------------------------------- 2289 // Mapping of PyNamedAttribute 2290 //---------------------------------------------------------------------------- 2291 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2292 .def("__repr__", 2293 [](PyNamedAttribute &self) { 2294 PyPrintAccumulator printAccum; 2295 printAccum.parts.append("NamedAttribute("); 2296 printAccum.parts.append( 2297 mlirIdentifierStr(self.namedAttr.name).data); 2298 printAccum.parts.append("="); 2299 mlirAttributePrint(self.namedAttr.attribute, 2300 printAccum.getCallback(), 2301 printAccum.getUserData()); 2302 printAccum.parts.append(")"); 2303 return printAccum.join(); 2304 }) 2305 .def_property_readonly( 2306 "name", 2307 [](PyNamedAttribute &self) { 2308 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2309 mlirIdentifierStr(self.namedAttr.name).length); 2310 }, 2311 "The name of the NamedAttribute binding") 2312 .def_property_readonly( 2313 "attr", 2314 [](PyNamedAttribute &self) { 2315 // TODO: When named attribute is removed/refactored, also remove 2316 // this constructor (it does an inefficient table lookup). 2317 auto contextRef = PyMlirContext::forContext( 2318 mlirAttributeGetContext(self.namedAttr.attribute)); 2319 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2320 }, 2321 py::keep_alive<0, 1>(), 2322 "The underlying generic attribute of the NamedAttribute binding"); 2323 2324 //---------------------------------------------------------------------------- 2325 // Mapping of PyType. 2326 //---------------------------------------------------------------------------- 2327 py::class_<PyType>(m, "Type") 2328 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2329 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2330 .def_static( 2331 "parse", 2332 [](std::string typeSpec, DefaultingPyMlirContext context) { 2333 MlirType type = 2334 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2335 // TODO: Rework error reporting once diagnostic engine is exposed 2336 // in C API. 2337 if (mlirTypeIsNull(type)) { 2338 throw SetPyError(PyExc_ValueError, 2339 Twine("Unable to parse type: '") + typeSpec + 2340 "'"); 2341 } 2342 return PyType(context->getRef(), type); 2343 }, 2344 py::arg("asm"), py::arg("context") = py::none(), 2345 kContextParseTypeDocstring) 2346 .def_property_readonly( 2347 "context", [](PyType &self) { return self.getContext().getObject(); }, 2348 "Context that owns the Type") 2349 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2350 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2351 .def( 2352 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2353 .def( 2354 "__str__", 2355 [](PyType &self) { 2356 PyPrintAccumulator printAccum; 2357 mlirTypePrint(self, printAccum.getCallback(), 2358 printAccum.getUserData()); 2359 return printAccum.join(); 2360 }, 2361 "Returns the assembly form of the type.") 2362 .def("__repr__", [](PyType &self) { 2363 // Generally, assembly formats are not printed for __repr__ because 2364 // this can cause exceptionally long debug output and exceptions. 2365 // However, types are an exception as they typically have compact 2366 // assembly forms and printing them is useful. 2367 PyPrintAccumulator printAccum; 2368 printAccum.parts.append("Type("); 2369 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2370 printAccum.parts.append(")"); 2371 return printAccum.join(); 2372 }); 2373 2374 //---------------------------------------------------------------------------- 2375 // Mapping of Value. 2376 //---------------------------------------------------------------------------- 2377 py::class_<PyValue>(m, "Value") 2378 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2379 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2380 .def_property_readonly( 2381 "context", 2382 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2383 "Context in which the value lives.") 2384 .def( 2385 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2386 kDumpDocstring) 2387 .def("__eq__", 2388 [](PyValue &self, PyValue &other) { 2389 return self.get().ptr == other.get().ptr; 2390 }) 2391 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2392 .def( 2393 "__str__", 2394 [](PyValue &self) { 2395 PyPrintAccumulator printAccum; 2396 printAccum.parts.append("Value("); 2397 mlirValuePrint(self.get(), printAccum.getCallback(), 2398 printAccum.getUserData()); 2399 printAccum.parts.append(")"); 2400 return printAccum.join(); 2401 }, 2402 kValueDunderStrDocstring) 2403 .def_property_readonly("type", [](PyValue &self) { 2404 return PyType(self.getParentOperation()->getContext(), 2405 mlirValueGetType(self.get())); 2406 }); 2407 PyBlockArgument::bind(m); 2408 PyOpResult::bind(m); 2409 2410 // Container bindings. 2411 PyBlockArgumentList::bind(m); 2412 PyBlockIterator::bind(m); 2413 PyBlockList::bind(m); 2414 PyOperationIterator::bind(m); 2415 PyOperationList::bind(m); 2416 PyOpAttributeMap::bind(m); 2417 PyOpOperandList::bind(m); 2418 PyOpResultList::bind(m); 2419 PyRegionIterator::bind(m); 2420 PyRegionList::bind(m); 2421 2422 // Debug bindings. 2423 PyGlobalDebugFlag::bind(m); 2424 } 2425