1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 namespace py = pybind11; 25 using namespace mlir; 26 using namespace mlir::python; 27 28 using llvm::SmallVector; 29 using llvm::StringRef; 30 using llvm::Twine; 31 32 //------------------------------------------------------------------------------ 33 // Docstrings (trivial, non-duplicated docstrings are included inline). 34 //------------------------------------------------------------------------------ 35 36 static const char kContextParseTypeDocstring[] = 37 R"(Parses the assembly form of a type. 38 39 Returns a Type object or raises a ValueError if the type cannot be parsed. 40 41 See also: https://mlir.llvm.org/docs/LangRef/#type-system 42 )"; 43 44 static const char kContextGetCallSiteLocationDocstring[] = 45 R"(Gets a Location representing a caller and callsite)"; 46 47 static const char kContextGetFileLocationDocstring[] = 48 R"(Gets a Location representing a file, line and column)"; 49 50 static const char kContextGetNameLocationDocString[] = 51 R"(Gets a Location representing a named location with optional child location)"; 52 53 static const char kModuleParseDocstring[] = 54 R"(Parses a module's assembly format from a string. 
55 56 Returns a new MlirModule or raises a ValueError if the parsing fails. 57 58 See also: https://mlir.llvm.org/docs/LangRef/ 59 )"; 60 61 static const char kOperationCreateDocstring[] = 62 R"(Creates a new operation. 63 64 Args: 65 name: Operation name (e.g. "dialect.operation"). 66 results: Sequence of Type representing op result types. 67 attributes: Dict of str:Attribute. 68 successors: List of Block for the operation's successors. 69 regions: Number of regions to create. 70 location: A Location object (defaults to resolve from context manager). 71 ip: An InsertionPoint (defaults to resolve from context manager or set to 72 False to disable insertion, even with an insertion point set in the 73 context manager). 74 Returns: 75 A new "detached" Operation object. Detached operations can be added 76 to blocks, which causes them to become "attached." 77 )"; 78 79 static const char kOperationPrintDocstring[] = 80 R"(Prints the assembly form of the operation to a file like object. 81 82 Args: 83 file: The file like object to write to. Defaults to sys.stdout. 84 binary: Whether to write bytes (True) or str (False). Defaults to False. 85 large_elements_limit: Whether to elide elements attributes above this 86 number of elements. Defaults to None (no limit). 87 enable_debug_info: Whether to print debug/location information. Defaults 88 to False. 89 pretty_debug_info: Whether to format debug information for easier reading 90 by a human (warning: the result is unparseable). 91 print_generic_op_form: Whether to print the generic assembly forms of all 92 ops. Defaults to False. 93 use_local_Scope: Whether to print in a way that is more optimized for 94 multi-threaded access but may not be consistent with how the overall 95 module prints. 96 )"; 97 98 static const char kOperationGetAsmDocstring[] = 99 R"(Gets the assembly form of the operation with all options available. 100 101 Args: 102 binary: Whether to return a bytes (True) or str (False) object. Defaults to 103 False. 
104 ... others ...: See the print() method for common keyword arguments for 105 configuring the printout. 106 Returns: 107 Either a bytes or str object, depending on the setting of the 'binary' 108 argument. 109 )"; 110 111 static const char kOperationStrDunderDocstring[] = 112 R"(Gets the assembly form of the operation with default options. 113 114 If more advanced control over the assembly formatting or I/O options is needed, 115 use the dedicated print or get_asm method, which supports keyword arguments to 116 customize behavior. 117 )"; 118 119 static const char kDumpDocstring[] = 120 R"(Dumps a debug representation of the object to stderr.)"; 121 122 static const char kAppendBlockDocstring[] = 123 R"(Appends a new block, with argument types as positional args. 124 125 Returns: 126 The created block. 127 )"; 128 129 static const char kValueDunderStrDocstring[] = 130 R"(Returns the string form of the value. 131 132 If the value is a block argument, this is the assembly form of its type and the 133 position in the argument list. If the value is an operation result, this is 134 equivalent to printing the operation that produced it. 135 )"; 136 137 //------------------------------------------------------------------------------ 138 // Utilities. 139 //------------------------------------------------------------------------------ 140 141 /// Helper for creating an @classmethod. 142 template <class Func, typename... Args> 143 py::object classmethod(Func f, Args... args) { 144 py::object cf = py::cpp_function(f, args...); 145 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 146 } 147 148 static py::object 149 createCustomDialectWrapper(const std::string &dialectNamespace, 150 py::object dialectDescriptor) { 151 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 152 if (!dialectClass) { 153 // Use the base class. 
154 return py::cast(PyDialect(std::move(dialectDescriptor))); 155 } 156 157 // Create the custom implementation. 158 return (*dialectClass)(std::move(dialectDescriptor)); 159 } 160 161 static MlirStringRef toMlirStringRef(const std::string &s) { 162 return mlirStringRefCreate(s.data(), s.size()); 163 } 164 165 /// Wrapper for the global LLVM debugging flag. 166 struct PyGlobalDebugFlag { 167 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 168 169 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 170 171 static void bind(py::module &m) { 172 // Debug flags. 173 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 174 .def_property_static("flag", &PyGlobalDebugFlag::get, 175 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 176 } 177 }; 178 179 //------------------------------------------------------------------------------ 180 // Collections. 181 //------------------------------------------------------------------------------ 182 183 namespace { 184 185 class PyRegionIterator { 186 public: 187 PyRegionIterator(PyOperationRef operation) 188 : operation(std::move(operation)) {} 189 190 PyRegionIterator &dunderIter() { return *this; } 191 192 PyRegion dunderNext() { 193 operation->checkValid(); 194 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 195 throw py::stop_iteration(); 196 } 197 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 198 return PyRegion(operation, region); 199 } 200 201 static void bind(py::module &m) { 202 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 203 .def("__iter__", &PyRegionIterator::dunderIter) 204 .def("__next__", &PyRegionIterator::dunderNext); 205 } 206 207 private: 208 PyOperationRef operation; 209 int nextIndex = 0; 210 }; 211 212 /// Regions of an op are fixed length and indexed numerically so are represented 213 /// with a sequence-like container. 
214 class PyRegionList { 215 public: 216 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 217 218 intptr_t dunderLen() { 219 operation->checkValid(); 220 return mlirOperationGetNumRegions(operation->get()); 221 } 222 223 PyRegion dunderGetItem(intptr_t index) { 224 // dunderLen checks validity. 225 if (index < 0 || index >= dunderLen()) { 226 throw SetPyError(PyExc_IndexError, 227 "attempt to access out of bounds region"); 228 } 229 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 230 return PyRegion(operation, region); 231 } 232 233 static void bind(py::module &m) { 234 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 235 .def("__len__", &PyRegionList::dunderLen) 236 .def("__getitem__", &PyRegionList::dunderGetItem); 237 } 238 239 private: 240 PyOperationRef operation; 241 }; 242 243 class PyBlockIterator { 244 public: 245 PyBlockIterator(PyOperationRef operation, MlirBlock next) 246 : operation(std::move(operation)), next(next) {} 247 248 PyBlockIterator &dunderIter() { return *this; } 249 250 PyBlock dunderNext() { 251 operation->checkValid(); 252 if (mlirBlockIsNull(next)) { 253 throw py::stop_iteration(); 254 } 255 256 PyBlock returnBlock(operation, next); 257 next = mlirBlockGetNextInRegion(next); 258 return returnBlock; 259 } 260 261 static void bind(py::module &m) { 262 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 263 .def("__iter__", &PyBlockIterator::dunderIter) 264 .def("__next__", &PyBlockIterator::dunderNext); 265 } 266 267 private: 268 PyOperationRef operation; 269 MlirBlock next; 270 }; 271 272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 273 /// we present them as a more full-featured list-like container but optimize 274 /// it for forward iteration. Blocks are always owned by a region. 
275 class PyBlockList { 276 public: 277 PyBlockList(PyOperationRef operation, MlirRegion region) 278 : operation(std::move(operation)), region(region) {} 279 280 PyBlockIterator dunderIter() { 281 operation->checkValid(); 282 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 283 } 284 285 intptr_t dunderLen() { 286 operation->checkValid(); 287 intptr_t count = 0; 288 MlirBlock block = mlirRegionGetFirstBlock(region); 289 while (!mlirBlockIsNull(block)) { 290 count += 1; 291 block = mlirBlockGetNextInRegion(block); 292 } 293 return count; 294 } 295 296 PyBlock dunderGetItem(intptr_t index) { 297 operation->checkValid(); 298 if (index < 0) { 299 throw SetPyError(PyExc_IndexError, 300 "attempt to access out of bounds block"); 301 } 302 MlirBlock block = mlirRegionGetFirstBlock(region); 303 while (!mlirBlockIsNull(block)) { 304 if (index == 0) { 305 return PyBlock(operation, block); 306 } 307 block = mlirBlockGetNextInRegion(block); 308 index -= 1; 309 } 310 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 311 } 312 313 PyBlock appendBlock(py::args pyArgTypes) { 314 operation->checkValid(); 315 llvm::SmallVector<MlirType, 4> argTypes; 316 argTypes.reserve(pyArgTypes.size()); 317 for (auto &pyArg : pyArgTypes) { 318 argTypes.push_back(pyArg.cast<PyType &>()); 319 } 320 321 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 322 mlirRegionAppendOwnedBlock(region, block); 323 return PyBlock(operation, block); 324 } 325 326 static void bind(py::module &m) { 327 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 328 .def("__getitem__", &PyBlockList::dunderGetItem) 329 .def("__iter__", &PyBlockList::dunderIter) 330 .def("__len__", &PyBlockList::dunderLen) 331 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 332 } 333 334 private: 335 PyOperationRef operation; 336 MlirRegion region; 337 }; 338 339 class PyOperationIterator { 340 public: 341 PyOperationIterator(PyOperationRef 
parentOperation, MlirOperation next) 342 : parentOperation(std::move(parentOperation)), next(next) {} 343 344 PyOperationIterator &dunderIter() { return *this; } 345 346 py::object dunderNext() { 347 parentOperation->checkValid(); 348 if (mlirOperationIsNull(next)) { 349 throw py::stop_iteration(); 350 } 351 352 PyOperationRef returnOperation = 353 PyOperation::forOperation(parentOperation->getContext(), next); 354 next = mlirOperationGetNextInBlock(next); 355 return returnOperation->createOpView(); 356 } 357 358 static void bind(py::module &m) { 359 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 360 .def("__iter__", &PyOperationIterator::dunderIter) 361 .def("__next__", &PyOperationIterator::dunderNext); 362 } 363 364 private: 365 PyOperationRef parentOperation; 366 MlirOperation next; 367 }; 368 369 /// Operations are exposed by the C-API as a forward-only linked list. In 370 /// Python, we present them as a more full-featured list-like container but 371 /// optimize it for forward iteration. Iterable operations are always owned 372 /// by a block. 
373 class PyOperationList { 374 public: 375 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 376 : parentOperation(std::move(parentOperation)), block(block) {} 377 378 PyOperationIterator dunderIter() { 379 parentOperation->checkValid(); 380 return PyOperationIterator(parentOperation, 381 mlirBlockGetFirstOperation(block)); 382 } 383 384 intptr_t dunderLen() { 385 parentOperation->checkValid(); 386 intptr_t count = 0; 387 MlirOperation childOp = mlirBlockGetFirstOperation(block); 388 while (!mlirOperationIsNull(childOp)) { 389 count += 1; 390 childOp = mlirOperationGetNextInBlock(childOp); 391 } 392 return count; 393 } 394 395 py::object dunderGetItem(intptr_t index) { 396 parentOperation->checkValid(); 397 if (index < 0) { 398 throw SetPyError(PyExc_IndexError, 399 "attempt to access out of bounds operation"); 400 } 401 MlirOperation childOp = mlirBlockGetFirstOperation(block); 402 while (!mlirOperationIsNull(childOp)) { 403 if (index == 0) { 404 return PyOperation::forOperation(parentOperation->getContext(), childOp) 405 ->createOpView(); 406 } 407 childOp = mlirOperationGetNextInBlock(childOp); 408 index -= 1; 409 } 410 throw SetPyError(PyExc_IndexError, 411 "attempt to access out of bounds operation"); 412 } 413 414 static void bind(py::module &m) { 415 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 416 .def("__getitem__", &PyOperationList::dunderGetItem) 417 .def("__iter__", &PyOperationList::dunderIter) 418 .def("__len__", &PyOperationList::dunderLen); 419 } 420 421 private: 422 PyOperationRef parentOperation; 423 MlirBlock block; 424 }; 425 426 } // namespace 427 428 //------------------------------------------------------------------------------ 429 // PyMlirContext 430 //------------------------------------------------------------------------------ 431 432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 433 py::gil_scoped_acquire acquire; 434 auto &liveContexts = getLiveContexts(); 435 
liveContexts[context.ptr] = this; 436 } 437 438 PyMlirContext::~PyMlirContext() { 439 // Note that the only public way to construct an instance is via the 440 // forContext method, which always puts the associated handle into 441 // liveContexts. 442 py::gil_scoped_acquire acquire; 443 getLiveContexts().erase(context.ptr); 444 mlirContextDestroy(context); 445 } 446 447 py::object PyMlirContext::getCapsule() { 448 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 449 } 450 451 py::object PyMlirContext::createFromCapsule(py::object capsule) { 452 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 453 if (mlirContextIsNull(rawContext)) 454 throw py::error_already_set(); 455 return forContext(rawContext).releaseObject(); 456 } 457 458 PyMlirContext *PyMlirContext::createNewContextForInit() { 459 MlirContext context = mlirContextCreate(); 460 mlirRegisterAllDialects(context); 461 return new PyMlirContext(context); 462 } 463 464 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 465 py::gil_scoped_acquire acquire; 466 auto &liveContexts = getLiveContexts(); 467 auto it = liveContexts.find(context.ptr); 468 if (it == liveContexts.end()) { 469 // Create. 470 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 471 py::object pyRef = py::cast(unownedContextWrapper); 472 assert(pyRef && "cast to py::object failed"); 473 liveContexts[context.ptr] = unownedContextWrapper; 474 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 475 } 476 // Use existing. 
477 py::object pyRef = py::cast(it->second); 478 return PyMlirContextRef(it->second, std::move(pyRef)); 479 } 480 481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 482 static LiveContextMap liveContexts; 483 return liveContexts; 484 } 485 486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 487 488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 489 490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 491 492 pybind11::object PyMlirContext::contextEnter() { 493 return PyThreadContextEntry::pushContext(*this); 494 } 495 496 void PyMlirContext::contextExit(pybind11::object excType, 497 pybind11::object excVal, 498 pybind11::object excTb) { 499 PyThreadContextEntry::popContext(*this); 500 } 501 502 PyMlirContext &DefaultingPyMlirContext::resolve() { 503 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 504 if (!context) { 505 throw SetPyError( 506 PyExc_RuntimeError, 507 "An MLIR function requires a Context but none was provided in the call " 508 "or from the surrounding environment. 
Either pass to the function with " 509 "a 'context=' argument or establish a default using 'with Context():'"); 510 } 511 return *context; 512 } 513 514 //------------------------------------------------------------------------------ 515 // PyThreadContextEntry management 516 //------------------------------------------------------------------------------ 517 518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 519 static thread_local std::vector<PyThreadContextEntry> stack; 520 return stack; 521 } 522 523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 524 auto &stack = getStack(); 525 if (stack.empty()) 526 return nullptr; 527 return &stack.back(); 528 } 529 530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 531 py::object insertionPoint, 532 py::object location) { 533 auto &stack = getStack(); 534 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 535 std::move(location)); 536 // If the new stack has more than one entry and the context of the new top 537 // entry matches the previous, copy the insertionPoint and location from the 538 // previous entry if missing from the new top entry. 539 if (stack.size() > 1) { 540 auto &prev = *(stack.rbegin() + 1); 541 auto ¤t = stack.back(); 542 if (current.context.is(prev.context)) { 543 // Default non-context objects from the previous entry. 
544 if (!current.insertionPoint) 545 current.insertionPoint = prev.insertionPoint; 546 if (!current.location) 547 current.location = prev.location; 548 } 549 } 550 } 551 552 PyMlirContext *PyThreadContextEntry::getContext() { 553 if (!context) 554 return nullptr; 555 return py::cast<PyMlirContext *>(context); 556 } 557 558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 559 if (!insertionPoint) 560 return nullptr; 561 return py::cast<PyInsertionPoint *>(insertionPoint); 562 } 563 564 PyLocation *PyThreadContextEntry::getLocation() { 565 if (!location) 566 return nullptr; 567 return py::cast<PyLocation *>(location); 568 } 569 570 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 571 auto *tos = getTopOfStack(); 572 return tos ? tos->getContext() : nullptr; 573 } 574 575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 576 auto *tos = getTopOfStack(); 577 return tos ? tos->getInsertionPoint() : nullptr; 578 } 579 580 PyLocation *PyThreadContextEntry::getDefaultLocation() { 581 auto *tos = getTopOfStack(); 582 return tos ? 
tos->getLocation() : nullptr; 583 } 584 585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 586 py::object contextObj = py::cast(context); 587 push(FrameKind::Context, /*context=*/contextObj, 588 /*insertionPoint=*/py::object(), 589 /*location=*/py::object()); 590 return contextObj; 591 } 592 593 void PyThreadContextEntry::popContext(PyMlirContext &context) { 594 auto &stack = getStack(); 595 if (stack.empty()) 596 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 597 auto &tos = stack.back(); 598 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 599 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 600 stack.pop_back(); 601 } 602 603 py::object 604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 605 py::object contextObj = 606 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 607 py::object insertionPointObj = py::cast(insertionPoint); 608 push(FrameKind::InsertionPoint, 609 /*context=*/contextObj, 610 /*insertionPoint=*/insertionPointObj, 611 /*location=*/py::object()); 612 return insertionPointObj; 613 } 614 615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 616 auto &stack = getStack(); 617 if (stack.empty()) 618 throw SetPyError(PyExc_RuntimeError, 619 "Unbalanced InsertionPoint enter/exit"); 620 auto &tos = stack.back(); 621 if (tos.frameKind != FrameKind::InsertionPoint && 622 tos.getInsertionPoint() != &insertionPoint) 623 throw SetPyError(PyExc_RuntimeError, 624 "Unbalanced InsertionPoint enter/exit"); 625 stack.pop_back(); 626 } 627 628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 629 py::object contextObj = location.getContext().getObject(); 630 py::object locationObj = py::cast(location); 631 push(FrameKind::Location, /*context=*/contextObj, 632 /*insertionPoint=*/py::object(), 633 /*location=*/locationObj); 634 return locationObj; 635 } 636 637 void 
PyThreadContextEntry::popLocation(PyLocation &location) { 638 auto &stack = getStack(); 639 if (stack.empty()) 640 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 641 auto &tos = stack.back(); 642 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 643 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 644 stack.pop_back(); 645 } 646 647 //------------------------------------------------------------------------------ 648 // PyDialect, PyDialectDescriptor, PyDialects 649 //------------------------------------------------------------------------------ 650 651 MlirDialect PyDialects::getDialectForKey(const std::string &key, 652 bool attrError) { 653 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 654 {key.data(), key.size()}); 655 if (mlirDialectIsNull(dialect)) { 656 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 657 Twine("Dialect '") + key + "' not found"); 658 } 659 return dialect; 660 } 661 662 //------------------------------------------------------------------------------ 663 // PyLocation 664 //------------------------------------------------------------------------------ 665 666 py::object PyLocation::getCapsule() { 667 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 668 } 669 670 PyLocation PyLocation::createFromCapsule(py::object capsule) { 671 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 672 if (mlirLocationIsNull(rawLoc)) 673 throw py::error_already_set(); 674 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 675 rawLoc); 676 } 677 678 py::object PyLocation::contextEnter() { 679 return PyThreadContextEntry::pushLocation(*this); 680 } 681 682 void PyLocation::contextExit(py::object excType, py::object excVal, 683 py::object excTb) { 684 PyThreadContextEntry::popLocation(*this); 685 } 686 687 PyLocation &DefaultingPyLocation::resolve() { 688 auto *location = 
PyThreadContextEntry::getDefaultLocation(); 689 if (!location) { 690 throw SetPyError( 691 PyExc_RuntimeError, 692 "An MLIR function requires a Location but none was provided in the " 693 "call or from the surrounding environment. Either pass to the function " 694 "with a 'loc=' argument or establish a default using 'with loc:'"); 695 } 696 return *location; 697 } 698 699 //------------------------------------------------------------------------------ 700 // PyModule 701 //------------------------------------------------------------------------------ 702 703 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 704 : BaseContextObject(std::move(contextRef)), module(module) {} 705 706 PyModule::~PyModule() { 707 py::gil_scoped_acquire acquire; 708 auto &liveModules = getContext()->liveModules; 709 assert(liveModules.count(module.ptr) == 1 && 710 "destroying module not in live map"); 711 liveModules.erase(module.ptr); 712 mlirModuleDestroy(module); 713 } 714 715 PyModuleRef PyModule::forModule(MlirModule module) { 716 MlirContext context = mlirModuleGetContext(module); 717 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 718 719 py::gil_scoped_acquire acquire; 720 auto &liveModules = contextRef->liveModules; 721 auto it = liveModules.find(module.ptr); 722 if (it == liveModules.end()) { 723 // Create. 724 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 725 // Note that the default return value policy on cast is automatic_reference, 726 // which does not take ownership (delete will not be called). 727 // Just be explicit. 728 py::object pyRef = 729 py::cast(unownedModule, py::return_value_policy::take_ownership); 730 unownedModule->handle = pyRef; 731 liveModules[module.ptr] = 732 std::make_pair(unownedModule->handle, unownedModule); 733 return PyModuleRef(unownedModule, std::move(pyRef)); 734 } 735 // Use existing. 
736 PyModule *existing = it->second.second; 737 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 738 return PyModuleRef(existing, std::move(pyRef)); 739 } 740 741 py::object PyModule::createFromCapsule(py::object capsule) { 742 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 743 if (mlirModuleIsNull(rawModule)) 744 throw py::error_already_set(); 745 return forModule(rawModule).releaseObject(); 746 } 747 748 py::object PyModule::getCapsule() { 749 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 750 } 751 752 //------------------------------------------------------------------------------ 753 // PyOperation 754 //------------------------------------------------------------------------------ 755 756 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 757 : BaseContextObject(std::move(contextRef)), operation(operation) {} 758 759 PyOperation::~PyOperation() { 760 // If the operation has already been invalidated there is nothing to do. 761 if (!valid) 762 return; 763 auto &liveOperations = getContext()->liveOperations; 764 assert(liveOperations.count(operation.ptr) == 1 && 765 "destroying operation not in live map"); 766 liveOperations.erase(operation.ptr); 767 if (!isAttached()) { 768 mlirOperationDestroy(operation); 769 } 770 } 771 772 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 773 MlirOperation operation, 774 py::object parentKeepAlive) { 775 auto &liveOperations = contextRef->liveOperations; 776 // Create. 777 PyOperation *unownedOperation = 778 new PyOperation(std::move(contextRef), operation); 779 // Note that the default return value policy on cast is automatic_reference, 780 // which does not take ownership (delete will not be called). 781 // Just be explicit. 
782 py::object pyRef = 783 py::cast(unownedOperation, py::return_value_policy::take_ownership); 784 unownedOperation->handle = pyRef; 785 if (parentKeepAlive) { 786 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 787 } 788 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 789 return PyOperationRef(unownedOperation, std::move(pyRef)); 790 } 791 792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 793 MlirOperation operation, 794 py::object parentKeepAlive) { 795 auto &liveOperations = contextRef->liveOperations; 796 auto it = liveOperations.find(operation.ptr); 797 if (it == liveOperations.end()) { 798 // Create. 799 return createInstance(std::move(contextRef), operation, 800 std::move(parentKeepAlive)); 801 } 802 // Use existing. 803 PyOperation *existing = it->second.second; 804 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 805 return PyOperationRef(existing, std::move(pyRef)); 806 } 807 808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 809 MlirOperation operation, 810 py::object parentKeepAlive) { 811 auto &liveOperations = contextRef->liveOperations; 812 assert(liveOperations.count(operation.ptr) == 0 && 813 "cannot create detached operation that already exists"); 814 (void)liveOperations; 815 816 PyOperationRef created = createInstance(std::move(contextRef), operation, 817 std::move(parentKeepAlive)); 818 created->attached = false; 819 return created; 820 } 821 822 void PyOperation::checkValid() const { 823 if (!valid) { 824 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 825 } 826 } 827 828 void PyOperationBase::print(py::object fileObject, bool binary, 829 llvm::Optional<int64_t> largeElementsLimit, 830 bool enableDebugInfo, bool prettyDebugInfo, 831 bool printGenericOpForm, bool useLocalScope) { 832 PyOperation &operation = getOperation(); 833 operation.checkValid(); 834 if (fileObject.is_none()) 835 fileObject = 
py::module::import("sys").attr("stdout"); 836 837 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 838 fileObject.attr("write")("// Verification failed, printing generic form\n"); 839 printGenericOpForm = true; 840 } 841 842 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 843 if (largeElementsLimit) 844 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 845 if (enableDebugInfo) 846 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 847 if (printGenericOpForm) 848 mlirOpPrintingFlagsPrintGenericOpForm(flags); 849 850 PyFileAccumulator accum(fileObject, binary); 851 py::gil_scoped_release(); 852 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 853 accum.getUserData()); 854 mlirOpPrintingFlagsDestroy(flags); 855 } 856 857 py::object PyOperationBase::getAsm(bool binary, 858 llvm::Optional<int64_t> largeElementsLimit, 859 bool enableDebugInfo, bool prettyDebugInfo, 860 bool printGenericOpForm, 861 bool useLocalScope) { 862 py::object fileObject; 863 if (binary) { 864 fileObject = py::module::import("io").attr("BytesIO")(); 865 } else { 866 fileObject = py::module::import("io").attr("StringIO")(); 867 } 868 print(fileObject, /*binary=*/binary, 869 /*largeElementsLimit=*/largeElementsLimit, 870 /*enableDebugInfo=*/enableDebugInfo, 871 /*prettyDebugInfo=*/prettyDebugInfo, 872 /*printGenericOpForm=*/printGenericOpForm, 873 /*useLocalScope=*/useLocalScope); 874 875 return fileObject.attr("getvalue")(); 876 } 877 878 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 879 checkValid(); 880 if (!isAttached()) 881 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 882 MlirOperation operation = mlirOperationGetParentOperation(get()); 883 if (mlirOperationIsNull(operation)) 884 return {}; 885 return PyOperation::forOperation(getContext(), operation); 886 } 887 888 PyBlock PyOperation::getBlock() { 889 checkValid(); 890 llvm::Optional<PyOperationRef> 
parentOperation = getParentOperation(); 891 MlirBlock block = mlirOperationGetBlock(get()); 892 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 893 assert(parentOperation && "Operation has no parent"); 894 return PyBlock{std::move(*parentOperation), block}; 895 } 896 897 py::object PyOperation::getCapsule() { 898 checkValid(); 899 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 900 } 901 902 py::object PyOperation::createFromCapsule(py::object capsule) { 903 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 904 if (mlirOperationIsNull(rawOperation)) 905 throw py::error_already_set(); 906 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 907 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 908 .releaseObject(); 909 } 910 911 py::object PyOperation::create( 912 std::string name, llvm::Optional<std::vector<PyType *>> results, 913 llvm::Optional<std::vector<PyValue *>> operands, 914 llvm::Optional<py::dict> attributes, 915 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 916 DefaultingPyLocation location, py::object maybeIp) { 917 llvm::SmallVector<MlirValue, 4> mlirOperands; 918 llvm::SmallVector<MlirType, 4> mlirResults; 919 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 920 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 921 922 // General parameter validation. 923 if (regions < 0) 924 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 925 926 // Unpack/validate operands. 927 if (operands) { 928 mlirOperands.reserve(operands->size()); 929 for (PyValue *operand : *operands) { 930 if (!operand) 931 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 932 mlirOperands.push_back(operand->get()); 933 } 934 } 935 936 // Unpack/validate results. 
937 if (results) { 938 mlirResults.reserve(results->size()); 939 for (PyType *result : *results) { 940 // TODO: Verify result type originate from the same context. 941 if (!result) 942 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 943 mlirResults.push_back(*result); 944 } 945 } 946 // Unpack/validate attributes. 947 if (attributes) { 948 mlirAttributes.reserve(attributes->size()); 949 for (auto &it : *attributes) { 950 std::string key; 951 try { 952 key = it.first.cast<std::string>(); 953 } catch (py::cast_error &err) { 954 std::string msg = "Invalid attribute key (not a string) when " 955 "attempting to create the operation \"" + 956 name + "\" (" + err.what() + ")"; 957 throw py::cast_error(msg); 958 } 959 try { 960 auto &attribute = it.second.cast<PyAttribute &>(); 961 // TODO: Verify attribute originates from the same context. 962 mlirAttributes.emplace_back(std::move(key), attribute); 963 } catch (py::reference_cast_error &) { 964 // This exception seems thrown when the value is "None". 965 std::string msg = 966 "Found an invalid (`None`?) attribute value for the key \"" + key + 967 "\" when attempting to create the operation \"" + name + "\""; 968 throw py::cast_error(msg); 969 } catch (py::cast_error &err) { 970 std::string msg = "Invalid attribute value for the key \"" + key + 971 "\" when attempting to create the operation \"" + 972 name + "\" (" + err.what() + ")"; 973 throw py::cast_error(msg); 974 } 975 } 976 } 977 // Unpack/validate successors. 978 if (successors) { 979 mlirSuccessors.reserve(successors->size()); 980 for (auto *successor : *successors) { 981 // TODO: Verify successor originate from the same context. 982 if (!successor) 983 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 984 mlirSuccessors.push_back(successor->get()); 985 } 986 } 987 988 // Apply unpacked/validated to the operation state. Beyond this 989 // point, exceptions cannot be thrown or else the state will leak. 
990 MlirOperationState state = 991 mlirOperationStateGet(toMlirStringRef(name), location); 992 if (!mlirOperands.empty()) 993 mlirOperationStateAddOperands(&state, mlirOperands.size(), 994 mlirOperands.data()); 995 if (!mlirResults.empty()) 996 mlirOperationStateAddResults(&state, mlirResults.size(), 997 mlirResults.data()); 998 if (!mlirAttributes.empty()) { 999 // Note that the attribute names directly reference bytes in 1000 // mlirAttributes, so that vector must not be changed from here 1001 // on. 1002 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1003 mlirNamedAttributes.reserve(mlirAttributes.size()); 1004 for (auto &it : mlirAttributes) 1005 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1006 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1007 toMlirStringRef(it.first)), 1008 it.second)); 1009 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1010 mlirNamedAttributes.data()); 1011 } 1012 if (!mlirSuccessors.empty()) 1013 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1014 mlirSuccessors.data()); 1015 if (regions) { 1016 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1017 mlirRegions.resize(regions); 1018 for (int i = 0; i < regions; ++i) 1019 mlirRegions[i] = mlirRegionCreate(); 1020 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1021 mlirRegions.data()); 1022 } 1023 1024 // Construct the operation. 1025 MlirOperation operation = mlirOperationCreate(&state); 1026 PyOperationRef created = 1027 PyOperation::createDetached(location->getContext(), operation); 1028 1029 // InsertPoint active? 
1030 if (!maybeIp.is(py::cast(false))) { 1031 PyInsertionPoint *ip; 1032 if (maybeIp.is_none()) { 1033 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1034 } else { 1035 ip = py::cast<PyInsertionPoint *>(maybeIp); 1036 } 1037 if (ip) 1038 ip->insert(*created.get()); 1039 } 1040 1041 return created->createOpView(); 1042 } 1043 1044 py::object PyOperation::createOpView() { 1045 checkValid(); 1046 MlirIdentifier ident = mlirOperationGetName(get()); 1047 MlirStringRef identStr = mlirIdentifierStr(ident); 1048 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1049 StringRef(identStr.data, identStr.length)); 1050 if (opViewClass) 1051 return (*opViewClass)(getRef().getObject()); 1052 return py::cast(PyOpView(getRef().getObject())); 1053 } 1054 1055 void PyOperation::erase() { 1056 checkValid(); 1057 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1058 // Python reference to a child operation is live. All children should also 1059 // have their `valid` bit set to false. 1060 auto &liveOperations = getContext()->liveOperations; 1061 if (liveOperations.count(operation.ptr)) 1062 liveOperations.erase(operation.ptr); 1063 mlirOperationDestroy(operation); 1064 valid = false; 1065 } 1066 1067 //------------------------------------------------------------------------------ 1068 // PyOpView 1069 //------------------------------------------------------------------------------ 1070 1071 py::object 1072 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1073 py::list operandList, 1074 llvm::Optional<py::dict> attributes, 1075 llvm::Optional<std::vector<PyBlock *>> successors, 1076 llvm::Optional<int> regions, 1077 DefaultingPyLocation location, py::object maybeIp) { 1078 PyMlirContextRef context = location->getContext(); 1079 // Class level operation construction metadata. 
1080 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1081 // Operand and result segment specs are either none, which does no 1082 // variadic unpacking, or a list of ints with segment sizes, where each 1083 // element is either a positive number (typically 1 for a scalar) or -1 to 1084 // indicate that it is derived from the length of the same-indexed operand 1085 // or result (implying that it is a list at that position). 1086 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1087 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1088 1089 std::vector<uint32_t> operandSegmentLengths; 1090 std::vector<uint32_t> resultSegmentLengths; 1091 1092 // Validate/determine region count. 1093 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1094 int opMinRegionCount = std::get<0>(opRegionSpec); 1095 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1096 if (!regions) { 1097 regions = opMinRegionCount; 1098 } 1099 if (*regions < opMinRegionCount) { 1100 throw py::value_error( 1101 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1102 llvm::Twine(opMinRegionCount) + 1103 " regions but was built with regions=" + llvm::Twine(*regions)) 1104 .str()); 1105 } 1106 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1107 throw py::value_error( 1108 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1109 llvm::Twine(opMinRegionCount) + 1110 " regions but was built with regions=" + llvm::Twine(*regions)) 1111 .str()); 1112 } 1113 1114 // Unpack results. 1115 std::vector<PyType *> resultTypes; 1116 resultTypes.reserve(resultTypeList.size()); 1117 if (resultSegmentSpecObj.is_none()) { 1118 // Non-variadic result unpacking. 
1119 for (auto it : llvm::enumerate(resultTypeList)) { 1120 try { 1121 resultTypes.push_back(py::cast<PyType *>(it.value())); 1122 if (!resultTypes.back()) 1123 throw py::cast_error(); 1124 } catch (py::cast_error &err) { 1125 throw py::value_error((llvm::Twine("Result ") + 1126 llvm::Twine(it.index()) + " of operation \"" + 1127 name + "\" must be a Type (" + err.what() + ")") 1128 .str()); 1129 } 1130 } 1131 } else { 1132 // Sized result unpacking. 1133 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1134 if (resultSegmentSpec.size() != resultTypeList.size()) { 1135 throw py::value_error((llvm::Twine("Operation \"") + name + 1136 "\" requires " + 1137 llvm::Twine(resultSegmentSpec.size()) + 1138 "result segments but was provided " + 1139 llvm::Twine(resultTypeList.size())) 1140 .str()); 1141 } 1142 resultSegmentLengths.reserve(resultTypeList.size()); 1143 for (auto it : 1144 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1145 int segmentSpec = std::get<1>(it.value()); 1146 if (segmentSpec == 1 || segmentSpec == 0) { 1147 // Unpack unary element. 1148 try { 1149 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1150 if (resultType) { 1151 resultTypes.push_back(resultType); 1152 resultSegmentLengths.push_back(1); 1153 } else if (segmentSpec == 0) { 1154 // Allowed to be optional. 1155 resultSegmentLengths.push_back(0); 1156 } else { 1157 throw py::cast_error("was None and result is not optional"); 1158 } 1159 } catch (py::cast_error &err) { 1160 throw py::value_error((llvm::Twine("Result ") + 1161 llvm::Twine(it.index()) + " of operation \"" + 1162 name + "\" must be a Type (" + err.what() + 1163 ")") 1164 .str()); 1165 } 1166 } else if (segmentSpec == -1) { 1167 // Unpack sequence by appending. 1168 try { 1169 if (std::get<0>(it.value()).is_none()) { 1170 // Treat it as an empty list. 1171 resultSegmentLengths.push_back(0); 1172 } else { 1173 // Unpack the list. 
1174 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1175 for (py::object segmentItem : segment) { 1176 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1177 if (!resultTypes.back()) { 1178 throw py::cast_error("contained a None item"); 1179 } 1180 } 1181 resultSegmentLengths.push_back(segment.size()); 1182 } 1183 } catch (std::exception &err) { 1184 // NOTE: Sloppy to be using a catch-all here, but there are at least 1185 // three different unrelated exceptions that can be thrown in the 1186 // above "casts". Just keep the scope above small and catch them all. 1187 throw py::value_error((llvm::Twine("Result ") + 1188 llvm::Twine(it.index()) + " of operation \"" + 1189 name + "\" must be a Sequence of Types (" + 1190 err.what() + ")") 1191 .str()); 1192 } 1193 } else { 1194 throw py::value_error("Unexpected segment spec"); 1195 } 1196 } 1197 } 1198 1199 // Unpack operands. 1200 std::vector<PyValue *> operands; 1201 operands.reserve(operands.size()); 1202 if (operandSegmentSpecObj.is_none()) { 1203 // Non-sized operand unpacking. 1204 for (auto it : llvm::enumerate(operandList)) { 1205 try { 1206 operands.push_back(py::cast<PyValue *>(it.value())); 1207 if (!operands.back()) 1208 throw py::cast_error(); 1209 } catch (py::cast_error &err) { 1210 throw py::value_error((llvm::Twine("Operand ") + 1211 llvm::Twine(it.index()) + " of operation \"" + 1212 name + "\" must be a Value (" + err.what() + ")") 1213 .str()); 1214 } 1215 } 1216 } else { 1217 // Sized operand unpacking. 
1218 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1219 if (operandSegmentSpec.size() != operandList.size()) { 1220 throw py::value_error((llvm::Twine("Operation \"") + name + 1221 "\" requires " + 1222 llvm::Twine(operandSegmentSpec.size()) + 1223 "operand segments but was provided " + 1224 llvm::Twine(operandList.size())) 1225 .str()); 1226 } 1227 operandSegmentLengths.reserve(operandList.size()); 1228 for (auto it : 1229 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1230 int segmentSpec = std::get<1>(it.value()); 1231 if (segmentSpec == 1 || segmentSpec == 0) { 1232 // Unpack unary element. 1233 try { 1234 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1235 if (operandValue) { 1236 operands.push_back(operandValue); 1237 operandSegmentLengths.push_back(1); 1238 } else if (segmentSpec == 0) { 1239 // Allowed to be optional. 1240 operandSegmentLengths.push_back(0); 1241 } else { 1242 throw py::cast_error("was None and operand is not optional"); 1243 } 1244 } catch (py::cast_error &err) { 1245 throw py::value_error((llvm::Twine("Operand ") + 1246 llvm::Twine(it.index()) + " of operation \"" + 1247 name + "\" must be a Value (" + err.what() + 1248 ")") 1249 .str()); 1250 } 1251 } else if (segmentSpec == -1) { 1252 // Unpack sequence by appending. 1253 try { 1254 if (std::get<0>(it.value()).is_none()) { 1255 // Treat it as an empty list. 1256 operandSegmentLengths.push_back(0); 1257 } else { 1258 // Unpack the list. 
1259 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1260 for (py::object segmentItem : segment) { 1261 operands.push_back(py::cast<PyValue *>(segmentItem)); 1262 if (!operands.back()) { 1263 throw py::cast_error("contained a None item"); 1264 } 1265 } 1266 operandSegmentLengths.push_back(segment.size()); 1267 } 1268 } catch (std::exception &err) { 1269 // NOTE: Sloppy to be using a catch-all here, but there are at least 1270 // three different unrelated exceptions that can be thrown in the 1271 // above "casts". Just keep the scope above small and catch them all. 1272 throw py::value_error((llvm::Twine("Operand ") + 1273 llvm::Twine(it.index()) + " of operation \"" + 1274 name + "\" must be a Sequence of Values (" + 1275 err.what() + ")") 1276 .str()); 1277 } 1278 } else { 1279 throw py::value_error("Unexpected segment spec"); 1280 } 1281 } 1282 } 1283 1284 // Merge operand/result segment lengths into attributes if needed. 1285 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1286 // Dup. 1287 if (attributes) { 1288 attributes = py::dict(*attributes); 1289 } else { 1290 attributes = py::dict(); 1291 } 1292 if (attributes->contains("result_segment_sizes") || 1293 attributes->contains("operand_segment_sizes")) { 1294 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1295 "'operand_segment_sizes' attribute is unsupported. " 1296 "Use Operation.create for such low-level access."); 1297 } 1298 1299 // Add result_segment_sizes attribute. 1300 if (!resultSegmentLengths.empty()) { 1301 int64_t size = resultSegmentLengths.size(); 1302 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1303 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1304 resultSegmentLengths.size(), resultSegmentLengths.data()); 1305 (*attributes)["result_segment_sizes"] = 1306 PyAttribute(context, segmentLengthAttr); 1307 } 1308 1309 // Add operand_segment_sizes attribute. 
1310 if (!operandSegmentLengths.empty()) { 1311 int64_t size = operandSegmentLengths.size(); 1312 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1313 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1314 operandSegmentLengths.size(), operandSegmentLengths.data()); 1315 (*attributes)["operand_segment_sizes"] = 1316 PyAttribute(context, segmentLengthAttr); 1317 } 1318 } 1319 1320 // Delegate to create. 1321 return PyOperation::create(std::move(name), 1322 /*results=*/std::move(resultTypes), 1323 /*operands=*/std::move(operands), 1324 /*attributes=*/std::move(attributes), 1325 /*successors=*/std::move(successors), 1326 /*regions=*/*regions, location, maybeIp); 1327 } 1328 1329 PyOpView::PyOpView(py::object operationObject) 1330 // Casting through the PyOperationBase base-class and then back to the 1331 // Operation lets us accept any PyOperationBase subclass. 1332 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1333 operationObject(operation.getRef().getObject()) {} 1334 1335 py::object PyOpView::createRawSubclass(py::object userClass) { 1336 // This is... a little gross. The typical pattern is to have a pure python 1337 // class that extends OpView like: 1338 // class AddFOp(_cext.ir.OpView): 1339 // def __init__(self, loc, lhs, rhs): 1340 // operation = loc.context.create_operation( 1341 // "addf", lhs, rhs, results=[lhs.type]) 1342 // super().__init__(operation) 1343 // 1344 // I.e. The goal of the user facing type is to provide a nice constructor 1345 // that has complete freedom for the op under construction. This is at odds 1346 // with our other desire to sometimes create this object by just passing an 1347 // operation (to initialize the base class). 
We could do *arg and **kwargs 1348 // munging to try to make it work, but instead, we synthesize a new class 1349 // on the fly which extends this user class (AddFOp in this example) and 1350 // *give it* the base class's __init__ method, thus bypassing the 1351 // intermediate subclass's __init__ method entirely. While slightly, 1352 // underhanded, this is safe/legal because the type hierarchy has not changed 1353 // (we just added a new leaf) and we aren't mucking around with __new__. 1354 // Typically, this new class will be stored on the original as "_Raw" and will 1355 // be used for casts and other things that need a variant of the class that 1356 // is initialized purely from an operation. 1357 py::object parentMetaclass = 1358 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1359 py::dict attributes; 1360 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1361 // now. 1362 // auto opViewType = py::type::of<PyOpView>(); 1363 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1364 attributes["__init__"] = opViewType.attr("__init__"); 1365 py::str origName = userClass.attr("__name__"); 1366 py::str newName = py::str("_") + origName; 1367 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1368 } 1369 1370 //------------------------------------------------------------------------------ 1371 // PyInsertionPoint. 
1372 //------------------------------------------------------------------------------ 1373 1374 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1375 1376 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1377 : refOperation(beforeOperationBase.getOperation().getRef()), 1378 block((*refOperation)->getBlock()) {} 1379 1380 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1381 PyOperation &operation = operationBase.getOperation(); 1382 if (operation.isAttached()) 1383 throw SetPyError(PyExc_ValueError, 1384 "Attempt to insert operation that is already attached"); 1385 block.getParentOperation()->checkValid(); 1386 MlirOperation beforeOp = {nullptr}; 1387 if (refOperation) { 1388 // Insert before operation. 1389 (*refOperation)->checkValid(); 1390 beforeOp = (*refOperation)->get(); 1391 } else { 1392 // Insert at end (before null) is only valid if the block does not 1393 // already end in a known terminator (violating this will cause assertion 1394 // failures later). 1395 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1396 throw py::index_error("Cannot insert operation at the end of a block " 1397 "that already has a terminator. Did you mean to " 1398 "use 'InsertionPoint.at_block_terminator(block)' " 1399 "versus 'InsertionPoint(block)'?"); 1400 } 1401 } 1402 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1403 operation.setAttached(); 1404 } 1405 1406 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1407 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1408 if (mlirOperationIsNull(firstOp)) { 1409 // Just insert at end. 1410 return PyInsertionPoint(block); 1411 } 1412 1413 // Insert before first op. 
1414 PyOperationRef firstOpRef = PyOperation::forOperation( 1415 block.getParentOperation()->getContext(), firstOp); 1416 return PyInsertionPoint{block, std::move(firstOpRef)}; 1417 } 1418 1419 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1420 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1421 if (mlirOperationIsNull(terminator)) 1422 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1423 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1424 block.getParentOperation()->getContext(), terminator); 1425 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1426 } 1427 1428 py::object PyInsertionPoint::contextEnter() { 1429 return PyThreadContextEntry::pushInsertionPoint(*this); 1430 } 1431 1432 void PyInsertionPoint::contextExit(pybind11::object excType, 1433 pybind11::object excVal, 1434 pybind11::object excTb) { 1435 PyThreadContextEntry::popInsertionPoint(*this); 1436 } 1437 1438 //------------------------------------------------------------------------------ 1439 // PyAttribute. 1440 //------------------------------------------------------------------------------ 1441 1442 bool PyAttribute::operator==(const PyAttribute &other) { 1443 return mlirAttributeEqual(attr, other.attr); 1444 } 1445 1446 py::object PyAttribute::getCapsule() { 1447 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1448 } 1449 1450 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1451 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1452 if (mlirAttributeIsNull(rawAttr)) 1453 throw py::error_already_set(); 1454 return PyAttribute( 1455 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1456 } 1457 1458 //------------------------------------------------------------------------------ 1459 // PyNamedAttribute. 
1460 //------------------------------------------------------------------------------ 1461 1462 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1463 : ownedName(new std::string(std::move(ownedName))) { 1464 namedAttr = mlirNamedAttributeGet( 1465 mlirIdentifierGet(mlirAttributeGetContext(attr), 1466 toMlirStringRef(*this->ownedName)), 1467 attr); 1468 } 1469 1470 //------------------------------------------------------------------------------ 1471 // PyType. 1472 //------------------------------------------------------------------------------ 1473 1474 bool PyType::operator==(const PyType &other) { 1475 return mlirTypeEqual(type, other.type); 1476 } 1477 1478 py::object PyType::getCapsule() { 1479 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1480 } 1481 1482 PyType PyType::createFromCapsule(py::object capsule) { 1483 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1484 if (mlirTypeIsNull(rawType)) 1485 throw py::error_already_set(); 1486 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1487 rawType); 1488 } 1489 1490 //------------------------------------------------------------------------------ 1491 // PyValue and subclases. 
1492 //------------------------------------------------------------------------------ 1493 1494 pybind11::object PyValue::getCapsule() { 1495 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1496 } 1497 1498 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1499 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1500 if (mlirValueIsNull(value)) 1501 throw py::error_already_set(); 1502 MlirOperation owner; 1503 if (mlirValueIsAOpResult(value)) 1504 owner = mlirOpResultGetOwner(value); 1505 if (mlirValueIsABlockArgument(value)) 1506 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1507 if (mlirOperationIsNull(owner)) 1508 throw py::error_already_set(); 1509 MlirContext ctx = mlirOperationGetContext(owner); 1510 PyOperationRef ownerRef = 1511 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1512 return PyValue(ownerRef, value); 1513 } 1514 1515 namespace { 1516 /// CRTP base class for Python MLIR values that subclass Value and should be 1517 /// castable from it. The value hierarchy is one level deep and is not supposed 1518 /// to accommodate other levels unless core MLIR changes. 1519 template <typename DerivedTy> 1520 class PyConcreteValue : public PyValue { 1521 public: 1522 // Derived classes must define statics for: 1523 // IsAFunctionTy isaFunction 1524 // const char *pyClassName 1525 // and redefine bindDerived. 1526 using ClassTy = py::class_<DerivedTy, PyValue>; 1527 using IsAFunctionTy = bool (*)(MlirValue); 1528 1529 PyConcreteValue() = default; 1530 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1531 : PyValue(operationRef, value) {} 1532 PyConcreteValue(PyValue &orig) 1533 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1534 1535 /// Attempts to cast the original value to the derived type and throws on 1536 /// type mismatches. 
1537 static MlirValue castFrom(PyValue &orig) { 1538 if (!DerivedTy::isaFunction(orig.get())) { 1539 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1540 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1541 DerivedTy::pyClassName + 1542 " (from " + origRepr + ")"); 1543 } 1544 return orig.get(); 1545 } 1546 1547 /// Binds the Python module objects to functions of this class. 1548 static void bind(py::module &m) { 1549 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1550 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1551 cls.def_static("isinstance", [](PyValue &otherValue) -> bool { 1552 return DerivedTy::isaFunction(otherValue); 1553 }); 1554 DerivedTy::bindDerived(cls); 1555 } 1556 1557 /// Implemented by derived classes to add methods to the Python subclass. 1558 static void bindDerived(ClassTy &m) {} 1559 }; 1560 1561 /// Python wrapper for MlirBlockArgument. 1562 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1563 public: 1564 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1565 static constexpr const char *pyClassName = "BlockArgument"; 1566 using PyConcreteValue::PyConcreteValue; 1567 1568 static void bindDerived(ClassTy &c) { 1569 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1570 return PyBlock(self.getParentOperation(), 1571 mlirBlockArgumentGetOwner(self.get())); 1572 }); 1573 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1574 return mlirBlockArgumentGetArgNumber(self.get()); 1575 }); 1576 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1577 return mlirBlockArgumentSetType(self.get(), type); 1578 }); 1579 } 1580 }; 1581 1582 /// Python wrapper for MlirOpResult. 
1583 class PyOpResult : public PyConcreteValue<PyOpResult> { 1584 public: 1585 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1586 static constexpr const char *pyClassName = "OpResult"; 1587 using PyConcreteValue::PyConcreteValue; 1588 1589 static void bindDerived(ClassTy &c) { 1590 c.def_property_readonly("owner", [](PyOpResult &self) { 1591 assert( 1592 mlirOperationEqual(self.getParentOperation()->get(), 1593 mlirOpResultGetOwner(self.get())) && 1594 "expected the owner of the value in Python to match that in the IR"); 1595 return self.getParentOperation().getObject(); 1596 }); 1597 c.def_property_readonly("result_number", [](PyOpResult &self) { 1598 return mlirOpResultGetResultNumber(self.get()); 1599 }); 1600 } 1601 }; 1602 1603 /// Returns the list of types of the values held by container. 1604 template <typename Container> 1605 static std::vector<PyType> getValueTypes(Container &container, 1606 PyMlirContextRef &context) { 1607 std::vector<PyType> result; 1608 result.reserve(container.getNumElements()); 1609 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1610 result.push_back( 1611 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1612 } 1613 return result; 1614 } 1615 1616 /// A list of block arguments. Internally, these are stored as consecutive 1617 /// elements, random access is cheap. The argument list is associated with the 1618 /// operation that contains the block (detached blocks are not allowed in 1619 /// Python bindings) and extends its lifetime. 1620 class PyBlockArgumentList 1621 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1622 public: 1623 static constexpr const char *pyClassName = "BlockArgumentList"; 1624 1625 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1626 intptr_t startIndex = 0, intptr_t length = -1, 1627 intptr_t step = 1) 1628 : Sliceable(startIndex, 1629 length == -1 ? 
mlirBlockGetNumArguments(block) : length, 1630 step), 1631 operation(std::move(operation)), block(block) {} 1632 1633 /// Returns the number of arguments in the list. 1634 intptr_t getNumElements() { 1635 operation->checkValid(); 1636 return mlirBlockGetNumArguments(block); 1637 } 1638 1639 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1640 PyBlockArgument getElement(intptr_t pos) { 1641 MlirValue argument = mlirBlockGetArgument(block, pos); 1642 return PyBlockArgument(operation, argument); 1643 } 1644 1645 /// Returns a sublist of this list. 1646 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1647 intptr_t step) { 1648 return PyBlockArgumentList(operation, block, startIndex, length, step); 1649 } 1650 1651 static void bindDerived(ClassTy &c) { 1652 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1653 return getValueTypes(self, self.operation->getContext()); 1654 }); 1655 } 1656 1657 private: 1658 PyOperationRef operation; 1659 MlirBlock block; 1660 }; 1661 1662 /// A list of operation operands. Internally, these are stored as consecutive 1663 /// elements, random access is cheap. The result list is associated with the 1664 /// operation whose results these are, and extends the lifetime of this 1665 /// operation. 1666 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1667 public: 1668 static constexpr const char *pyClassName = "OpOperandList"; 1669 1670 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1671 intptr_t length = -1, intptr_t step = 1) 1672 : Sliceable(startIndex, 1673 length == -1 ? 
mlirOperationGetNumOperands(operation->get()) 1674 : length, 1675 step), 1676 operation(operation) {} 1677 1678 intptr_t getNumElements() { 1679 operation->checkValid(); 1680 return mlirOperationGetNumOperands(operation->get()); 1681 } 1682 1683 PyValue getElement(intptr_t pos) { 1684 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1685 MlirOperation owner; 1686 if (mlirValueIsAOpResult(operand)) 1687 owner = mlirOpResultGetOwner(operand); 1688 else if (mlirValueIsABlockArgument(operand)) 1689 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1690 else 1691 assert(false && "Value must be an block arg or op result."); 1692 PyOperationRef pyOwner = 1693 PyOperation::forOperation(operation->getContext(), owner); 1694 return PyValue(pyOwner, operand); 1695 } 1696 1697 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1698 return PyOpOperandList(operation, startIndex, length, step); 1699 } 1700 1701 void dunderSetItem(intptr_t index, PyValue value) { 1702 index = wrapIndex(index); 1703 mlirOperationSetOperand(operation->get(), index, value.get()); 1704 } 1705 1706 static void bindDerived(ClassTy &c) { 1707 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1708 } 1709 1710 private: 1711 PyOperationRef operation; 1712 }; 1713 1714 /// A list of operation results. Internally, these are stored as consecutive 1715 /// elements, random access is cheap. The result list is associated with the 1716 /// operation whose results these are, and extends the lifetime of this 1717 /// operation. 1718 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1719 public: 1720 static constexpr const char *pyClassName = "OpResultList"; 1721 1722 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1723 intptr_t length = -1, intptr_t step = 1) 1724 : Sliceable(startIndex, 1725 length == -1 ? 
mlirOperationGetNumResults(operation->get()) 1726 : length, 1727 step), 1728 operation(operation) {} 1729 1730 intptr_t getNumElements() { 1731 operation->checkValid(); 1732 return mlirOperationGetNumResults(operation->get()); 1733 } 1734 1735 PyOpResult getElement(intptr_t index) { 1736 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1737 return PyOpResult(value); 1738 } 1739 1740 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1741 return PyOpResultList(operation, startIndex, length, step); 1742 } 1743 1744 static void bindDerived(ClassTy &c) { 1745 c.def_property_readonly("types", [](PyOpResultList &self) { 1746 return getValueTypes(self, self.operation->getContext()); 1747 }); 1748 } 1749 1750 private: 1751 PyOperationRef operation; 1752 }; 1753 1754 /// A list of operation attributes. Can be indexed by name, producing 1755 /// attributes, or by index, producing named attributes. 1756 class PyOpAttributeMap { 1757 public: 1758 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1759 1760 PyAttribute dunderGetItemNamed(const std::string &name) { 1761 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1762 toMlirStringRef(name)); 1763 if (mlirAttributeIsNull(attr)) { 1764 throw SetPyError(PyExc_KeyError, 1765 "attempt to access a non-existent attribute"); 1766 } 1767 return PyAttribute(operation->getContext(), attr); 1768 } 1769 1770 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1771 if (index < 0 || index >= dunderLen()) { 1772 throw SetPyError(PyExc_IndexError, 1773 "attempt to access out of bounds attribute"); 1774 } 1775 MlirNamedAttribute namedAttr = 1776 mlirOperationGetAttribute(operation->get(), index); 1777 return PyNamedAttribute( 1778 namedAttr.attribute, 1779 std::string(mlirIdentifierStr(namedAttr.name).data)); 1780 } 1781 1782 void dunderSetItem(const std::string &name, PyAttribute attr) { 1783 mlirOperationSetAttributeByName(operation->get(), 
toMlirStringRef(name), 1784 attr); 1785 } 1786 1787 void dunderDelItem(const std::string &name) { 1788 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1789 toMlirStringRef(name)); 1790 if (!removed) 1791 throw SetPyError(PyExc_KeyError, 1792 "attempt to delete a non-existent attribute"); 1793 } 1794 1795 intptr_t dunderLen() { 1796 return mlirOperationGetNumAttributes(operation->get()); 1797 } 1798 1799 bool dunderContains(const std::string &name) { 1800 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1801 operation->get(), toMlirStringRef(name))); 1802 } 1803 1804 static void bind(py::module &m) { 1805 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1806 .def("__contains__", &PyOpAttributeMap::dunderContains) 1807 .def("__len__", &PyOpAttributeMap::dunderLen) 1808 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1809 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1810 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1811 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1812 } 1813 1814 private: 1815 PyOperationRef operation; 1816 }; 1817 1818 } // end namespace 1819 1820 //------------------------------------------------------------------------------ 1821 // Populates the core exports of the 'ir' submodule. 1822 //------------------------------------------------------------------------------ 1823 1824 void mlir::python::populateIRCore(py::module &m) { 1825 //---------------------------------------------------------------------------- 1826 // Mapping of MlirContext. 
1827 //---------------------------------------------------------------------------- 1828 py::class_<PyMlirContext>(m, "Context", py::module_local()) 1829 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1830 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1831 .def("_get_context_again", 1832 [](PyMlirContext &self) { 1833 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1834 return ref.releaseObject(); 1835 }) 1836 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1837 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1838 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1839 &PyMlirContext::getCapsule) 1840 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1841 .def("__enter__", &PyMlirContext::contextEnter) 1842 .def("__exit__", &PyMlirContext::contextExit) 1843 .def_property_readonly_static( 1844 "current", 1845 [](py::object & /*class*/) { 1846 auto *context = PyThreadContextEntry::getDefaultContext(); 1847 if (!context) 1848 throw SetPyError(PyExc_ValueError, "No current Context"); 1849 return context; 1850 }, 1851 "Gets the Context bound to the current thread or raises ValueError") 1852 .def_property_readonly( 1853 "dialects", 1854 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1855 "Gets a container for accessing dialects by name") 1856 .def_property_readonly( 1857 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1858 "Alias for 'dialect'") 1859 .def( 1860 "get_dialect_descriptor", 1861 [=](PyMlirContext &self, std::string &name) { 1862 MlirDialect dialect = mlirContextGetOrLoadDialect( 1863 self.get(), {name.data(), name.size()}); 1864 if (mlirDialectIsNull(dialect)) { 1865 throw SetPyError(PyExc_ValueError, 1866 Twine("Dialect '") + name + "' not found"); 1867 } 1868 return PyDialectDescriptor(self.getRef(), dialect); 1869 }, 1870 "Gets or loads a dialect by name, returning its descriptor object") 1871 .def_property( 1872 
"allow_unregistered_dialects", 1873 [](PyMlirContext &self) -> bool { 1874 return mlirContextGetAllowUnregisteredDialects(self.get()); 1875 }, 1876 [](PyMlirContext &self, bool value) { 1877 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1878 }) 1879 .def("enable_multithreading", 1880 [](PyMlirContext &self, bool enable) { 1881 mlirContextEnableMultithreading(self.get(), enable); 1882 }) 1883 .def("is_registered_operation", 1884 [](PyMlirContext &self, std::string &name) { 1885 return mlirContextIsRegisteredOperation( 1886 self.get(), MlirStringRef{name.data(), name.size()}); 1887 }); 1888 1889 //---------------------------------------------------------------------------- 1890 // Mapping of PyDialectDescriptor 1891 //---------------------------------------------------------------------------- 1892 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1893 .def_property_readonly("namespace", 1894 [](PyDialectDescriptor &self) { 1895 MlirStringRef ns = 1896 mlirDialectGetNamespace(self.get()); 1897 return py::str(ns.data, ns.length); 1898 }) 1899 .def("__repr__", [](PyDialectDescriptor &self) { 1900 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1901 std::string repr("<DialectDescriptor "); 1902 repr.append(ns.data, ns.length); 1903 repr.append(">"); 1904 return repr; 1905 }); 1906 1907 //---------------------------------------------------------------------------- 1908 // Mapping of PyDialects 1909 //---------------------------------------------------------------------------- 1910 py::class_<PyDialects>(m, "Dialects", py::module_local()) 1911 .def("__getitem__", 1912 [=](PyDialects &self, std::string keyName) { 1913 MlirDialect dialect = 1914 self.getDialectForKey(keyName, /*attrError=*/false); 1915 py::object descriptor = 1916 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1917 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1918 }) 1919 .def("__getattr__", [=](PyDialects &self, std::string 
attrName) { 1920 MlirDialect dialect = 1921 self.getDialectForKey(attrName, /*attrError=*/true); 1922 py::object descriptor = 1923 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1924 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1925 }); 1926 1927 //---------------------------------------------------------------------------- 1928 // Mapping of PyDialect 1929 //---------------------------------------------------------------------------- 1930 py::class_<PyDialect>(m, "Dialect", py::module_local()) 1931 .def(py::init<py::object>(), "descriptor") 1932 .def_property_readonly( 1933 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1934 .def("__repr__", [](py::object self) { 1935 auto clazz = self.attr("__class__"); 1936 return py::str("<Dialect ") + 1937 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1938 clazz.attr("__module__") + py::str(".") + 1939 clazz.attr("__name__") + py::str(")>"); 1940 }); 1941 1942 //---------------------------------------------------------------------------- 1943 // Mapping of Location 1944 //---------------------------------------------------------------------------- 1945 py::class_<PyLocation>(m, "Location", py::module_local()) 1946 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1947 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1948 .def("__enter__", &PyLocation::contextEnter) 1949 .def("__exit__", &PyLocation::contextExit) 1950 .def("__eq__", 1951 [](PyLocation &self, PyLocation &other) -> bool { 1952 return mlirLocationEqual(self, other); 1953 }) 1954 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1955 .def_property_readonly_static( 1956 "current", 1957 [](py::object & /*class*/) { 1958 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1959 if (!loc) 1960 throw SetPyError(PyExc_ValueError, "No current Location"); 1961 return loc; 1962 }, 1963 "Gets the Location bound to the current 
thread or raises ValueError") 1964 .def_static( 1965 "unknown", 1966 [](DefaultingPyMlirContext context) { 1967 return PyLocation(context->getRef(), 1968 mlirLocationUnknownGet(context->get())); 1969 }, 1970 py::arg("context") = py::none(), 1971 "Gets a Location representing an unknown location") 1972 .def_static( 1973 "callsite", 1974 [](PyLocation callee, const std::vector<PyLocation> &frames, 1975 DefaultingPyMlirContext context) { 1976 if (frames.empty()) 1977 throw py::value_error("No caller frames provided"); 1978 MlirLocation caller = frames.back().get(); 1979 for (PyLocation frame : 1980 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 1981 caller = mlirLocationCallSiteGet(frame.get(), caller); 1982 return PyLocation(context->getRef(), 1983 mlirLocationCallSiteGet(callee.get(), caller)); 1984 }, 1985 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 1986 kContextGetCallSiteLocationDocstring) 1987 .def_static( 1988 "file", 1989 [](std::string filename, int line, int col, 1990 DefaultingPyMlirContext context) { 1991 return PyLocation( 1992 context->getRef(), 1993 mlirLocationFileLineColGet( 1994 context->get(), toMlirStringRef(filename), line, col)); 1995 }, 1996 py::arg("filename"), py::arg("line"), py::arg("col"), 1997 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1998 .def_static( 1999 "name", 2000 [](std::string name, llvm::Optional<PyLocation> childLoc, 2001 DefaultingPyMlirContext context) { 2002 return PyLocation( 2003 context->getRef(), 2004 mlirLocationNameGet( 2005 context->get(), toMlirStringRef(name), 2006 childLoc ? 
childLoc->get() 2007 : mlirLocationUnknownGet(context->get()))); 2008 }, 2009 py::arg("name"), py::arg("childLoc") = py::none(), 2010 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2011 .def_property_readonly( 2012 "context", 2013 [](PyLocation &self) { return self.getContext().getObject(); }, 2014 "Context that owns the Location") 2015 .def("__repr__", [](PyLocation &self) { 2016 PyPrintAccumulator printAccum; 2017 mlirLocationPrint(self, printAccum.getCallback(), 2018 printAccum.getUserData()); 2019 return printAccum.join(); 2020 }); 2021 2022 //---------------------------------------------------------------------------- 2023 // Mapping of Module 2024 //---------------------------------------------------------------------------- 2025 py::class_<PyModule>(m, "Module", py::module_local()) 2026 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2027 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2028 .def_static( 2029 "parse", 2030 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2031 MlirModule module = mlirModuleCreateParse( 2032 context->get(), toMlirStringRef(moduleAsm)); 2033 // TODO: Rework error reporting once diagnostic engine is exposed 2034 // in C API. 
2035 if (mlirModuleIsNull(module)) { 2036 throw SetPyError( 2037 PyExc_ValueError, 2038 "Unable to parse module assembly (see diagnostics)"); 2039 } 2040 return PyModule::forModule(module).releaseObject(); 2041 }, 2042 py::arg("asm"), py::arg("context") = py::none(), 2043 kModuleParseDocstring) 2044 .def_static( 2045 "create", 2046 [](DefaultingPyLocation loc) { 2047 MlirModule module = mlirModuleCreateEmpty(loc); 2048 return PyModule::forModule(module).releaseObject(); 2049 }, 2050 py::arg("loc") = py::none(), "Creates an empty module") 2051 .def_property_readonly( 2052 "context", 2053 [](PyModule &self) { return self.getContext().getObject(); }, 2054 "Context that created the Module") 2055 .def_property_readonly( 2056 "operation", 2057 [](PyModule &self) { 2058 return PyOperation::forOperation(self.getContext(), 2059 mlirModuleGetOperation(self.get()), 2060 self.getRef().releaseObject()) 2061 .releaseObject(); 2062 }, 2063 "Accesses the module as an operation") 2064 .def_property_readonly( 2065 "body", 2066 [](PyModule &self) { 2067 PyOperationRef module_op = PyOperation::forOperation( 2068 self.getContext(), mlirModuleGetOperation(self.get()), 2069 self.getRef().releaseObject()); 2070 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2071 return returnBlock; 2072 }, 2073 "Return the block for this module") 2074 .def( 2075 "dump", 2076 [](PyModule &self) { 2077 mlirOperationDump(mlirModuleGetOperation(self.get())); 2078 }, 2079 kDumpDocstring) 2080 .def( 2081 "__str__", 2082 [](PyModule &self) { 2083 MlirOperation operation = mlirModuleGetOperation(self.get()); 2084 PyPrintAccumulator printAccum; 2085 mlirOperationPrint(operation, printAccum.getCallback(), 2086 printAccum.getUserData()); 2087 return printAccum.join(); 2088 }, 2089 kOperationStrDunderDocstring); 2090 2091 //---------------------------------------------------------------------------- 2092 // Mapping of Operation. 
//----------------------------------------------------------------------------
  // `_OperationBase` holds the API shared by `Operation` and `OpView`; every
  // binding goes through `getOperation()` to reach the underlying PyOperation.
  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             [](PyOperationBase &self) {
                               return self.getOperation().getCapsule();
                             })
      // Two `__eq__` overloads. pybind11 tries overloads in registration order,
      // so the typed overload (identity of the underlying PyOperation) must
      // precede the catch-all that returns False for any other object.
      .def("__eq__",
           [](PyOperationBase &self, PyOperationBase &other) {
             return &self.getOperation() == &other.getOperation();
           })
      .def("__eq__",
           [](PyOperationBase &self, py::object other) { return false; })
      // Container views; each holds a reference to the operation.
      .def_property_readonly("attributes",
                             [](PyOperationBase &self) {
                               return PyOpAttributeMap(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("operands",
                             [](PyOperationBase &self) {
                               return PyOpOperandList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("regions",
                             [](PyOperationBase &self) {
                               return PyRegionList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      // Raises ValueError (naming the op) unless there is exactly one result.
      .def_property_readonly(
          "result",
          [](PyOperationBase &self) {
            auto &operation = self.getOperation();
            auto numResults = mlirOperationGetNumResults(operation);
            if (numResults != 1) {
              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
              throw SetPyError(
                  PyExc_ValueError,
                  Twine("Cannot call .result on operation ") +
                      StringRef(name.data, name.length) + " which has " +
                      Twine(numResults) +
                      " results (it is only valid for operations with a "
                      "single result)");
            }
            return PyOpResult(operation.getRef(),
                              mlirOperationGetResult(operation, 0));
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      // Iterating an operation yields its regions.
      .def("__iter__",
           [](PyOperationBase &self) {
             return PyRegionIterator(self.getOperation().getRef());
           })
      // str() uses getAsm with every printing option at its default-off value.
      .def(
          "__str__",
          [](PyOperationBase &self) {
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/llvm::None,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false);
          },
          "Returns the assembly form of the operation.")
      .def("print", &PyOperationBase::print,
           // Careful: Lots of arguments must match up with print method.
           py::arg("file") = py::none(), py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false, kOperationPrintDocstring)
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
      .def(
          "verify",
          [](PyOperationBase &self) {
            return mlirOperationVerify(self.getOperation());
          },
          "Verify the operation and return true if it passes, false if it "
          "fails.");

  // Concrete `Operation` class: creation, parent navigation, erasure, and
  // C-API capsule interop on top of the shared base API.
  py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
      .def_static("create", &PyOperation::create, py::arg("name"),
                  py::arg("results") = py::none(),
                  py::arg("operands") = py::none(),
                  py::arg("attributes") = py::none(),
                  py::arg("successors") = py::none(), py::arg("regions") = 0,
                  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                  kOperationCreateDocstring)
      // Returns None when the operation has no parent operation.
      .def_property_readonly("parent",
                             [](PyOperation &self) -> py::object {
                               auto parent = self.getParentOperation();
                               if (parent)
                                 return parent->getObject();
                               return py::none();
                             })
      .def("erase", &PyOperation::erase)
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             &PyOperation::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
      .def_property_readonly("name",
                             [](PyOperation &self) {
                               self.checkValid();
                               MlirOperation operation = self.get();
                               MlirStringRef name = mlirIdentifierStr(
                                   mlirOperationGetName(operation));
                               return py::str(name.data, name.length);
                             })
      .def_property_readonly(
          "context",
          [](PyOperation &self) {
            self.checkValid();
            return self.getContext().getObject();
          },
          "Context that owns the Operation")
      .def_property_readonly("opview", &PyOperation::createOpView);

  // `OpView` wraps an operation object (constructed from a py::object);
  // `__str__` delegates to str() of the wrapped operation object.
  auto opViewClass =
      py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
          .def(py::init<py::object>())
          .def_property_readonly("operation", &PyOpView::getOperationObject)
          .def_property_readonly(
              "context",
              [](PyOpView &self) {
                return self.getOperation().getContext().getObject();
              },
              "Context that owns the Operation")
          .def("__str__", [](PyOpView &self) {
            return py::str(self.getOperationObject());
          });
  // Class-level defaults read by `build_generic` ("based on class level
  // attributes"); presumably overridden by generated OpView subclasses —
  // TODO confirm the exact meaning of each tuple element.
  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
  opViewClass.attr("build_generic") = classmethod(
      &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
      py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
      py::arg("successors") = py::none(), py::arg("regions") = py::none(),
      py::arg("loc") = py::none(), py::arg("ip") = py::none(),
      "Builds a specific, generated OpView based on class level attributes.");

  //----------------------------------------------------------------------------
  // Mapping of PyRegion.
  //----------------------------------------------------------------------------
  py::class_<PyRegion>(m, "Region", py::module_local())
      .def_property_readonly(
          "blocks",
          [](PyRegion &self) {
            return PyBlockList(self.getParentOperation(), self.get());
          },
          "Returns a forward-optimized sequence of blocks.")
      .def_property_readonly(
          "owner",
          [](PyRegion &self) {
            return self.getParentOperation()->createOpView();
          },
          "Returns the operation owning this region.")
      .def(
          "__iter__",
          [](PyRegion &self) {
            self.checkValid();
            MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
            return PyBlockIterator(self.getParentOperation(), firstBlock);
          },
          "Iterates over blocks in the region.")
      // Equality compares the underlying C-API region pointers; the second
      // overload makes comparison with any non-Region object return False.
      .def("__eq__",
           [](PyRegion &self, PyRegion &other) {
             return self.get().ptr == other.get().ptr;
           })
      .def("__eq__", [](PyRegion &self, py::object &other) { return false; });

  //----------------------------------------------------------------------------
  // Mapping of PyBlock.
//----------------------------------------------------------------------------
  // Blocks are always wrapped together with their parent operation
  // (`getParentOperation()`), which every derived view below is built from.
  py::class_<PyBlock>(m, "Block", py::module_local())
      .def_property_readonly(
          "owner",
          [](PyBlock &self) {
            return self.getParentOperation()->createOpView();
          },
          "Returns the owning operation of this block.")
      .def_property_readonly(
          "region",
          [](PyBlock &self) {
            MlirRegion region = mlirBlockGetParentRegion(self.get());
            return PyRegion(self.getParentOperation(), region);
          },
          "Returns the owning region of this block.")
      .def_property_readonly(
          "arguments",
          [](PyBlock &self) {
            return PyBlockArgumentList(self.getParentOperation(), self.get());
          },
          "Returns a list of block arguments.")
      .def_property_readonly(
          "operations",
          [](PyBlock &self) {
            return PyOperationList(self.getParentOperation(), self.get());
          },
          "Returns a forward-optimized sequence of operations.")
      // The three creators below share one pattern: convert the Python
      // argument types to MlirType, create a detached block with
      // mlirBlockCreate, then insert it into the region. `create_at_start`
      // takes the types as a list; the two instance methods take them as
      // positional varargs (py::args).
      .def_static(
          "create_at_start",
          [](PyRegion &parent, py::list pyArgTypes) {
            parent.checkValid();
            llvm::SmallVector<MlirType, 4> argTypes;
            argTypes.reserve(pyArgTypes.size());
            for (auto &pyArg : pyArgTypes) {
              argTypes.push_back(pyArg.cast<PyType &>());
            }

            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
            // Position 0 places the new block first in the region.
            mlirRegionInsertOwnedBlock(parent, 0, block);
            return PyBlock(parent.getParentOperation(), block);
          },
          py::arg("parent"), py::arg("pyArgTypes") = py::list(),
          "Creates and returns a new Block at the beginning of the given "
          "region (with given argument types).")
      .def(
          "create_before",
          [](PyBlock &self, py::args pyArgTypes) {
            self.checkValid();
            llvm::SmallVector<MlirType, 4> argTypes;
            argTypes.reserve(pyArgTypes.size());
            for (auto &pyArg : pyArgTypes) {
              argTypes.push_back(pyArg.cast<PyType &>());
            }

            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
            MlirRegion region = mlirBlockGetParentRegion(self.get());
            mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
            return PyBlock(self.getParentOperation(), block);
          },
          "Creates and returns a new Block before this block "
          "(with given argument types).")
      .def(
          "create_after",
          [](PyBlock &self, py::args pyArgTypes) {
            self.checkValid();
            llvm::SmallVector<MlirType, 4> argTypes;
            argTypes.reserve(pyArgTypes.size());
            for (auto &pyArg : pyArgTypes) {
              argTypes.push_back(pyArg.cast<PyType &>());
            }

            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
            MlirRegion region = mlirBlockGetParentRegion(self.get());
            mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
            return PyBlock(self.getParentOperation(), block);
          },
          "Creates and returns a new Block after this block "
          "(with given argument types).")
      .def(
          "__iter__",
          [](PyBlock &self) {
            self.checkValid();
            MlirOperation firstOperation =
                mlirBlockGetFirstOperation(self.get());
            return PyOperationIterator(self.getParentOperation(),
                                       firstOperation);
          },
          "Iterates over operations in the block.")
      // Equality compares the underlying C-API block pointers; the second
      // overload makes comparison with any non-Block object return False.
      .def("__eq__",
           [](PyBlock &self, PyBlock &other) {
             return self.get().ptr == other.get().ptr;
           })
      .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
      .def(
          "__str__",
          [](PyBlock &self) {
            self.checkValid();
            PyPrintAccumulator printAccum;
            mlirBlockPrint(self.get(), printAccum.getCallback(),
                           printAccum.getUserData());
            return printAccum.join();
          },
          "Returns the assembly form of the block.");

  //----------------------------------------------------------------------------
  // Mapping of PyInsertionPoint.
2382 //---------------------------------------------------------------------------- 2383 2384 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2385 .def(py::init<PyBlock &>(), py::arg("block"), 2386 "Inserts after the last operation but still inside the block.") 2387 .def("__enter__", &PyInsertionPoint::contextEnter) 2388 .def("__exit__", &PyInsertionPoint::contextExit) 2389 .def_property_readonly_static( 2390 "current", 2391 [](py::object & /*class*/) { 2392 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2393 if (!ip) 2394 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2395 return ip; 2396 }, 2397 "Gets the InsertionPoint bound to the current thread or raises " 2398 "ValueError if none has been set") 2399 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2400 "Inserts before a referenced operation.") 2401 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2402 py::arg("block"), "Inserts at the beginning of the block.") 2403 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2404 py::arg("block"), "Inserts before the block terminator.") 2405 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2406 "Inserts an operation.") 2407 .def_property_readonly( 2408 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2409 "Returns the block that this InsertionPoint points to."); 2410 2411 //---------------------------------------------------------------------------- 2412 // Mapping of PyAttribute. 2413 //---------------------------------------------------------------------------- 2414 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2415 // Delegate to the PyAttribute copy constructor, which will also lifetime 2416 // extend the backing context which owns the MlirAttribute. 
2417 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2418 "Casts the passed attribute to the generic Attribute") 2419 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2420 &PyAttribute::getCapsule) 2421 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2422 .def_static( 2423 "parse", 2424 [](std::string attrSpec, DefaultingPyMlirContext context) { 2425 MlirAttribute type = mlirAttributeParseGet( 2426 context->get(), toMlirStringRef(attrSpec)); 2427 // TODO: Rework error reporting once diagnostic engine is exposed 2428 // in C API. 2429 if (mlirAttributeIsNull(type)) { 2430 throw SetPyError(PyExc_ValueError, 2431 Twine("Unable to parse attribute: '") + 2432 attrSpec + "'"); 2433 } 2434 return PyAttribute(context->getRef(), type); 2435 }, 2436 py::arg("asm"), py::arg("context") = py::none(), 2437 "Parses an attribute from an assembly form") 2438 .def_property_readonly( 2439 "context", 2440 [](PyAttribute &self) { return self.getContext().getObject(); }, 2441 "Context that owns the Attribute") 2442 .def_property_readonly("type", 2443 [](PyAttribute &self) { 2444 return PyType(self.getContext()->getRef(), 2445 mlirAttributeGetType(self)); 2446 }) 2447 .def( 2448 "get_named", 2449 [](PyAttribute &self, std::string name) { 2450 return PyNamedAttribute(self, std::move(name)); 2451 }, 2452 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2453 .def("__eq__", 2454 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2455 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2456 .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; }) 2457 .def( 2458 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2459 kDumpDocstring) 2460 .def( 2461 "__str__", 2462 [](PyAttribute &self) { 2463 PyPrintAccumulator printAccum; 2464 mlirAttributePrint(self, printAccum.getCallback(), 2465 printAccum.getUserData()); 2466 return printAccum.join(); 2467 }, 2468 "Returns the assembly form of the 
Attribute.") 2469 .def("__repr__", [](PyAttribute &self) { 2470 // Generally, assembly formats are not printed for __repr__ because 2471 // this can cause exceptionally long debug output and exceptions. 2472 // However, attribute values are generally considered useful and are 2473 // printed. This may need to be re-evaluated if debug dumps end up 2474 // being excessive. 2475 PyPrintAccumulator printAccum; 2476 printAccum.parts.append("Attribute("); 2477 mlirAttributePrint(self, printAccum.getCallback(), 2478 printAccum.getUserData()); 2479 printAccum.parts.append(")"); 2480 return printAccum.join(); 2481 }); 2482 2483 //---------------------------------------------------------------------------- 2484 // Mapping of PyNamedAttribute 2485 //---------------------------------------------------------------------------- 2486 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2487 .def("__repr__", 2488 [](PyNamedAttribute &self) { 2489 PyPrintAccumulator printAccum; 2490 printAccum.parts.append("NamedAttribute("); 2491 printAccum.parts.append( 2492 mlirIdentifierStr(self.namedAttr.name).data); 2493 printAccum.parts.append("="); 2494 mlirAttributePrint(self.namedAttr.attribute, 2495 printAccum.getCallback(), 2496 printAccum.getUserData()); 2497 printAccum.parts.append(")"); 2498 return printAccum.join(); 2499 }) 2500 .def_property_readonly( 2501 "name", 2502 [](PyNamedAttribute &self) { 2503 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2504 mlirIdentifierStr(self.namedAttr.name).length); 2505 }, 2506 "The name of the NamedAttribute binding") 2507 .def_property_readonly( 2508 "attr", 2509 [](PyNamedAttribute &self) { 2510 // TODO: When named attribute is removed/refactored, also remove 2511 // this constructor (it does an inefficient table lookup). 
2512 auto contextRef = PyMlirContext::forContext( 2513 mlirAttributeGetContext(self.namedAttr.attribute)); 2514 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2515 }, 2516 py::keep_alive<0, 1>(), 2517 "The underlying generic attribute of the NamedAttribute binding"); 2518 2519 //---------------------------------------------------------------------------- 2520 // Mapping of PyType. 2521 //---------------------------------------------------------------------------- 2522 py::class_<PyType>(m, "Type", py::module_local()) 2523 // Delegate to the PyType copy constructor, which will also lifetime 2524 // extend the backing context which owns the MlirType. 2525 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2526 "Casts the passed type to the generic Type") 2527 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2528 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2529 .def_static( 2530 "parse", 2531 [](std::string typeSpec, DefaultingPyMlirContext context) { 2532 MlirType type = 2533 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2534 // TODO: Rework error reporting once diagnostic engine is exposed 2535 // in C API. 
2536 if (mlirTypeIsNull(type)) { 2537 throw SetPyError(PyExc_ValueError, 2538 Twine("Unable to parse type: '") + typeSpec + 2539 "'"); 2540 } 2541 return PyType(context->getRef(), type); 2542 }, 2543 py::arg("asm"), py::arg("context") = py::none(), 2544 kContextParseTypeDocstring) 2545 .def_property_readonly( 2546 "context", [](PyType &self) { return self.getContext().getObject(); }, 2547 "Context that owns the Type") 2548 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2549 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2550 .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; }) 2551 .def( 2552 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2553 .def( 2554 "__str__", 2555 [](PyType &self) { 2556 PyPrintAccumulator printAccum; 2557 mlirTypePrint(self, printAccum.getCallback(), 2558 printAccum.getUserData()); 2559 return printAccum.join(); 2560 }, 2561 "Returns the assembly form of the type.") 2562 .def("__repr__", [](PyType &self) { 2563 // Generally, assembly formats are not printed for __repr__ because 2564 // this can cause exceptionally long debug output and exceptions. 2565 // However, types are an exception as they typically have compact 2566 // assembly forms and printing them is useful. 2567 PyPrintAccumulator printAccum; 2568 printAccum.parts.append("Type("); 2569 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2570 printAccum.parts.append(")"); 2571 return printAccum.join(); 2572 }); 2573 2574 //---------------------------------------------------------------------------- 2575 // Mapping of Value. 
2576 //---------------------------------------------------------------------------- 2577 py::class_<PyValue>(m, "Value", py::module_local()) 2578 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2579 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2580 .def_property_readonly( 2581 "context", 2582 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2583 "Context in which the value lives.") 2584 .def( 2585 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2586 kDumpDocstring) 2587 .def_property_readonly( 2588 "owner", 2589 [](PyValue &self) { 2590 assert(mlirOperationEqual(self.getParentOperation()->get(), 2591 mlirOpResultGetOwner(self.get())) && 2592 "expected the owner of the value in Python to match that in " 2593 "the IR"); 2594 return self.getParentOperation().getObject(); 2595 }) 2596 .def("__eq__", 2597 [](PyValue &self, PyValue &other) { 2598 return self.get().ptr == other.get().ptr; 2599 }) 2600 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2601 .def( 2602 "__str__", 2603 [](PyValue &self) { 2604 PyPrintAccumulator printAccum; 2605 printAccum.parts.append("Value("); 2606 mlirValuePrint(self.get(), printAccum.getCallback(), 2607 printAccum.getUserData()); 2608 printAccum.parts.append(")"); 2609 return printAccum.join(); 2610 }, 2611 kValueDunderStrDocstring) 2612 .def_property_readonly("type", [](PyValue &self) { 2613 return PyType(self.getParentOperation()->getContext(), 2614 mlirValueGetType(self.get())); 2615 }); 2616 PyBlockArgument::bind(m); 2617 PyOpResult::bind(m); 2618 2619 // Container bindings. 2620 PyBlockArgumentList::bind(m); 2621 PyBlockIterator::bind(m); 2622 PyBlockList::bind(m); 2623 PyOperationIterator::bind(m); 2624 PyOperationList::bind(m); 2625 PyOpAttributeMap::bind(m); 2626 PyOpOperandList::bind(m); 2627 PyOpResultList::bind(m); 2628 PyRegionIterator::bind(m); 2629 PyRegionList::bind(m); 2630 2631 // Debug bindings. 
2632 PyGlobalDebugFlag::bind(m); 2633 } 2634