1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include <pybind11/stl.h> 22 23 namespace py = pybind11; 24 using namespace mlir; 25 using namespace mlir::python; 26 27 using llvm::SmallVector; 28 using llvm::StringRef; 29 using llvm::Twine; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kContextParseTypeDocstring[] = 36 R"(Parses the assembly form of a type. 37 38 Returns a Type object or raises a ValueError if the type cannot be parsed. 39 40 See also: https://mlir.llvm.org/docs/LangRef/#type-system 41 )"; 42 43 static const char kContextGetFileLocationDocstring[] = 44 R"(Gets a Location representing a file, line and column)"; 45 46 static const char kModuleParseDocstring[] = 47 R"(Parses a module's assembly format from a string. 48 49 Returns a new MlirModule or raises a ValueError if the parsing fails. 50 51 See also: https://mlir.llvm.org/docs/LangRef/ 52 )"; 53 54 static const char kOperationCreateDocstring[] = 55 R"(Creates a new operation. 56 57 Args: 58 name: Operation name (e.g. "dialect.operation"). 59 results: Sequence of Type representing op result types. 60 attributes: Dict of str:Attribute. 61 successors: List of Block for the operation's successors. 62 regions: Number of regions to create. 63 location: A Location object (defaults to resolve from context manager). 64 ip: An InsertionPoint (defaults to resolve from context manager or set to 65 False to disable insertion, even with an insertion point set in the 66 context manager). 67 Returns: 68 A new "detached" Operation object. Detached operations can be added 69 to blocks, which causes them to become "attached." 70 )"; 71 72 static const char kOperationPrintDocstring[] = 73 R"(Prints the assembly form of the operation to a file like object. 74 75 Args: 76 file: The file like object to write to. Defaults to sys.stdout. 77 binary: Whether to write bytes (True) or str (False). Defaults to False. 78 large_elements_limit: Whether to elide elements attributes above this 79 number of elements. Defaults to None (no limit). 80 enable_debug_info: Whether to print debug/location information. Defaults 81 to False. 82 pretty_debug_info: Whether to format debug information for easier reading 83 by a human (warning: the result is unparseable). 84 print_generic_op_form: Whether to print the generic assembly forms of all 85 ops. Defaults to False. 86 use_local_Scope: Whether to print in a way that is more optimized for 87 multi-threaded access but may not be consistent with how the overall 88 module prints. 89 )"; 90 91 static const char kOperationGetAsmDocstring[] = 92 R"(Gets the assembly form of the operation with all options available. 93 94 Args: 95 binary: Whether to return a bytes (True) or str (False) object. Defaults to 96 False. 97 ... others ...: See the print() method for common keyword arguments for 98 configuring the printout. 99 Returns: 100 Either a bytes or str object, depending on the setting of the 'binary' 101 argument. 102 )"; 103 104 static const char kOperationStrDunderDocstring[] = 105 R"(Gets the assembly form of the operation with default options. 106 107 If more advanced control over the assembly formatting or I/O options is needed, 108 use the dedicated print or get_asm method, which supports keyword arguments to 109 customize behavior. 110 )"; 111 112 static const char kDumpDocstring[] = 113 R"(Dumps a debug representation of the object to stderr.)"; 114 115 static const char kAppendBlockDocstring[] = 116 R"(Appends a new block, with argument types as positional args. 117 118 Returns: 119 The created block. 120 )"; 121 122 static const char kValueDunderStrDocstring[] = 123 R"(Returns the string form of the value. 124 125 If the value is a block argument, this is the assembly form of its type and the 126 position in the argument list. If the value is an operation result, this is 127 equivalent to printing the operation that produced it. 128 )"; 129 130 //------------------------------------------------------------------------------ 131 // Utilities. 132 //------------------------------------------------------------------------------ 133 134 /// Helper for creating an @classmethod. 135 template <class Func, typename... Args> 136 py::object classmethod(Func f, Args... args) { 137 py::object cf = py::cpp_function(f, args...); 138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 139 } 140 141 static py::object 142 createCustomDialectWrapper(const std::string &dialectNamespace, 143 py::object dialectDescriptor) { 144 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 145 if (!dialectClass) { 146 // Use the base class. 147 return py::cast(PyDialect(std::move(dialectDescriptor))); 148 } 149 150 // Create the custom implementation. 151 return (*dialectClass)(std::move(dialectDescriptor)); 152 } 153 154 static MlirStringRef toMlirStringRef(const std::string &s) { 155 return mlirStringRefCreate(s.data(), s.size()); 156 } 157 158 /// Wrapper for the global LLVM debugging flag. 159 struct PyGlobalDebugFlag { 160 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 161 162 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 163 164 static void bind(py::module &m) { 165 // Debug flags. 166 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 167 .def_property_static("flag", &PyGlobalDebugFlag::get, 168 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 169 } 170 }; 171 172 //------------------------------------------------------------------------------ 173 // Collections. 174 //------------------------------------------------------------------------------ 175 176 namespace { 177 178 class PyRegionIterator { 179 public: 180 PyRegionIterator(PyOperationRef operation) 181 : operation(std::move(operation)) {} 182 183 PyRegionIterator &dunderIter() { return *this; } 184 185 PyRegion dunderNext() { 186 operation->checkValid(); 187 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 188 throw py::stop_iteration(); 189 } 190 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 191 return PyRegion(operation, region); 192 } 193 194 static void bind(py::module &m) { 195 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 196 .def("__iter__", &PyRegionIterator::dunderIter) 197 .def("__next__", &PyRegionIterator::dunderNext); 198 } 199 200 private: 201 PyOperationRef operation; 202 int nextIndex = 0; 203 }; 204 205 /// Regions of an op are fixed length and indexed numerically so are represented 206 /// with a sequence-like container. 207 class PyRegionList { 208 public: 209 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 210 211 intptr_t dunderLen() { 212 operation->checkValid(); 213 return mlirOperationGetNumRegions(operation->get()); 214 } 215 216 PyRegion dunderGetItem(intptr_t index) { 217 // dunderLen checks validity. 218 if (index < 0 || index >= dunderLen()) { 219 throw SetPyError(PyExc_IndexError, 220 "attempt to access out of bounds region"); 221 } 222 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 223 return PyRegion(operation, region); 224 } 225 226 static void bind(py::module &m) { 227 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 228 .def("__len__", &PyRegionList::dunderLen) 229 .def("__getitem__", &PyRegionList::dunderGetItem); 230 } 231 232 private: 233 PyOperationRef operation; 234 }; 235 236 class PyBlockIterator { 237 public: 238 PyBlockIterator(PyOperationRef operation, MlirBlock next) 239 : operation(std::move(operation)), next(next) {} 240 241 PyBlockIterator &dunderIter() { return *this; } 242 243 PyBlock dunderNext() { 244 operation->checkValid(); 245 if (mlirBlockIsNull(next)) { 246 throw py::stop_iteration(); 247 } 248 249 PyBlock returnBlock(operation, next); 250 next = mlirBlockGetNextInRegion(next); 251 return returnBlock; 252 } 253 254 static void bind(py::module &m) { 255 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 256 .def("__iter__", &PyBlockIterator::dunderIter) 257 .def("__next__", &PyBlockIterator::dunderNext); 258 } 259 260 private: 261 PyOperationRef operation; 262 MlirBlock next; 263 }; 264 265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 266 /// we present them as a more full-featured list-like container but optimize 267 /// it for forward iteration. Blocks are always owned by a region. 268 class PyBlockList { 269 public: 270 PyBlockList(PyOperationRef operation, MlirRegion region) 271 : operation(std::move(operation)), region(region) {} 272 273 PyBlockIterator dunderIter() { 274 operation->checkValid(); 275 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 276 } 277 278 intptr_t dunderLen() { 279 operation->checkValid(); 280 intptr_t count = 0; 281 MlirBlock block = mlirRegionGetFirstBlock(region); 282 while (!mlirBlockIsNull(block)) { 283 count += 1; 284 block = mlirBlockGetNextInRegion(block); 285 } 286 return count; 287 } 288 289 PyBlock dunderGetItem(intptr_t index) { 290 operation->checkValid(); 291 if (index < 0) { 292 throw SetPyError(PyExc_IndexError, 293 "attempt to access out of bounds block"); 294 } 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 if (index == 0) { 298 return PyBlock(operation, block); 299 } 300 block = mlirBlockGetNextInRegion(block); 301 index -= 1; 302 } 303 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 304 } 305 306 PyBlock appendBlock(py::args pyArgTypes) { 307 operation->checkValid(); 308 llvm::SmallVector<MlirType, 4> argTypes; 309 argTypes.reserve(pyArgTypes.size()); 310 for (auto &pyArg : pyArgTypes) { 311 argTypes.push_back(pyArg.cast<PyType &>()); 312 } 313 314 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 315 mlirRegionAppendOwnedBlock(region, block); 316 return PyBlock(operation, block); 317 } 318 319 static void bind(py::module &m) { 320 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 321 .def("__getitem__", &PyBlockList::dunderGetItem) 322 .def("__iter__", &PyBlockList::dunderIter) 323 .def("__len__", &PyBlockList::dunderLen) 324 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 325 } 326 327 private: 328 PyOperationRef operation; 329 MlirRegion region; 330 }; 331 332 class PyOperationIterator { 333 public: 334 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 335 : parentOperation(std::move(parentOperation)), next(next) {} 336 337 PyOperationIterator &dunderIter() { return *this; } 338 339 py::object dunderNext() { 340 parentOperation->checkValid(); 341 if (mlirOperationIsNull(next)) { 342 throw py::stop_iteration(); 343 } 344 345 PyOperationRef returnOperation = 346 PyOperation::forOperation(parentOperation->getContext(), next); 347 next = mlirOperationGetNextInBlock(next); 348 return returnOperation->createOpView(); 349 } 350 351 static void bind(py::module &m) { 352 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 353 .def("__iter__", &PyOperationIterator::dunderIter) 354 .def("__next__", &PyOperationIterator::dunderNext); 355 } 356 357 private: 358 PyOperationRef parentOperation; 359 MlirOperation next; 360 }; 361 362 /// Operations are exposed by the C-API as a forward-only linked list. In 363 /// Python, we present them as a more full-featured list-like container but 364 /// optimize it for forward iteration. Iterable operations are always owned 365 /// by a block. 366 class PyOperationList { 367 public: 368 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 369 : parentOperation(std::move(parentOperation)), block(block) {} 370 371 PyOperationIterator dunderIter() { 372 parentOperation->checkValid(); 373 return PyOperationIterator(parentOperation, 374 mlirBlockGetFirstOperation(block)); 375 } 376 377 intptr_t dunderLen() { 378 parentOperation->checkValid(); 379 intptr_t count = 0; 380 MlirOperation childOp = mlirBlockGetFirstOperation(block); 381 while (!mlirOperationIsNull(childOp)) { 382 count += 1; 383 childOp = mlirOperationGetNextInBlock(childOp); 384 } 385 return count; 386 } 387 388 py::object dunderGetItem(intptr_t index) { 389 parentOperation->checkValid(); 390 if (index < 0) { 391 throw SetPyError(PyExc_IndexError, 392 "attempt to access out of bounds operation"); 393 } 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 if (index == 0) { 397 return PyOperation::forOperation(parentOperation->getContext(), childOp) 398 ->createOpView(); 399 } 400 childOp = mlirOperationGetNextInBlock(childOp); 401 index -= 1; 402 } 403 throw SetPyError(PyExc_IndexError, 404 "attempt to access out of bounds operation"); 405 } 406 407 static void bind(py::module &m) { 408 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 409 .def("__getitem__", &PyOperationList::dunderGetItem) 410 .def("__iter__", &PyOperationList::dunderIter) 411 .def("__len__", &PyOperationList::dunderLen); 412 } 413 414 private: 415 PyOperationRef parentOperation; 416 MlirBlock block; 417 }; 418 419 } // namespace 420 421 //------------------------------------------------------------------------------ 422 // PyMlirContext 423 //------------------------------------------------------------------------------ 424 425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 426 py::gil_scoped_acquire acquire; 427 auto &liveContexts = getLiveContexts(); 428 liveContexts[context.ptr] = this; 429 } 430 431 PyMlirContext::~PyMlirContext() { 432 // Note that the only public way to construct an instance is via the 433 // forContext method, which always puts the associated handle into 434 // liveContexts. 435 py::gil_scoped_acquire acquire; 436 getLiveContexts().erase(context.ptr); 437 mlirContextDestroy(context); 438 } 439 440 py::object PyMlirContext::getCapsule() { 441 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 442 } 443 444 py::object PyMlirContext::createFromCapsule(py::object capsule) { 445 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 446 if (mlirContextIsNull(rawContext)) 447 throw py::error_already_set(); 448 return forContext(rawContext).releaseObject(); 449 } 450 451 PyMlirContext *PyMlirContext::createNewContextForInit() { 452 MlirContext context = mlirContextCreate(); 453 mlirRegisterAllDialects(context); 454 return new PyMlirContext(context); 455 } 456 457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 458 py::gil_scoped_acquire acquire; 459 auto &liveContexts = getLiveContexts(); 460 auto it = liveContexts.find(context.ptr); 461 if (it == liveContexts.end()) { 462 // Create. 463 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 464 py::object pyRef = py::cast(unownedContextWrapper); 465 assert(pyRef && "cast to py::object failed"); 466 liveContexts[context.ptr] = unownedContextWrapper; 467 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 468 } 469 // Use existing. 470 py::object pyRef = py::cast(it->second); 471 return PyMlirContextRef(it->second, std::move(pyRef)); 472 } 473 474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 475 static LiveContextMap liveContexts; 476 return liveContexts; 477 } 478 479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 480 481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 482 483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 484 485 pybind11::object PyMlirContext::contextEnter() { 486 return PyThreadContextEntry::pushContext(*this); 487 } 488 489 void PyMlirContext::contextExit(pybind11::object excType, 490 pybind11::object excVal, 491 pybind11::object excTb) { 492 PyThreadContextEntry::popContext(*this); 493 } 494 495 PyMlirContext &DefaultingPyMlirContext::resolve() { 496 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 497 if (!context) { 498 throw SetPyError( 499 PyExc_RuntimeError, 500 "An MLIR function requires a Context but none was provided in the call " 501 "or from the surrounding environment. Either pass to the function with " 502 "a 'context=' argument or establish a default using 'with Context():'"); 503 } 504 return *context; 505 } 506 507 //------------------------------------------------------------------------------ 508 // PyThreadContextEntry management 509 //------------------------------------------------------------------------------ 510 511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 512 static thread_local std::vector<PyThreadContextEntry> stack; 513 return stack; 514 } 515 516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 517 auto &stack = getStack(); 518 if (stack.empty()) 519 return nullptr; 520 return &stack.back(); 521 } 522 523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 524 py::object insertionPoint, 525 py::object location) { 526 auto &stack = getStack(); 527 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 528 std::move(location)); 529 // If the new stack has more than one entry and the context of the new top 530 // entry matches the previous, copy the insertionPoint and location from the 531 // previous entry if missing from the new top entry. 532 if (stack.size() > 1) { 533 auto &prev = *(stack.rbegin() + 1); 534 auto ¤t = stack.back(); 535 if (current.context.is(prev.context)) { 536 // Default non-context objects from the previous entry. 537 if (!current.insertionPoint) 538 current.insertionPoint = prev.insertionPoint; 539 if (!current.location) 540 current.location = prev.location; 541 } 542 } 543 } 544 545 PyMlirContext *PyThreadContextEntry::getContext() { 546 if (!context) 547 return nullptr; 548 return py::cast<PyMlirContext *>(context); 549 } 550 551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 552 if (!insertionPoint) 553 return nullptr; 554 return py::cast<PyInsertionPoint *>(insertionPoint); 555 } 556 557 PyLocation *PyThreadContextEntry::getLocation() { 558 if (!location) 559 return nullptr; 560 return py::cast<PyLocation *>(location); 561 } 562 563 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 564 auto *tos = getTopOfStack(); 565 return tos ? tos->getContext() : nullptr; 566 } 567 568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 569 auto *tos = getTopOfStack(); 570 return tos ? tos->getInsertionPoint() : nullptr; 571 } 572 573 PyLocation *PyThreadContextEntry::getDefaultLocation() { 574 auto *tos = getTopOfStack(); 575 return tos ? tos->getLocation() : nullptr; 576 } 577 578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 579 py::object contextObj = py::cast(context); 580 push(FrameKind::Context, /*context=*/contextObj, 581 /*insertionPoint=*/py::object(), 582 /*location=*/py::object()); 583 return contextObj; 584 } 585 586 void PyThreadContextEntry::popContext(PyMlirContext &context) { 587 auto &stack = getStack(); 588 if (stack.empty()) 589 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 590 auto &tos = stack.back(); 591 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593 stack.pop_back(); 594 } 595 596 py::object 597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 598 py::object contextObj = 599 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 600 py::object insertionPointObj = py::cast(insertionPoint); 601 push(FrameKind::InsertionPoint, 602 /*context=*/contextObj, 603 /*insertionPoint=*/insertionPointObj, 604 /*location=*/py::object()); 605 return insertionPointObj; 606 } 607 608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 609 auto &stack = getStack(); 610 if (stack.empty()) 611 throw SetPyError(PyExc_RuntimeError, 612 "Unbalanced InsertionPoint enter/exit"); 613 auto &tos = stack.back(); 614 if (tos.frameKind != FrameKind::InsertionPoint && 615 tos.getInsertionPoint() != &insertionPoint) 616 throw SetPyError(PyExc_RuntimeError, 617 "Unbalanced InsertionPoint enter/exit"); 618 stack.pop_back(); 619 } 620 621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 622 py::object contextObj = location.getContext().getObject(); 623 py::object locationObj = py::cast(location); 624 push(FrameKind::Location, /*context=*/contextObj, 625 /*insertionPoint=*/py::object(), 626 /*location=*/locationObj); 627 return locationObj; 628 } 629 630 void PyThreadContextEntry::popLocation(PyLocation &location) { 631 auto &stack = getStack(); 632 if (stack.empty()) 633 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 634 auto &tos = stack.back(); 635 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637 stack.pop_back(); 638 } 639 640 //------------------------------------------------------------------------------ 641 // PyDialect, PyDialectDescriptor, PyDialects 642 //------------------------------------------------------------------------------ 643 644 MlirDialect PyDialects::getDialectForKey(const std::string &key, 645 bool attrError) { 646 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 647 {key.data(), key.size()}); 648 if (mlirDialectIsNull(dialect)) { 649 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 650 Twine("Dialect '") + key + "' not found"); 651 } 652 return dialect; 653 } 654 655 //------------------------------------------------------------------------------ 656 // PyLocation 657 //------------------------------------------------------------------------------ 658 659 py::object PyLocation::getCapsule() { 660 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 661 } 662 663 PyLocation PyLocation::createFromCapsule(py::object capsule) { 664 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 665 if (mlirLocationIsNull(rawLoc)) 666 throw py::error_already_set(); 667 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 668 rawLoc); 669 } 670 671 py::object PyLocation::contextEnter() { 672 return PyThreadContextEntry::pushLocation(*this); 673 } 674 675 void PyLocation::contextExit(py::object excType, py::object excVal, 676 py::object excTb) { 677 PyThreadContextEntry::popLocation(*this); 678 } 679 680 PyLocation &DefaultingPyLocation::resolve() { 681 auto *location = PyThreadContextEntry::getDefaultLocation(); 682 if (!location) { 683 throw SetPyError( 684 PyExc_RuntimeError, 685 "An MLIR function requires a Location but none was provided in the " 686 "call or from the surrounding environment. Either pass to the function " 687 "with a 'loc=' argument or establish a default using 'with loc:'"); 688 } 689 return *location; 690 } 691 692 //------------------------------------------------------------------------------ 693 // PyModule 694 //------------------------------------------------------------------------------ 695 696 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 697 : BaseContextObject(std::move(contextRef)), module(module) {} 698 699 PyModule::~PyModule() { 700 py::gil_scoped_acquire acquire; 701 auto &liveModules = getContext()->liveModules; 702 assert(liveModules.count(module.ptr) == 1 && 703 "destroying module not in live map"); 704 liveModules.erase(module.ptr); 705 mlirModuleDestroy(module); 706 } 707 708 PyModuleRef PyModule::forModule(MlirModule module) { 709 MlirContext context = mlirModuleGetContext(module); 710 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 711 712 py::gil_scoped_acquire acquire; 713 auto &liveModules = contextRef->liveModules; 714 auto it = liveModules.find(module.ptr); 715 if (it == liveModules.end()) { 716 // Create. 717 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 718 // Note that the default return value policy on cast is automatic_reference, 719 // which does not take ownership (delete will not be called). 720 // Just be explicit. 721 py::object pyRef = 722 py::cast(unownedModule, py::return_value_policy::take_ownership); 723 unownedModule->handle = pyRef; 724 liveModules[module.ptr] = 725 std::make_pair(unownedModule->handle, unownedModule); 726 return PyModuleRef(unownedModule, std::move(pyRef)); 727 } 728 // Use existing. 729 PyModule *existing = it->second.second; 730 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 731 return PyModuleRef(existing, std::move(pyRef)); 732 } 733 734 py::object PyModule::createFromCapsule(py::object capsule) { 735 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 736 if (mlirModuleIsNull(rawModule)) 737 throw py::error_already_set(); 738 return forModule(rawModule).releaseObject(); 739 } 740 741 py::object PyModule::getCapsule() { 742 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 743 } 744 745 //------------------------------------------------------------------------------ 746 // PyOperation 747 //------------------------------------------------------------------------------ 748 749 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 750 : BaseContextObject(std::move(contextRef)), operation(operation) {} 751 752 PyOperation::~PyOperation() { 753 // If the operation has already been invalidated there is nothing to do. 754 if (!valid) 755 return; 756 auto &liveOperations = getContext()->liveOperations; 757 assert(liveOperations.count(operation.ptr) == 1 && 758 "destroying operation not in live map"); 759 liveOperations.erase(operation.ptr); 760 if (!isAttached()) { 761 mlirOperationDestroy(operation); 762 } 763 } 764 765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 766 MlirOperation operation, 767 py::object parentKeepAlive) { 768 auto &liveOperations = contextRef->liveOperations; 769 // Create. 770 PyOperation *unownedOperation = 771 new PyOperation(std::move(contextRef), operation); 772 // Note that the default return value policy on cast is automatic_reference, 773 // which does not take ownership (delete will not be called). 774 // Just be explicit. 775 py::object pyRef = 776 py::cast(unownedOperation, py::return_value_policy::take_ownership); 777 unownedOperation->handle = pyRef; 778 if (parentKeepAlive) { 779 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 780 } 781 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 782 return PyOperationRef(unownedOperation, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 auto it = liveOperations.find(operation.ptr); 790 if (it == liveOperations.end()) { 791 // Create. 792 return createInstance(std::move(contextRef), operation, 793 std::move(parentKeepAlive)); 794 } 795 // Use existing. 796 PyOperation *existing = it->second.second; 797 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 798 return PyOperationRef(existing, std::move(pyRef)); 799 } 800 801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 802 MlirOperation operation, 803 py::object parentKeepAlive) { 804 auto &liveOperations = contextRef->liveOperations; 805 assert(liveOperations.count(operation.ptr) == 0 && 806 "cannot create detached operation that already exists"); 807 (void)liveOperations; 808 809 PyOperationRef created = createInstance(std::move(contextRef), operation, 810 std::move(parentKeepAlive)); 811 created->attached = false; 812 return created; 813 } 814 815 void PyOperation::checkValid() const { 816 if (!valid) { 817 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 818 } 819 } 820 821 void PyOperationBase::print(py::object fileObject, bool binary, 822 llvm::Optional<int64_t> largeElementsLimit, 823 bool enableDebugInfo, bool prettyDebugInfo, 824 bool printGenericOpForm, bool useLocalScope) { 825 PyOperation &operation = getOperation(); 826 operation.checkValid(); 827 if (fileObject.is_none()) 828 fileObject = py::module::import("sys").attr("stdout"); 829 830 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 831 fileObject.attr("write")("// Verification failed, printing generic form\n"); 832 printGenericOpForm = true; 833 } 834 835 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 836 if (largeElementsLimit) 837 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 838 if (enableDebugInfo) 839 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 840 if (printGenericOpForm) 841 mlirOpPrintingFlagsPrintGenericOpForm(flags); 842 843 PyFileAccumulator accum(fileObject, binary); 844 py::gil_scoped_release(); 845 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 846 accum.getUserData()); 847 mlirOpPrintingFlagsDestroy(flags); 848 } 849 850 py::object PyOperationBase::getAsm(bool binary, 851 llvm::Optional<int64_t> largeElementsLimit, 852 bool enableDebugInfo, bool prettyDebugInfo, 853 bool printGenericOpForm, 854 bool useLocalScope) { 855 py::object fileObject; 856 if (binary) { 857 fileObject = py::module::import("io").attr("BytesIO")(); 858 } else { 859 fileObject = py::module::import("io").attr("StringIO")(); 860 } 861 print(fileObject, /*binary=*/binary, 862 /*largeElementsLimit=*/largeElementsLimit, 863 /*enableDebugInfo=*/enableDebugInfo, 864 /*prettyDebugInfo=*/prettyDebugInfo, 865 /*printGenericOpForm=*/printGenericOpForm, 866 /*useLocalScope=*/useLocalScope); 867 868 return fileObject.attr("getvalue")(); 869 } 870 871 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 872 checkValid(); 873 if (!isAttached()) 874 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 875 MlirOperation operation = mlirOperationGetParentOperation(get()); 876 if (mlirOperationIsNull(operation)) 877 return {}; 878 return PyOperation::forOperation(getContext(), operation); 879 } 880 881 PyBlock PyOperation::getBlock() { 882 checkValid(); 883 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 884 MlirBlock block = mlirOperationGetBlock(get()); 885 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 886 assert(parentOperation && "Operation has no parent"); 887 return PyBlock{std::move(*parentOperation), block}; 888 } 889 890 py::object PyOperation::getCapsule() { 891 checkValid(); 892 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 893 } 894 895 py::object PyOperation::createFromCapsule(py::object capsule) { 896 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 897 if (mlirOperationIsNull(rawOperation)) 898 throw py::error_already_set(); 899 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 900 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 901 .releaseObject(); 902 } 903 904 py::object PyOperation::create( 905 std::string name, llvm::Optional<std::vector<PyType *>> results, 906 llvm::Optional<std::vector<PyValue *>> operands, 907 llvm::Optional<py::dict> attributes, 908 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 909 DefaultingPyLocation location, py::object maybeIp) { 910 llvm::SmallVector<MlirValue, 4> mlirOperands; 911 llvm::SmallVector<MlirType, 4> mlirResults; 912 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 913 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 914 915 // General parameter validation. 916 if (regions < 0) 917 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 918 919 // Unpack/validate operands. 920 if (operands) { 921 mlirOperands.reserve(operands->size()); 922 for (PyValue *operand : *operands) { 923 if (!operand) 924 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 925 mlirOperands.push_back(operand->get()); 926 } 927 } 928 929 // Unpack/validate results. 930 if (results) { 931 mlirResults.reserve(results->size()); 932 for (PyType *result : *results) { 933 // TODO: Verify result type originate from the same context. 934 if (!result) 935 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 936 mlirResults.push_back(*result); 937 } 938 } 939 // Unpack/validate attributes. 940 if (attributes) { 941 mlirAttributes.reserve(attributes->size()); 942 for (auto &it : *attributes) { 943 std::string key; 944 try { 945 key = it.first.cast<std::string>(); 946 } catch (py::cast_error &err) { 947 std::string msg = "Invalid attribute key (not a string) when " 948 "attempting to create the operation \"" + 949 name + "\" (" + err.what() + ")"; 950 throw py::cast_error(msg); 951 } 952 try { 953 auto &attribute = it.second.cast<PyAttribute &>(); 954 // TODO: Verify attribute originates from the same context. 955 mlirAttributes.emplace_back(std::move(key), attribute); 956 } catch (py::reference_cast_error &) { 957 // This exception seems thrown when the value is "None". 958 std::string msg = 959 "Found an invalid (`None`?) attribute value for the key \"" + key + 960 "\" when attempting to create the operation \"" + name + "\""; 961 throw py::cast_error(msg); 962 } catch (py::cast_error &err) { 963 std::string msg = "Invalid attribute value for the key \"" + key + 964 "\" when attempting to create the operation \"" + 965 name + "\" (" + err.what() + ")"; 966 throw py::cast_error(msg); 967 } 968 } 969 } 970 // Unpack/validate successors. 971 if (successors) { 972 mlirSuccessors.reserve(successors->size()); 973 for (auto *successor : *successors) { 974 // TODO: Verify successor originate from the same context. 975 if (!successor) 976 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 977 mlirSuccessors.push_back(successor->get()); 978 } 979 } 980 981 // Apply unpacked/validated to the operation state. Beyond this 982 // point, exceptions cannot be thrown or else the state will leak. 983 MlirOperationState state = 984 mlirOperationStateGet(toMlirStringRef(name), location); 985 if (!mlirOperands.empty()) 986 mlirOperationStateAddOperands(&state, mlirOperands.size(), 987 mlirOperands.data()); 988 if (!mlirResults.empty()) 989 mlirOperationStateAddResults(&state, mlirResults.size(), 990 mlirResults.data()); 991 if (!mlirAttributes.empty()) { 992 // Note that the attribute names directly reference bytes in 993 // mlirAttributes, so that vector must not be changed from here 994 // on. 995 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 996 mlirNamedAttributes.reserve(mlirAttributes.size()); 997 for (auto &it : mlirAttributes) 998 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 999 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1000 toMlirStringRef(it.first)), 1001 it.second)); 1002 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1003 mlirNamedAttributes.data()); 1004 } 1005 if (!mlirSuccessors.empty()) 1006 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1007 mlirSuccessors.data()); 1008 if (regions) { 1009 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1010 mlirRegions.resize(regions); 1011 for (int i = 0; i < regions; ++i) 1012 mlirRegions[i] = mlirRegionCreate(); 1013 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1014 mlirRegions.data()); 1015 } 1016 1017 // Construct the operation. 1018 MlirOperation operation = mlirOperationCreate(&state); 1019 PyOperationRef created = 1020 PyOperation::createDetached(location->getContext(), operation); 1021 1022 // InsertPoint active? 1023 if (!maybeIp.is(py::cast(false))) { 1024 PyInsertionPoint *ip; 1025 if (maybeIp.is_none()) { 1026 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1027 } else { 1028 ip = py::cast<PyInsertionPoint *>(maybeIp); 1029 } 1030 if (ip) 1031 ip->insert(*created.get()); 1032 } 1033 1034 return created->createOpView(); 1035 } 1036 1037 py::object PyOperation::createOpView() { 1038 checkValid(); 1039 MlirIdentifier ident = mlirOperationGetName(get()); 1040 MlirStringRef identStr = mlirIdentifierStr(ident); 1041 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1042 StringRef(identStr.data, identStr.length)); 1043 if (opViewClass) 1044 return (*opViewClass)(getRef().getObject()); 1045 return py::cast(PyOpView(getRef().getObject())); 1046 } 1047 1048 void PyOperation::erase() { 1049 checkValid(); 1050 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1051 // Python reference to a child operation is live. All children should also 1052 // have their `valid` bit set to false. 1053 auto &liveOperations = getContext()->liveOperations; 1054 if (liveOperations.count(operation.ptr)) 1055 liveOperations.erase(operation.ptr); 1056 mlirOperationDestroy(operation); 1057 valid = false; 1058 } 1059 1060 //------------------------------------------------------------------------------ 1061 // PyOpView 1062 //------------------------------------------------------------------------------ 1063 1064 py::object 1065 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1066 py::list operandList, 1067 llvm::Optional<py::dict> attributes, 1068 llvm::Optional<std::vector<PyBlock *>> successors, 1069 llvm::Optional<int> regions, 1070 DefaultingPyLocation location, py::object maybeIp) { 1071 PyMlirContextRef context = location->getContext(); 1072 // Class level operation construction metadata. 1073 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1074 // Operand and result segment specs are either none, which does no 1075 // variadic unpacking, or a list of ints with segment sizes, where each 1076 // element is either a positive number (typically 1 for a scalar) or -1 to 1077 // indicate that it is derived from the length of the same-indexed operand 1078 // or result (implying that it is a list at that position). 1079 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1080 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1081 1082 std::vector<uint32_t> operandSegmentLengths; 1083 std::vector<uint32_t> resultSegmentLengths; 1084 1085 // Validate/determine region count. 1086 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1087 int opMinRegionCount = std::get<0>(opRegionSpec); 1088 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1089 if (!regions) { 1090 regions = opMinRegionCount; 1091 } 1092 if (*regions < opMinRegionCount) { 1093 throw py::value_error( 1094 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1095 llvm::Twine(opMinRegionCount) + 1096 " regions but was built with regions=" + llvm::Twine(*regions)) 1097 .str()); 1098 } 1099 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1100 throw py::value_error( 1101 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1102 llvm::Twine(opMinRegionCount) + 1103 " regions but was built with regions=" + llvm::Twine(*regions)) 1104 .str()); 1105 } 1106 1107 // Unpack results. 1108 std::vector<PyType *> resultTypes; 1109 resultTypes.reserve(resultTypeList.size()); 1110 if (resultSegmentSpecObj.is_none()) { 1111 // Non-variadic result unpacking. 1112 for (auto it : llvm::enumerate(resultTypeList)) { 1113 try { 1114 resultTypes.push_back(py::cast<PyType *>(it.value())); 1115 if (!resultTypes.back()) 1116 throw py::cast_error(); 1117 } catch (py::cast_error &err) { 1118 throw py::value_error((llvm::Twine("Result ") + 1119 llvm::Twine(it.index()) + " of operation \"" + 1120 name + "\" must be a Type (" + err.what() + ")") 1121 .str()); 1122 } 1123 } 1124 } else { 1125 // Sized result unpacking. 1126 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1127 if (resultSegmentSpec.size() != resultTypeList.size()) { 1128 throw py::value_error((llvm::Twine("Operation \"") + name + 1129 "\" requires " + 1130 llvm::Twine(resultSegmentSpec.size()) + 1131 "result segments but was provided " + 1132 llvm::Twine(resultTypeList.size())) 1133 .str()); 1134 } 1135 resultSegmentLengths.reserve(resultTypeList.size()); 1136 for (auto it : 1137 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1138 int segmentSpec = std::get<1>(it.value()); 1139 if (segmentSpec == 1 || segmentSpec == 0) { 1140 // Unpack unary element. 1141 try { 1142 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1143 if (resultType) { 1144 resultTypes.push_back(resultType); 1145 resultSegmentLengths.push_back(1); 1146 } else if (segmentSpec == 0) { 1147 // Allowed to be optional. 1148 resultSegmentLengths.push_back(0); 1149 } else { 1150 throw py::cast_error("was None and result is not optional"); 1151 } 1152 } catch (py::cast_error &err) { 1153 throw py::value_error((llvm::Twine("Result ") + 1154 llvm::Twine(it.index()) + " of operation \"" + 1155 name + "\" must be a Type (" + err.what() + 1156 ")") 1157 .str()); 1158 } 1159 } else if (segmentSpec == -1) { 1160 // Unpack sequence by appending. 1161 try { 1162 if (std::get<0>(it.value()).is_none()) { 1163 // Treat it as an empty list. 1164 resultSegmentLengths.push_back(0); 1165 } else { 1166 // Unpack the list. 1167 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1168 for (py::object segmentItem : segment) { 1169 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1170 if (!resultTypes.back()) { 1171 throw py::cast_error("contained a None item"); 1172 } 1173 } 1174 resultSegmentLengths.push_back(segment.size()); 1175 } 1176 } catch (std::exception &err) { 1177 // NOTE: Sloppy to be using a catch-all here, but there are at least 1178 // three different unrelated exceptions that can be thrown in the 1179 // above "casts". Just keep the scope above small and catch them all. 1180 throw py::value_error((llvm::Twine("Result ") + 1181 llvm::Twine(it.index()) + " of operation \"" + 1182 name + "\" must be a Sequence of Types (" + 1183 err.what() + ")") 1184 .str()); 1185 } 1186 } else { 1187 throw py::value_error("Unexpected segment spec"); 1188 } 1189 } 1190 } 1191 1192 // Unpack operands. 1193 std::vector<PyValue *> operands; 1194 operands.reserve(operands.size()); 1195 if (operandSegmentSpecObj.is_none()) { 1196 // Non-sized operand unpacking. 1197 for (auto it : llvm::enumerate(operandList)) { 1198 try { 1199 operands.push_back(py::cast<PyValue *>(it.value())); 1200 if (!operands.back()) 1201 throw py::cast_error(); 1202 } catch (py::cast_error &err) { 1203 throw py::value_error((llvm::Twine("Operand ") + 1204 llvm::Twine(it.index()) + " of operation \"" + 1205 name + "\" must be a Value (" + err.what() + ")") 1206 .str()); 1207 } 1208 } 1209 } else { 1210 // Sized operand unpacking. 1211 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1212 if (operandSegmentSpec.size() != operandList.size()) { 1213 throw py::value_error((llvm::Twine("Operation \"") + name + 1214 "\" requires " + 1215 llvm::Twine(operandSegmentSpec.size()) + 1216 "operand segments but was provided " + 1217 llvm::Twine(operandList.size())) 1218 .str()); 1219 } 1220 operandSegmentLengths.reserve(operandList.size()); 1221 for (auto it : 1222 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1223 int segmentSpec = std::get<1>(it.value()); 1224 if (segmentSpec == 1 || segmentSpec == 0) { 1225 // Unpack unary element. 1226 try { 1227 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1228 if (operandValue) { 1229 operands.push_back(operandValue); 1230 operandSegmentLengths.push_back(1); 1231 } else if (segmentSpec == 0) { 1232 // Allowed to be optional. 1233 operandSegmentLengths.push_back(0); 1234 } else { 1235 throw py::cast_error("was None and operand is not optional"); 1236 } 1237 } catch (py::cast_error &err) { 1238 throw py::value_error((llvm::Twine("Operand ") + 1239 llvm::Twine(it.index()) + " of operation \"" + 1240 name + "\" must be a Value (" + err.what() + 1241 ")") 1242 .str()); 1243 } 1244 } else if (segmentSpec == -1) { 1245 // Unpack sequence by appending. 1246 try { 1247 if (std::get<0>(it.value()).is_none()) { 1248 // Treat it as an empty list. 1249 operandSegmentLengths.push_back(0); 1250 } else { 1251 // Unpack the list. 1252 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1253 for (py::object segmentItem : segment) { 1254 operands.push_back(py::cast<PyValue *>(segmentItem)); 1255 if (!operands.back()) { 1256 throw py::cast_error("contained a None item"); 1257 } 1258 } 1259 operandSegmentLengths.push_back(segment.size()); 1260 } 1261 } catch (std::exception &err) { 1262 // NOTE: Sloppy to be using a catch-all here, but there are at least 1263 // three different unrelated exceptions that can be thrown in the 1264 // above "casts". Just keep the scope above small and catch them all. 1265 throw py::value_error((llvm::Twine("Operand ") + 1266 llvm::Twine(it.index()) + " of operation \"" + 1267 name + "\" must be a Sequence of Values (" + 1268 err.what() + ")") 1269 .str()); 1270 } 1271 } else { 1272 throw py::value_error("Unexpected segment spec"); 1273 } 1274 } 1275 } 1276 1277 // Merge operand/result segment lengths into attributes if needed. 1278 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1279 // Dup. 1280 if (attributes) { 1281 attributes = py::dict(*attributes); 1282 } else { 1283 attributes = py::dict(); 1284 } 1285 if (attributes->contains("result_segment_sizes") || 1286 attributes->contains("operand_segment_sizes")) { 1287 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1288 "'operand_segment_sizes' attribute is unsupported. " 1289 "Use Operation.create for such low-level access."); 1290 } 1291 1292 // Add result_segment_sizes attribute. 1293 if (!resultSegmentLengths.empty()) { 1294 int64_t size = resultSegmentLengths.size(); 1295 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1296 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1297 resultSegmentLengths.size(), resultSegmentLengths.data()); 1298 (*attributes)["result_segment_sizes"] = 1299 PyAttribute(context, segmentLengthAttr); 1300 } 1301 1302 // Add operand_segment_sizes attribute. 1303 if (!operandSegmentLengths.empty()) { 1304 int64_t size = operandSegmentLengths.size(); 1305 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1306 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1307 operandSegmentLengths.size(), operandSegmentLengths.data()); 1308 (*attributes)["operand_segment_sizes"] = 1309 PyAttribute(context, segmentLengthAttr); 1310 } 1311 } 1312 1313 // Delegate to create. 1314 return PyOperation::create(std::move(name), 1315 /*results=*/std::move(resultTypes), 1316 /*operands=*/std::move(operands), 1317 /*attributes=*/std::move(attributes), 1318 /*successors=*/std::move(successors), 1319 /*regions=*/*regions, location, maybeIp); 1320 } 1321 1322 PyOpView::PyOpView(py::object operationObject) 1323 // Casting through the PyOperationBase base-class and then back to the 1324 // Operation lets us accept any PyOperationBase subclass. 1325 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1326 operationObject(operation.getRef().getObject()) {} 1327 1328 py::object PyOpView::createRawSubclass(py::object userClass) { 1329 // This is... a little gross. The typical pattern is to have a pure python 1330 // class that extends OpView like: 1331 // class AddFOp(_cext.ir.OpView): 1332 // def __init__(self, loc, lhs, rhs): 1333 // operation = loc.context.create_operation( 1334 // "addf", lhs, rhs, results=[lhs.type]) 1335 // super().__init__(operation) 1336 // 1337 // I.e. The goal of the user facing type is to provide a nice constructor 1338 // that has complete freedom for the op under construction. This is at odds 1339 // with our other desire to sometimes create this object by just passing an 1340 // operation (to initialize the base class). We could do *arg and **kwargs 1341 // munging to try to make it work, but instead, we synthesize a new class 1342 // on the fly which extends this user class (AddFOp in this example) and 1343 // *give it* the base class's __init__ method, thus bypassing the 1344 // intermediate subclass's __init__ method entirely. While slightly, 1345 // underhanded, this is safe/legal because the type hierarchy has not changed 1346 // (we just added a new leaf) and we aren't mucking around with __new__. 1347 // Typically, this new class will be stored on the original as "_Raw" and will 1348 // be used for casts and other things that need a variant of the class that 1349 // is initialized purely from an operation. 1350 py::object parentMetaclass = 1351 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1352 py::dict attributes; 1353 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1354 // now. 1355 // auto opViewType = py::type::of<PyOpView>(); 1356 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1357 attributes["__init__"] = opViewType.attr("__init__"); 1358 py::str origName = userClass.attr("__name__"); 1359 py::str newName = py::str("_") + origName; 1360 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1361 } 1362 1363 //------------------------------------------------------------------------------ 1364 // PyInsertionPoint. 1365 //------------------------------------------------------------------------------ 1366 1367 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1368 1369 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1370 : refOperation(beforeOperationBase.getOperation().getRef()), 1371 block((*refOperation)->getBlock()) {} 1372 1373 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1374 PyOperation &operation = operationBase.getOperation(); 1375 if (operation.isAttached()) 1376 throw SetPyError(PyExc_ValueError, 1377 "Attempt to insert operation that is already attached"); 1378 block.getParentOperation()->checkValid(); 1379 MlirOperation beforeOp = {nullptr}; 1380 if (refOperation) { 1381 // Insert before operation. 1382 (*refOperation)->checkValid(); 1383 beforeOp = (*refOperation)->get(); 1384 } else { 1385 // Insert at end (before null) is only valid if the block does not 1386 // already end in a known terminator (violating this will cause assertion 1387 // failures later). 1388 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1389 throw py::index_error("Cannot insert operation at the end of a block " 1390 "that already has a terminator. Did you mean to " 1391 "use 'InsertionPoint.at_block_terminator(block)' " 1392 "versus 'InsertionPoint(block)'?"); 1393 } 1394 } 1395 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1396 operation.setAttached(); 1397 } 1398 1399 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1400 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1401 if (mlirOperationIsNull(firstOp)) { 1402 // Just insert at end. 1403 return PyInsertionPoint(block); 1404 } 1405 1406 // Insert before first op. 1407 PyOperationRef firstOpRef = PyOperation::forOperation( 1408 block.getParentOperation()->getContext(), firstOp); 1409 return PyInsertionPoint{block, std::move(firstOpRef)}; 1410 } 1411 1412 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1413 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1414 if (mlirOperationIsNull(terminator)) 1415 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1416 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1417 block.getParentOperation()->getContext(), terminator); 1418 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1419 } 1420 1421 py::object PyInsertionPoint::contextEnter() { 1422 return PyThreadContextEntry::pushInsertionPoint(*this); 1423 } 1424 1425 void PyInsertionPoint::contextExit(pybind11::object excType, 1426 pybind11::object excVal, 1427 pybind11::object excTb) { 1428 PyThreadContextEntry::popInsertionPoint(*this); 1429 } 1430 1431 //------------------------------------------------------------------------------ 1432 // PyAttribute. 1433 //------------------------------------------------------------------------------ 1434 1435 bool PyAttribute::operator==(const PyAttribute &other) { 1436 return mlirAttributeEqual(attr, other.attr); 1437 } 1438 1439 py::object PyAttribute::getCapsule() { 1440 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1441 } 1442 1443 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1444 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1445 if (mlirAttributeIsNull(rawAttr)) 1446 throw py::error_already_set(); 1447 return PyAttribute( 1448 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1449 } 1450 1451 //------------------------------------------------------------------------------ 1452 // PyNamedAttribute. 1453 //------------------------------------------------------------------------------ 1454 1455 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1456 : ownedName(new std::string(std::move(ownedName))) { 1457 namedAttr = mlirNamedAttributeGet( 1458 mlirIdentifierGet(mlirAttributeGetContext(attr), 1459 toMlirStringRef(*this->ownedName)), 1460 attr); 1461 } 1462 1463 //------------------------------------------------------------------------------ 1464 // PyType. 1465 //------------------------------------------------------------------------------ 1466 1467 bool PyType::operator==(const PyType &other) { 1468 return mlirTypeEqual(type, other.type); 1469 } 1470 1471 py::object PyType::getCapsule() { 1472 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1473 } 1474 1475 PyType PyType::createFromCapsule(py::object capsule) { 1476 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1477 if (mlirTypeIsNull(rawType)) 1478 throw py::error_already_set(); 1479 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1480 rawType); 1481 } 1482 1483 //------------------------------------------------------------------------------ 1484 // PyValue and subclases. 1485 //------------------------------------------------------------------------------ 1486 1487 pybind11::object PyValue::getCapsule() { 1488 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1489 } 1490 1491 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1492 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1493 if (mlirValueIsNull(value)) 1494 throw py::error_already_set(); 1495 MlirOperation owner; 1496 if (mlirValueIsAOpResult(value)) 1497 owner = mlirOpResultGetOwner(value); 1498 if (mlirValueIsABlockArgument(value)) 1499 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1500 if (mlirOperationIsNull(owner)) 1501 throw py::error_already_set(); 1502 MlirContext ctx = mlirOperationGetContext(owner); 1503 PyOperationRef ownerRef = 1504 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1505 return PyValue(ownerRef, value); 1506 } 1507 1508 namespace { 1509 /// CRTP base class for Python MLIR values that subclass Value and should be 1510 /// castable from it. The value hierarchy is one level deep and is not supposed 1511 /// to accommodate other levels unless core MLIR changes. 1512 template <typename DerivedTy> 1513 class PyConcreteValue : public PyValue { 1514 public: 1515 // Derived classes must define statics for: 1516 // IsAFunctionTy isaFunction 1517 // const char *pyClassName 1518 // and redefine bindDerived. 1519 using ClassTy = py::class_<DerivedTy, PyValue>; 1520 using IsAFunctionTy = bool (*)(MlirValue); 1521 1522 PyConcreteValue() = default; 1523 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1524 : PyValue(operationRef, value) {} 1525 PyConcreteValue(PyValue &orig) 1526 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1527 1528 /// Attempts to cast the original value to the derived type and throws on 1529 /// type mismatches. 1530 static MlirValue castFrom(PyValue &orig) { 1531 if (!DerivedTy::isaFunction(orig.get())) { 1532 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1533 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1534 DerivedTy::pyClassName + 1535 " (from " + origRepr + ")"); 1536 } 1537 return orig.get(); 1538 } 1539 1540 /// Binds the Python module objects to functions of this class. 1541 static void bind(py::module &m) { 1542 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1543 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1544 DerivedTy::bindDerived(cls); 1545 } 1546 1547 /// Implemented by derived classes to add methods to the Python subclass. 1548 static void bindDerived(ClassTy &m) {} 1549 }; 1550 1551 /// Python wrapper for MlirBlockArgument. 1552 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1553 public: 1554 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1555 static constexpr const char *pyClassName = "BlockArgument"; 1556 using PyConcreteValue::PyConcreteValue; 1557 1558 static void bindDerived(ClassTy &c) { 1559 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1560 return PyBlock(self.getParentOperation(), 1561 mlirBlockArgumentGetOwner(self.get())); 1562 }); 1563 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1564 return mlirBlockArgumentGetArgNumber(self.get()); 1565 }); 1566 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1567 return mlirBlockArgumentSetType(self.get(), type); 1568 }); 1569 } 1570 }; 1571 1572 /// Python wrapper for MlirOpResult. 1573 class PyOpResult : public PyConcreteValue<PyOpResult> { 1574 public: 1575 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1576 static constexpr const char *pyClassName = "OpResult"; 1577 using PyConcreteValue::PyConcreteValue; 1578 1579 static void bindDerived(ClassTy &c) { 1580 c.def_property_readonly("owner", [](PyOpResult &self) { 1581 assert( 1582 mlirOperationEqual(self.getParentOperation()->get(), 1583 mlirOpResultGetOwner(self.get())) && 1584 "expected the owner of the value in Python to match that in the IR"); 1585 return self.getParentOperation().getObject(); 1586 }); 1587 c.def_property_readonly("result_number", [](PyOpResult &self) { 1588 return mlirOpResultGetResultNumber(self.get()); 1589 }); 1590 } 1591 }; 1592 1593 /// A list of block arguments. Internally, these are stored as consecutive 1594 /// elements, random access is cheap. The argument list is associated with the 1595 /// operation that contains the block (detached blocks are not allowed in 1596 /// Python bindings) and extends its lifetime. 1597 class PyBlockArgumentList 1598 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1599 public: 1600 static constexpr const char *pyClassName = "BlockArgumentList"; 1601 1602 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1603 intptr_t startIndex = 0, intptr_t length = -1, 1604 intptr_t step = 1) 1605 : Sliceable(startIndex, 1606 length == -1 ? mlirBlockGetNumArguments(block) : length, 1607 step), 1608 operation(std::move(operation)), block(block) {} 1609 1610 /// Returns the number of arguments in the list. 1611 intptr_t getNumElements() { 1612 operation->checkValid(); 1613 return mlirBlockGetNumArguments(block); 1614 } 1615 1616 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1617 PyBlockArgument getElement(intptr_t pos) { 1618 MlirValue argument = mlirBlockGetArgument(block, pos); 1619 return PyBlockArgument(operation, argument); 1620 } 1621 1622 /// Returns a sublist of this list. 1623 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1624 intptr_t step) { 1625 return PyBlockArgumentList(operation, block, startIndex, length, step); 1626 } 1627 1628 private: 1629 PyOperationRef operation; 1630 MlirBlock block; 1631 }; 1632 1633 /// A list of operation operands. Internally, these are stored as consecutive 1634 /// elements, random access is cheap. The result list is associated with the 1635 /// operation whose results these are, and extends the lifetime of this 1636 /// operation. 1637 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1638 public: 1639 static constexpr const char *pyClassName = "OpOperandList"; 1640 1641 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1642 intptr_t length = -1, intptr_t step = 1) 1643 : Sliceable(startIndex, 1644 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1645 : length, 1646 step), 1647 operation(operation) {} 1648 1649 intptr_t getNumElements() { 1650 operation->checkValid(); 1651 return mlirOperationGetNumOperands(operation->get()); 1652 } 1653 1654 PyValue getElement(intptr_t pos) { 1655 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1656 MlirOperation owner; 1657 if (mlirValueIsAOpResult(operand)) 1658 owner = mlirOpResultGetOwner(operand); 1659 else if (mlirValueIsABlockArgument(operand)) 1660 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1661 else 1662 assert(false && "Value must be an block arg or op result."); 1663 PyOperationRef pyOwner = 1664 PyOperation::forOperation(operation->getContext(), owner); 1665 return PyValue(pyOwner, operand); 1666 } 1667 1668 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1669 return PyOpOperandList(operation, startIndex, length, step); 1670 } 1671 1672 void dunderSetItem(intptr_t index, PyValue value) { 1673 index = wrapIndex(index); 1674 mlirOperationSetOperand(operation->get(), index, value.get()); 1675 } 1676 1677 static void bindDerived(ClassTy &c) { 1678 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1679 } 1680 1681 private: 1682 PyOperationRef operation; 1683 }; 1684 1685 /// A list of operation results. Internally, these are stored as consecutive 1686 /// elements, random access is cheap. The result list is associated with the 1687 /// operation whose results these are, and extends the lifetime of this 1688 /// operation. 1689 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1690 public: 1691 static constexpr const char *pyClassName = "OpResultList"; 1692 1693 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1694 intptr_t length = -1, intptr_t step = 1) 1695 : Sliceable(startIndex, 1696 length == -1 ? mlirOperationGetNumResults(operation->get()) 1697 : length, 1698 step), 1699 operation(operation) {} 1700 1701 intptr_t getNumElements() { 1702 operation->checkValid(); 1703 return mlirOperationGetNumResults(operation->get()); 1704 } 1705 1706 PyOpResult getElement(intptr_t index) { 1707 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1708 return PyOpResult(value); 1709 } 1710 1711 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1712 return PyOpResultList(operation, startIndex, length, step); 1713 } 1714 1715 private: 1716 PyOperationRef operation; 1717 }; 1718 1719 /// A list of operation attributes. Can be indexed by name, producing 1720 /// attributes, or by index, producing named attributes. 1721 class PyOpAttributeMap { 1722 public: 1723 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1724 1725 PyAttribute dunderGetItemNamed(const std::string &name) { 1726 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1727 toMlirStringRef(name)); 1728 if (mlirAttributeIsNull(attr)) { 1729 throw SetPyError(PyExc_KeyError, 1730 "attempt to access a non-existent attribute"); 1731 } 1732 return PyAttribute(operation->getContext(), attr); 1733 } 1734 1735 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1736 if (index < 0 || index >= dunderLen()) { 1737 throw SetPyError(PyExc_IndexError, 1738 "attempt to access out of bounds attribute"); 1739 } 1740 MlirNamedAttribute namedAttr = 1741 mlirOperationGetAttribute(operation->get(), index); 1742 return PyNamedAttribute( 1743 namedAttr.attribute, 1744 std::string(mlirIdentifierStr(namedAttr.name).data)); 1745 } 1746 1747 void dunderSetItem(const std::string &name, PyAttribute attr) { 1748 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1749 attr); 1750 } 1751 1752 void dunderDelItem(const std::string &name) { 1753 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1754 toMlirStringRef(name)); 1755 if (!removed) 1756 throw SetPyError(PyExc_KeyError, 1757 "attempt to delete a non-existent attribute"); 1758 } 1759 1760 intptr_t dunderLen() { 1761 return mlirOperationGetNumAttributes(operation->get()); 1762 } 1763 1764 bool dunderContains(const std::string &name) { 1765 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1766 operation->get(), toMlirStringRef(name))); 1767 } 1768 1769 static void bind(py::module &m) { 1770 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1771 .def("__contains__", &PyOpAttributeMap::dunderContains) 1772 .def("__len__", &PyOpAttributeMap::dunderLen) 1773 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1774 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1775 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1776 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1777 } 1778 1779 private: 1780 PyOperationRef operation; 1781 }; 1782 1783 } // end namespace 1784 1785 //------------------------------------------------------------------------------ 1786 // Populates the core exports of the 'ir' submodule. 1787 //------------------------------------------------------------------------------ 1788 1789 void mlir::python::populateIRCore(py::module &m) { 1790 //---------------------------------------------------------------------------- 1791 // Mapping of MlirContext. 1792 //---------------------------------------------------------------------------- 1793 py::class_<PyMlirContext>(m, "Context", py::module_local()) 1794 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1795 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1796 .def("_get_context_again", 1797 [](PyMlirContext &self) { 1798 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1799 return ref.releaseObject(); 1800 }) 1801 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1802 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1803 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1804 &PyMlirContext::getCapsule) 1805 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1806 .def("__enter__", &PyMlirContext::contextEnter) 1807 .def("__exit__", &PyMlirContext::contextExit) 1808 .def_property_readonly_static( 1809 "current", 1810 [](py::object & /*class*/) { 1811 auto *context = PyThreadContextEntry::getDefaultContext(); 1812 if (!context) 1813 throw SetPyError(PyExc_ValueError, "No current Context"); 1814 return context; 1815 }, 1816 "Gets the Context bound to the current thread or raises ValueError") 1817 .def_property_readonly( 1818 "dialects", 1819 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1820 "Gets a container for accessing dialects by name") 1821 .def_property_readonly( 1822 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1823 "Alias for 'dialect'") 1824 .def( 1825 "get_dialect_descriptor", 1826 [=](PyMlirContext &self, std::string &name) { 1827 MlirDialect dialect = mlirContextGetOrLoadDialect( 1828 self.get(), {name.data(), name.size()}); 1829 if (mlirDialectIsNull(dialect)) { 1830 throw SetPyError(PyExc_ValueError, 1831 Twine("Dialect '") + name + "' not found"); 1832 } 1833 return PyDialectDescriptor(self.getRef(), dialect); 1834 }, 1835 "Gets or loads a dialect by name, returning its descriptor object") 1836 .def_property( 1837 "allow_unregistered_dialects", 1838 [](PyMlirContext &self) -> bool { 1839 return mlirContextGetAllowUnregisteredDialects(self.get()); 1840 }, 1841 [](PyMlirContext &self, bool value) { 1842 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1843 }) 1844 .def("enable_multithreading", 1845 [](PyMlirContext &self, bool enable) { 1846 mlirContextEnableMultithreading(self.get(), enable); 1847 }) 1848 .def("is_registered_operation", 1849 [](PyMlirContext &self, std::string &name) { 1850 return mlirContextIsRegisteredOperation( 1851 self.get(), MlirStringRef{name.data(), name.size()}); 1852 }); 1853 1854 //---------------------------------------------------------------------------- 1855 // Mapping of PyDialectDescriptor 1856 //---------------------------------------------------------------------------- 1857 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1858 .def_property_readonly("namespace", 1859 [](PyDialectDescriptor &self) { 1860 MlirStringRef ns = 1861 mlirDialectGetNamespace(self.get()); 1862 return py::str(ns.data, ns.length); 1863 }) 1864 .def("__repr__", [](PyDialectDescriptor &self) { 1865 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1866 std::string repr("<DialectDescriptor "); 1867 repr.append(ns.data, ns.length); 1868 repr.append(">"); 1869 return repr; 1870 }); 1871 1872 //---------------------------------------------------------------------------- 1873 // Mapping of PyDialects 1874 //---------------------------------------------------------------------------- 1875 py::class_<PyDialects>(m, "Dialects", py::module_local()) 1876 .def("__getitem__", 1877 [=](PyDialects &self, std::string keyName) { 1878 MlirDialect dialect = 1879 self.getDialectForKey(keyName, /*attrError=*/false); 1880 py::object descriptor = 1881 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1882 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1883 }) 1884 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1885 MlirDialect dialect = 1886 self.getDialectForKey(attrName, /*attrError=*/true); 1887 py::object descriptor = 1888 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1889 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1890 }); 1891 1892 //---------------------------------------------------------------------------- 1893 // Mapping of PyDialect 1894 //---------------------------------------------------------------------------- 1895 py::class_<PyDialect>(m, "Dialect", py::module_local()) 1896 .def(py::init<py::object>(), "descriptor") 1897 .def_property_readonly( 1898 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1899 .def("__repr__", [](py::object self) { 1900 auto clazz = self.attr("__class__"); 1901 return py::str("<Dialect ") + 1902 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1903 clazz.attr("__module__") + py::str(".") + 1904 clazz.attr("__name__") + py::str(")>"); 1905 }); 1906 1907 //---------------------------------------------------------------------------- 1908 // Mapping of Location 1909 //---------------------------------------------------------------------------- 1910 py::class_<PyLocation>(m, "Location", py::module_local()) 1911 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1912 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1913 .def("__enter__", &PyLocation::contextEnter) 1914 .def("__exit__", &PyLocation::contextExit) 1915 .def("__eq__", 1916 [](PyLocation &self, PyLocation &other) -> bool { 1917 return mlirLocationEqual(self, other); 1918 }) 1919 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1920 .def_property_readonly_static( 1921 "current", 1922 [](py::object & /*class*/) { 1923 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1924 if (!loc) 1925 throw SetPyError(PyExc_ValueError, "No current Location"); 1926 return loc; 1927 }, 1928 "Gets the Location bound to the current thread or raises ValueError") 1929 .def_static( 1930 "unknown", 1931 [](DefaultingPyMlirContext context) { 1932 return PyLocation(context->getRef(), 1933 mlirLocationUnknownGet(context->get())); 1934 }, 1935 py::arg("context") = py::none(), 1936 "Gets a Location representing an unknown location") 1937 .def_static( 1938 "file", 1939 [](std::string filename, int line, int col, 1940 DefaultingPyMlirContext context) { 1941 return PyLocation( 1942 context->getRef(), 1943 mlirLocationFileLineColGet( 1944 context->get(), toMlirStringRef(filename), line, col)); 1945 }, 1946 py::arg("filename"), py::arg("line"), py::arg("col"), 1947 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1948 .def_property_readonly( 1949 "context", 1950 [](PyLocation &self) { return self.getContext().getObject(); }, 1951 "Context that owns the Location") 1952 .def("__repr__", [](PyLocation &self) { 1953 PyPrintAccumulator printAccum; 1954 mlirLocationPrint(self, printAccum.getCallback(), 1955 printAccum.getUserData()); 1956 return printAccum.join(); 1957 }); 1958 1959 //---------------------------------------------------------------------------- 1960 // Mapping of Module 1961 //---------------------------------------------------------------------------- 1962 py::class_<PyModule>(m, "Module", py::module_local()) 1963 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1964 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1965 .def_static( 1966 "parse", 1967 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1968 MlirModule module = mlirModuleCreateParse( 1969 context->get(), toMlirStringRef(moduleAsm)); 1970 // TODO: Rework error reporting once diagnostic engine is exposed 1971 // in C API. 1972 if (mlirModuleIsNull(module)) { 1973 throw SetPyError( 1974 PyExc_ValueError, 1975 "Unable to parse module assembly (see diagnostics)"); 1976 } 1977 return PyModule::forModule(module).releaseObject(); 1978 }, 1979 py::arg("asm"), py::arg("context") = py::none(), 1980 kModuleParseDocstring) 1981 .def_static( 1982 "create", 1983 [](DefaultingPyLocation loc) { 1984 MlirModule module = mlirModuleCreateEmpty(loc); 1985 return PyModule::forModule(module).releaseObject(); 1986 }, 1987 py::arg("loc") = py::none(), "Creates an empty module") 1988 .def_property_readonly( 1989 "context", 1990 [](PyModule &self) { return self.getContext().getObject(); }, 1991 "Context that created the Module") 1992 .def_property_readonly( 1993 "operation", 1994 [](PyModule &self) { 1995 return PyOperation::forOperation(self.getContext(), 1996 mlirModuleGetOperation(self.get()), 1997 self.getRef().releaseObject()) 1998 .releaseObject(); 1999 }, 2000 "Accesses the module as an operation") 2001 .def_property_readonly( 2002 "body", 2003 [](PyModule &self) { 2004 PyOperationRef module_op = PyOperation::forOperation( 2005 self.getContext(), mlirModuleGetOperation(self.get()), 2006 self.getRef().releaseObject()); 2007 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2008 return returnBlock; 2009 }, 2010 "Return the block for this module") 2011 .def( 2012 "dump", 2013 [](PyModule &self) { 2014 mlirOperationDump(mlirModuleGetOperation(self.get())); 2015 }, 2016 kDumpDocstring) 2017 .def( 2018 "__str__", 2019 [](PyModule &self) { 2020 MlirOperation operation = mlirModuleGetOperation(self.get()); 2021 PyPrintAccumulator printAccum; 2022 mlirOperationPrint(operation, printAccum.getCallback(), 2023 printAccum.getUserData()); 2024 return printAccum.join(); 2025 }, 2026 kOperationStrDunderDocstring); 2027 2028 //---------------------------------------------------------------------------- 2029 // Mapping of Operation. 2030 //---------------------------------------------------------------------------- 2031 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2032 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2033 [](PyOperationBase &self) { 2034 return self.getOperation().getCapsule(); 2035 }) 2036 .def("__eq__", 2037 [](PyOperationBase &self, PyOperationBase &other) { 2038 return &self.getOperation() == &other.getOperation(); 2039 }) 2040 .def("__eq__", 2041 [](PyOperationBase &self, py::object other) { return false; }) 2042 .def_property_readonly("attributes", 2043 [](PyOperationBase &self) { 2044 return PyOpAttributeMap( 2045 self.getOperation().getRef()); 2046 }) 2047 .def_property_readonly("operands", 2048 [](PyOperationBase &self) { 2049 return PyOpOperandList( 2050 self.getOperation().getRef()); 2051 }) 2052 .def_property_readonly("regions", 2053 [](PyOperationBase &self) { 2054 return PyRegionList( 2055 self.getOperation().getRef()); 2056 }) 2057 .def_property_readonly( 2058 "results", 2059 [](PyOperationBase &self) { 2060 return PyOpResultList(self.getOperation().getRef()); 2061 }, 2062 "Returns the list of Operation results.") 2063 .def_property_readonly( 2064 "result", 2065 [](PyOperationBase &self) { 2066 auto &operation = self.getOperation(); 2067 auto numResults = mlirOperationGetNumResults(operation); 2068 if (numResults != 1) { 2069 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2070 throw SetPyError( 2071 PyExc_ValueError, 2072 Twine("Cannot call .result on operation ") + 2073 StringRef(name.data, name.length) + " which has " + 2074 Twine(numResults) + 2075 " results (it is only valid for operations with a " 2076 "single result)"); 2077 } 2078 return PyOpResult(operation.getRef(), 2079 mlirOperationGetResult(operation, 0)); 2080 }, 2081 "Shortcut to get an op result if it has only one (throws an error " 2082 "otherwise).") 2083 .def("__iter__", 2084 [](PyOperationBase &self) { 2085 return PyRegionIterator(self.getOperation().getRef()); 2086 }) 2087 .def( 2088 "__str__", 2089 [](PyOperationBase &self) { 2090 return self.getAsm(/*binary=*/false, 2091 /*largeElementsLimit=*/llvm::None, 2092 /*enableDebugInfo=*/false, 2093 /*prettyDebugInfo=*/false, 2094 /*printGenericOpForm=*/false, 2095 /*useLocalScope=*/false); 2096 }, 2097 "Returns the assembly form of the operation.") 2098 .def("print", &PyOperationBase::print, 2099 // Careful: Lots of arguments must match up with print method. 2100 py::arg("file") = py::none(), py::arg("binary") = false, 2101 py::arg("large_elements_limit") = py::none(), 2102 py::arg("enable_debug_info") = false, 2103 py::arg("pretty_debug_info") = false, 2104 py::arg("print_generic_op_form") = false, 2105 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2106 .def("get_asm", &PyOperationBase::getAsm, 2107 // Careful: Lots of arguments must match up with get_asm method. 2108 py::arg("binary") = false, 2109 py::arg("large_elements_limit") = py::none(), 2110 py::arg("enable_debug_info") = false, 2111 py::arg("pretty_debug_info") = false, 2112 py::arg("print_generic_op_form") = false, 2113 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2114 .def( 2115 "verify", 2116 [](PyOperationBase &self) { 2117 return mlirOperationVerify(self.getOperation()); 2118 }, 2119 "Verify the operation and return true if it passes, false if it " 2120 "fails."); 2121 2122 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2123 .def_static("create", &PyOperation::create, py::arg("name"), 2124 py::arg("results") = py::none(), 2125 py::arg("operands") = py::none(), 2126 py::arg("attributes") = py::none(), 2127 py::arg("successors") = py::none(), py::arg("regions") = 0, 2128 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2129 kOperationCreateDocstring) 2130 .def_property_readonly("parent", 2131 [](PyOperation &self) -> py::object { 2132 auto parent = self.getParentOperation(); 2133 if (parent) 2134 return parent->getObject(); 2135 return py::none(); 2136 }) 2137 .def("erase", &PyOperation::erase) 2138 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2139 &PyOperation::getCapsule) 2140 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2141 .def_property_readonly("name", 2142 [](PyOperation &self) { 2143 self.checkValid(); 2144 MlirOperation operation = self.get(); 2145 MlirStringRef name = mlirIdentifierStr( 2146 mlirOperationGetName(operation)); 2147 return py::str(name.data, name.length); 2148 }) 2149 .def_property_readonly( 2150 "context", 2151 [](PyOperation &self) { 2152 self.checkValid(); 2153 return self.getContext().getObject(); 2154 }, 2155 "Context that owns the Operation") 2156 .def_property_readonly("opview", &PyOperation::createOpView); 2157 2158 auto opViewClass = 2159 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2160 .def(py::init<py::object>()) 2161 .def_property_readonly("operation", &PyOpView::getOperationObject) 2162 .def_property_readonly( 2163 "context", 2164 [](PyOpView &self) { 2165 return self.getOperation().getContext().getObject(); 2166 }, 2167 "Context that owns the Operation") 2168 .def("__str__", [](PyOpView &self) { 2169 return py::str(self.getOperationObject()); 2170 }); 2171 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2172 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2173 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2174 opViewClass.attr("build_generic") = classmethod( 2175 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2176 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2177 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2178 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2179 "Builds a specific, generated OpView based on class level attributes."); 2180 2181 //---------------------------------------------------------------------------- 2182 // Mapping of PyRegion. 2183 //---------------------------------------------------------------------------- 2184 py::class_<PyRegion>(m, "Region", py::module_local()) 2185 .def_property_readonly( 2186 "blocks", 2187 [](PyRegion &self) { 2188 return PyBlockList(self.getParentOperation(), self.get()); 2189 }, 2190 "Returns a forward-optimized sequence of blocks.") 2191 .def( 2192 "__iter__", 2193 [](PyRegion &self) { 2194 self.checkValid(); 2195 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2196 return PyBlockIterator(self.getParentOperation(), firstBlock); 2197 }, 2198 "Iterates over blocks in the region.") 2199 .def("__eq__", 2200 [](PyRegion &self, PyRegion &other) { 2201 return self.get().ptr == other.get().ptr; 2202 }) 2203 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2204 2205 //---------------------------------------------------------------------------- 2206 // Mapping of PyBlock. 2207 //---------------------------------------------------------------------------- 2208 py::class_<PyBlock>(m, "Block", py::module_local()) 2209 .def_property_readonly( 2210 "owner", 2211 [](PyBlock &self) { 2212 return self.getParentOperation()->createOpView(); 2213 }, 2214 "Returns the owning operation of this block.") 2215 .def_property_readonly( 2216 "region", 2217 [](PyBlock &self) { 2218 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2219 return PyRegion(self.getParentOperation(), region); 2220 }, 2221 "Returns the owning region of this block.") 2222 .def_property_readonly( 2223 "arguments", 2224 [](PyBlock &self) { 2225 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2226 }, 2227 "Returns a list of block arguments.") 2228 .def_property_readonly( 2229 "operations", 2230 [](PyBlock &self) { 2231 return PyOperationList(self.getParentOperation(), self.get()); 2232 }, 2233 "Returns a forward-optimized sequence of operations.") 2234 .def( 2235 "create_before", 2236 [](PyBlock &self, py::args pyArgTypes) { 2237 self.checkValid(); 2238 llvm::SmallVector<MlirType, 4> argTypes; 2239 argTypes.reserve(pyArgTypes.size()); 2240 for (auto &pyArg : pyArgTypes) { 2241 argTypes.push_back(pyArg.cast<PyType &>()); 2242 } 2243 2244 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2245 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2246 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2247 return PyBlock(self.getParentOperation(), block); 2248 }, 2249 "Creates and returns a new Block before this block " 2250 "(with given argument types).") 2251 .def( 2252 "create_after", 2253 [](PyBlock &self, py::args pyArgTypes) { 2254 self.checkValid(); 2255 llvm::SmallVector<MlirType, 4> argTypes; 2256 argTypes.reserve(pyArgTypes.size()); 2257 for (auto &pyArg : pyArgTypes) { 2258 argTypes.push_back(pyArg.cast<PyType &>()); 2259 } 2260 2261 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2262 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2263 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2264 return PyBlock(self.getParentOperation(), block); 2265 }, 2266 "Creates and returns a new Block after this block " 2267 "(with given argument types).") 2268 .def( 2269 "__iter__", 2270 [](PyBlock &self) { 2271 self.checkValid(); 2272 MlirOperation firstOperation = 2273 mlirBlockGetFirstOperation(self.get()); 2274 return PyOperationIterator(self.getParentOperation(), 2275 firstOperation); 2276 }, 2277 "Iterates over operations in the block.") 2278 .def("__eq__", 2279 [](PyBlock &self, PyBlock &other) { 2280 return self.get().ptr == other.get().ptr; 2281 }) 2282 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2283 .def( 2284 "__str__", 2285 [](PyBlock &self) { 2286 self.checkValid(); 2287 PyPrintAccumulator printAccum; 2288 mlirBlockPrint(self.get(), printAccum.getCallback(), 2289 printAccum.getUserData()); 2290 return printAccum.join(); 2291 }, 2292 "Returns the assembly form of the block."); 2293 2294 //---------------------------------------------------------------------------- 2295 // Mapping of PyInsertionPoint. 2296 //---------------------------------------------------------------------------- 2297 2298 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2299 .def(py::init<PyBlock &>(), py::arg("block"), 2300 "Inserts after the last operation but still inside the block.") 2301 .def("__enter__", &PyInsertionPoint::contextEnter) 2302 .def("__exit__", &PyInsertionPoint::contextExit) 2303 .def_property_readonly_static( 2304 "current", 2305 [](py::object & /*class*/) { 2306 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2307 if (!ip) 2308 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2309 return ip; 2310 }, 2311 "Gets the InsertionPoint bound to the current thread or raises " 2312 "ValueError if none has been set") 2313 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2314 "Inserts before a referenced operation.") 2315 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2316 py::arg("block"), "Inserts at the beginning of the block.") 2317 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2318 py::arg("block"), "Inserts before the block terminator.") 2319 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2320 "Inserts an operation.") 2321 .def_property_readonly( 2322 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2323 "Returns the block that this InsertionPoint points to."); 2324 2325 //---------------------------------------------------------------------------- 2326 // Mapping of PyAttribute. 2327 //---------------------------------------------------------------------------- 2328 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2329 // Delegate to the PyAttribute copy constructor, which will also lifetime 2330 // extend the backing context which owns the MlirAttribute. 2331 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2332 "Casts the passed attribute to the generic Attribute") 2333 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2334 &PyAttribute::getCapsule) 2335 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2336 .def_static( 2337 "parse", 2338 [](std::string attrSpec, DefaultingPyMlirContext context) { 2339 MlirAttribute type = mlirAttributeParseGet( 2340 context->get(), toMlirStringRef(attrSpec)); 2341 // TODO: Rework error reporting once diagnostic engine is exposed 2342 // in C API. 2343 if (mlirAttributeIsNull(type)) { 2344 throw SetPyError(PyExc_ValueError, 2345 Twine("Unable to parse attribute: '") + 2346 attrSpec + "'"); 2347 } 2348 return PyAttribute(context->getRef(), type); 2349 }, 2350 py::arg("asm"), py::arg("context") = py::none(), 2351 "Parses an attribute from an assembly form") 2352 .def_property_readonly( 2353 "context", 2354 [](PyAttribute &self) { return self.getContext().getObject(); }, 2355 "Context that owns the Attribute") 2356 .def_property_readonly("type", 2357 [](PyAttribute &self) { 2358 return PyType(self.getContext()->getRef(), 2359 mlirAttributeGetType(self)); 2360 }) 2361 .def( 2362 "get_named", 2363 [](PyAttribute &self, std::string name) { 2364 return PyNamedAttribute(self, std::move(name)); 2365 }, 2366 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2367 .def("__eq__", 2368 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2369 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2370 .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; }) 2371 .def( 2372 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2373 kDumpDocstring) 2374 .def( 2375 "__str__", 2376 [](PyAttribute &self) { 2377 PyPrintAccumulator printAccum; 2378 mlirAttributePrint(self, printAccum.getCallback(), 2379 printAccum.getUserData()); 2380 return printAccum.join(); 2381 }, 2382 "Returns the assembly form of the Attribute.") 2383 .def("__repr__", [](PyAttribute &self) { 2384 // Generally, assembly formats are not printed for __repr__ because 2385 // this can cause exceptionally long debug output and exceptions. 2386 // However, attribute values are generally considered useful and are 2387 // printed. This may need to be re-evaluated if debug dumps end up 2388 // being excessive. 2389 PyPrintAccumulator printAccum; 2390 printAccum.parts.append("Attribute("); 2391 mlirAttributePrint(self, printAccum.getCallback(), 2392 printAccum.getUserData()); 2393 printAccum.parts.append(")"); 2394 return printAccum.join(); 2395 }); 2396 2397 //---------------------------------------------------------------------------- 2398 // Mapping of PyNamedAttribute 2399 //---------------------------------------------------------------------------- 2400 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2401 .def("__repr__", 2402 [](PyNamedAttribute &self) { 2403 PyPrintAccumulator printAccum; 2404 printAccum.parts.append("NamedAttribute("); 2405 printAccum.parts.append( 2406 mlirIdentifierStr(self.namedAttr.name).data); 2407 printAccum.parts.append("="); 2408 mlirAttributePrint(self.namedAttr.attribute, 2409 printAccum.getCallback(), 2410 printAccum.getUserData()); 2411 printAccum.parts.append(")"); 2412 return printAccum.join(); 2413 }) 2414 .def_property_readonly( 2415 "name", 2416 [](PyNamedAttribute &self) { 2417 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2418 mlirIdentifierStr(self.namedAttr.name).length); 2419 }, 2420 "The name of the NamedAttribute binding") 2421 .def_property_readonly( 2422 "attr", 2423 [](PyNamedAttribute &self) { 2424 // TODO: When named attribute is removed/refactored, also remove 2425 // this constructor (it does an inefficient table lookup). 2426 auto contextRef = PyMlirContext::forContext( 2427 mlirAttributeGetContext(self.namedAttr.attribute)); 2428 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2429 }, 2430 py::keep_alive<0, 1>(), 2431 "The underlying generic attribute of the NamedAttribute binding"); 2432 2433 //---------------------------------------------------------------------------- 2434 // Mapping of PyType. 2435 //---------------------------------------------------------------------------- 2436 py::class_<PyType>(m, "Type", py::module_local()) 2437 // Delegate to the PyType copy constructor, which will also lifetime 2438 // extend the backing context which owns the MlirType. 2439 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2440 "Casts the passed type to the generic Type") 2441 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2442 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2443 .def_static( 2444 "parse", 2445 [](std::string typeSpec, DefaultingPyMlirContext context) { 2446 MlirType type = 2447 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2448 // TODO: Rework error reporting once diagnostic engine is exposed 2449 // in C API. 2450 if (mlirTypeIsNull(type)) { 2451 throw SetPyError(PyExc_ValueError, 2452 Twine("Unable to parse type: '") + typeSpec + 2453 "'"); 2454 } 2455 return PyType(context->getRef(), type); 2456 }, 2457 py::arg("asm"), py::arg("context") = py::none(), 2458 kContextParseTypeDocstring) 2459 .def_property_readonly( 2460 "context", [](PyType &self) { return self.getContext().getObject(); }, 2461 "Context that owns the Type") 2462 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2463 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2464 .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; }) 2465 .def( 2466 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2467 .def( 2468 "__str__", 2469 [](PyType &self) { 2470 PyPrintAccumulator printAccum; 2471 mlirTypePrint(self, printAccum.getCallback(), 2472 printAccum.getUserData()); 2473 return printAccum.join(); 2474 }, 2475 "Returns the assembly form of the type.") 2476 .def("__repr__", [](PyType &self) { 2477 // Generally, assembly formats are not printed for __repr__ because 2478 // this can cause exceptionally long debug output and exceptions. 2479 // However, types are an exception as they typically have compact 2480 // assembly forms and printing them is useful. 2481 PyPrintAccumulator printAccum; 2482 printAccum.parts.append("Type("); 2483 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2484 printAccum.parts.append(")"); 2485 return printAccum.join(); 2486 }); 2487 2488 //---------------------------------------------------------------------------- 2489 // Mapping of Value. 2490 //---------------------------------------------------------------------------- 2491 py::class_<PyValue>(m, "Value", py::module_local()) 2492 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2493 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2494 .def_property_readonly( 2495 "context", 2496 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2497 "Context in which the value lives.") 2498 .def( 2499 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2500 kDumpDocstring) 2501 .def_property_readonly( 2502 "owner", 2503 [](PyValue &self) { 2504 assert(mlirOperationEqual(self.getParentOperation()->get(), 2505 mlirOpResultGetOwner(self.get())) && 2506 "expected the owner of the value in Python to match that in " 2507 "the IR"); 2508 return self.getParentOperation().getObject(); 2509 }) 2510 .def("__eq__", 2511 [](PyValue &self, PyValue &other) { 2512 return self.get().ptr == other.get().ptr; 2513 }) 2514 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2515 .def( 2516 "__str__", 2517 [](PyValue &self) { 2518 PyPrintAccumulator printAccum; 2519 printAccum.parts.append("Value("); 2520 mlirValuePrint(self.get(), printAccum.getCallback(), 2521 printAccum.getUserData()); 2522 printAccum.parts.append(")"); 2523 return printAccum.join(); 2524 }, 2525 kValueDunderStrDocstring) 2526 .def_property_readonly("type", [](PyValue &self) { 2527 return PyType(self.getParentOperation()->getContext(), 2528 mlirValueGetType(self.get())); 2529 }); 2530 PyBlockArgument::bind(m); 2531 PyOpResult::bind(m); 2532 2533 // Container bindings. 2534 PyBlockArgumentList::bind(m); 2535 PyBlockIterator::bind(m); 2536 PyBlockList::bind(m); 2537 PyOperationIterator::bind(m); 2538 PyOperationList::bind(m); 2539 PyOpAttributeMap::bind(m); 2540 PyOpOperandList::bind(m); 2541 PyOpResultList::bind(m); 2542 PyRegionIterator::bind(m); 2543 PyRegionList::bind(m); 2544 2545 // Debug bindings. 2546 PyGlobalDebugFlag::bind(m); 2547 } 2548