1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include <pybind11/stl.h> 22 23 namespace py = pybind11; 24 using namespace mlir; 25 using namespace mlir::python; 26 27 using llvm::SmallVector; 28 using llvm::StringRef; 29 using llvm::Twine; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kContextParseTypeDocstring[] = 36 R"(Parses the assembly form of a type. 37 38 Returns a Type object or raises a ValueError if the type cannot be parsed. 39 40 See also: https://mlir.llvm.org/docs/LangRef/#type-system 41 )"; 42 43 static const char kContextGetFileLocationDocstring[] = 44 R"(Gets a Location representing a file, line and column)"; 45 46 static const char kModuleParseDocstring[] = 47 R"(Parses a module's assembly format from a string. 48 49 Returns a new MlirModule or raises a ValueError if the parsing fails. 50 51 See also: https://mlir.llvm.org/docs/LangRef/ 52 )"; 53 54 static const char kOperationCreateDocstring[] = 55 R"(Creates a new operation. 56 57 Args: 58 name: Operation name (e.g. "dialect.operation"). 59 results: Sequence of Type representing op result types. 60 attributes: Dict of str:Attribute. 61 successors: List of Block for the operation's successors. 62 regions: Number of regions to create. 63 location: A Location object (defaults to resolve from context manager). 64 ip: An InsertionPoint (defaults to resolve from context manager or set to 65 False to disable insertion, even with an insertion point set in the 66 context manager). 67 Returns: 68 A new "detached" Operation object. Detached operations can be added 69 to blocks, which causes them to become "attached." 70 )"; 71 72 static const char kOperationPrintDocstring[] = 73 R"(Prints the assembly form of the operation to a file like object. 74 75 Args: 76 file: The file like object to write to. Defaults to sys.stdout. 77 binary: Whether to write bytes (True) or str (False). Defaults to False. 78 large_elements_limit: Whether to elide elements attributes above this 79 number of elements. Defaults to None (no limit). 80 enable_debug_info: Whether to print debug/location information. Defaults 81 to False. 82 pretty_debug_info: Whether to format debug information for easier reading 83 by a human (warning: the result is unparseable). 84 print_generic_op_form: Whether to print the generic assembly forms of all 85 ops. Defaults to False. 86 use_local_Scope: Whether to print in a way that is more optimized for 87 multi-threaded access but may not be consistent with how the overall 88 module prints. 89 )"; 90 91 static const char kOperationGetAsmDocstring[] = 92 R"(Gets the assembly form of the operation with all options available. 93 94 Args: 95 binary: Whether to return a bytes (True) or str (False) object. Defaults to 96 False. 97 ... others ...: See the print() method for common keyword arguments for 98 configuring the printout. 99 Returns: 100 Either a bytes or str object, depending on the setting of the 'binary' 101 argument. 102 )"; 103 104 static const char kOperationStrDunderDocstring[] = 105 R"(Gets the assembly form of the operation with default options. 106 107 If more advanced control over the assembly formatting or I/O options is needed, 108 use the dedicated print or get_asm method, which supports keyword arguments to 109 customize behavior. 110 )"; 111 112 static const char kDumpDocstring[] = 113 R"(Dumps a debug representation of the object to stderr.)"; 114 115 static const char kAppendBlockDocstring[] = 116 R"(Appends a new block, with argument types as positional args. 117 118 Returns: 119 The created block. 120 )"; 121 122 static const char kValueDunderStrDocstring[] = 123 R"(Returns the string form of the value. 124 125 If the value is a block argument, this is the assembly form of its type and the 126 position in the argument list. If the value is an operation result, this is 127 equivalent to printing the operation that produced it. 128 )"; 129 130 //------------------------------------------------------------------------------ 131 // Utilities. 132 //------------------------------------------------------------------------------ 133 134 /// Helper for creating an @classmethod. 135 template <class Func, typename... Args> 136 py::object classmethod(Func f, Args... args) { 137 py::object cf = py::cpp_function(f, args...); 138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 139 } 140 141 static py::object 142 createCustomDialectWrapper(const std::string &dialectNamespace, 143 py::object dialectDescriptor) { 144 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 145 if (!dialectClass) { 146 // Use the base class. 147 return py::cast(PyDialect(std::move(dialectDescriptor))); 148 } 149 150 // Create the custom implementation. 151 return (*dialectClass)(std::move(dialectDescriptor)); 152 } 153 154 static MlirStringRef toMlirStringRef(const std::string &s) { 155 return mlirStringRefCreate(s.data(), s.size()); 156 } 157 158 /// Wrapper for the global LLVM debugging flag. 159 struct PyGlobalDebugFlag { 160 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 161 162 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 163 164 static void bind(py::module &m) { 165 // Debug flags. 166 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") 167 .def_property_static("flag", &PyGlobalDebugFlag::get, 168 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 169 } 170 }; 171 172 //------------------------------------------------------------------------------ 173 // Collections. 174 //------------------------------------------------------------------------------ 175 176 namespace { 177 178 class PyRegionIterator { 179 public: 180 PyRegionIterator(PyOperationRef operation) 181 : operation(std::move(operation)) {} 182 183 PyRegionIterator &dunderIter() { return *this; } 184 185 PyRegion dunderNext() { 186 operation->checkValid(); 187 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 188 throw py::stop_iteration(); 189 } 190 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 191 return PyRegion(operation, region); 192 } 193 194 static void bind(py::module &m) { 195 py::class_<PyRegionIterator>(m, "RegionIterator") 196 .def("__iter__", &PyRegionIterator::dunderIter) 197 .def("__next__", &PyRegionIterator::dunderNext); 198 } 199 200 private: 201 PyOperationRef operation; 202 int nextIndex = 0; 203 }; 204 205 /// Regions of an op are fixed length and indexed numerically so are represented 206 /// with a sequence-like container. 207 class PyRegionList { 208 public: 209 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 210 211 intptr_t dunderLen() { 212 operation->checkValid(); 213 return mlirOperationGetNumRegions(operation->get()); 214 } 215 216 PyRegion dunderGetItem(intptr_t index) { 217 // dunderLen checks validity. 218 if (index < 0 || index >= dunderLen()) { 219 throw SetPyError(PyExc_IndexError, 220 "attempt to access out of bounds region"); 221 } 222 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 223 return PyRegion(operation, region); 224 } 225 226 static void bind(py::module &m) { 227 py::class_<PyRegionList>(m, "RegionSequence") 228 .def("__len__", &PyRegionList::dunderLen) 229 .def("__getitem__", &PyRegionList::dunderGetItem); 230 } 231 232 private: 233 PyOperationRef operation; 234 }; 235 236 class PyBlockIterator { 237 public: 238 PyBlockIterator(PyOperationRef operation, MlirBlock next) 239 : operation(std::move(operation)), next(next) {} 240 241 PyBlockIterator &dunderIter() { return *this; } 242 243 PyBlock dunderNext() { 244 operation->checkValid(); 245 if (mlirBlockIsNull(next)) { 246 throw py::stop_iteration(); 247 } 248 249 PyBlock returnBlock(operation, next); 250 next = mlirBlockGetNextInRegion(next); 251 return returnBlock; 252 } 253 254 static void bind(py::module &m) { 255 py::class_<PyBlockIterator>(m, "BlockIterator") 256 .def("__iter__", &PyBlockIterator::dunderIter) 257 .def("__next__", &PyBlockIterator::dunderNext); 258 } 259 260 private: 261 PyOperationRef operation; 262 MlirBlock next; 263 }; 264 265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 266 /// we present them as a more full-featured list-like container but optimize 267 /// it for forward iteration. Blocks are always owned by a region. 268 class PyBlockList { 269 public: 270 PyBlockList(PyOperationRef operation, MlirRegion region) 271 : operation(std::move(operation)), region(region) {} 272 273 PyBlockIterator dunderIter() { 274 operation->checkValid(); 275 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 276 } 277 278 intptr_t dunderLen() { 279 operation->checkValid(); 280 intptr_t count = 0; 281 MlirBlock block = mlirRegionGetFirstBlock(region); 282 while (!mlirBlockIsNull(block)) { 283 count += 1; 284 block = mlirBlockGetNextInRegion(block); 285 } 286 return count; 287 } 288 289 PyBlock dunderGetItem(intptr_t index) { 290 operation->checkValid(); 291 if (index < 0) { 292 throw SetPyError(PyExc_IndexError, 293 "attempt to access out of bounds block"); 294 } 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 if (index == 0) { 298 return PyBlock(operation, block); 299 } 300 block = mlirBlockGetNextInRegion(block); 301 index -= 1; 302 } 303 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 304 } 305 306 PyBlock appendBlock(py::args pyArgTypes) { 307 operation->checkValid(); 308 llvm::SmallVector<MlirType, 4> argTypes; 309 argTypes.reserve(pyArgTypes.size()); 310 for (auto &pyArg : pyArgTypes) { 311 argTypes.push_back(pyArg.cast<PyType &>()); 312 } 313 314 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 315 mlirRegionAppendOwnedBlock(region, block); 316 return PyBlock(operation, block); 317 } 318 319 static void bind(py::module &m) { 320 py::class_<PyBlockList>(m, "BlockList") 321 .def("__getitem__", &PyBlockList::dunderGetItem) 322 .def("__iter__", &PyBlockList::dunderIter) 323 .def("__len__", &PyBlockList::dunderLen) 324 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 325 } 326 327 private: 328 PyOperationRef operation; 329 MlirRegion region; 330 }; 331 332 class PyOperationIterator { 333 public: 334 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 335 : parentOperation(std::move(parentOperation)), next(next) {} 336 337 PyOperationIterator &dunderIter() { return *this; } 338 339 py::object dunderNext() { 340 parentOperation->checkValid(); 341 if (mlirOperationIsNull(next)) { 342 throw py::stop_iteration(); 343 } 344 345 PyOperationRef returnOperation = 346 PyOperation::forOperation(parentOperation->getContext(), next); 347 next = mlirOperationGetNextInBlock(next); 348 return returnOperation->createOpView(); 349 } 350 351 static void bind(py::module &m) { 352 py::class_<PyOperationIterator>(m, "OperationIterator") 353 .def("__iter__", &PyOperationIterator::dunderIter) 354 .def("__next__", &PyOperationIterator::dunderNext); 355 } 356 357 private: 358 PyOperationRef parentOperation; 359 MlirOperation next; 360 }; 361 362 /// Operations are exposed by the C-API as a forward-only linked list. In 363 /// Python, we present them as a more full-featured list-like container but 364 /// optimize it for forward iteration. Iterable operations are always owned 365 /// by a block. 366 class PyOperationList { 367 public: 368 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 369 : parentOperation(std::move(parentOperation)), block(block) {} 370 371 PyOperationIterator dunderIter() { 372 parentOperation->checkValid(); 373 return PyOperationIterator(parentOperation, 374 mlirBlockGetFirstOperation(block)); 375 } 376 377 intptr_t dunderLen() { 378 parentOperation->checkValid(); 379 intptr_t count = 0; 380 MlirOperation childOp = mlirBlockGetFirstOperation(block); 381 while (!mlirOperationIsNull(childOp)) { 382 count += 1; 383 childOp = mlirOperationGetNextInBlock(childOp); 384 } 385 return count; 386 } 387 388 py::object dunderGetItem(intptr_t index) { 389 parentOperation->checkValid(); 390 if (index < 0) { 391 throw SetPyError(PyExc_IndexError, 392 "attempt to access out of bounds operation"); 393 } 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 if (index == 0) { 397 return PyOperation::forOperation(parentOperation->getContext(), childOp) 398 ->createOpView(); 399 } 400 childOp = mlirOperationGetNextInBlock(childOp); 401 index -= 1; 402 } 403 throw SetPyError(PyExc_IndexError, 404 "attempt to access out of bounds operation"); 405 } 406 407 static void bind(py::module &m) { 408 py::class_<PyOperationList>(m, "OperationList") 409 .def("__getitem__", &PyOperationList::dunderGetItem) 410 .def("__iter__", &PyOperationList::dunderIter) 411 .def("__len__", &PyOperationList::dunderLen); 412 } 413 414 private: 415 PyOperationRef parentOperation; 416 MlirBlock block; 417 }; 418 419 } // namespace 420 421 //------------------------------------------------------------------------------ 422 // PyMlirContext 423 //------------------------------------------------------------------------------ 424 425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 426 py::gil_scoped_acquire acquire; 427 auto &liveContexts = getLiveContexts(); 428 liveContexts[context.ptr] = this; 429 } 430 431 PyMlirContext::~PyMlirContext() { 432 // Note that the only public way to construct an instance is via the 433 // forContext method, which always puts the associated handle into 434 // liveContexts. 435 py::gil_scoped_acquire acquire; 436 getLiveContexts().erase(context.ptr); 437 mlirContextDestroy(context); 438 } 439 440 py::object PyMlirContext::getCapsule() { 441 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 442 } 443 444 py::object PyMlirContext::createFromCapsule(py::object capsule) { 445 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 446 if (mlirContextIsNull(rawContext)) 447 throw py::error_already_set(); 448 return forContext(rawContext).releaseObject(); 449 } 450 451 PyMlirContext *PyMlirContext::createNewContextForInit() { 452 MlirContext context = mlirContextCreate(); 453 mlirRegisterAllDialects(context); 454 return new PyMlirContext(context); 455 } 456 457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 458 py::gil_scoped_acquire acquire; 459 auto &liveContexts = getLiveContexts(); 460 auto it = liveContexts.find(context.ptr); 461 if (it == liveContexts.end()) { 462 // Create. 463 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 464 py::object pyRef = py::cast(unownedContextWrapper); 465 assert(pyRef && "cast to py::object failed"); 466 liveContexts[context.ptr] = unownedContextWrapper; 467 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 468 } 469 // Use existing. 470 py::object pyRef = py::cast(it->second); 471 return PyMlirContextRef(it->second, std::move(pyRef)); 472 } 473 474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 475 static LiveContextMap liveContexts; 476 return liveContexts; 477 } 478 479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 480 481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 482 483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 484 485 pybind11::object PyMlirContext::contextEnter() { 486 return PyThreadContextEntry::pushContext(*this); 487 } 488 489 void PyMlirContext::contextExit(pybind11::object excType, 490 pybind11::object excVal, 491 pybind11::object excTb) { 492 PyThreadContextEntry::popContext(*this); 493 } 494 495 PyMlirContext &DefaultingPyMlirContext::resolve() { 496 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 497 if (!context) { 498 throw SetPyError( 499 PyExc_RuntimeError, 500 "An MLIR function requires a Context but none was provided in the call " 501 "or from the surrounding environment. Either pass to the function with " 502 "a 'context=' argument or establish a default using 'with Context():'"); 503 } 504 return *context; 505 } 506 507 //------------------------------------------------------------------------------ 508 // PyThreadContextEntry management 509 //------------------------------------------------------------------------------ 510 511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 512 static thread_local std::vector<PyThreadContextEntry> stack; 513 return stack; 514 } 515 516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 517 auto &stack = getStack(); 518 if (stack.empty()) 519 return nullptr; 520 return &stack.back(); 521 } 522 523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 524 py::object insertionPoint, 525 py::object location) { 526 auto &stack = getStack(); 527 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 528 std::move(location)); 529 // If the new stack has more than one entry and the context of the new top 530 // entry matches the previous, copy the insertionPoint and location from the 531 // previous entry if missing from the new top entry. 532 if (stack.size() > 1) { 533 auto &prev = *(stack.rbegin() + 1); 534 auto ¤t = stack.back(); 535 if (current.context.is(prev.context)) { 536 // Default non-context objects from the previous entry. 537 if (!current.insertionPoint) 538 current.insertionPoint = prev.insertionPoint; 539 if (!current.location) 540 current.location = prev.location; 541 } 542 } 543 } 544 545 PyMlirContext *PyThreadContextEntry::getContext() { 546 if (!context) 547 return nullptr; 548 return py::cast<PyMlirContext *>(context); 549 } 550 551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 552 if (!insertionPoint) 553 return nullptr; 554 return py::cast<PyInsertionPoint *>(insertionPoint); 555 } 556 557 PyLocation *PyThreadContextEntry::getLocation() { 558 if (!location) 559 return nullptr; 560 return py::cast<PyLocation *>(location); 561 } 562 563 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 564 auto *tos = getTopOfStack(); 565 return tos ? tos->getContext() : nullptr; 566 } 567 568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 569 auto *tos = getTopOfStack(); 570 return tos ? tos->getInsertionPoint() : nullptr; 571 } 572 573 PyLocation *PyThreadContextEntry::getDefaultLocation() { 574 auto *tos = getTopOfStack(); 575 return tos ? tos->getLocation() : nullptr; 576 } 577 578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 579 py::object contextObj = py::cast(context); 580 push(FrameKind::Context, /*context=*/contextObj, 581 /*insertionPoint=*/py::object(), 582 /*location=*/py::object()); 583 return contextObj; 584 } 585 586 void PyThreadContextEntry::popContext(PyMlirContext &context) { 587 auto &stack = getStack(); 588 if (stack.empty()) 589 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 590 auto &tos = stack.back(); 591 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593 stack.pop_back(); 594 } 595 596 py::object 597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 598 py::object contextObj = 599 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 600 py::object insertionPointObj = py::cast(insertionPoint); 601 push(FrameKind::InsertionPoint, 602 /*context=*/contextObj, 603 /*insertionPoint=*/insertionPointObj, 604 /*location=*/py::object()); 605 return insertionPointObj; 606 } 607 608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 609 auto &stack = getStack(); 610 if (stack.empty()) 611 throw SetPyError(PyExc_RuntimeError, 612 "Unbalanced InsertionPoint enter/exit"); 613 auto &tos = stack.back(); 614 if (tos.frameKind != FrameKind::InsertionPoint && 615 tos.getInsertionPoint() != &insertionPoint) 616 throw SetPyError(PyExc_RuntimeError, 617 "Unbalanced InsertionPoint enter/exit"); 618 stack.pop_back(); 619 } 620 621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 622 py::object contextObj = location.getContext().getObject(); 623 py::object locationObj = py::cast(location); 624 push(FrameKind::Location, /*context=*/contextObj, 625 /*insertionPoint=*/py::object(), 626 /*location=*/locationObj); 627 return locationObj; 628 } 629 630 void PyThreadContextEntry::popLocation(PyLocation &location) { 631 auto &stack = getStack(); 632 if (stack.empty()) 633 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 634 auto &tos = stack.back(); 635 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637 stack.pop_back(); 638 } 639 640 //------------------------------------------------------------------------------ 641 // PyDialect, PyDialectDescriptor, PyDialects 642 //------------------------------------------------------------------------------ 643 644 MlirDialect PyDialects::getDialectForKey(const std::string &key, 645 bool attrError) { 646 // If the "std" dialect was asked for, substitute the empty namespace :( 647 static const std::string emptyKey; 648 const std::string *canonKey = key == "std" ? &emptyKey : &key; 649 MlirDialect dialect = mlirContextGetOrLoadDialect( 650 getContext()->get(), {canonKey->data(), canonKey->size()}); 651 if (mlirDialectIsNull(dialect)) { 652 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 653 Twine("Dialect '") + key + "' not found"); 654 } 655 return dialect; 656 } 657 658 //------------------------------------------------------------------------------ 659 // PyLocation 660 //------------------------------------------------------------------------------ 661 662 py::object PyLocation::getCapsule() { 663 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 664 } 665 666 PyLocation PyLocation::createFromCapsule(py::object capsule) { 667 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 668 if (mlirLocationIsNull(rawLoc)) 669 throw py::error_already_set(); 670 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 671 rawLoc); 672 } 673 674 py::object PyLocation::contextEnter() { 675 return PyThreadContextEntry::pushLocation(*this); 676 } 677 678 void PyLocation::contextExit(py::object excType, py::object excVal, 679 py::object excTb) { 680 PyThreadContextEntry::popLocation(*this); 681 } 682 683 PyLocation &DefaultingPyLocation::resolve() { 684 auto *location = PyThreadContextEntry::getDefaultLocation(); 685 if (!location) { 686 throw SetPyError( 687 PyExc_RuntimeError, 688 "An MLIR function requires a Location but none was provided in the " 689 "call or from the surrounding environment. Either pass to the function " 690 "with a 'loc=' argument or establish a default using 'with loc:'"); 691 } 692 return *location; 693 } 694 695 //------------------------------------------------------------------------------ 696 // PyModule 697 //------------------------------------------------------------------------------ 698 699 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 700 : BaseContextObject(std::move(contextRef)), module(module) {} 701 702 PyModule::~PyModule() { 703 py::gil_scoped_acquire acquire; 704 auto &liveModules = getContext()->liveModules; 705 assert(liveModules.count(module.ptr) == 1 && 706 "destroying module not in live map"); 707 liveModules.erase(module.ptr); 708 mlirModuleDestroy(module); 709 } 710 711 PyModuleRef PyModule::forModule(MlirModule module) { 712 MlirContext context = mlirModuleGetContext(module); 713 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 714 715 py::gil_scoped_acquire acquire; 716 auto &liveModules = contextRef->liveModules; 717 auto it = liveModules.find(module.ptr); 718 if (it == liveModules.end()) { 719 // Create. 720 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 721 // Note that the default return value policy on cast is automatic_reference, 722 // which does not take ownership (delete will not be called). 723 // Just be explicit. 724 py::object pyRef = 725 py::cast(unownedModule, py::return_value_policy::take_ownership); 726 unownedModule->handle = pyRef; 727 liveModules[module.ptr] = 728 std::make_pair(unownedModule->handle, unownedModule); 729 return PyModuleRef(unownedModule, std::move(pyRef)); 730 } 731 // Use existing. 732 PyModule *existing = it->second.second; 733 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 734 return PyModuleRef(existing, std::move(pyRef)); 735 } 736 737 py::object PyModule::createFromCapsule(py::object capsule) { 738 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 739 if (mlirModuleIsNull(rawModule)) 740 throw py::error_already_set(); 741 return forModule(rawModule).releaseObject(); 742 } 743 744 py::object PyModule::getCapsule() { 745 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 746 } 747 748 //------------------------------------------------------------------------------ 749 // PyOperation 750 //------------------------------------------------------------------------------ 751 752 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 753 : BaseContextObject(std::move(contextRef)), operation(operation) {} 754 755 PyOperation::~PyOperation() { 756 // If the operation has already been invalidated there is nothing to do. 757 if (!valid) 758 return; 759 auto &liveOperations = getContext()->liveOperations; 760 assert(liveOperations.count(operation.ptr) == 1 && 761 "destroying operation not in live map"); 762 liveOperations.erase(operation.ptr); 763 if (!isAttached()) { 764 mlirOperationDestroy(operation); 765 } 766 } 767 768 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 769 MlirOperation operation, 770 py::object parentKeepAlive) { 771 auto &liveOperations = contextRef->liveOperations; 772 // Create. 773 PyOperation *unownedOperation = 774 new PyOperation(std::move(contextRef), operation); 775 // Note that the default return value policy on cast is automatic_reference, 776 // which does not take ownership (delete will not be called). 777 // Just be explicit. 778 py::object pyRef = 779 py::cast(unownedOperation, py::return_value_policy::take_ownership); 780 unownedOperation->handle = pyRef; 781 if (parentKeepAlive) { 782 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 783 } 784 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 785 return PyOperationRef(unownedOperation, std::move(pyRef)); 786 } 787 788 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 789 MlirOperation operation, 790 py::object parentKeepAlive) { 791 auto &liveOperations = contextRef->liveOperations; 792 auto it = liveOperations.find(operation.ptr); 793 if (it == liveOperations.end()) { 794 // Create. 795 return createInstance(std::move(contextRef), operation, 796 std::move(parentKeepAlive)); 797 } 798 // Use existing. 799 PyOperation *existing = it->second.second; 800 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 801 return PyOperationRef(existing, std::move(pyRef)); 802 } 803 804 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 805 MlirOperation operation, 806 py::object parentKeepAlive) { 807 auto &liveOperations = contextRef->liveOperations; 808 assert(liveOperations.count(operation.ptr) == 0 && 809 "cannot create detached operation that already exists"); 810 (void)liveOperations; 811 812 PyOperationRef created = createInstance(std::move(contextRef), operation, 813 std::move(parentKeepAlive)); 814 created->attached = false; 815 return created; 816 } 817 818 void PyOperation::checkValid() const { 819 if (!valid) { 820 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 821 } 822 } 823 824 void PyOperationBase::print(py::object fileObject, bool binary, 825 llvm::Optional<int64_t> largeElementsLimit, 826 bool enableDebugInfo, bool prettyDebugInfo, 827 bool printGenericOpForm, bool useLocalScope) { 828 PyOperation &operation = getOperation(); 829 operation.checkValid(); 830 if (fileObject.is_none()) 831 fileObject = py::module::import("sys").attr("stdout"); 832 833 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 834 fileObject.attr("write")("// Verification failed, printing generic form\n"); 835 printGenericOpForm = true; 836 } 837 838 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 839 if (largeElementsLimit) 840 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 841 if (enableDebugInfo) 842 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 843 if (printGenericOpForm) 844 mlirOpPrintingFlagsPrintGenericOpForm(flags); 845 846 PyFileAccumulator accum(fileObject, binary); 847 py::gil_scoped_release(); 848 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 849 accum.getUserData()); 850 mlirOpPrintingFlagsDestroy(flags); 851 } 852 853 py::object PyOperationBase::getAsm(bool binary, 854 llvm::Optional<int64_t> largeElementsLimit, 855 bool enableDebugInfo, bool prettyDebugInfo, 856 bool printGenericOpForm, 857 bool useLocalScope) { 858 py::object fileObject; 859 if (binary) { 860 fileObject = py::module::import("io").attr("BytesIO")(); 861 } else { 862 fileObject = py::module::import("io").attr("StringIO")(); 863 } 864 print(fileObject, /*binary=*/binary, 865 /*largeElementsLimit=*/largeElementsLimit, 866 /*enableDebugInfo=*/enableDebugInfo, 867 /*prettyDebugInfo=*/prettyDebugInfo, 868 /*printGenericOpForm=*/printGenericOpForm, 869 /*useLocalScope=*/useLocalScope); 870 871 return fileObject.attr("getvalue")(); 872 } 873 874 PyOperationRef PyOperation::getParentOperation() { 875 checkValid(); 876 if (!isAttached()) 877 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 878 MlirOperation operation = mlirOperationGetParentOperation(get()); 879 if (mlirOperationIsNull(operation)) 880 throw SetPyError(PyExc_ValueError, "Operation has no parent."); 881 return PyOperation::forOperation(getContext(), operation); 882 } 883 884 PyBlock PyOperation::getBlock() { 885 checkValid(); 886 PyOperationRef parentOperation = getParentOperation(); 887 MlirBlock block = mlirOperationGetBlock(get()); 888 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 889 return PyBlock{std::move(parentOperation), block}; 890 } 891 892 py::object PyOperation::getCapsule() { 893 checkValid(); 894 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 895 } 896 897 py::object PyOperation::createFromCapsule(py::object capsule) { 898 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 899 if (mlirOperationIsNull(rawOperation)) 900 throw py::error_already_set(); 901 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 902 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 903 .releaseObject(); 904 } 905 906 py::object PyOperation::create( 907 std::string name, llvm::Optional<std::vector<PyType *>> results, 908 llvm::Optional<std::vector<PyValue *>> operands, 909 llvm::Optional<py::dict> attributes, 910 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 911 DefaultingPyLocation location, py::object maybeIp) { 912 llvm::SmallVector<MlirValue, 4> mlirOperands; 913 llvm::SmallVector<MlirType, 4> mlirResults; 914 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 915 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 916 917 // General parameter validation. 918 if (regions < 0) 919 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 920 921 // Unpack/validate operands. 922 if (operands) { 923 mlirOperands.reserve(operands->size()); 924 for (PyValue *operand : *operands) { 925 if (!operand) 926 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 927 mlirOperands.push_back(operand->get()); 928 } 929 } 930 931 // Unpack/validate results. 932 if (results) { 933 mlirResults.reserve(results->size()); 934 for (PyType *result : *results) { 935 // TODO: Verify result type originate from the same context. 936 if (!result) 937 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 938 mlirResults.push_back(*result); 939 } 940 } 941 // Unpack/validate attributes. 942 if (attributes) { 943 mlirAttributes.reserve(attributes->size()); 944 for (auto &it : *attributes) { 945 std::string key; 946 try { 947 key = it.first.cast<std::string>(); 948 } catch (py::cast_error &err) { 949 std::string msg = "Invalid attribute key (not a string) when " 950 "attempting to create the operation \"" + 951 name + "\" (" + err.what() + ")"; 952 throw py::cast_error(msg); 953 } 954 try { 955 auto &attribute = it.second.cast<PyAttribute &>(); 956 // TODO: Verify attribute originates from the same context. 957 mlirAttributes.emplace_back(std::move(key), attribute); 958 } catch (py::reference_cast_error &) { 959 // This exception seems thrown when the value is "None". 960 std::string msg = 961 "Found an invalid (`None`?) attribute value for the key \"" + key + 962 "\" when attempting to create the operation \"" + name + "\""; 963 throw py::cast_error(msg); 964 } catch (py::cast_error &err) { 965 std::string msg = "Invalid attribute value for the key \"" + key + 966 "\" when attempting to create the operation \"" + 967 name + "\" (" + err.what() + ")"; 968 throw py::cast_error(msg); 969 } 970 } 971 } 972 // Unpack/validate successors. 973 if (successors) { 974 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 975 mlirSuccessors.reserve(successors->size()); 976 for (auto *successor : *successors) { 977 // TODO: Verify successor originate from the same context. 978 if (!successor) 979 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 980 mlirSuccessors.push_back(successor->get()); 981 } 982 } 983 984 // Apply unpacked/validated to the operation state. Beyond this 985 // point, exceptions cannot be thrown or else the state will leak. 986 MlirOperationState state = 987 mlirOperationStateGet(toMlirStringRef(name), location); 988 if (!mlirOperands.empty()) 989 mlirOperationStateAddOperands(&state, mlirOperands.size(), 990 mlirOperands.data()); 991 if (!mlirResults.empty()) 992 mlirOperationStateAddResults(&state, mlirResults.size(), 993 mlirResults.data()); 994 if (!mlirAttributes.empty()) { 995 // Note that the attribute names directly reference bytes in 996 // mlirAttributes, so that vector must not be changed from here 997 // on. 998 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 999 mlirNamedAttributes.reserve(mlirAttributes.size()); 1000 for (auto &it : mlirAttributes) 1001 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1002 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1003 toMlirStringRef(it.first)), 1004 it.second)); 1005 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1006 mlirNamedAttributes.data()); 1007 } 1008 if (!mlirSuccessors.empty()) 1009 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1010 mlirSuccessors.data()); 1011 if (regions) { 1012 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1013 mlirRegions.resize(regions); 1014 for (int i = 0; i < regions; ++i) 1015 mlirRegions[i] = mlirRegionCreate(); 1016 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1017 mlirRegions.data()); 1018 } 1019 1020 // Construct the operation. 1021 MlirOperation operation = mlirOperationCreate(&state); 1022 PyOperationRef created = 1023 PyOperation::createDetached(location->getContext(), operation); 1024 1025 // InsertPoint active? 1026 if (!maybeIp.is(py::cast(false))) { 1027 PyInsertionPoint *ip; 1028 if (maybeIp.is_none()) { 1029 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1030 } else { 1031 ip = py::cast<PyInsertionPoint *>(maybeIp); 1032 } 1033 if (ip) 1034 ip->insert(*created.get()); 1035 } 1036 1037 return created->createOpView(); 1038 } 1039 1040 py::object PyOperation::createOpView() { 1041 checkValid(); 1042 MlirIdentifier ident = mlirOperationGetName(get()); 1043 MlirStringRef identStr = mlirIdentifierStr(ident); 1044 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1045 StringRef(identStr.data, identStr.length)); 1046 if (opViewClass) 1047 return (*opViewClass)(getRef().getObject()); 1048 return py::cast(PyOpView(getRef().getObject())); 1049 } 1050 1051 void PyOperation::erase() { 1052 checkValid(); 1053 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1054 // Python reference to a child operation is live. All children should also 1055 // have their `valid` bit set to false. 1056 auto &liveOperations = getContext()->liveOperations; 1057 if (liveOperations.count(operation.ptr)) 1058 liveOperations.erase(operation.ptr); 1059 mlirOperationDestroy(operation); 1060 valid = false; 1061 } 1062 1063 //------------------------------------------------------------------------------ 1064 // PyOpView 1065 //------------------------------------------------------------------------------ 1066 1067 py::object 1068 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1069 py::list operandList, 1070 llvm::Optional<py::dict> attributes, 1071 llvm::Optional<std::vector<PyBlock *>> successors, 1072 llvm::Optional<int> regions, 1073 DefaultingPyLocation location, py::object maybeIp) { 1074 PyMlirContextRef context = location->getContext(); 1075 // Class level operation construction metadata. 1076 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1077 // Operand and result segment specs are either none, which does no 1078 // variadic unpacking, or a list of ints with segment sizes, where each 1079 // element is either a positive number (typically 1 for a scalar) or -1 to 1080 // indicate that it is derived from the length of the same-indexed operand 1081 // or result (implying that it is a list at that position). 1082 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1083 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1084 1085 std::vector<uint32_t> operandSegmentLengths; 1086 std::vector<uint32_t> resultSegmentLengths; 1087 1088 // Validate/determine region count. 1089 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1090 int opMinRegionCount = std::get<0>(opRegionSpec); 1091 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1092 if (!regions) { 1093 regions = opMinRegionCount; 1094 } 1095 if (*regions < opMinRegionCount) { 1096 throw py::value_error( 1097 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1098 llvm::Twine(opMinRegionCount) + 1099 " regions but was built with regions=" + llvm::Twine(*regions)) 1100 .str()); 1101 } 1102 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1103 throw py::value_error( 1104 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1105 llvm::Twine(opMinRegionCount) + 1106 " regions but was built with regions=" + llvm::Twine(*regions)) 1107 .str()); 1108 } 1109 1110 // Unpack results. 1111 std::vector<PyType *> resultTypes; 1112 resultTypes.reserve(resultTypeList.size()); 1113 if (resultSegmentSpecObj.is_none()) { 1114 // Non-variadic result unpacking. 1115 for (auto it : llvm::enumerate(resultTypeList)) { 1116 try { 1117 resultTypes.push_back(py::cast<PyType *>(it.value())); 1118 if (!resultTypes.back()) 1119 throw py::cast_error(); 1120 } catch (py::cast_error &err) { 1121 throw py::value_error((llvm::Twine("Result ") + 1122 llvm::Twine(it.index()) + " of operation \"" + 1123 name + "\" must be a Type (" + err.what() + ")") 1124 .str()); 1125 } 1126 } 1127 } else { 1128 // Sized result unpacking. 1129 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1130 if (resultSegmentSpec.size() != resultTypeList.size()) { 1131 throw py::value_error((llvm::Twine("Operation \"") + name + 1132 "\" requires " + 1133 llvm::Twine(resultSegmentSpec.size()) + 1134 "result segments but was provided " + 1135 llvm::Twine(resultTypeList.size())) 1136 .str()); 1137 } 1138 resultSegmentLengths.reserve(resultTypeList.size()); 1139 for (auto it : 1140 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1141 int segmentSpec = std::get<1>(it.value()); 1142 if (segmentSpec == 1 || segmentSpec == 0) { 1143 // Unpack unary element. 1144 try { 1145 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1146 if (resultType) { 1147 resultTypes.push_back(resultType); 1148 resultSegmentLengths.push_back(1); 1149 } else if (segmentSpec == 0) { 1150 // Allowed to be optional. 1151 resultSegmentLengths.push_back(0); 1152 } else { 1153 throw py::cast_error("was None and result is not optional"); 1154 } 1155 } catch (py::cast_error &err) { 1156 throw py::value_error((llvm::Twine("Result ") + 1157 llvm::Twine(it.index()) + " of operation \"" + 1158 name + "\" must be a Type (" + err.what() + 1159 ")") 1160 .str()); 1161 } 1162 } else if (segmentSpec == -1) { 1163 // Unpack sequence by appending. 1164 try { 1165 if (std::get<0>(it.value()).is_none()) { 1166 // Treat it as an empty list. 1167 resultSegmentLengths.push_back(0); 1168 } else { 1169 // Unpack the list. 1170 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1171 for (py::object segmentItem : segment) { 1172 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1173 if (!resultTypes.back()) { 1174 throw py::cast_error("contained a None item"); 1175 } 1176 } 1177 resultSegmentLengths.push_back(segment.size()); 1178 } 1179 } catch (std::exception &err) { 1180 // NOTE: Sloppy to be using a catch-all here, but there are at least 1181 // three different unrelated exceptions that can be thrown in the 1182 // above "casts". Just keep the scope above small and catch them all. 1183 throw py::value_error((llvm::Twine("Result ") + 1184 llvm::Twine(it.index()) + " of operation \"" + 1185 name + "\" must be a Sequence of Types (" + 1186 err.what() + ")") 1187 .str()); 1188 } 1189 } else { 1190 throw py::value_error("Unexpected segment spec"); 1191 } 1192 } 1193 } 1194 1195 // Unpack operands. 1196 std::vector<PyValue *> operands; 1197 operands.reserve(operands.size()); 1198 if (operandSegmentSpecObj.is_none()) { 1199 // Non-sized operand unpacking. 1200 for (auto it : llvm::enumerate(operandList)) { 1201 try { 1202 operands.push_back(py::cast<PyValue *>(it.value())); 1203 if (!operands.back()) 1204 throw py::cast_error(); 1205 } catch (py::cast_error &err) { 1206 throw py::value_error((llvm::Twine("Operand ") + 1207 llvm::Twine(it.index()) + " of operation \"" + 1208 name + "\" must be a Value (" + err.what() + ")") 1209 .str()); 1210 } 1211 } 1212 } else { 1213 // Sized operand unpacking. 1214 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1215 if (operandSegmentSpec.size() != operandList.size()) { 1216 throw py::value_error((llvm::Twine("Operation \"") + name + 1217 "\" requires " + 1218 llvm::Twine(operandSegmentSpec.size()) + 1219 "operand segments but was provided " + 1220 llvm::Twine(operandList.size())) 1221 .str()); 1222 } 1223 operandSegmentLengths.reserve(operandList.size()); 1224 for (auto it : 1225 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1226 int segmentSpec = std::get<1>(it.value()); 1227 if (segmentSpec == 1 || segmentSpec == 0) { 1228 // Unpack unary element. 1229 try { 1230 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1231 if (operandValue) { 1232 operands.push_back(operandValue); 1233 operandSegmentLengths.push_back(1); 1234 } else if (segmentSpec == 0) { 1235 // Allowed to be optional. 1236 operandSegmentLengths.push_back(0); 1237 } else { 1238 throw py::cast_error("was None and operand is not optional"); 1239 } 1240 } catch (py::cast_error &err) { 1241 throw py::value_error((llvm::Twine("Operand ") + 1242 llvm::Twine(it.index()) + " of operation \"" + 1243 name + "\" must be a Value (" + err.what() + 1244 ")") 1245 .str()); 1246 } 1247 } else if (segmentSpec == -1) { 1248 // Unpack sequence by appending. 1249 try { 1250 if (std::get<0>(it.value()).is_none()) { 1251 // Treat it as an empty list. 1252 operandSegmentLengths.push_back(0); 1253 } else { 1254 // Unpack the list. 1255 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1256 for (py::object segmentItem : segment) { 1257 operands.push_back(py::cast<PyValue *>(segmentItem)); 1258 if (!operands.back()) { 1259 throw py::cast_error("contained a None item"); 1260 } 1261 } 1262 operandSegmentLengths.push_back(segment.size()); 1263 } 1264 } catch (std::exception &err) { 1265 // NOTE: Sloppy to be using a catch-all here, but there are at least 1266 // three different unrelated exceptions that can be thrown in the 1267 // above "casts". Just keep the scope above small and catch them all. 1268 throw py::value_error((llvm::Twine("Operand ") + 1269 llvm::Twine(it.index()) + " of operation \"" + 1270 name + "\" must be a Sequence of Values (" + 1271 err.what() + ")") 1272 .str()); 1273 } 1274 } else { 1275 throw py::value_error("Unexpected segment spec"); 1276 } 1277 } 1278 } 1279 1280 // Merge operand/result segment lengths into attributes if needed. 1281 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1282 // Dup. 1283 if (attributes) { 1284 attributes = py::dict(*attributes); 1285 } else { 1286 attributes = py::dict(); 1287 } 1288 if (attributes->contains("result_segment_sizes") || 1289 attributes->contains("operand_segment_sizes")) { 1290 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1291 "'operand_segment_sizes' attribute is unsupported. " 1292 "Use Operation.create for such low-level access."); 1293 } 1294 1295 // Add result_segment_sizes attribute. 1296 if (!resultSegmentLengths.empty()) { 1297 int64_t size = resultSegmentLengths.size(); 1298 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1299 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1300 resultSegmentLengths.size(), resultSegmentLengths.data()); 1301 (*attributes)["result_segment_sizes"] = 1302 PyAttribute(context, segmentLengthAttr); 1303 } 1304 1305 // Add operand_segment_sizes attribute. 1306 if (!operandSegmentLengths.empty()) { 1307 int64_t size = operandSegmentLengths.size(); 1308 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1309 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1310 operandSegmentLengths.size(), operandSegmentLengths.data()); 1311 (*attributes)["operand_segment_sizes"] = 1312 PyAttribute(context, segmentLengthAttr); 1313 } 1314 } 1315 1316 // Delegate to create. 1317 return PyOperation::create(std::move(name), 1318 /*results=*/std::move(resultTypes), 1319 /*operands=*/std::move(operands), 1320 /*attributes=*/std::move(attributes), 1321 /*successors=*/std::move(successors), 1322 /*regions=*/*regions, location, maybeIp); 1323 } 1324 1325 PyOpView::PyOpView(py::object operationObject) 1326 // Casting through the PyOperationBase base-class and then back to the 1327 // Operation lets us accept any PyOperationBase subclass. 1328 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1329 operationObject(operation.getRef().getObject()) {} 1330 1331 py::object PyOpView::createRawSubclass(py::object userClass) { 1332 // This is... a little gross. The typical pattern is to have a pure python 1333 // class that extends OpView like: 1334 // class AddFOp(_cext.ir.OpView): 1335 // def __init__(self, loc, lhs, rhs): 1336 // operation = loc.context.create_operation( 1337 // "addf", lhs, rhs, results=[lhs.type]) 1338 // super().__init__(operation) 1339 // 1340 // I.e. The goal of the user facing type is to provide a nice constructor 1341 // that has complete freedom for the op under construction. This is at odds 1342 // with our other desire to sometimes create this object by just passing an 1343 // operation (to initialize the base class). We could do *arg and **kwargs 1344 // munging to try to make it work, but instead, we synthesize a new class 1345 // on the fly which extends this user class (AddFOp in this example) and 1346 // *give it* the base class's __init__ method, thus bypassing the 1347 // intermediate subclass's __init__ method entirely. While slightly, 1348 // underhanded, this is safe/legal because the type hierarchy has not changed 1349 // (we just added a new leaf) and we aren't mucking around with __new__. 1350 // Typically, this new class will be stored on the original as "_Raw" and will 1351 // be used for casts and other things that need a variant of the class that 1352 // is initialized purely from an operation. 1353 py::object parentMetaclass = 1354 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1355 py::dict attributes; 1356 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1357 // now. 1358 // auto opViewType = py::type::of<PyOpView>(); 1359 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1360 attributes["__init__"] = opViewType.attr("__init__"); 1361 py::str origName = userClass.attr("__name__"); 1362 py::str newName = py::str("_") + origName; 1363 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1364 } 1365 1366 //------------------------------------------------------------------------------ 1367 // PyInsertionPoint. 1368 //------------------------------------------------------------------------------ 1369 1370 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1371 1372 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1373 : refOperation(beforeOperationBase.getOperation().getRef()), 1374 block((*refOperation)->getBlock()) {} 1375 1376 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1377 PyOperation &operation = operationBase.getOperation(); 1378 if (operation.isAttached()) 1379 throw SetPyError(PyExc_ValueError, 1380 "Attempt to insert operation that is already attached"); 1381 block.getParentOperation()->checkValid(); 1382 MlirOperation beforeOp = {nullptr}; 1383 if (refOperation) { 1384 // Insert before operation. 1385 (*refOperation)->checkValid(); 1386 beforeOp = (*refOperation)->get(); 1387 } else { 1388 // Insert at end (before null) is only valid if the block does not 1389 // already end in a known terminator (violating this will cause assertion 1390 // failures later). 1391 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1392 throw py::index_error("Cannot insert operation at the end of a block " 1393 "that already has a terminator. Did you mean to " 1394 "use 'InsertionPoint.at_block_terminator(block)' " 1395 "versus 'InsertionPoint(block)'?"); 1396 } 1397 } 1398 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1399 operation.setAttached(); 1400 } 1401 1402 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1403 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1404 if (mlirOperationIsNull(firstOp)) { 1405 // Just insert at end. 1406 return PyInsertionPoint(block); 1407 } 1408 1409 // Insert before first op. 1410 PyOperationRef firstOpRef = PyOperation::forOperation( 1411 block.getParentOperation()->getContext(), firstOp); 1412 return PyInsertionPoint{block, std::move(firstOpRef)}; 1413 } 1414 1415 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1416 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1417 if (mlirOperationIsNull(terminator)) 1418 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1419 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1420 block.getParentOperation()->getContext(), terminator); 1421 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1422 } 1423 1424 py::object PyInsertionPoint::contextEnter() { 1425 return PyThreadContextEntry::pushInsertionPoint(*this); 1426 } 1427 1428 void PyInsertionPoint::contextExit(pybind11::object excType, 1429 pybind11::object excVal, 1430 pybind11::object excTb) { 1431 PyThreadContextEntry::popInsertionPoint(*this); 1432 } 1433 1434 //------------------------------------------------------------------------------ 1435 // PyAttribute. 1436 //------------------------------------------------------------------------------ 1437 1438 bool PyAttribute::operator==(const PyAttribute &other) { 1439 return mlirAttributeEqual(attr, other.attr); 1440 } 1441 1442 py::object PyAttribute::getCapsule() { 1443 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1444 } 1445 1446 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1447 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1448 if (mlirAttributeIsNull(rawAttr)) 1449 throw py::error_already_set(); 1450 return PyAttribute( 1451 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1452 } 1453 1454 //------------------------------------------------------------------------------ 1455 // PyNamedAttribute. 1456 //------------------------------------------------------------------------------ 1457 1458 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1459 : ownedName(new std::string(std::move(ownedName))) { 1460 namedAttr = mlirNamedAttributeGet( 1461 mlirIdentifierGet(mlirAttributeGetContext(attr), 1462 toMlirStringRef(*this->ownedName)), 1463 attr); 1464 } 1465 1466 //------------------------------------------------------------------------------ 1467 // PyType. 1468 //------------------------------------------------------------------------------ 1469 1470 bool PyType::operator==(const PyType &other) { 1471 return mlirTypeEqual(type, other.type); 1472 } 1473 1474 py::object PyType::getCapsule() { 1475 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1476 } 1477 1478 PyType PyType::createFromCapsule(py::object capsule) { 1479 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1480 if (mlirTypeIsNull(rawType)) 1481 throw py::error_already_set(); 1482 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1483 rawType); 1484 } 1485 1486 //------------------------------------------------------------------------------ 1487 // PyValue and subclases. 1488 //------------------------------------------------------------------------------ 1489 1490 pybind11::object PyValue::getCapsule() { 1491 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1492 } 1493 1494 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1495 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1496 if (mlirValueIsNull(value)) 1497 throw py::error_already_set(); 1498 MlirOperation owner; 1499 if (mlirValueIsAOpResult(value)) 1500 owner = mlirOpResultGetOwner(value); 1501 if (mlirValueIsABlockArgument(value)) 1502 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1503 if (mlirOperationIsNull(owner)) 1504 throw py::error_already_set(); 1505 MlirContext ctx = mlirOperationGetContext(owner); 1506 PyOperationRef ownerRef = 1507 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1508 return PyValue(ownerRef, value); 1509 } 1510 1511 namespace { 1512 /// CRTP base class for Python MLIR values that subclass Value and should be 1513 /// castable from it. The value hierarchy is one level deep and is not supposed 1514 /// to accommodate other levels unless core MLIR changes. 1515 template <typename DerivedTy> 1516 class PyConcreteValue : public PyValue { 1517 public: 1518 // Derived classes must define statics for: 1519 // IsAFunctionTy isaFunction 1520 // const char *pyClassName 1521 // and redefine bindDerived. 1522 using ClassTy = py::class_<DerivedTy, PyValue>; 1523 using IsAFunctionTy = bool (*)(MlirValue); 1524 1525 PyConcreteValue() = default; 1526 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1527 : PyValue(operationRef, value) {} 1528 PyConcreteValue(PyValue &orig) 1529 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1530 1531 /// Attempts to cast the original value to the derived type and throws on 1532 /// type mismatches. 1533 static MlirValue castFrom(PyValue &orig) { 1534 if (!DerivedTy::isaFunction(orig.get())) { 1535 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1536 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1537 DerivedTy::pyClassName + 1538 " (from " + origRepr + ")"); 1539 } 1540 return orig.get(); 1541 } 1542 1543 /// Binds the Python module objects to functions of this class. 1544 static void bind(py::module &m) { 1545 auto cls = ClassTy(m, DerivedTy::pyClassName); 1546 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1547 DerivedTy::bindDerived(cls); 1548 } 1549 1550 /// Implemented by derived classes to add methods to the Python subclass. 1551 static void bindDerived(ClassTy &m) {} 1552 }; 1553 1554 /// Python wrapper for MlirBlockArgument. 1555 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1556 public: 1557 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1558 static constexpr const char *pyClassName = "BlockArgument"; 1559 using PyConcreteValue::PyConcreteValue; 1560 1561 static void bindDerived(ClassTy &c) { 1562 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1563 return PyBlock(self.getParentOperation(), 1564 mlirBlockArgumentGetOwner(self.get())); 1565 }); 1566 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1567 return mlirBlockArgumentGetArgNumber(self.get()); 1568 }); 1569 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1570 return mlirBlockArgumentSetType(self.get(), type); 1571 }); 1572 } 1573 }; 1574 1575 /// Python wrapper for MlirOpResult. 1576 class PyOpResult : public PyConcreteValue<PyOpResult> { 1577 public: 1578 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1579 static constexpr const char *pyClassName = "OpResult"; 1580 using PyConcreteValue::PyConcreteValue; 1581 1582 static void bindDerived(ClassTy &c) { 1583 c.def_property_readonly("owner", [](PyOpResult &self) { 1584 assert( 1585 mlirOperationEqual(self.getParentOperation()->get(), 1586 mlirOpResultGetOwner(self.get())) && 1587 "expected the owner of the value in Python to match that in the IR"); 1588 return self.getParentOperation().getObject(); 1589 }); 1590 c.def_property_readonly("result_number", [](PyOpResult &self) { 1591 return mlirOpResultGetResultNumber(self.get()); 1592 }); 1593 } 1594 }; 1595 1596 /// A list of block arguments. Internally, these are stored as consecutive 1597 /// elements, random access is cheap. The argument list is associated with the 1598 /// operation that contains the block (detached blocks are not allowed in 1599 /// Python bindings) and extends its lifetime. 1600 class PyBlockArgumentList { 1601 public: 1602 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1603 : operation(std::move(operation)), block(block) {} 1604 1605 /// Returns the length of the block argument list. 1606 intptr_t dunderLen() { 1607 operation->checkValid(); 1608 return mlirBlockGetNumArguments(block); 1609 } 1610 1611 /// Returns `index`-th element of the block argument list. 1612 PyBlockArgument dunderGetItem(intptr_t index) { 1613 if (index < 0 || index >= dunderLen()) { 1614 throw SetPyError(PyExc_IndexError, 1615 "attempt to access out of bounds region"); 1616 } 1617 PyValue value(operation, mlirBlockGetArgument(block, index)); 1618 return PyBlockArgument(value); 1619 } 1620 1621 /// Defines a Python class in the bindings. 1622 static void bind(py::module &m) { 1623 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1624 .def("__len__", &PyBlockArgumentList::dunderLen) 1625 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1626 } 1627 1628 private: 1629 PyOperationRef operation; 1630 MlirBlock block; 1631 }; 1632 1633 /// A list of operation operands. Internally, these are stored as consecutive 1634 /// elements, random access is cheap. The result list is associated with the 1635 /// operation whose results these are, and extends the lifetime of this 1636 /// operation. 1637 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1638 public: 1639 static constexpr const char *pyClassName = "OpOperandList"; 1640 1641 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1642 intptr_t length = -1, intptr_t step = 1) 1643 : Sliceable(startIndex, 1644 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1645 : length, 1646 step), 1647 operation(operation) {} 1648 1649 intptr_t getNumElements() { 1650 operation->checkValid(); 1651 return mlirOperationGetNumOperands(operation->get()); 1652 } 1653 1654 PyValue getElement(intptr_t pos) { 1655 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); 1656 } 1657 1658 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1659 return PyOpOperandList(operation, startIndex, length, step); 1660 } 1661 1662 void dunderSetItem(intptr_t index, PyValue value) { 1663 index = wrapIndex(index); 1664 mlirOperationSetOperand(operation->get(), index, value.get()); 1665 } 1666 1667 static void bindDerived(ClassTy &c) { 1668 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1669 } 1670 1671 private: 1672 PyOperationRef operation; 1673 }; 1674 1675 /// A list of operation results. Internally, these are stored as consecutive 1676 /// elements, random access is cheap. The result list is associated with the 1677 /// operation whose results these are, and extends the lifetime of this 1678 /// operation. 1679 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1680 public: 1681 static constexpr const char *pyClassName = "OpResultList"; 1682 1683 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1684 intptr_t length = -1, intptr_t step = 1) 1685 : Sliceable(startIndex, 1686 length == -1 ? mlirOperationGetNumResults(operation->get()) 1687 : length, 1688 step), 1689 operation(operation) {} 1690 1691 intptr_t getNumElements() { 1692 operation->checkValid(); 1693 return mlirOperationGetNumResults(operation->get()); 1694 } 1695 1696 PyOpResult getElement(intptr_t index) { 1697 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1698 return PyOpResult(value); 1699 } 1700 1701 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1702 return PyOpResultList(operation, startIndex, length, step); 1703 } 1704 1705 private: 1706 PyOperationRef operation; 1707 }; 1708 1709 /// A list of operation attributes. Can be indexed by name, producing 1710 /// attributes, or by index, producing named attributes. 1711 class PyOpAttributeMap { 1712 public: 1713 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1714 1715 PyAttribute dunderGetItemNamed(const std::string &name) { 1716 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1717 toMlirStringRef(name)); 1718 if (mlirAttributeIsNull(attr)) { 1719 throw SetPyError(PyExc_KeyError, 1720 "attempt to access a non-existent attribute"); 1721 } 1722 return PyAttribute(operation->getContext(), attr); 1723 } 1724 1725 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1726 if (index < 0 || index >= dunderLen()) { 1727 throw SetPyError(PyExc_IndexError, 1728 "attempt to access out of bounds attribute"); 1729 } 1730 MlirNamedAttribute namedAttr = 1731 mlirOperationGetAttribute(operation->get(), index); 1732 return PyNamedAttribute( 1733 namedAttr.attribute, 1734 std::string(mlirIdentifierStr(namedAttr.name).data)); 1735 } 1736 1737 void dunderSetItem(const std::string &name, PyAttribute attr) { 1738 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1739 attr); 1740 } 1741 1742 void dunderDelItem(const std::string &name) { 1743 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1744 toMlirStringRef(name)); 1745 if (!removed) 1746 throw SetPyError(PyExc_KeyError, 1747 "attempt to delete a non-existent attribute"); 1748 } 1749 1750 intptr_t dunderLen() { 1751 return mlirOperationGetNumAttributes(operation->get()); 1752 } 1753 1754 bool dunderContains(const std::string &name) { 1755 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1756 operation->get(), toMlirStringRef(name))); 1757 } 1758 1759 static void bind(py::module &m) { 1760 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1761 .def("__contains__", &PyOpAttributeMap::dunderContains) 1762 .def("__len__", &PyOpAttributeMap::dunderLen) 1763 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1764 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1765 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1766 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1767 } 1768 1769 private: 1770 PyOperationRef operation; 1771 }; 1772 1773 } // end namespace 1774 1775 //------------------------------------------------------------------------------ 1776 // Populates the core exports of the 'ir' submodule. 1777 //------------------------------------------------------------------------------ 1778 1779 void mlir::python::populateIRCore(py::module &m) { 1780 //---------------------------------------------------------------------------- 1781 // Mapping of MlirContext. 1782 //---------------------------------------------------------------------------- 1783 py::class_<PyMlirContext>(m, "Context") 1784 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1785 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1786 .def("_get_context_again", 1787 [](PyMlirContext &self) { 1788 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1789 return ref.releaseObject(); 1790 }) 1791 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1792 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1793 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1794 &PyMlirContext::getCapsule) 1795 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1796 .def("__enter__", &PyMlirContext::contextEnter) 1797 .def("__exit__", &PyMlirContext::contextExit) 1798 .def_property_readonly_static( 1799 "current", 1800 [](py::object & /*class*/) { 1801 auto *context = PyThreadContextEntry::getDefaultContext(); 1802 if (!context) 1803 throw SetPyError(PyExc_ValueError, "No current Context"); 1804 return context; 1805 }, 1806 "Gets the Context bound to the current thread or raises ValueError") 1807 .def_property_readonly( 1808 "dialects", 1809 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1810 "Gets a container for accessing dialects by name") 1811 .def_property_readonly( 1812 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1813 "Alias for 'dialect'") 1814 .def( 1815 "get_dialect_descriptor", 1816 [=](PyMlirContext &self, std::string &name) { 1817 MlirDialect dialect = mlirContextGetOrLoadDialect( 1818 self.get(), {name.data(), name.size()}); 1819 if (mlirDialectIsNull(dialect)) { 1820 throw SetPyError(PyExc_ValueError, 1821 Twine("Dialect '") + name + "' not found"); 1822 } 1823 return PyDialectDescriptor(self.getRef(), dialect); 1824 }, 1825 "Gets or loads a dialect by name, returning its descriptor object") 1826 .def_property( 1827 "allow_unregistered_dialects", 1828 [](PyMlirContext &self) -> bool { 1829 return mlirContextGetAllowUnregisteredDialects(self.get()); 1830 }, 1831 [](PyMlirContext &self, bool value) { 1832 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1833 }) 1834 .def("enable_multithreading", 1835 [](PyMlirContext &self, bool enable) { 1836 mlirContextEnableMultithreading(self.get(), enable); 1837 }) 1838 .def("is_registered_operation", 1839 [](PyMlirContext &self, std::string &name) { 1840 return mlirContextIsRegisteredOperation( 1841 self.get(), MlirStringRef{name.data(), name.size()}); 1842 }); 1843 1844 //---------------------------------------------------------------------------- 1845 // Mapping of PyDialectDescriptor 1846 //---------------------------------------------------------------------------- 1847 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1848 .def_property_readonly("namespace", 1849 [](PyDialectDescriptor &self) { 1850 MlirStringRef ns = 1851 mlirDialectGetNamespace(self.get()); 1852 return py::str(ns.data, ns.length); 1853 }) 1854 .def("__repr__", [](PyDialectDescriptor &self) { 1855 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1856 std::string repr("<DialectDescriptor "); 1857 repr.append(ns.data, ns.length); 1858 repr.append(">"); 1859 return repr; 1860 }); 1861 1862 //---------------------------------------------------------------------------- 1863 // Mapping of PyDialects 1864 //---------------------------------------------------------------------------- 1865 py::class_<PyDialects>(m, "Dialects") 1866 .def("__getitem__", 1867 [=](PyDialects &self, std::string keyName) { 1868 MlirDialect dialect = 1869 self.getDialectForKey(keyName, /*attrError=*/false); 1870 py::object descriptor = 1871 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1872 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1873 }) 1874 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1875 MlirDialect dialect = 1876 self.getDialectForKey(attrName, /*attrError=*/true); 1877 py::object descriptor = 1878 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1879 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1880 }); 1881 1882 //---------------------------------------------------------------------------- 1883 // Mapping of PyDialect 1884 //---------------------------------------------------------------------------- 1885 py::class_<PyDialect>(m, "Dialect") 1886 .def(py::init<py::object>(), "descriptor") 1887 .def_property_readonly( 1888 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1889 .def("__repr__", [](py::object self) { 1890 auto clazz = self.attr("__class__"); 1891 return py::str("<Dialect ") + 1892 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1893 clazz.attr("__module__") + py::str(".") + 1894 clazz.attr("__name__") + py::str(")>"); 1895 }); 1896 1897 //---------------------------------------------------------------------------- 1898 // Mapping of Location 1899 //---------------------------------------------------------------------------- 1900 py::class_<PyLocation>(m, "Location") 1901 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1902 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1903 .def("__enter__", &PyLocation::contextEnter) 1904 .def("__exit__", &PyLocation::contextExit) 1905 .def("__eq__", 1906 [](PyLocation &self, PyLocation &other) -> bool { 1907 return mlirLocationEqual(self, other); 1908 }) 1909 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1910 .def_property_readonly_static( 1911 "current", 1912 [](py::object & /*class*/) { 1913 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1914 if (!loc) 1915 throw SetPyError(PyExc_ValueError, "No current Location"); 1916 return loc; 1917 }, 1918 "Gets the Location bound to the current thread or raises ValueError") 1919 .def_static( 1920 "unknown", 1921 [](DefaultingPyMlirContext context) { 1922 return PyLocation(context->getRef(), 1923 mlirLocationUnknownGet(context->get())); 1924 }, 1925 py::arg("context") = py::none(), 1926 "Gets a Location representing an unknown location") 1927 .def_static( 1928 "file", 1929 [](std::string filename, int line, int col, 1930 DefaultingPyMlirContext context) { 1931 return PyLocation( 1932 context->getRef(), 1933 mlirLocationFileLineColGet( 1934 context->get(), toMlirStringRef(filename), line, col)); 1935 }, 1936 py::arg("filename"), py::arg("line"), py::arg("col"), 1937 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1938 .def_property_readonly( 1939 "context", 1940 [](PyLocation &self) { return self.getContext().getObject(); }, 1941 "Context that owns the Location") 1942 .def("__repr__", [](PyLocation &self) { 1943 PyPrintAccumulator printAccum; 1944 mlirLocationPrint(self, printAccum.getCallback(), 1945 printAccum.getUserData()); 1946 return printAccum.join(); 1947 }); 1948 1949 //---------------------------------------------------------------------------- 1950 // Mapping of Module 1951 //---------------------------------------------------------------------------- 1952 py::class_<PyModule>(m, "Module") 1953 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1954 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1955 .def_static( 1956 "parse", 1957 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1958 MlirModule module = mlirModuleCreateParse( 1959 context->get(), toMlirStringRef(moduleAsm)); 1960 // TODO: Rework error reporting once diagnostic engine is exposed 1961 // in C API. 1962 if (mlirModuleIsNull(module)) { 1963 throw SetPyError( 1964 PyExc_ValueError, 1965 "Unable to parse module assembly (see diagnostics)"); 1966 } 1967 return PyModule::forModule(module).releaseObject(); 1968 }, 1969 py::arg("asm"), py::arg("context") = py::none(), 1970 kModuleParseDocstring) 1971 .def_static( 1972 "create", 1973 [](DefaultingPyLocation loc) { 1974 MlirModule module = mlirModuleCreateEmpty(loc); 1975 return PyModule::forModule(module).releaseObject(); 1976 }, 1977 py::arg("loc") = py::none(), "Creates an empty module") 1978 .def_property_readonly( 1979 "context", 1980 [](PyModule &self) { return self.getContext().getObject(); }, 1981 "Context that created the Module") 1982 .def_property_readonly( 1983 "operation", 1984 [](PyModule &self) { 1985 return PyOperation::forOperation(self.getContext(), 1986 mlirModuleGetOperation(self.get()), 1987 self.getRef().releaseObject()) 1988 .releaseObject(); 1989 }, 1990 "Accesses the module as an operation") 1991 .def_property_readonly( 1992 "body", 1993 [](PyModule &self) { 1994 PyOperationRef module_op = PyOperation::forOperation( 1995 self.getContext(), mlirModuleGetOperation(self.get()), 1996 self.getRef().releaseObject()); 1997 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 1998 return returnBlock; 1999 }, 2000 "Return the block for this module") 2001 .def( 2002 "dump", 2003 [](PyModule &self) { 2004 mlirOperationDump(mlirModuleGetOperation(self.get())); 2005 }, 2006 kDumpDocstring) 2007 .def( 2008 "__str__", 2009 [](PyModule &self) { 2010 MlirOperation operation = mlirModuleGetOperation(self.get()); 2011 PyPrintAccumulator printAccum; 2012 mlirOperationPrint(operation, printAccum.getCallback(), 2013 printAccum.getUserData()); 2014 return printAccum.join(); 2015 }, 2016 kOperationStrDunderDocstring); 2017 2018 //---------------------------------------------------------------------------- 2019 // Mapping of Operation. 2020 //---------------------------------------------------------------------------- 2021 py::class_<PyOperationBase>(m, "_OperationBase") 2022 .def("__eq__", 2023 [](PyOperationBase &self, PyOperationBase &other) { 2024 return &self.getOperation() == &other.getOperation(); 2025 }) 2026 .def("__eq__", 2027 [](PyOperationBase &self, py::object other) { return false; }) 2028 .def_property_readonly("attributes", 2029 [](PyOperationBase &self) { 2030 return PyOpAttributeMap( 2031 self.getOperation().getRef()); 2032 }) 2033 .def_property_readonly("operands", 2034 [](PyOperationBase &self) { 2035 return PyOpOperandList( 2036 self.getOperation().getRef()); 2037 }) 2038 .def_property_readonly("regions", 2039 [](PyOperationBase &self) { 2040 return PyRegionList( 2041 self.getOperation().getRef()); 2042 }) 2043 .def_property_readonly( 2044 "results", 2045 [](PyOperationBase &self) { 2046 return PyOpResultList(self.getOperation().getRef()); 2047 }, 2048 "Returns the list of Operation results.") 2049 .def_property_readonly( 2050 "result", 2051 [](PyOperationBase &self) { 2052 auto &operation = self.getOperation(); 2053 auto numResults = mlirOperationGetNumResults(operation); 2054 if (numResults != 1) { 2055 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2056 throw SetPyError( 2057 PyExc_ValueError, 2058 Twine("Cannot call .result on operation ") + 2059 StringRef(name.data, name.length) + " which has " + 2060 Twine(numResults) + 2061 " results (it is only valid for operations with a " 2062 "single result)"); 2063 } 2064 return PyOpResult(operation.getRef(), 2065 mlirOperationGetResult(operation, 0)); 2066 }, 2067 "Shortcut to get an op result if it has only one (throws an error " 2068 "otherwise).") 2069 .def("__iter__", 2070 [](PyOperationBase &self) { 2071 return PyRegionIterator(self.getOperation().getRef()); 2072 }) 2073 .def( 2074 "__str__", 2075 [](PyOperationBase &self) { 2076 return self.getAsm(/*binary=*/false, 2077 /*largeElementsLimit=*/llvm::None, 2078 /*enableDebugInfo=*/false, 2079 /*prettyDebugInfo=*/false, 2080 /*printGenericOpForm=*/false, 2081 /*useLocalScope=*/false); 2082 }, 2083 "Returns the assembly form of the operation.") 2084 .def("print", &PyOperationBase::print, 2085 // Careful: Lots of arguments must match up with print method. 2086 py::arg("file") = py::none(), py::arg("binary") = false, 2087 py::arg("large_elements_limit") = py::none(), 2088 py::arg("enable_debug_info") = false, 2089 py::arg("pretty_debug_info") = false, 2090 py::arg("print_generic_op_form") = false, 2091 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2092 .def("get_asm", &PyOperationBase::getAsm, 2093 // Careful: Lots of arguments must match up with get_asm method. 2094 py::arg("binary") = false, 2095 py::arg("large_elements_limit") = py::none(), 2096 py::arg("enable_debug_info") = false, 2097 py::arg("pretty_debug_info") = false, 2098 py::arg("print_generic_op_form") = false, 2099 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2100 .def( 2101 "verify", 2102 [](PyOperationBase &self) { 2103 return mlirOperationVerify(self.getOperation()); 2104 }, 2105 "Verify the operation and return true if it passes, false if it " 2106 "fails."); 2107 2108 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2109 .def_static("create", &PyOperation::create, py::arg("name"), 2110 py::arg("results") = py::none(), 2111 py::arg("operands") = py::none(), 2112 py::arg("attributes") = py::none(), 2113 py::arg("successors") = py::none(), py::arg("regions") = 0, 2114 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2115 kOperationCreateDocstring) 2116 .def("erase", &PyOperation::erase) 2117 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2118 &PyOperation::getCapsule) 2119 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2120 .def_property_readonly("name", 2121 [](PyOperation &self) { 2122 self.checkValid(); 2123 MlirOperation operation = self.get(); 2124 MlirStringRef name = mlirIdentifierStr( 2125 mlirOperationGetName(operation)); 2126 return py::str(name.data, name.length); 2127 }) 2128 .def_property_readonly( 2129 "context", 2130 [](PyOperation &self) { 2131 self.checkValid(); 2132 return self.getContext().getObject(); 2133 }, 2134 "Context that owns the Operation") 2135 .def_property_readonly("opview", &PyOperation::createOpView); 2136 2137 auto opViewClass = 2138 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2139 .def(py::init<py::object>()) 2140 .def_property_readonly("operation", &PyOpView::getOperationObject) 2141 .def_property_readonly( 2142 "context", 2143 [](PyOpView &self) { 2144 return self.getOperation().getContext().getObject(); 2145 }, 2146 "Context that owns the Operation") 2147 .def("__str__", [](PyOpView &self) { 2148 return py::str(self.getOperationObject()); 2149 }); 2150 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2151 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2152 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2153 opViewClass.attr("build_generic") = classmethod( 2154 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2155 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2156 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2157 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2158 "Builds a specific, generated OpView based on class level attributes."); 2159 2160 //---------------------------------------------------------------------------- 2161 // Mapping of PyRegion. 2162 //---------------------------------------------------------------------------- 2163 py::class_<PyRegion>(m, "Region") 2164 .def_property_readonly( 2165 "blocks", 2166 [](PyRegion &self) { 2167 return PyBlockList(self.getParentOperation(), self.get()); 2168 }, 2169 "Returns a forward-optimized sequence of blocks.") 2170 .def( 2171 "__iter__", 2172 [](PyRegion &self) { 2173 self.checkValid(); 2174 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2175 return PyBlockIterator(self.getParentOperation(), firstBlock); 2176 }, 2177 "Iterates over blocks in the region.") 2178 .def("__eq__", 2179 [](PyRegion &self, PyRegion &other) { 2180 return self.get().ptr == other.get().ptr; 2181 }) 2182 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2183 2184 //---------------------------------------------------------------------------- 2185 // Mapping of PyBlock. 2186 //---------------------------------------------------------------------------- 2187 py::class_<PyBlock>(m, "Block") 2188 .def_property_readonly( 2189 "arguments", 2190 [](PyBlock &self) { 2191 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2192 }, 2193 "Returns a list of block arguments.") 2194 .def_property_readonly( 2195 "operations", 2196 [](PyBlock &self) { 2197 return PyOperationList(self.getParentOperation(), self.get()); 2198 }, 2199 "Returns a forward-optimized sequence of operations.") 2200 .def( 2201 "__iter__", 2202 [](PyBlock &self) { 2203 self.checkValid(); 2204 MlirOperation firstOperation = 2205 mlirBlockGetFirstOperation(self.get()); 2206 return PyOperationIterator(self.getParentOperation(), 2207 firstOperation); 2208 }, 2209 "Iterates over operations in the block.") 2210 .def("__eq__", 2211 [](PyBlock &self, PyBlock &other) { 2212 return self.get().ptr == other.get().ptr; 2213 }) 2214 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2215 .def( 2216 "__str__", 2217 [](PyBlock &self) { 2218 self.checkValid(); 2219 PyPrintAccumulator printAccum; 2220 mlirBlockPrint(self.get(), printAccum.getCallback(), 2221 printAccum.getUserData()); 2222 return printAccum.join(); 2223 }, 2224 "Returns the assembly form of the block."); 2225 2226 //---------------------------------------------------------------------------- 2227 // Mapping of PyInsertionPoint. 2228 //---------------------------------------------------------------------------- 2229 2230 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2231 .def(py::init<PyBlock &>(), py::arg("block"), 2232 "Inserts after the last operation but still inside the block.") 2233 .def("__enter__", &PyInsertionPoint::contextEnter) 2234 .def("__exit__", &PyInsertionPoint::contextExit) 2235 .def_property_readonly_static( 2236 "current", 2237 [](py::object & /*class*/) { 2238 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2239 if (!ip) 2240 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2241 return ip; 2242 }, 2243 "Gets the InsertionPoint bound to the current thread or raises " 2244 "ValueError if none has been set") 2245 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2246 "Inserts before a referenced operation.") 2247 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2248 py::arg("block"), "Inserts at the beginning of the block.") 2249 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2250 py::arg("block"), "Inserts before the block terminator.") 2251 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2252 "Inserts an operation."); 2253 2254 //---------------------------------------------------------------------------- 2255 // Mapping of PyAttribute. 2256 //---------------------------------------------------------------------------- 2257 py::class_<PyAttribute>(m, "Attribute") 2258 // Delegate to the PyAttribute copy constructor, which will also lifetime 2259 // extend the backing context which owns the MlirAttribute. 2260 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2261 "Casts the passed attribute to the generic Attribute") 2262 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2263 &PyAttribute::getCapsule) 2264 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2265 .def_static( 2266 "parse", 2267 [](std::string attrSpec, DefaultingPyMlirContext context) { 2268 MlirAttribute type = mlirAttributeParseGet( 2269 context->get(), toMlirStringRef(attrSpec)); 2270 // TODO: Rework error reporting once diagnostic engine is exposed 2271 // in C API. 2272 if (mlirAttributeIsNull(type)) { 2273 throw SetPyError(PyExc_ValueError, 2274 Twine("Unable to parse attribute: '") + 2275 attrSpec + "'"); 2276 } 2277 return PyAttribute(context->getRef(), type); 2278 }, 2279 py::arg("asm"), py::arg("context") = py::none(), 2280 "Parses an attribute from an assembly form") 2281 .def_property_readonly( 2282 "context", 2283 [](PyAttribute &self) { return self.getContext().getObject(); }, 2284 "Context that owns the Attribute") 2285 .def_property_readonly("type", 2286 [](PyAttribute &self) { 2287 return PyType(self.getContext()->getRef(), 2288 mlirAttributeGetType(self)); 2289 }) 2290 .def( 2291 "get_named", 2292 [](PyAttribute &self, std::string name) { 2293 return PyNamedAttribute(self, std::move(name)); 2294 }, 2295 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2296 .def("__eq__", 2297 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2298 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2299 .def( 2300 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2301 kDumpDocstring) 2302 .def( 2303 "__str__", 2304 [](PyAttribute &self) { 2305 PyPrintAccumulator printAccum; 2306 mlirAttributePrint(self, printAccum.getCallback(), 2307 printAccum.getUserData()); 2308 return printAccum.join(); 2309 }, 2310 "Returns the assembly form of the Attribute.") 2311 .def("__repr__", [](PyAttribute &self) { 2312 // Generally, assembly formats are not printed for __repr__ because 2313 // this can cause exceptionally long debug output and exceptions. 2314 // However, attribute values are generally considered useful and are 2315 // printed. This may need to be re-evaluated if debug dumps end up 2316 // being excessive. 2317 PyPrintAccumulator printAccum; 2318 printAccum.parts.append("Attribute("); 2319 mlirAttributePrint(self, printAccum.getCallback(), 2320 printAccum.getUserData()); 2321 printAccum.parts.append(")"); 2322 return printAccum.join(); 2323 }); 2324 2325 //---------------------------------------------------------------------------- 2326 // Mapping of PyNamedAttribute 2327 //---------------------------------------------------------------------------- 2328 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2329 .def("__repr__", 2330 [](PyNamedAttribute &self) { 2331 PyPrintAccumulator printAccum; 2332 printAccum.parts.append("NamedAttribute("); 2333 printAccum.parts.append( 2334 mlirIdentifierStr(self.namedAttr.name).data); 2335 printAccum.parts.append("="); 2336 mlirAttributePrint(self.namedAttr.attribute, 2337 printAccum.getCallback(), 2338 printAccum.getUserData()); 2339 printAccum.parts.append(")"); 2340 return printAccum.join(); 2341 }) 2342 .def_property_readonly( 2343 "name", 2344 [](PyNamedAttribute &self) { 2345 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2346 mlirIdentifierStr(self.namedAttr.name).length); 2347 }, 2348 "The name of the NamedAttribute binding") 2349 .def_property_readonly( 2350 "attr", 2351 [](PyNamedAttribute &self) { 2352 // TODO: When named attribute is removed/refactored, also remove 2353 // this constructor (it does an inefficient table lookup). 2354 auto contextRef = PyMlirContext::forContext( 2355 mlirAttributeGetContext(self.namedAttr.attribute)); 2356 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2357 }, 2358 py::keep_alive<0, 1>(), 2359 "The underlying generic attribute of the NamedAttribute binding"); 2360 2361 //---------------------------------------------------------------------------- 2362 // Mapping of PyType. 2363 //---------------------------------------------------------------------------- 2364 py::class_<PyType>(m, "Type") 2365 // Delegate to the PyType copy constructor, which will also lifetime 2366 // extend the backing context which owns the MlirType. 2367 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2368 "Casts the passed type to the generic Type") 2369 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2370 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2371 .def_static( 2372 "parse", 2373 [](std::string typeSpec, DefaultingPyMlirContext context) { 2374 MlirType type = 2375 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2376 // TODO: Rework error reporting once diagnostic engine is exposed 2377 // in C API. 2378 if (mlirTypeIsNull(type)) { 2379 throw SetPyError(PyExc_ValueError, 2380 Twine("Unable to parse type: '") + typeSpec + 2381 "'"); 2382 } 2383 return PyType(context->getRef(), type); 2384 }, 2385 py::arg("asm"), py::arg("context") = py::none(), 2386 kContextParseTypeDocstring) 2387 .def_property_readonly( 2388 "context", [](PyType &self) { return self.getContext().getObject(); }, 2389 "Context that owns the Type") 2390 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2391 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2392 .def( 2393 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2394 .def( 2395 "__str__", 2396 [](PyType &self) { 2397 PyPrintAccumulator printAccum; 2398 mlirTypePrint(self, printAccum.getCallback(), 2399 printAccum.getUserData()); 2400 return printAccum.join(); 2401 }, 2402 "Returns the assembly form of the type.") 2403 .def("__repr__", [](PyType &self) { 2404 // Generally, assembly formats are not printed for __repr__ because 2405 // this can cause exceptionally long debug output and exceptions. 2406 // However, types are an exception as they typically have compact 2407 // assembly forms and printing them is useful. 2408 PyPrintAccumulator printAccum; 2409 printAccum.parts.append("Type("); 2410 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2411 printAccum.parts.append(")"); 2412 return printAccum.join(); 2413 }); 2414 2415 //---------------------------------------------------------------------------- 2416 // Mapping of Value. 2417 //---------------------------------------------------------------------------- 2418 py::class_<PyValue>(m, "Value") 2419 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2420 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2421 .def_property_readonly( 2422 "context", 2423 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2424 "Context in which the value lives.") 2425 .def( 2426 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2427 kDumpDocstring) 2428 .def("__eq__", 2429 [](PyValue &self, PyValue &other) { 2430 return self.get().ptr == other.get().ptr; 2431 }) 2432 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2433 .def( 2434 "__str__", 2435 [](PyValue &self) { 2436 PyPrintAccumulator printAccum; 2437 printAccum.parts.append("Value("); 2438 mlirValuePrint(self.get(), printAccum.getCallback(), 2439 printAccum.getUserData()); 2440 printAccum.parts.append(")"); 2441 return printAccum.join(); 2442 }, 2443 kValueDunderStrDocstring) 2444 .def_property_readonly("type", [](PyValue &self) { 2445 return PyType(self.getParentOperation()->getContext(), 2446 mlirValueGetType(self.get())); 2447 }); 2448 PyBlockArgument::bind(m); 2449 PyOpResult::bind(m); 2450 2451 // Container bindings. 2452 PyBlockArgumentList::bind(m); 2453 PyBlockIterator::bind(m); 2454 PyBlockList::bind(m); 2455 PyOperationIterator::bind(m); 2456 PyOperationList::bind(m); 2457 PyOpAttributeMap::bind(m); 2458 PyOpOperandList::bind(m); 2459 PyOpResultList::bind(m); 2460 PyRegionIterator::bind(m); 2461 PyRegionList::bind(m); 2462 2463 // Debug bindings. 2464 PyGlobalDebugFlag::bind(m); 2465 } 2466