1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include <pybind11/stl.h> 22 23 namespace py = pybind11; 24 using namespace mlir; 25 using namespace mlir::python; 26 27 using llvm::SmallVector; 28 using llvm::StringRef; 29 using llvm::Twine; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kContextParseTypeDocstring[] = 36 R"(Parses the assembly form of a type. 37 38 Returns a Type object or raises a ValueError if the type cannot be parsed. 39 40 See also: https://mlir.llvm.org/docs/LangRef/#type-system 41 )"; 42 43 static const char kContextGetFileLocationDocstring[] = 44 R"(Gets a Location representing a file, line and column)"; 45 46 static const char kModuleParseDocstring[] = 47 R"(Parses a module's assembly format from a string. 48 49 Returns a new MlirModule or raises a ValueError if the parsing fails. 50 51 See also: https://mlir.llvm.org/docs/LangRef/ 52 )"; 53 54 static const char kOperationCreateDocstring[] = 55 R"(Creates a new operation. 56 57 Args: 58 name: Operation name (e.g. "dialect.operation"). 59 results: Sequence of Type representing op result types. 60 attributes: Dict of str:Attribute. 61 successors: List of Block for the operation's successors. 62 regions: Number of regions to create. 63 location: A Location object (defaults to resolve from context manager). 64 ip: An InsertionPoint (defaults to resolve from context manager or set to 65 False to disable insertion, even with an insertion point set in the 66 context manager). 67 Returns: 68 A new "detached" Operation object. Detached operations can be added 69 to blocks, which causes them to become "attached." 70 )"; 71 72 static const char kOperationPrintDocstring[] = 73 R"(Prints the assembly form of the operation to a file like object. 74 75 Args: 76 file: The file like object to write to. Defaults to sys.stdout. 77 binary: Whether to write bytes (True) or str (False). Defaults to False. 78 large_elements_limit: Whether to elide elements attributes above this 79 number of elements. Defaults to None (no limit). 80 enable_debug_info: Whether to print debug/location information. Defaults 81 to False. 82 pretty_debug_info: Whether to format debug information for easier reading 83 by a human (warning: the result is unparseable). 84 print_generic_op_form: Whether to print the generic assembly forms of all 85 ops. Defaults to False. 86 use_local_Scope: Whether to print in a way that is more optimized for 87 multi-threaded access but may not be consistent with how the overall 88 module prints. 89 )"; 90 91 static const char kOperationGetAsmDocstring[] = 92 R"(Gets the assembly form of the operation with all options available. 93 94 Args: 95 binary: Whether to return a bytes (True) or str (False) object. Defaults to 96 False. 97 ... others ...: See the print() method for common keyword arguments for 98 configuring the printout. 99 Returns: 100 Either a bytes or str object, depending on the setting of the 'binary' 101 argument. 102 )"; 103 104 static const char kOperationStrDunderDocstring[] = 105 R"(Gets the assembly form of the operation with default options. 106 107 If more advanced control over the assembly formatting or I/O options is needed, 108 use the dedicated print or get_asm method, which supports keyword arguments to 109 customize behavior. 110 )"; 111 112 static const char kDumpDocstring[] = 113 R"(Dumps a debug representation of the object to stderr.)"; 114 115 static const char kAppendBlockDocstring[] = 116 R"(Appends a new block, with argument types as positional args. 117 118 Returns: 119 The created block. 120 )"; 121 122 static const char kValueDunderStrDocstring[] = 123 R"(Returns the string form of the value. 124 125 If the value is a block argument, this is the assembly form of its type and the 126 position in the argument list. If the value is an operation result, this is 127 equivalent to printing the operation that produced it. 128 )"; 129 130 //------------------------------------------------------------------------------ 131 // Utilities. 132 //------------------------------------------------------------------------------ 133 134 /// Helper for creating an @classmethod. 135 template <class Func, typename... Args> 136 py::object classmethod(Func f, Args... args) { 137 py::object cf = py::cpp_function(f, args...); 138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 139 } 140 141 static py::object 142 createCustomDialectWrapper(const std::string &dialectNamespace, 143 py::object dialectDescriptor) { 144 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 145 if (!dialectClass) { 146 // Use the base class. 147 return py::cast(PyDialect(std::move(dialectDescriptor))); 148 } 149 150 // Create the custom implementation. 151 return (*dialectClass)(std::move(dialectDescriptor)); 152 } 153 154 static MlirStringRef toMlirStringRef(const std::string &s) { 155 return mlirStringRefCreate(s.data(), s.size()); 156 } 157 158 /// Wrapper for the global LLVM debugging flag. 159 struct PyGlobalDebugFlag { 160 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 161 162 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 163 164 static void bind(py::module &m) { 165 // Debug flags. 166 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") 167 .def_property_static("flag", &PyGlobalDebugFlag::get, 168 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 169 } 170 }; 171 172 //------------------------------------------------------------------------------ 173 // Collections. 174 //------------------------------------------------------------------------------ 175 176 namespace { 177 178 class PyRegionIterator { 179 public: 180 PyRegionIterator(PyOperationRef operation) 181 : operation(std::move(operation)) {} 182 183 PyRegionIterator &dunderIter() { return *this; } 184 185 PyRegion dunderNext() { 186 operation->checkValid(); 187 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 188 throw py::stop_iteration(); 189 } 190 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 191 return PyRegion(operation, region); 192 } 193 194 static void bind(py::module &m) { 195 py::class_<PyRegionIterator>(m, "RegionIterator") 196 .def("__iter__", &PyRegionIterator::dunderIter) 197 .def("__next__", &PyRegionIterator::dunderNext); 198 } 199 200 private: 201 PyOperationRef operation; 202 int nextIndex = 0; 203 }; 204 205 /// Regions of an op are fixed length and indexed numerically so are represented 206 /// with a sequence-like container. 207 class PyRegionList { 208 public: 209 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 210 211 intptr_t dunderLen() { 212 operation->checkValid(); 213 return mlirOperationGetNumRegions(operation->get()); 214 } 215 216 PyRegion dunderGetItem(intptr_t index) { 217 // dunderLen checks validity. 218 if (index < 0 || index >= dunderLen()) { 219 throw SetPyError(PyExc_IndexError, 220 "attempt to access out of bounds region"); 221 } 222 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 223 return PyRegion(operation, region); 224 } 225 226 static void bind(py::module &m) { 227 py::class_<PyRegionList>(m, "RegionSequence") 228 .def("__len__", &PyRegionList::dunderLen) 229 .def("__getitem__", &PyRegionList::dunderGetItem); 230 } 231 232 private: 233 PyOperationRef operation; 234 }; 235 236 class PyBlockIterator { 237 public: 238 PyBlockIterator(PyOperationRef operation, MlirBlock next) 239 : operation(std::move(operation)), next(next) {} 240 241 PyBlockIterator &dunderIter() { return *this; } 242 243 PyBlock dunderNext() { 244 operation->checkValid(); 245 if (mlirBlockIsNull(next)) { 246 throw py::stop_iteration(); 247 } 248 249 PyBlock returnBlock(operation, next); 250 next = mlirBlockGetNextInRegion(next); 251 return returnBlock; 252 } 253 254 static void bind(py::module &m) { 255 py::class_<PyBlockIterator>(m, "BlockIterator") 256 .def("__iter__", &PyBlockIterator::dunderIter) 257 .def("__next__", &PyBlockIterator::dunderNext); 258 } 259 260 private: 261 PyOperationRef operation; 262 MlirBlock next; 263 }; 264 265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 266 /// we present them as a more full-featured list-like container but optimize 267 /// it for forward iteration. Blocks are always owned by a region. 268 class PyBlockList { 269 public: 270 PyBlockList(PyOperationRef operation, MlirRegion region) 271 : operation(std::move(operation)), region(region) {} 272 273 PyBlockIterator dunderIter() { 274 operation->checkValid(); 275 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 276 } 277 278 intptr_t dunderLen() { 279 operation->checkValid(); 280 intptr_t count = 0; 281 MlirBlock block = mlirRegionGetFirstBlock(region); 282 while (!mlirBlockIsNull(block)) { 283 count += 1; 284 block = mlirBlockGetNextInRegion(block); 285 } 286 return count; 287 } 288 289 PyBlock dunderGetItem(intptr_t index) { 290 operation->checkValid(); 291 if (index < 0) { 292 throw SetPyError(PyExc_IndexError, 293 "attempt to access out of bounds block"); 294 } 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 if (index == 0) { 298 return PyBlock(operation, block); 299 } 300 block = mlirBlockGetNextInRegion(block); 301 index -= 1; 302 } 303 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 304 } 305 306 PyBlock appendBlock(py::args pyArgTypes) { 307 operation->checkValid(); 308 llvm::SmallVector<MlirType, 4> argTypes; 309 argTypes.reserve(pyArgTypes.size()); 310 for (auto &pyArg : pyArgTypes) { 311 argTypes.push_back(pyArg.cast<PyType &>()); 312 } 313 314 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 315 mlirRegionAppendOwnedBlock(region, block); 316 return PyBlock(operation, block); 317 } 318 319 static void bind(py::module &m) { 320 py::class_<PyBlockList>(m, "BlockList") 321 .def("__getitem__", &PyBlockList::dunderGetItem) 322 .def("__iter__", &PyBlockList::dunderIter) 323 .def("__len__", &PyBlockList::dunderLen) 324 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 325 } 326 327 private: 328 PyOperationRef operation; 329 MlirRegion region; 330 }; 331 332 class PyOperationIterator { 333 public: 334 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 335 : parentOperation(std::move(parentOperation)), next(next) {} 336 337 PyOperationIterator &dunderIter() { return *this; } 338 339 py::object dunderNext() { 340 parentOperation->checkValid(); 341 if (mlirOperationIsNull(next)) { 342 throw py::stop_iteration(); 343 } 344 345 PyOperationRef returnOperation = 346 PyOperation::forOperation(parentOperation->getContext(), next); 347 next = mlirOperationGetNextInBlock(next); 348 return returnOperation->createOpView(); 349 } 350 351 static void bind(py::module &m) { 352 py::class_<PyOperationIterator>(m, "OperationIterator") 353 .def("__iter__", &PyOperationIterator::dunderIter) 354 .def("__next__", &PyOperationIterator::dunderNext); 355 } 356 357 private: 358 PyOperationRef parentOperation; 359 MlirOperation next; 360 }; 361 362 /// Operations are exposed by the C-API as a forward-only linked list. In 363 /// Python, we present them as a more full-featured list-like container but 364 /// optimize it for forward iteration. Iterable operations are always owned 365 /// by a block. 366 class PyOperationList { 367 public: 368 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 369 : parentOperation(std::move(parentOperation)), block(block) {} 370 371 PyOperationIterator dunderIter() { 372 parentOperation->checkValid(); 373 return PyOperationIterator(parentOperation, 374 mlirBlockGetFirstOperation(block)); 375 } 376 377 intptr_t dunderLen() { 378 parentOperation->checkValid(); 379 intptr_t count = 0; 380 MlirOperation childOp = mlirBlockGetFirstOperation(block); 381 while (!mlirOperationIsNull(childOp)) { 382 count += 1; 383 childOp = mlirOperationGetNextInBlock(childOp); 384 } 385 return count; 386 } 387 388 py::object dunderGetItem(intptr_t index) { 389 parentOperation->checkValid(); 390 if (index < 0) { 391 throw SetPyError(PyExc_IndexError, 392 "attempt to access out of bounds operation"); 393 } 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 if (index == 0) { 397 return PyOperation::forOperation(parentOperation->getContext(), childOp) 398 ->createOpView(); 399 } 400 childOp = mlirOperationGetNextInBlock(childOp); 401 index -= 1; 402 } 403 throw SetPyError(PyExc_IndexError, 404 "attempt to access out of bounds operation"); 405 } 406 407 static void bind(py::module &m) { 408 py::class_<PyOperationList>(m, "OperationList") 409 .def("__getitem__", &PyOperationList::dunderGetItem) 410 .def("__iter__", &PyOperationList::dunderIter) 411 .def("__len__", &PyOperationList::dunderLen); 412 } 413 414 private: 415 PyOperationRef parentOperation; 416 MlirBlock block; 417 }; 418 419 } // namespace 420 421 //------------------------------------------------------------------------------ 422 // PyMlirContext 423 //------------------------------------------------------------------------------ 424 425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 426 py::gil_scoped_acquire acquire; 427 auto &liveContexts = getLiveContexts(); 428 liveContexts[context.ptr] = this; 429 } 430 431 PyMlirContext::~PyMlirContext() { 432 // Note that the only public way to construct an instance is via the 433 // forContext method, which always puts the associated handle into 434 // liveContexts. 435 py::gil_scoped_acquire acquire; 436 getLiveContexts().erase(context.ptr); 437 mlirContextDestroy(context); 438 } 439 440 py::object PyMlirContext::getCapsule() { 441 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 442 } 443 444 py::object PyMlirContext::createFromCapsule(py::object capsule) { 445 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 446 if (mlirContextIsNull(rawContext)) 447 throw py::error_already_set(); 448 return forContext(rawContext).releaseObject(); 449 } 450 451 PyMlirContext *PyMlirContext::createNewContextForInit() { 452 MlirContext context = mlirContextCreate(); 453 mlirRegisterAllDialects(context); 454 return new PyMlirContext(context); 455 } 456 457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 458 py::gil_scoped_acquire acquire; 459 auto &liveContexts = getLiveContexts(); 460 auto it = liveContexts.find(context.ptr); 461 if (it == liveContexts.end()) { 462 // Create. 463 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 464 py::object pyRef = py::cast(unownedContextWrapper); 465 assert(pyRef && "cast to py::object failed"); 466 liveContexts[context.ptr] = unownedContextWrapper; 467 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 468 } 469 // Use existing. 470 py::object pyRef = py::cast(it->second); 471 return PyMlirContextRef(it->second, std::move(pyRef)); 472 } 473 474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 475 static LiveContextMap liveContexts; 476 return liveContexts; 477 } 478 479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 480 481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 482 483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 484 485 pybind11::object PyMlirContext::contextEnter() { 486 return PyThreadContextEntry::pushContext(*this); 487 } 488 489 void PyMlirContext::contextExit(pybind11::object excType, 490 pybind11::object excVal, 491 pybind11::object excTb) { 492 PyThreadContextEntry::popContext(*this); 493 } 494 495 PyMlirContext &DefaultingPyMlirContext::resolve() { 496 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 497 if (!context) { 498 throw SetPyError( 499 PyExc_RuntimeError, 500 "An MLIR function requires a Context but none was provided in the call " 501 "or from the surrounding environment. Either pass to the function with " 502 "a 'context=' argument or establish a default using 'with Context():'"); 503 } 504 return *context; 505 } 506 507 //------------------------------------------------------------------------------ 508 // PyThreadContextEntry management 509 //------------------------------------------------------------------------------ 510 511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 512 static thread_local std::vector<PyThreadContextEntry> stack; 513 return stack; 514 } 515 516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 517 auto &stack = getStack(); 518 if (stack.empty()) 519 return nullptr; 520 return &stack.back(); 521 } 522 523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 524 py::object insertionPoint, 525 py::object location) { 526 auto &stack = getStack(); 527 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 528 std::move(location)); 529 // If the new stack has more than one entry and the context of the new top 530 // entry matches the previous, copy the insertionPoint and location from the 531 // previous entry if missing from the new top entry. 532 if (stack.size() > 1) { 533 auto &prev = *(stack.rbegin() + 1); 534 auto ¤t = stack.back(); 535 if (current.context.is(prev.context)) { 536 // Default non-context objects from the previous entry. 537 if (!current.insertionPoint) 538 current.insertionPoint = prev.insertionPoint; 539 if (!current.location) 540 current.location = prev.location; 541 } 542 } 543 } 544 545 PyMlirContext *PyThreadContextEntry::getContext() { 546 if (!context) 547 return nullptr; 548 return py::cast<PyMlirContext *>(context); 549 } 550 551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 552 if (!insertionPoint) 553 return nullptr; 554 return py::cast<PyInsertionPoint *>(insertionPoint); 555 } 556 557 PyLocation *PyThreadContextEntry::getLocation() { 558 if (!location) 559 return nullptr; 560 return py::cast<PyLocation *>(location); 561 } 562 563 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 564 auto *tos = getTopOfStack(); 565 return tos ? tos->getContext() : nullptr; 566 } 567 568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 569 auto *tos = getTopOfStack(); 570 return tos ? tos->getInsertionPoint() : nullptr; 571 } 572 573 PyLocation *PyThreadContextEntry::getDefaultLocation() { 574 auto *tos = getTopOfStack(); 575 return tos ? tos->getLocation() : nullptr; 576 } 577 578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 579 py::object contextObj = py::cast(context); 580 push(FrameKind::Context, /*context=*/contextObj, 581 /*insertionPoint=*/py::object(), 582 /*location=*/py::object()); 583 return contextObj; 584 } 585 586 void PyThreadContextEntry::popContext(PyMlirContext &context) { 587 auto &stack = getStack(); 588 if (stack.empty()) 589 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 590 auto &tos = stack.back(); 591 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593 stack.pop_back(); 594 } 595 596 py::object 597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 598 py::object contextObj = 599 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 600 py::object insertionPointObj = py::cast(insertionPoint); 601 push(FrameKind::InsertionPoint, 602 /*context=*/contextObj, 603 /*insertionPoint=*/insertionPointObj, 604 /*location=*/py::object()); 605 return insertionPointObj; 606 } 607 608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 609 auto &stack = getStack(); 610 if (stack.empty()) 611 throw SetPyError(PyExc_RuntimeError, 612 "Unbalanced InsertionPoint enter/exit"); 613 auto &tos = stack.back(); 614 if (tos.frameKind != FrameKind::InsertionPoint && 615 tos.getInsertionPoint() != &insertionPoint) 616 throw SetPyError(PyExc_RuntimeError, 617 "Unbalanced InsertionPoint enter/exit"); 618 stack.pop_back(); 619 } 620 621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 622 py::object contextObj = location.getContext().getObject(); 623 py::object locationObj = py::cast(location); 624 push(FrameKind::Location, /*context=*/contextObj, 625 /*insertionPoint=*/py::object(), 626 /*location=*/locationObj); 627 return locationObj; 628 } 629 630 void PyThreadContextEntry::popLocation(PyLocation &location) { 631 auto &stack = getStack(); 632 if (stack.empty()) 633 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 634 auto &tos = stack.back(); 635 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637 stack.pop_back(); 638 } 639 640 //------------------------------------------------------------------------------ 641 // PyDialect, PyDialectDescriptor, PyDialects 642 //------------------------------------------------------------------------------ 643 644 MlirDialect PyDialects::getDialectForKey(const std::string &key, 645 bool attrError) { 646 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 647 {key.data(), key.size()}); 648 if (mlirDialectIsNull(dialect)) { 649 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 650 Twine("Dialect '") + key + "' not found"); 651 } 652 return dialect; 653 } 654 655 //------------------------------------------------------------------------------ 656 // PyLocation 657 //------------------------------------------------------------------------------ 658 659 py::object PyLocation::getCapsule() { 660 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 661 } 662 663 PyLocation PyLocation::createFromCapsule(py::object capsule) { 664 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 665 if (mlirLocationIsNull(rawLoc)) 666 throw py::error_already_set(); 667 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 668 rawLoc); 669 } 670 671 py::object PyLocation::contextEnter() { 672 return PyThreadContextEntry::pushLocation(*this); 673 } 674 675 void PyLocation::contextExit(py::object excType, py::object excVal, 676 py::object excTb) { 677 PyThreadContextEntry::popLocation(*this); 678 } 679 680 PyLocation &DefaultingPyLocation::resolve() { 681 auto *location = PyThreadContextEntry::getDefaultLocation(); 682 if (!location) { 683 throw SetPyError( 684 PyExc_RuntimeError, 685 "An MLIR function requires a Location but none was provided in the " 686 "call or from the surrounding environment. Either pass to the function " 687 "with a 'loc=' argument or establish a default using 'with loc:'"); 688 } 689 return *location; 690 } 691 692 //------------------------------------------------------------------------------ 693 // PyModule 694 //------------------------------------------------------------------------------ 695 696 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 697 : BaseContextObject(std::move(contextRef)), module(module) {} 698 699 PyModule::~PyModule() { 700 py::gil_scoped_acquire acquire; 701 auto &liveModules = getContext()->liveModules; 702 assert(liveModules.count(module.ptr) == 1 && 703 "destroying module not in live map"); 704 liveModules.erase(module.ptr); 705 mlirModuleDestroy(module); 706 } 707 708 PyModuleRef PyModule::forModule(MlirModule module) { 709 MlirContext context = mlirModuleGetContext(module); 710 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 711 712 py::gil_scoped_acquire acquire; 713 auto &liveModules = contextRef->liveModules; 714 auto it = liveModules.find(module.ptr); 715 if (it == liveModules.end()) { 716 // Create. 717 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 718 // Note that the default return value policy on cast is automatic_reference, 719 // which does not take ownership (delete will not be called). 720 // Just be explicit. 721 py::object pyRef = 722 py::cast(unownedModule, py::return_value_policy::take_ownership); 723 unownedModule->handle = pyRef; 724 liveModules[module.ptr] = 725 std::make_pair(unownedModule->handle, unownedModule); 726 return PyModuleRef(unownedModule, std::move(pyRef)); 727 } 728 // Use existing. 729 PyModule *existing = it->second.second; 730 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 731 return PyModuleRef(existing, std::move(pyRef)); 732 } 733 734 py::object PyModule::createFromCapsule(py::object capsule) { 735 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 736 if (mlirModuleIsNull(rawModule)) 737 throw py::error_already_set(); 738 return forModule(rawModule).releaseObject(); 739 } 740 741 py::object PyModule::getCapsule() { 742 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 743 } 744 745 //------------------------------------------------------------------------------ 746 // PyOperation 747 //------------------------------------------------------------------------------ 748 749 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 750 : BaseContextObject(std::move(contextRef)), operation(operation) {} 751 752 PyOperation::~PyOperation() { 753 // If the operation has already been invalidated there is nothing to do. 754 if (!valid) 755 return; 756 auto &liveOperations = getContext()->liveOperations; 757 assert(liveOperations.count(operation.ptr) == 1 && 758 "destroying operation not in live map"); 759 liveOperations.erase(operation.ptr); 760 if (!isAttached()) { 761 mlirOperationDestroy(operation); 762 } 763 } 764 765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 766 MlirOperation operation, 767 py::object parentKeepAlive) { 768 auto &liveOperations = contextRef->liveOperations; 769 // Create. 770 PyOperation *unownedOperation = 771 new PyOperation(std::move(contextRef), operation); 772 // Note that the default return value policy on cast is automatic_reference, 773 // which does not take ownership (delete will not be called). 774 // Just be explicit. 775 py::object pyRef = 776 py::cast(unownedOperation, py::return_value_policy::take_ownership); 777 unownedOperation->handle = pyRef; 778 if (parentKeepAlive) { 779 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 780 } 781 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 782 return PyOperationRef(unownedOperation, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 auto it = liveOperations.find(operation.ptr); 790 if (it == liveOperations.end()) { 791 // Create. 792 return createInstance(std::move(contextRef), operation, 793 std::move(parentKeepAlive)); 794 } 795 // Use existing. 796 PyOperation *existing = it->second.second; 797 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 798 return PyOperationRef(existing, std::move(pyRef)); 799 } 800 801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 802 MlirOperation operation, 803 py::object parentKeepAlive) { 804 auto &liveOperations = contextRef->liveOperations; 805 assert(liveOperations.count(operation.ptr) == 0 && 806 "cannot create detached operation that already exists"); 807 (void)liveOperations; 808 809 PyOperationRef created = createInstance(std::move(contextRef), operation, 810 std::move(parentKeepAlive)); 811 created->attached = false; 812 return created; 813 } 814 815 void PyOperation::checkValid() const { 816 if (!valid) { 817 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 818 } 819 } 820 821 void PyOperationBase::print(py::object fileObject, bool binary, 822 llvm::Optional<int64_t> largeElementsLimit, 823 bool enableDebugInfo, bool prettyDebugInfo, 824 bool printGenericOpForm, bool useLocalScope) { 825 PyOperation &operation = getOperation(); 826 operation.checkValid(); 827 if (fileObject.is_none()) 828 fileObject = py::module::import("sys").attr("stdout"); 829 830 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 831 fileObject.attr("write")("// Verification failed, printing generic form\n"); 832 printGenericOpForm = true; 833 } 834 835 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 836 if (largeElementsLimit) 837 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 838 if (enableDebugInfo) 839 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 840 if (printGenericOpForm) 841 mlirOpPrintingFlagsPrintGenericOpForm(flags); 842 843 PyFileAccumulator accum(fileObject, binary); 844 py::gil_scoped_release(); 845 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 846 accum.getUserData()); 847 mlirOpPrintingFlagsDestroy(flags); 848 } 849 850 py::object PyOperationBase::getAsm(bool binary, 851 llvm::Optional<int64_t> largeElementsLimit, 852 bool enableDebugInfo, bool prettyDebugInfo, 853 bool printGenericOpForm, 854 bool useLocalScope) { 855 py::object fileObject; 856 if (binary) { 857 fileObject = py::module::import("io").attr("BytesIO")(); 858 } else { 859 fileObject = py::module::import("io").attr("StringIO")(); 860 } 861 print(fileObject, /*binary=*/binary, 862 /*largeElementsLimit=*/largeElementsLimit, 863 /*enableDebugInfo=*/enableDebugInfo, 864 /*prettyDebugInfo=*/prettyDebugInfo, 865 /*printGenericOpForm=*/printGenericOpForm, 866 /*useLocalScope=*/useLocalScope); 867 868 return fileObject.attr("getvalue")(); 869 } 870 871 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 872 checkValid(); 873 if (!isAttached()) 874 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 875 MlirOperation operation = mlirOperationGetParentOperation(get()); 876 if (mlirOperationIsNull(operation)) 877 return {}; 878 return PyOperation::forOperation(getContext(), operation); 879 } 880 881 PyBlock PyOperation::getBlock() { 882 checkValid(); 883 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 884 MlirBlock block = mlirOperationGetBlock(get()); 885 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 886 assert(parentOperation && "Operation has no parent"); 887 return PyBlock{std::move(*parentOperation), block}; 888 } 889 890 py::object PyOperation::getCapsule() { 891 checkValid(); 892 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 893 } 894 895 py::object PyOperation::createFromCapsule(py::object capsule) { 896 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 897 if (mlirOperationIsNull(rawOperation)) 898 throw py::error_already_set(); 899 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 900 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 901 .releaseObject(); 902 } 903 904 py::object PyOperation::create( 905 std::string name, llvm::Optional<std::vector<PyType *>> results, 906 llvm::Optional<std::vector<PyValue *>> operands, 907 llvm::Optional<py::dict> attributes, 908 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 909 DefaultingPyLocation location, py::object maybeIp) { 910 llvm::SmallVector<MlirValue, 4> mlirOperands; 911 llvm::SmallVector<MlirType, 4> mlirResults; 912 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 913 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 914 915 // General parameter validation. 916 if (regions < 0) 917 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 918 919 // Unpack/validate operands. 920 if (operands) { 921 mlirOperands.reserve(operands->size()); 922 for (PyValue *operand : *operands) { 923 if (!operand) 924 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 925 mlirOperands.push_back(operand->get()); 926 } 927 } 928 929 // Unpack/validate results. 930 if (results) { 931 mlirResults.reserve(results->size()); 932 for (PyType *result : *results) { 933 // TODO: Verify result type originate from the same context. 934 if (!result) 935 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 936 mlirResults.push_back(*result); 937 } 938 } 939 // Unpack/validate attributes. 940 if (attributes) { 941 mlirAttributes.reserve(attributes->size()); 942 for (auto &it : *attributes) { 943 std::string key; 944 try { 945 key = it.first.cast<std::string>(); 946 } catch (py::cast_error &err) { 947 std::string msg = "Invalid attribute key (not a string) when " 948 "attempting to create the operation \"" + 949 name + "\" (" + err.what() + ")"; 950 throw py::cast_error(msg); 951 } 952 try { 953 auto &attribute = it.second.cast<PyAttribute &>(); 954 // TODO: Verify attribute originates from the same context. 955 mlirAttributes.emplace_back(std::move(key), attribute); 956 } catch (py::reference_cast_error &) { 957 // This exception seems thrown when the value is "None". 958 std::string msg = 959 "Found an invalid (`None`?) attribute value for the key \"" + key + 960 "\" when attempting to create the operation \"" + name + "\""; 961 throw py::cast_error(msg); 962 } catch (py::cast_error &err) { 963 std::string msg = "Invalid attribute value for the key \"" + key + 964 "\" when attempting to create the operation \"" + 965 name + "\" (" + err.what() + ")"; 966 throw py::cast_error(msg); 967 } 968 } 969 } 970 // Unpack/validate successors. 971 if (successors) { 972 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 973 mlirSuccessors.reserve(successors->size()); 974 for (auto *successor : *successors) { 975 // TODO: Verify successor originate from the same context. 976 if (!successor) 977 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 978 mlirSuccessors.push_back(successor->get()); 979 } 980 } 981 982 // Apply unpacked/validated to the operation state. Beyond this 983 // point, exceptions cannot be thrown or else the state will leak. 984 MlirOperationState state = 985 mlirOperationStateGet(toMlirStringRef(name), location); 986 if (!mlirOperands.empty()) 987 mlirOperationStateAddOperands(&state, mlirOperands.size(), 988 mlirOperands.data()); 989 if (!mlirResults.empty()) 990 mlirOperationStateAddResults(&state, mlirResults.size(), 991 mlirResults.data()); 992 if (!mlirAttributes.empty()) { 993 // Note that the attribute names directly reference bytes in 994 // mlirAttributes, so that vector must not be changed from here 995 // on. 996 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 997 mlirNamedAttributes.reserve(mlirAttributes.size()); 998 for (auto &it : mlirAttributes) 999 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1000 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1001 toMlirStringRef(it.first)), 1002 it.second)); 1003 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1004 mlirNamedAttributes.data()); 1005 } 1006 if (!mlirSuccessors.empty()) 1007 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1008 mlirSuccessors.data()); 1009 if (regions) { 1010 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1011 mlirRegions.resize(regions); 1012 for (int i = 0; i < regions; ++i) 1013 mlirRegions[i] = mlirRegionCreate(); 1014 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1015 mlirRegions.data()); 1016 } 1017 1018 // Construct the operation. 1019 MlirOperation operation = mlirOperationCreate(&state); 1020 PyOperationRef created = 1021 PyOperation::createDetached(location->getContext(), operation); 1022 1023 // InsertPoint active? 1024 if (!maybeIp.is(py::cast(false))) { 1025 PyInsertionPoint *ip; 1026 if (maybeIp.is_none()) { 1027 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1028 } else { 1029 ip = py::cast<PyInsertionPoint *>(maybeIp); 1030 } 1031 if (ip) 1032 ip->insert(*created.get()); 1033 } 1034 1035 return created->createOpView(); 1036 } 1037 1038 py::object PyOperation::createOpView() { 1039 checkValid(); 1040 MlirIdentifier ident = mlirOperationGetName(get()); 1041 MlirStringRef identStr = mlirIdentifierStr(ident); 1042 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1043 StringRef(identStr.data, identStr.length)); 1044 if (opViewClass) 1045 return (*opViewClass)(getRef().getObject()); 1046 return py::cast(PyOpView(getRef().getObject())); 1047 } 1048 1049 void PyOperation::erase() { 1050 checkValid(); 1051 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1052 // Python reference to a child operation is live. All children should also 1053 // have their `valid` bit set to false. 1054 auto &liveOperations = getContext()->liveOperations; 1055 if (liveOperations.count(operation.ptr)) 1056 liveOperations.erase(operation.ptr); 1057 mlirOperationDestroy(operation); 1058 valid = false; 1059 } 1060 1061 //------------------------------------------------------------------------------ 1062 // PyOpView 1063 //------------------------------------------------------------------------------ 1064 1065 py::object 1066 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1067 py::list operandList, 1068 llvm::Optional<py::dict> attributes, 1069 llvm::Optional<std::vector<PyBlock *>> successors, 1070 llvm::Optional<int> regions, 1071 DefaultingPyLocation location, py::object maybeIp) { 1072 PyMlirContextRef context = location->getContext(); 1073 // Class level operation construction metadata. 1074 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1075 // Operand and result segment specs are either none, which does no 1076 // variadic unpacking, or a list of ints with segment sizes, where each 1077 // element is either a positive number (typically 1 for a scalar) or -1 to 1078 // indicate that it is derived from the length of the same-indexed operand 1079 // or result (implying that it is a list at that position). 1080 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1081 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1082 1083 std::vector<uint32_t> operandSegmentLengths; 1084 std::vector<uint32_t> resultSegmentLengths; 1085 1086 // Validate/determine region count. 1087 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1088 int opMinRegionCount = std::get<0>(opRegionSpec); 1089 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1090 if (!regions) { 1091 regions = opMinRegionCount; 1092 } 1093 if (*regions < opMinRegionCount) { 1094 throw py::value_error( 1095 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1096 llvm::Twine(opMinRegionCount) + 1097 " regions but was built with regions=" + llvm::Twine(*regions)) 1098 .str()); 1099 } 1100 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1101 throw py::value_error( 1102 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1103 llvm::Twine(opMinRegionCount) + 1104 " regions but was built with regions=" + llvm::Twine(*regions)) 1105 .str()); 1106 } 1107 1108 // Unpack results. 1109 std::vector<PyType *> resultTypes; 1110 resultTypes.reserve(resultTypeList.size()); 1111 if (resultSegmentSpecObj.is_none()) { 1112 // Non-variadic result unpacking. 1113 for (auto it : llvm::enumerate(resultTypeList)) { 1114 try { 1115 resultTypes.push_back(py::cast<PyType *>(it.value())); 1116 if (!resultTypes.back()) 1117 throw py::cast_error(); 1118 } catch (py::cast_error &err) { 1119 throw py::value_error((llvm::Twine("Result ") + 1120 llvm::Twine(it.index()) + " of operation \"" + 1121 name + "\" must be a Type (" + err.what() + ")") 1122 .str()); 1123 } 1124 } 1125 } else { 1126 // Sized result unpacking. 1127 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1128 if (resultSegmentSpec.size() != resultTypeList.size()) { 1129 throw py::value_error((llvm::Twine("Operation \"") + name + 1130 "\" requires " + 1131 llvm::Twine(resultSegmentSpec.size()) + 1132 "result segments but was provided " + 1133 llvm::Twine(resultTypeList.size())) 1134 .str()); 1135 } 1136 resultSegmentLengths.reserve(resultTypeList.size()); 1137 for (auto it : 1138 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1139 int segmentSpec = std::get<1>(it.value()); 1140 if (segmentSpec == 1 || segmentSpec == 0) { 1141 // Unpack unary element. 1142 try { 1143 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1144 if (resultType) { 1145 resultTypes.push_back(resultType); 1146 resultSegmentLengths.push_back(1); 1147 } else if (segmentSpec == 0) { 1148 // Allowed to be optional. 1149 resultSegmentLengths.push_back(0); 1150 } else { 1151 throw py::cast_error("was None and result is not optional"); 1152 } 1153 } catch (py::cast_error &err) { 1154 throw py::value_error((llvm::Twine("Result ") + 1155 llvm::Twine(it.index()) + " of operation \"" + 1156 name + "\" must be a Type (" + err.what() + 1157 ")") 1158 .str()); 1159 } 1160 } else if (segmentSpec == -1) { 1161 // Unpack sequence by appending. 1162 try { 1163 if (std::get<0>(it.value()).is_none()) { 1164 // Treat it as an empty list. 1165 resultSegmentLengths.push_back(0); 1166 } else { 1167 // Unpack the list. 1168 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1169 for (py::object segmentItem : segment) { 1170 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1171 if (!resultTypes.back()) { 1172 throw py::cast_error("contained a None item"); 1173 } 1174 } 1175 resultSegmentLengths.push_back(segment.size()); 1176 } 1177 } catch (std::exception &err) { 1178 // NOTE: Sloppy to be using a catch-all here, but there are at least 1179 // three different unrelated exceptions that can be thrown in the 1180 // above "casts". Just keep the scope above small and catch them all. 1181 throw py::value_error((llvm::Twine("Result ") + 1182 llvm::Twine(it.index()) + " of operation \"" + 1183 name + "\" must be a Sequence of Types (" + 1184 err.what() + ")") 1185 .str()); 1186 } 1187 } else { 1188 throw py::value_error("Unexpected segment spec"); 1189 } 1190 } 1191 } 1192 1193 // Unpack operands. 1194 std::vector<PyValue *> operands; 1195 operands.reserve(operands.size()); 1196 if (operandSegmentSpecObj.is_none()) { 1197 // Non-sized operand unpacking. 1198 for (auto it : llvm::enumerate(operandList)) { 1199 try { 1200 operands.push_back(py::cast<PyValue *>(it.value())); 1201 if (!operands.back()) 1202 throw py::cast_error(); 1203 } catch (py::cast_error &err) { 1204 throw py::value_error((llvm::Twine("Operand ") + 1205 llvm::Twine(it.index()) + " of operation \"" + 1206 name + "\" must be a Value (" + err.what() + ")") 1207 .str()); 1208 } 1209 } 1210 } else { 1211 // Sized operand unpacking. 1212 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1213 if (operandSegmentSpec.size() != operandList.size()) { 1214 throw py::value_error((llvm::Twine("Operation \"") + name + 1215 "\" requires " + 1216 llvm::Twine(operandSegmentSpec.size()) + 1217 "operand segments but was provided " + 1218 llvm::Twine(operandList.size())) 1219 .str()); 1220 } 1221 operandSegmentLengths.reserve(operandList.size()); 1222 for (auto it : 1223 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1224 int segmentSpec = std::get<1>(it.value()); 1225 if (segmentSpec == 1 || segmentSpec == 0) { 1226 // Unpack unary element. 1227 try { 1228 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1229 if (operandValue) { 1230 operands.push_back(operandValue); 1231 operandSegmentLengths.push_back(1); 1232 } else if (segmentSpec == 0) { 1233 // Allowed to be optional. 1234 operandSegmentLengths.push_back(0); 1235 } else { 1236 throw py::cast_error("was None and operand is not optional"); 1237 } 1238 } catch (py::cast_error &err) { 1239 throw py::value_error((llvm::Twine("Operand ") + 1240 llvm::Twine(it.index()) + " of operation \"" + 1241 name + "\" must be a Value (" + err.what() + 1242 ")") 1243 .str()); 1244 } 1245 } else if (segmentSpec == -1) { 1246 // Unpack sequence by appending. 1247 try { 1248 if (std::get<0>(it.value()).is_none()) { 1249 // Treat it as an empty list. 1250 operandSegmentLengths.push_back(0); 1251 } else { 1252 // Unpack the list. 1253 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1254 for (py::object segmentItem : segment) { 1255 operands.push_back(py::cast<PyValue *>(segmentItem)); 1256 if (!operands.back()) { 1257 throw py::cast_error("contained a None item"); 1258 } 1259 } 1260 operandSegmentLengths.push_back(segment.size()); 1261 } 1262 } catch (std::exception &err) { 1263 // NOTE: Sloppy to be using a catch-all here, but there are at least 1264 // three different unrelated exceptions that can be thrown in the 1265 // above "casts". Just keep the scope above small and catch them all. 1266 throw py::value_error((llvm::Twine("Operand ") + 1267 llvm::Twine(it.index()) + " of operation \"" + 1268 name + "\" must be a Sequence of Values (" + 1269 err.what() + ")") 1270 .str()); 1271 } 1272 } else { 1273 throw py::value_error("Unexpected segment spec"); 1274 } 1275 } 1276 } 1277 1278 // Merge operand/result segment lengths into attributes if needed. 1279 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1280 // Dup. 1281 if (attributes) { 1282 attributes = py::dict(*attributes); 1283 } else { 1284 attributes = py::dict(); 1285 } 1286 if (attributes->contains("result_segment_sizes") || 1287 attributes->contains("operand_segment_sizes")) { 1288 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1289 "'operand_segment_sizes' attribute is unsupported. " 1290 "Use Operation.create for such low-level access."); 1291 } 1292 1293 // Add result_segment_sizes attribute. 1294 if (!resultSegmentLengths.empty()) { 1295 int64_t size = resultSegmentLengths.size(); 1296 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1297 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1298 resultSegmentLengths.size(), resultSegmentLengths.data()); 1299 (*attributes)["result_segment_sizes"] = 1300 PyAttribute(context, segmentLengthAttr); 1301 } 1302 1303 // Add operand_segment_sizes attribute. 1304 if (!operandSegmentLengths.empty()) { 1305 int64_t size = operandSegmentLengths.size(); 1306 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1307 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1308 operandSegmentLengths.size(), operandSegmentLengths.data()); 1309 (*attributes)["operand_segment_sizes"] = 1310 PyAttribute(context, segmentLengthAttr); 1311 } 1312 } 1313 1314 // Delegate to create. 1315 return PyOperation::create(std::move(name), 1316 /*results=*/std::move(resultTypes), 1317 /*operands=*/std::move(operands), 1318 /*attributes=*/std::move(attributes), 1319 /*successors=*/std::move(successors), 1320 /*regions=*/*regions, location, maybeIp); 1321 } 1322 1323 PyOpView::PyOpView(py::object operationObject) 1324 // Casting through the PyOperationBase base-class and then back to the 1325 // Operation lets us accept any PyOperationBase subclass. 1326 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1327 operationObject(operation.getRef().getObject()) {} 1328 1329 py::object PyOpView::createRawSubclass(py::object userClass) { 1330 // This is... a little gross. The typical pattern is to have a pure python 1331 // class that extends OpView like: 1332 // class AddFOp(_cext.ir.OpView): 1333 // def __init__(self, loc, lhs, rhs): 1334 // operation = loc.context.create_operation( 1335 // "addf", lhs, rhs, results=[lhs.type]) 1336 // super().__init__(operation) 1337 // 1338 // I.e. The goal of the user facing type is to provide a nice constructor 1339 // that has complete freedom for the op under construction. This is at odds 1340 // with our other desire to sometimes create this object by just passing an 1341 // operation (to initialize the base class). We could do *arg and **kwargs 1342 // munging to try to make it work, but instead, we synthesize a new class 1343 // on the fly which extends this user class (AddFOp in this example) and 1344 // *give it* the base class's __init__ method, thus bypassing the 1345 // intermediate subclass's __init__ method entirely. While slightly, 1346 // underhanded, this is safe/legal because the type hierarchy has not changed 1347 // (we just added a new leaf) and we aren't mucking around with __new__. 1348 // Typically, this new class will be stored on the original as "_Raw" and will 1349 // be used for casts and other things that need a variant of the class that 1350 // is initialized purely from an operation. 1351 py::object parentMetaclass = 1352 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1353 py::dict attributes; 1354 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1355 // now. 1356 // auto opViewType = py::type::of<PyOpView>(); 1357 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1358 attributes["__init__"] = opViewType.attr("__init__"); 1359 py::str origName = userClass.attr("__name__"); 1360 py::str newName = py::str("_") + origName; 1361 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1362 } 1363 1364 //------------------------------------------------------------------------------ 1365 // PyInsertionPoint. 1366 //------------------------------------------------------------------------------ 1367 1368 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1369 1370 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1371 : refOperation(beforeOperationBase.getOperation().getRef()), 1372 block((*refOperation)->getBlock()) {} 1373 1374 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1375 PyOperation &operation = operationBase.getOperation(); 1376 if (operation.isAttached()) 1377 throw SetPyError(PyExc_ValueError, 1378 "Attempt to insert operation that is already attached"); 1379 block.getParentOperation()->checkValid(); 1380 MlirOperation beforeOp = {nullptr}; 1381 if (refOperation) { 1382 // Insert before operation. 1383 (*refOperation)->checkValid(); 1384 beforeOp = (*refOperation)->get(); 1385 } else { 1386 // Insert at end (before null) is only valid if the block does not 1387 // already end in a known terminator (violating this will cause assertion 1388 // failures later). 1389 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1390 throw py::index_error("Cannot insert operation at the end of a block " 1391 "that already has a terminator. Did you mean to " 1392 "use 'InsertionPoint.at_block_terminator(block)' " 1393 "versus 'InsertionPoint(block)'?"); 1394 } 1395 } 1396 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1397 operation.setAttached(); 1398 } 1399 1400 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1401 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1402 if (mlirOperationIsNull(firstOp)) { 1403 // Just insert at end. 1404 return PyInsertionPoint(block); 1405 } 1406 1407 // Insert before first op. 1408 PyOperationRef firstOpRef = PyOperation::forOperation( 1409 block.getParentOperation()->getContext(), firstOp); 1410 return PyInsertionPoint{block, std::move(firstOpRef)}; 1411 } 1412 1413 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1414 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1415 if (mlirOperationIsNull(terminator)) 1416 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1417 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1418 block.getParentOperation()->getContext(), terminator); 1419 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1420 } 1421 1422 py::object PyInsertionPoint::contextEnter() { 1423 return PyThreadContextEntry::pushInsertionPoint(*this); 1424 } 1425 1426 void PyInsertionPoint::contextExit(pybind11::object excType, 1427 pybind11::object excVal, 1428 pybind11::object excTb) { 1429 PyThreadContextEntry::popInsertionPoint(*this); 1430 } 1431 1432 //------------------------------------------------------------------------------ 1433 // PyAttribute. 1434 //------------------------------------------------------------------------------ 1435 1436 bool PyAttribute::operator==(const PyAttribute &other) { 1437 return mlirAttributeEqual(attr, other.attr); 1438 } 1439 1440 py::object PyAttribute::getCapsule() { 1441 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1442 } 1443 1444 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1445 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1446 if (mlirAttributeIsNull(rawAttr)) 1447 throw py::error_already_set(); 1448 return PyAttribute( 1449 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1450 } 1451 1452 //------------------------------------------------------------------------------ 1453 // PyNamedAttribute. 1454 //------------------------------------------------------------------------------ 1455 1456 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1457 : ownedName(new std::string(std::move(ownedName))) { 1458 namedAttr = mlirNamedAttributeGet( 1459 mlirIdentifierGet(mlirAttributeGetContext(attr), 1460 toMlirStringRef(*this->ownedName)), 1461 attr); 1462 } 1463 1464 //------------------------------------------------------------------------------ 1465 // PyType. 1466 //------------------------------------------------------------------------------ 1467 1468 bool PyType::operator==(const PyType &other) { 1469 return mlirTypeEqual(type, other.type); 1470 } 1471 1472 py::object PyType::getCapsule() { 1473 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1474 } 1475 1476 PyType PyType::createFromCapsule(py::object capsule) { 1477 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1478 if (mlirTypeIsNull(rawType)) 1479 throw py::error_already_set(); 1480 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1481 rawType); 1482 } 1483 1484 //------------------------------------------------------------------------------ 1485 // PyValue and subclases. 1486 //------------------------------------------------------------------------------ 1487 1488 pybind11::object PyValue::getCapsule() { 1489 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1490 } 1491 1492 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1493 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1494 if (mlirValueIsNull(value)) 1495 throw py::error_already_set(); 1496 MlirOperation owner; 1497 if (mlirValueIsAOpResult(value)) 1498 owner = mlirOpResultGetOwner(value); 1499 if (mlirValueIsABlockArgument(value)) 1500 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1501 if (mlirOperationIsNull(owner)) 1502 throw py::error_already_set(); 1503 MlirContext ctx = mlirOperationGetContext(owner); 1504 PyOperationRef ownerRef = 1505 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1506 return PyValue(ownerRef, value); 1507 } 1508 1509 namespace { 1510 /// CRTP base class for Python MLIR values that subclass Value and should be 1511 /// castable from it. The value hierarchy is one level deep and is not supposed 1512 /// to accommodate other levels unless core MLIR changes. 1513 template <typename DerivedTy> 1514 class PyConcreteValue : public PyValue { 1515 public: 1516 // Derived classes must define statics for: 1517 // IsAFunctionTy isaFunction 1518 // const char *pyClassName 1519 // and redefine bindDerived. 1520 using ClassTy = py::class_<DerivedTy, PyValue>; 1521 using IsAFunctionTy = bool (*)(MlirValue); 1522 1523 PyConcreteValue() = default; 1524 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1525 : PyValue(operationRef, value) {} 1526 PyConcreteValue(PyValue &orig) 1527 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1528 1529 /// Attempts to cast the original value to the derived type and throws on 1530 /// type mismatches. 1531 static MlirValue castFrom(PyValue &orig) { 1532 if (!DerivedTy::isaFunction(orig.get())) { 1533 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1534 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1535 DerivedTy::pyClassName + 1536 " (from " + origRepr + ")"); 1537 } 1538 return orig.get(); 1539 } 1540 1541 /// Binds the Python module objects to functions of this class. 1542 static void bind(py::module &m) { 1543 auto cls = ClassTy(m, DerivedTy::pyClassName); 1544 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1545 DerivedTy::bindDerived(cls); 1546 } 1547 1548 /// Implemented by derived classes to add methods to the Python subclass. 1549 static void bindDerived(ClassTy &m) {} 1550 }; 1551 1552 /// Python wrapper for MlirBlockArgument. 1553 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1554 public: 1555 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1556 static constexpr const char *pyClassName = "BlockArgument"; 1557 using PyConcreteValue::PyConcreteValue; 1558 1559 static void bindDerived(ClassTy &c) { 1560 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1561 return PyBlock(self.getParentOperation(), 1562 mlirBlockArgumentGetOwner(self.get())); 1563 }); 1564 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1565 return mlirBlockArgumentGetArgNumber(self.get()); 1566 }); 1567 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1568 return mlirBlockArgumentSetType(self.get(), type); 1569 }); 1570 } 1571 }; 1572 1573 /// Python wrapper for MlirOpResult. 1574 class PyOpResult : public PyConcreteValue<PyOpResult> { 1575 public: 1576 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1577 static constexpr const char *pyClassName = "OpResult"; 1578 using PyConcreteValue::PyConcreteValue; 1579 1580 static void bindDerived(ClassTy &c) { 1581 c.def_property_readonly("owner", [](PyOpResult &self) { 1582 assert( 1583 mlirOperationEqual(self.getParentOperation()->get(), 1584 mlirOpResultGetOwner(self.get())) && 1585 "expected the owner of the value in Python to match that in the IR"); 1586 return self.getParentOperation().getObject(); 1587 }); 1588 c.def_property_readonly("result_number", [](PyOpResult &self) { 1589 return mlirOpResultGetResultNumber(self.get()); 1590 }); 1591 } 1592 }; 1593 1594 /// A list of block arguments. Internally, these are stored as consecutive 1595 /// elements, random access is cheap. The argument list is associated with the 1596 /// operation that contains the block (detached blocks are not allowed in 1597 /// Python bindings) and extends its lifetime. 1598 class PyBlockArgumentList { 1599 public: 1600 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1601 : operation(std::move(operation)), block(block) {} 1602 1603 /// Returns the length of the block argument list. 1604 intptr_t dunderLen() { 1605 operation->checkValid(); 1606 return mlirBlockGetNumArguments(block); 1607 } 1608 1609 /// Returns `index`-th element of the block argument list. 1610 PyBlockArgument dunderGetItem(intptr_t index) { 1611 if (index < 0 || index >= dunderLen()) { 1612 throw SetPyError(PyExc_IndexError, 1613 "attempt to access out of bounds region"); 1614 } 1615 PyValue value(operation, mlirBlockGetArgument(block, index)); 1616 return PyBlockArgument(value); 1617 } 1618 1619 /// Defines a Python class in the bindings. 1620 static void bind(py::module &m) { 1621 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1622 .def("__len__", &PyBlockArgumentList::dunderLen) 1623 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1624 } 1625 1626 private: 1627 PyOperationRef operation; 1628 MlirBlock block; 1629 }; 1630 1631 /// A list of operation operands. Internally, these are stored as consecutive 1632 /// elements, random access is cheap. The result list is associated with the 1633 /// operation whose results these are, and extends the lifetime of this 1634 /// operation. 1635 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1636 public: 1637 static constexpr const char *pyClassName = "OpOperandList"; 1638 1639 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1640 intptr_t length = -1, intptr_t step = 1) 1641 : Sliceable(startIndex, 1642 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1643 : length, 1644 step), 1645 operation(operation) {} 1646 1647 intptr_t getNumElements() { 1648 operation->checkValid(); 1649 return mlirOperationGetNumOperands(operation->get()); 1650 } 1651 1652 PyValue getElement(intptr_t pos) { 1653 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1654 MlirOperation owner; 1655 if (mlirValueIsAOpResult(operand)) 1656 owner = mlirOpResultGetOwner(operand); 1657 else if (mlirValueIsABlockArgument(operand)) 1658 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1659 else 1660 assert(false && "Value must be an block arg or op result."); 1661 PyOperationRef pyOwner = 1662 PyOperation::forOperation(operation->getContext(), owner); 1663 return PyValue(pyOwner, operand); 1664 } 1665 1666 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1667 return PyOpOperandList(operation, startIndex, length, step); 1668 } 1669 1670 void dunderSetItem(intptr_t index, PyValue value) { 1671 index = wrapIndex(index); 1672 mlirOperationSetOperand(operation->get(), index, value.get()); 1673 } 1674 1675 static void bindDerived(ClassTy &c) { 1676 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1677 } 1678 1679 private: 1680 PyOperationRef operation; 1681 }; 1682 1683 /// A list of operation results. Internally, these are stored as consecutive 1684 /// elements, random access is cheap. The result list is associated with the 1685 /// operation whose results these are, and extends the lifetime of this 1686 /// operation. 1687 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1688 public: 1689 static constexpr const char *pyClassName = "OpResultList"; 1690 1691 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1692 intptr_t length = -1, intptr_t step = 1) 1693 : Sliceable(startIndex, 1694 length == -1 ? mlirOperationGetNumResults(operation->get()) 1695 : length, 1696 step), 1697 operation(operation) {} 1698 1699 intptr_t getNumElements() { 1700 operation->checkValid(); 1701 return mlirOperationGetNumResults(operation->get()); 1702 } 1703 1704 PyOpResult getElement(intptr_t index) { 1705 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1706 return PyOpResult(value); 1707 } 1708 1709 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1710 return PyOpResultList(operation, startIndex, length, step); 1711 } 1712 1713 private: 1714 PyOperationRef operation; 1715 }; 1716 1717 /// A list of operation attributes. Can be indexed by name, producing 1718 /// attributes, or by index, producing named attributes. 1719 class PyOpAttributeMap { 1720 public: 1721 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1722 1723 PyAttribute dunderGetItemNamed(const std::string &name) { 1724 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1725 toMlirStringRef(name)); 1726 if (mlirAttributeIsNull(attr)) { 1727 throw SetPyError(PyExc_KeyError, 1728 "attempt to access a non-existent attribute"); 1729 } 1730 return PyAttribute(operation->getContext(), attr); 1731 } 1732 1733 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1734 if (index < 0 || index >= dunderLen()) { 1735 throw SetPyError(PyExc_IndexError, 1736 "attempt to access out of bounds attribute"); 1737 } 1738 MlirNamedAttribute namedAttr = 1739 mlirOperationGetAttribute(operation->get(), index); 1740 return PyNamedAttribute( 1741 namedAttr.attribute, 1742 std::string(mlirIdentifierStr(namedAttr.name).data)); 1743 } 1744 1745 void dunderSetItem(const std::string &name, PyAttribute attr) { 1746 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1747 attr); 1748 } 1749 1750 void dunderDelItem(const std::string &name) { 1751 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1752 toMlirStringRef(name)); 1753 if (!removed) 1754 throw SetPyError(PyExc_KeyError, 1755 "attempt to delete a non-existent attribute"); 1756 } 1757 1758 intptr_t dunderLen() { 1759 return mlirOperationGetNumAttributes(operation->get()); 1760 } 1761 1762 bool dunderContains(const std::string &name) { 1763 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1764 operation->get(), toMlirStringRef(name))); 1765 } 1766 1767 static void bind(py::module &m) { 1768 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1769 .def("__contains__", &PyOpAttributeMap::dunderContains) 1770 .def("__len__", &PyOpAttributeMap::dunderLen) 1771 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1772 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1773 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1774 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1775 } 1776 1777 private: 1778 PyOperationRef operation; 1779 }; 1780 1781 } // end namespace 1782 1783 //------------------------------------------------------------------------------ 1784 // Populates the core exports of the 'ir' submodule. 1785 //------------------------------------------------------------------------------ 1786 1787 void mlir::python::populateIRCore(py::module &m) { 1788 //---------------------------------------------------------------------------- 1789 // Mapping of MlirContext. 1790 //---------------------------------------------------------------------------- 1791 py::class_<PyMlirContext>(m, "Context") 1792 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1793 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1794 .def("_get_context_again", 1795 [](PyMlirContext &self) { 1796 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1797 return ref.releaseObject(); 1798 }) 1799 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1800 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1801 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1802 &PyMlirContext::getCapsule) 1803 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1804 .def("__enter__", &PyMlirContext::contextEnter) 1805 .def("__exit__", &PyMlirContext::contextExit) 1806 .def_property_readonly_static( 1807 "current", 1808 [](py::object & /*class*/) { 1809 auto *context = PyThreadContextEntry::getDefaultContext(); 1810 if (!context) 1811 throw SetPyError(PyExc_ValueError, "No current Context"); 1812 return context; 1813 }, 1814 "Gets the Context bound to the current thread or raises ValueError") 1815 .def_property_readonly( 1816 "dialects", 1817 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1818 "Gets a container for accessing dialects by name") 1819 .def_property_readonly( 1820 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1821 "Alias for 'dialect'") 1822 .def( 1823 "get_dialect_descriptor", 1824 [=](PyMlirContext &self, std::string &name) { 1825 MlirDialect dialect = mlirContextGetOrLoadDialect( 1826 self.get(), {name.data(), name.size()}); 1827 if (mlirDialectIsNull(dialect)) { 1828 throw SetPyError(PyExc_ValueError, 1829 Twine("Dialect '") + name + "' not found"); 1830 } 1831 return PyDialectDescriptor(self.getRef(), dialect); 1832 }, 1833 "Gets or loads a dialect by name, returning its descriptor object") 1834 .def_property( 1835 "allow_unregistered_dialects", 1836 [](PyMlirContext &self) -> bool { 1837 return mlirContextGetAllowUnregisteredDialects(self.get()); 1838 }, 1839 [](PyMlirContext &self, bool value) { 1840 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1841 }) 1842 .def("enable_multithreading", 1843 [](PyMlirContext &self, bool enable) { 1844 mlirContextEnableMultithreading(self.get(), enable); 1845 }) 1846 .def("is_registered_operation", 1847 [](PyMlirContext &self, std::string &name) { 1848 return mlirContextIsRegisteredOperation( 1849 self.get(), MlirStringRef{name.data(), name.size()}); 1850 }); 1851 1852 //---------------------------------------------------------------------------- 1853 // Mapping of PyDialectDescriptor 1854 //---------------------------------------------------------------------------- 1855 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1856 .def_property_readonly("namespace", 1857 [](PyDialectDescriptor &self) { 1858 MlirStringRef ns = 1859 mlirDialectGetNamespace(self.get()); 1860 return py::str(ns.data, ns.length); 1861 }) 1862 .def("__repr__", [](PyDialectDescriptor &self) { 1863 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1864 std::string repr("<DialectDescriptor "); 1865 repr.append(ns.data, ns.length); 1866 repr.append(">"); 1867 return repr; 1868 }); 1869 1870 //---------------------------------------------------------------------------- 1871 // Mapping of PyDialects 1872 //---------------------------------------------------------------------------- 1873 py::class_<PyDialects>(m, "Dialects") 1874 .def("__getitem__", 1875 [=](PyDialects &self, std::string keyName) { 1876 MlirDialect dialect = 1877 self.getDialectForKey(keyName, /*attrError=*/false); 1878 py::object descriptor = 1879 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1880 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1881 }) 1882 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1883 MlirDialect dialect = 1884 self.getDialectForKey(attrName, /*attrError=*/true); 1885 py::object descriptor = 1886 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1887 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1888 }); 1889 1890 //---------------------------------------------------------------------------- 1891 // Mapping of PyDialect 1892 //---------------------------------------------------------------------------- 1893 py::class_<PyDialect>(m, "Dialect") 1894 .def(py::init<py::object>(), "descriptor") 1895 .def_property_readonly( 1896 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1897 .def("__repr__", [](py::object self) { 1898 auto clazz = self.attr("__class__"); 1899 return py::str("<Dialect ") + 1900 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1901 clazz.attr("__module__") + py::str(".") + 1902 clazz.attr("__name__") + py::str(")>"); 1903 }); 1904 1905 //---------------------------------------------------------------------------- 1906 // Mapping of Location 1907 //---------------------------------------------------------------------------- 1908 py::class_<PyLocation>(m, "Location") 1909 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1910 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1911 .def("__enter__", &PyLocation::contextEnter) 1912 .def("__exit__", &PyLocation::contextExit) 1913 .def("__eq__", 1914 [](PyLocation &self, PyLocation &other) -> bool { 1915 return mlirLocationEqual(self, other); 1916 }) 1917 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1918 .def_property_readonly_static( 1919 "current", 1920 [](py::object & /*class*/) { 1921 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1922 if (!loc) 1923 throw SetPyError(PyExc_ValueError, "No current Location"); 1924 return loc; 1925 }, 1926 "Gets the Location bound to the current thread or raises ValueError") 1927 .def_static( 1928 "unknown", 1929 [](DefaultingPyMlirContext context) { 1930 return PyLocation(context->getRef(), 1931 mlirLocationUnknownGet(context->get())); 1932 }, 1933 py::arg("context") = py::none(), 1934 "Gets a Location representing an unknown location") 1935 .def_static( 1936 "file", 1937 [](std::string filename, int line, int col, 1938 DefaultingPyMlirContext context) { 1939 return PyLocation( 1940 context->getRef(), 1941 mlirLocationFileLineColGet( 1942 context->get(), toMlirStringRef(filename), line, col)); 1943 }, 1944 py::arg("filename"), py::arg("line"), py::arg("col"), 1945 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1946 .def_property_readonly( 1947 "context", 1948 [](PyLocation &self) { return self.getContext().getObject(); }, 1949 "Context that owns the Location") 1950 .def("__repr__", [](PyLocation &self) { 1951 PyPrintAccumulator printAccum; 1952 mlirLocationPrint(self, printAccum.getCallback(), 1953 printAccum.getUserData()); 1954 return printAccum.join(); 1955 }); 1956 1957 //---------------------------------------------------------------------------- 1958 // Mapping of Module 1959 //---------------------------------------------------------------------------- 1960 py::class_<PyModule>(m, "Module") 1961 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1962 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1963 .def_static( 1964 "parse", 1965 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1966 MlirModule module = mlirModuleCreateParse( 1967 context->get(), toMlirStringRef(moduleAsm)); 1968 // TODO: Rework error reporting once diagnostic engine is exposed 1969 // in C API. 1970 if (mlirModuleIsNull(module)) { 1971 throw SetPyError( 1972 PyExc_ValueError, 1973 "Unable to parse module assembly (see diagnostics)"); 1974 } 1975 return PyModule::forModule(module).releaseObject(); 1976 }, 1977 py::arg("asm"), py::arg("context") = py::none(), 1978 kModuleParseDocstring) 1979 .def_static( 1980 "create", 1981 [](DefaultingPyLocation loc) { 1982 MlirModule module = mlirModuleCreateEmpty(loc); 1983 return PyModule::forModule(module).releaseObject(); 1984 }, 1985 py::arg("loc") = py::none(), "Creates an empty module") 1986 .def_property_readonly( 1987 "context", 1988 [](PyModule &self) { return self.getContext().getObject(); }, 1989 "Context that created the Module") 1990 .def_property_readonly( 1991 "operation", 1992 [](PyModule &self) { 1993 return PyOperation::forOperation(self.getContext(), 1994 mlirModuleGetOperation(self.get()), 1995 self.getRef().releaseObject()) 1996 .releaseObject(); 1997 }, 1998 "Accesses the module as an operation") 1999 .def_property_readonly( 2000 "body", 2001 [](PyModule &self) { 2002 PyOperationRef module_op = PyOperation::forOperation( 2003 self.getContext(), mlirModuleGetOperation(self.get()), 2004 self.getRef().releaseObject()); 2005 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2006 return returnBlock; 2007 }, 2008 "Return the block for this module") 2009 .def( 2010 "dump", 2011 [](PyModule &self) { 2012 mlirOperationDump(mlirModuleGetOperation(self.get())); 2013 }, 2014 kDumpDocstring) 2015 .def( 2016 "__str__", 2017 [](PyModule &self) { 2018 MlirOperation operation = mlirModuleGetOperation(self.get()); 2019 PyPrintAccumulator printAccum; 2020 mlirOperationPrint(operation, printAccum.getCallback(), 2021 printAccum.getUserData()); 2022 return printAccum.join(); 2023 }, 2024 kOperationStrDunderDocstring); 2025 2026 //---------------------------------------------------------------------------- 2027 // Mapping of Operation. 2028 //---------------------------------------------------------------------------- 2029 py::class_<PyOperationBase>(m, "_OperationBase") 2030 .def("__eq__", 2031 [](PyOperationBase &self, PyOperationBase &other) { 2032 return &self.getOperation() == &other.getOperation(); 2033 }) 2034 .def("__eq__", 2035 [](PyOperationBase &self, py::object other) { return false; }) 2036 .def_property_readonly("attributes", 2037 [](PyOperationBase &self) { 2038 return PyOpAttributeMap( 2039 self.getOperation().getRef()); 2040 }) 2041 .def_property_readonly("operands", 2042 [](PyOperationBase &self) { 2043 return PyOpOperandList( 2044 self.getOperation().getRef()); 2045 }) 2046 .def_property_readonly("regions", 2047 [](PyOperationBase &self) { 2048 return PyRegionList( 2049 self.getOperation().getRef()); 2050 }) 2051 .def_property_readonly( 2052 "results", 2053 [](PyOperationBase &self) { 2054 return PyOpResultList(self.getOperation().getRef()); 2055 }, 2056 "Returns the list of Operation results.") 2057 .def_property_readonly( 2058 "result", 2059 [](PyOperationBase &self) { 2060 auto &operation = self.getOperation(); 2061 auto numResults = mlirOperationGetNumResults(operation); 2062 if (numResults != 1) { 2063 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2064 throw SetPyError( 2065 PyExc_ValueError, 2066 Twine("Cannot call .result on operation ") + 2067 StringRef(name.data, name.length) + " which has " + 2068 Twine(numResults) + 2069 " results (it is only valid for operations with a " 2070 "single result)"); 2071 } 2072 return PyOpResult(operation.getRef(), 2073 mlirOperationGetResult(operation, 0)); 2074 }, 2075 "Shortcut to get an op result if it has only one (throws an error " 2076 "otherwise).") 2077 .def("__iter__", 2078 [](PyOperationBase &self) { 2079 return PyRegionIterator(self.getOperation().getRef()); 2080 }) 2081 .def( 2082 "__str__", 2083 [](PyOperationBase &self) { 2084 return self.getAsm(/*binary=*/false, 2085 /*largeElementsLimit=*/llvm::None, 2086 /*enableDebugInfo=*/false, 2087 /*prettyDebugInfo=*/false, 2088 /*printGenericOpForm=*/false, 2089 /*useLocalScope=*/false); 2090 }, 2091 "Returns the assembly form of the operation.") 2092 .def("print", &PyOperationBase::print, 2093 // Careful: Lots of arguments must match up with print method. 2094 py::arg("file") = py::none(), py::arg("binary") = false, 2095 py::arg("large_elements_limit") = py::none(), 2096 py::arg("enable_debug_info") = false, 2097 py::arg("pretty_debug_info") = false, 2098 py::arg("print_generic_op_form") = false, 2099 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2100 .def("get_asm", &PyOperationBase::getAsm, 2101 // Careful: Lots of arguments must match up with get_asm method. 2102 py::arg("binary") = false, 2103 py::arg("large_elements_limit") = py::none(), 2104 py::arg("enable_debug_info") = false, 2105 py::arg("pretty_debug_info") = false, 2106 py::arg("print_generic_op_form") = false, 2107 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2108 .def( 2109 "verify", 2110 [](PyOperationBase &self) { 2111 return mlirOperationVerify(self.getOperation()); 2112 }, 2113 "Verify the operation and return true if it passes, false if it " 2114 "fails."); 2115 2116 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2117 .def_static("create", &PyOperation::create, py::arg("name"), 2118 py::arg("results") = py::none(), 2119 py::arg("operands") = py::none(), 2120 py::arg("attributes") = py::none(), 2121 py::arg("successors") = py::none(), py::arg("regions") = 0, 2122 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2123 kOperationCreateDocstring) 2124 .def_property_readonly("parent", 2125 [](PyOperation &self) -> py::object { 2126 auto parent = self.getParentOperation(); 2127 if (parent) 2128 return parent->getObject(); 2129 return py::none(); 2130 }) 2131 .def("erase", &PyOperation::erase) 2132 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2133 &PyOperation::getCapsule) 2134 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2135 .def_property_readonly("name", 2136 [](PyOperation &self) { 2137 self.checkValid(); 2138 MlirOperation operation = self.get(); 2139 MlirStringRef name = mlirIdentifierStr( 2140 mlirOperationGetName(operation)); 2141 return py::str(name.data, name.length); 2142 }) 2143 .def_property_readonly( 2144 "context", 2145 [](PyOperation &self) { 2146 self.checkValid(); 2147 return self.getContext().getObject(); 2148 }, 2149 "Context that owns the Operation") 2150 .def_property_readonly("opview", &PyOperation::createOpView); 2151 2152 auto opViewClass = 2153 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2154 .def(py::init<py::object>()) 2155 .def_property_readonly("operation", &PyOpView::getOperationObject) 2156 .def_property_readonly( 2157 "context", 2158 [](PyOpView &self) { 2159 return self.getOperation().getContext().getObject(); 2160 }, 2161 "Context that owns the Operation") 2162 .def("__str__", [](PyOpView &self) { 2163 return py::str(self.getOperationObject()); 2164 }); 2165 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2166 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2167 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2168 opViewClass.attr("build_generic") = classmethod( 2169 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2170 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2171 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2172 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2173 "Builds a specific, generated OpView based on class level attributes."); 2174 2175 //---------------------------------------------------------------------------- 2176 // Mapping of PyRegion. 2177 //---------------------------------------------------------------------------- 2178 py::class_<PyRegion>(m, "Region") 2179 .def_property_readonly( 2180 "blocks", 2181 [](PyRegion &self) { 2182 return PyBlockList(self.getParentOperation(), self.get()); 2183 }, 2184 "Returns a forward-optimized sequence of blocks.") 2185 .def( 2186 "__iter__", 2187 [](PyRegion &self) { 2188 self.checkValid(); 2189 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2190 return PyBlockIterator(self.getParentOperation(), firstBlock); 2191 }, 2192 "Iterates over blocks in the region.") 2193 .def("__eq__", 2194 [](PyRegion &self, PyRegion &other) { 2195 return self.get().ptr == other.get().ptr; 2196 }) 2197 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2198 2199 //---------------------------------------------------------------------------- 2200 // Mapping of PyBlock. 2201 //---------------------------------------------------------------------------- 2202 py::class_<PyBlock>(m, "Block") 2203 .def_property_readonly( 2204 "owner", 2205 [](PyBlock &self) { 2206 return self.getParentOperation()->createOpView(); 2207 }, 2208 "Returns the owning operation of this block.") 2209 .def_property_readonly( 2210 "arguments", 2211 [](PyBlock &self) { 2212 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2213 }, 2214 "Returns a list of block arguments.") 2215 .def_property_readonly( 2216 "operations", 2217 [](PyBlock &self) { 2218 return PyOperationList(self.getParentOperation(), self.get()); 2219 }, 2220 "Returns a forward-optimized sequence of operations.") 2221 .def( 2222 "__iter__", 2223 [](PyBlock &self) { 2224 self.checkValid(); 2225 MlirOperation firstOperation = 2226 mlirBlockGetFirstOperation(self.get()); 2227 return PyOperationIterator(self.getParentOperation(), 2228 firstOperation); 2229 }, 2230 "Iterates over operations in the block.") 2231 .def("__eq__", 2232 [](PyBlock &self, PyBlock &other) { 2233 return self.get().ptr == other.get().ptr; 2234 }) 2235 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2236 .def( 2237 "__str__", 2238 [](PyBlock &self) { 2239 self.checkValid(); 2240 PyPrintAccumulator printAccum; 2241 mlirBlockPrint(self.get(), printAccum.getCallback(), 2242 printAccum.getUserData()); 2243 return printAccum.join(); 2244 }, 2245 "Returns the assembly form of the block."); 2246 2247 //---------------------------------------------------------------------------- 2248 // Mapping of PyInsertionPoint. 2249 //---------------------------------------------------------------------------- 2250 2251 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2252 .def(py::init<PyBlock &>(), py::arg("block"), 2253 "Inserts after the last operation but still inside the block.") 2254 .def("__enter__", &PyInsertionPoint::contextEnter) 2255 .def("__exit__", &PyInsertionPoint::contextExit) 2256 .def_property_readonly_static( 2257 "current", 2258 [](py::object & /*class*/) { 2259 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2260 if (!ip) 2261 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2262 return ip; 2263 }, 2264 "Gets the InsertionPoint bound to the current thread or raises " 2265 "ValueError if none has been set") 2266 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2267 "Inserts before a referenced operation.") 2268 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2269 py::arg("block"), "Inserts at the beginning of the block.") 2270 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2271 py::arg("block"), "Inserts before the block terminator.") 2272 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2273 "Inserts an operation."); 2274 2275 //---------------------------------------------------------------------------- 2276 // Mapping of PyAttribute. 2277 //---------------------------------------------------------------------------- 2278 py::class_<PyAttribute>(m, "Attribute") 2279 // Delegate to the PyAttribute copy constructor, which will also lifetime 2280 // extend the backing context which owns the MlirAttribute. 2281 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2282 "Casts the passed attribute to the generic Attribute") 2283 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2284 &PyAttribute::getCapsule) 2285 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2286 .def_static( 2287 "parse", 2288 [](std::string attrSpec, DefaultingPyMlirContext context) { 2289 MlirAttribute type = mlirAttributeParseGet( 2290 context->get(), toMlirStringRef(attrSpec)); 2291 // TODO: Rework error reporting once diagnostic engine is exposed 2292 // in C API. 2293 if (mlirAttributeIsNull(type)) { 2294 throw SetPyError(PyExc_ValueError, 2295 Twine("Unable to parse attribute: '") + 2296 attrSpec + "'"); 2297 } 2298 return PyAttribute(context->getRef(), type); 2299 }, 2300 py::arg("asm"), py::arg("context") = py::none(), 2301 "Parses an attribute from an assembly form") 2302 .def_property_readonly( 2303 "context", 2304 [](PyAttribute &self) { return self.getContext().getObject(); }, 2305 "Context that owns the Attribute") 2306 .def_property_readonly("type", 2307 [](PyAttribute &self) { 2308 return PyType(self.getContext()->getRef(), 2309 mlirAttributeGetType(self)); 2310 }) 2311 .def( 2312 "get_named", 2313 [](PyAttribute &self, std::string name) { 2314 return PyNamedAttribute(self, std::move(name)); 2315 }, 2316 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2317 .def("__eq__", 2318 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2319 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2320 .def( 2321 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2322 kDumpDocstring) 2323 .def( 2324 "__str__", 2325 [](PyAttribute &self) { 2326 PyPrintAccumulator printAccum; 2327 mlirAttributePrint(self, printAccum.getCallback(), 2328 printAccum.getUserData()); 2329 return printAccum.join(); 2330 }, 2331 "Returns the assembly form of the Attribute.") 2332 .def("__repr__", [](PyAttribute &self) { 2333 // Generally, assembly formats are not printed for __repr__ because 2334 // this can cause exceptionally long debug output and exceptions. 2335 // However, attribute values are generally considered useful and are 2336 // printed. This may need to be re-evaluated if debug dumps end up 2337 // being excessive. 2338 PyPrintAccumulator printAccum; 2339 printAccum.parts.append("Attribute("); 2340 mlirAttributePrint(self, printAccum.getCallback(), 2341 printAccum.getUserData()); 2342 printAccum.parts.append(")"); 2343 return printAccum.join(); 2344 }); 2345 2346 //---------------------------------------------------------------------------- 2347 // Mapping of PyNamedAttribute 2348 //---------------------------------------------------------------------------- 2349 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2350 .def("__repr__", 2351 [](PyNamedAttribute &self) { 2352 PyPrintAccumulator printAccum; 2353 printAccum.parts.append("NamedAttribute("); 2354 printAccum.parts.append( 2355 mlirIdentifierStr(self.namedAttr.name).data); 2356 printAccum.parts.append("="); 2357 mlirAttributePrint(self.namedAttr.attribute, 2358 printAccum.getCallback(), 2359 printAccum.getUserData()); 2360 printAccum.parts.append(")"); 2361 return printAccum.join(); 2362 }) 2363 .def_property_readonly( 2364 "name", 2365 [](PyNamedAttribute &self) { 2366 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2367 mlirIdentifierStr(self.namedAttr.name).length); 2368 }, 2369 "The name of the NamedAttribute binding") 2370 .def_property_readonly( 2371 "attr", 2372 [](PyNamedAttribute &self) { 2373 // TODO: When named attribute is removed/refactored, also remove 2374 // this constructor (it does an inefficient table lookup). 2375 auto contextRef = PyMlirContext::forContext( 2376 mlirAttributeGetContext(self.namedAttr.attribute)); 2377 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2378 }, 2379 py::keep_alive<0, 1>(), 2380 "The underlying generic attribute of the NamedAttribute binding"); 2381 2382 //---------------------------------------------------------------------------- 2383 // Mapping of PyType. 2384 //---------------------------------------------------------------------------- 2385 py::class_<PyType>(m, "Type") 2386 // Delegate to the PyType copy constructor, which will also lifetime 2387 // extend the backing context which owns the MlirType. 2388 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2389 "Casts the passed type to the generic Type") 2390 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2391 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2392 .def_static( 2393 "parse", 2394 [](std::string typeSpec, DefaultingPyMlirContext context) { 2395 MlirType type = 2396 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2397 // TODO: Rework error reporting once diagnostic engine is exposed 2398 // in C API. 2399 if (mlirTypeIsNull(type)) { 2400 throw SetPyError(PyExc_ValueError, 2401 Twine("Unable to parse type: '") + typeSpec + 2402 "'"); 2403 } 2404 return PyType(context->getRef(), type); 2405 }, 2406 py::arg("asm"), py::arg("context") = py::none(), 2407 kContextParseTypeDocstring) 2408 .def_property_readonly( 2409 "context", [](PyType &self) { return self.getContext().getObject(); }, 2410 "Context that owns the Type") 2411 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2412 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2413 .def( 2414 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2415 .def( 2416 "__str__", 2417 [](PyType &self) { 2418 PyPrintAccumulator printAccum; 2419 mlirTypePrint(self, printAccum.getCallback(), 2420 printAccum.getUserData()); 2421 return printAccum.join(); 2422 }, 2423 "Returns the assembly form of the type.") 2424 .def("__repr__", [](PyType &self) { 2425 // Generally, assembly formats are not printed for __repr__ because 2426 // this can cause exceptionally long debug output and exceptions. 2427 // However, types are an exception as they typically have compact 2428 // assembly forms and printing them is useful. 2429 PyPrintAccumulator printAccum; 2430 printAccum.parts.append("Type("); 2431 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2432 printAccum.parts.append(")"); 2433 return printAccum.join(); 2434 }); 2435 2436 //---------------------------------------------------------------------------- 2437 // Mapping of Value. 2438 //---------------------------------------------------------------------------- 2439 py::class_<PyValue>(m, "Value") 2440 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2441 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2442 .def_property_readonly( 2443 "context", 2444 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2445 "Context in which the value lives.") 2446 .def( 2447 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2448 kDumpDocstring) 2449 .def_property_readonly( 2450 "owner", 2451 [](PyValue &self) { 2452 assert(mlirOperationEqual(self.getParentOperation()->get(), 2453 mlirOpResultGetOwner(self.get())) && 2454 "expected the owner of the value in Python to match that in " 2455 "the IR"); 2456 return self.getParentOperation().getObject(); 2457 }) 2458 .def("__eq__", 2459 [](PyValue &self, PyValue &other) { 2460 return self.get().ptr == other.get().ptr; 2461 }) 2462 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2463 .def( 2464 "__str__", 2465 [](PyValue &self) { 2466 PyPrintAccumulator printAccum; 2467 printAccum.parts.append("Value("); 2468 mlirValuePrint(self.get(), printAccum.getCallback(), 2469 printAccum.getUserData()); 2470 printAccum.parts.append(")"); 2471 return printAccum.join(); 2472 }, 2473 kValueDunderStrDocstring) 2474 .def_property_readonly("type", [](PyValue &self) { 2475 return PyType(self.getParentOperation()->getContext(), 2476 mlirValueGetType(self.get())); 2477 }); 2478 PyBlockArgument::bind(m); 2479 PyOpResult::bind(m); 2480 2481 // Container bindings. 2482 PyBlockArgumentList::bind(m); 2483 PyBlockIterator::bind(m); 2484 PyBlockList::bind(m); 2485 PyOperationIterator::bind(m); 2486 PyOperationList::bind(m); 2487 PyOpAttributeMap::bind(m); 2488 PyOpOperandList::bind(m); 2489 PyOpResultList::bind(m); 2490 PyRegionIterator::bind(m); 2491 PyRegionList::bind(m); 2492 2493 // Debug bindings. 2494 PyGlobalDebugFlag::bind(m); 2495 } 2496