1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include <pybind11/stl.h> 22 23 namespace py = pybind11; 24 using namespace mlir; 25 using namespace mlir::python; 26 27 using llvm::SmallVector; 28 using llvm::StringRef; 29 using llvm::Twine; 30 31 //------------------------------------------------------------------------------ 32 // Docstrings (trivial, non-duplicated docstrings are included inline). 33 //------------------------------------------------------------------------------ 34 35 static const char kContextParseTypeDocstring[] = 36 R"(Parses the assembly form of a type. 37 38 Returns a Type object or raises a ValueError if the type cannot be parsed. 39 40 See also: https://mlir.llvm.org/docs/LangRef/#type-system 41 )"; 42 43 static const char kContextGetFileLocationDocstring[] = 44 R"(Gets a Location representing a file, line and column)"; 45 46 static const char kContextGetNameLocationDocString[] = 47 R"(Gets a Location representing a named location with optional child location)"; 48 49 static const char kModuleParseDocstring[] = 50 R"(Parses a module's assembly format from a string. 51 52 Returns a new MlirModule or raises a ValueError if the parsing fails. 53 54 See also: https://mlir.llvm.org/docs/LangRef/ 55 )"; 56 57 static const char kOperationCreateDocstring[] = 58 R"(Creates a new operation. 59 60 Args: 61 name: Operation name (e.g. "dialect.operation"). 62 results: Sequence of Type representing op result types. 63 attributes: Dict of str:Attribute. 64 successors: List of Block for the operation's successors. 65 regions: Number of regions to create. 66 location: A Location object (defaults to resolve from context manager). 67 ip: An InsertionPoint (defaults to resolve from context manager or set to 68 False to disable insertion, even with an insertion point set in the 69 context manager). 70 Returns: 71 A new "detached" Operation object. Detached operations can be added 72 to blocks, which causes them to become "attached." 73 )"; 74 75 static const char kOperationPrintDocstring[] = 76 R"(Prints the assembly form of the operation to a file like object. 77 78 Args: 79 file: The file like object to write to. Defaults to sys.stdout. 80 binary: Whether to write bytes (True) or str (False). Defaults to False. 81 large_elements_limit: Whether to elide elements attributes above this 82 number of elements. Defaults to None (no limit). 83 enable_debug_info: Whether to print debug/location information. Defaults 84 to False. 85 pretty_debug_info: Whether to format debug information for easier reading 86 by a human (warning: the result is unparseable). 87 print_generic_op_form: Whether to print the generic assembly forms of all 88 ops. Defaults to False. 89 use_local_Scope: Whether to print in a way that is more optimized for 90 multi-threaded access but may not be consistent with how the overall 91 module prints. 92 )"; 93 94 static const char kOperationGetAsmDocstring[] = 95 R"(Gets the assembly form of the operation with all options available. 96 97 Args: 98 binary: Whether to return a bytes (True) or str (False) object. Defaults to 99 False. 100 ... others ...: See the print() method for common keyword arguments for 101 configuring the printout. 102 Returns: 103 Either a bytes or str object, depending on the setting of the 'binary' 104 argument. 105 )"; 106 107 static const char kOperationStrDunderDocstring[] = 108 R"(Gets the assembly form of the operation with default options. 109 110 If more advanced control over the assembly formatting or I/O options is needed, 111 use the dedicated print or get_asm method, which supports keyword arguments to 112 customize behavior. 113 )"; 114 115 static const char kDumpDocstring[] = 116 R"(Dumps a debug representation of the object to stderr.)"; 117 118 static const char kAppendBlockDocstring[] = 119 R"(Appends a new block, with argument types as positional args. 120 121 Returns: 122 The created block. 123 )"; 124 125 static const char kValueDunderStrDocstring[] = 126 R"(Returns the string form of the value. 127 128 If the value is a block argument, this is the assembly form of its type and the 129 position in the argument list. If the value is an operation result, this is 130 equivalent to printing the operation that produced it. 131 )"; 132 133 //------------------------------------------------------------------------------ 134 // Utilities. 135 //------------------------------------------------------------------------------ 136 137 /// Helper for creating an @classmethod. 138 template <class Func, typename... Args> 139 py::object classmethod(Func f, Args... args) { 140 py::object cf = py::cpp_function(f, args...); 141 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 142 } 143 144 static py::object 145 createCustomDialectWrapper(const std::string &dialectNamespace, 146 py::object dialectDescriptor) { 147 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 148 if (!dialectClass) { 149 // Use the base class. 150 return py::cast(PyDialect(std::move(dialectDescriptor))); 151 } 152 153 // Create the custom implementation. 154 return (*dialectClass)(std::move(dialectDescriptor)); 155 } 156 157 static MlirStringRef toMlirStringRef(const std::string &s) { 158 return mlirStringRefCreate(s.data(), s.size()); 159 } 160 161 /// Wrapper for the global LLVM debugging flag. 162 struct PyGlobalDebugFlag { 163 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 164 165 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 166 167 static void bind(py::module &m) { 168 // Debug flags. 169 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 170 .def_property_static("flag", &PyGlobalDebugFlag::get, 171 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 172 } 173 }; 174 175 //------------------------------------------------------------------------------ 176 // Collections. 177 //------------------------------------------------------------------------------ 178 179 namespace { 180 181 class PyRegionIterator { 182 public: 183 PyRegionIterator(PyOperationRef operation) 184 : operation(std::move(operation)) {} 185 186 PyRegionIterator &dunderIter() { return *this; } 187 188 PyRegion dunderNext() { 189 operation->checkValid(); 190 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 191 throw py::stop_iteration(); 192 } 193 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 194 return PyRegion(operation, region); 195 } 196 197 static void bind(py::module &m) { 198 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 199 .def("__iter__", &PyRegionIterator::dunderIter) 200 .def("__next__", &PyRegionIterator::dunderNext); 201 } 202 203 private: 204 PyOperationRef operation; 205 int nextIndex = 0; 206 }; 207 208 /// Regions of an op are fixed length and indexed numerically so are represented 209 /// with a sequence-like container. 210 class PyRegionList { 211 public: 212 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 213 214 intptr_t dunderLen() { 215 operation->checkValid(); 216 return mlirOperationGetNumRegions(operation->get()); 217 } 218 219 PyRegion dunderGetItem(intptr_t index) { 220 // dunderLen checks validity. 221 if (index < 0 || index >= dunderLen()) { 222 throw SetPyError(PyExc_IndexError, 223 "attempt to access out of bounds region"); 224 } 225 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 226 return PyRegion(operation, region); 227 } 228 229 static void bind(py::module &m) { 230 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 231 .def("__len__", &PyRegionList::dunderLen) 232 .def("__getitem__", &PyRegionList::dunderGetItem); 233 } 234 235 private: 236 PyOperationRef operation; 237 }; 238 239 class PyBlockIterator { 240 public: 241 PyBlockIterator(PyOperationRef operation, MlirBlock next) 242 : operation(std::move(operation)), next(next) {} 243 244 PyBlockIterator &dunderIter() { return *this; } 245 246 PyBlock dunderNext() { 247 operation->checkValid(); 248 if (mlirBlockIsNull(next)) { 249 throw py::stop_iteration(); 250 } 251 252 PyBlock returnBlock(operation, next); 253 next = mlirBlockGetNextInRegion(next); 254 return returnBlock; 255 } 256 257 static void bind(py::module &m) { 258 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 259 .def("__iter__", &PyBlockIterator::dunderIter) 260 .def("__next__", &PyBlockIterator::dunderNext); 261 } 262 263 private: 264 PyOperationRef operation; 265 MlirBlock next; 266 }; 267 268 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 269 /// we present them as a more full-featured list-like container but optimize 270 /// it for forward iteration. Blocks are always owned by a region. 271 class PyBlockList { 272 public: 273 PyBlockList(PyOperationRef operation, MlirRegion region) 274 : operation(std::move(operation)), region(region) {} 275 276 PyBlockIterator dunderIter() { 277 operation->checkValid(); 278 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 279 } 280 281 intptr_t dunderLen() { 282 operation->checkValid(); 283 intptr_t count = 0; 284 MlirBlock block = mlirRegionGetFirstBlock(region); 285 while (!mlirBlockIsNull(block)) { 286 count += 1; 287 block = mlirBlockGetNextInRegion(block); 288 } 289 return count; 290 } 291 292 PyBlock dunderGetItem(intptr_t index) { 293 operation->checkValid(); 294 if (index < 0) { 295 throw SetPyError(PyExc_IndexError, 296 "attempt to access out of bounds block"); 297 } 298 MlirBlock block = mlirRegionGetFirstBlock(region); 299 while (!mlirBlockIsNull(block)) { 300 if (index == 0) { 301 return PyBlock(operation, block); 302 } 303 block = mlirBlockGetNextInRegion(block); 304 index -= 1; 305 } 306 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 307 } 308 309 PyBlock appendBlock(py::args pyArgTypes) { 310 operation->checkValid(); 311 llvm::SmallVector<MlirType, 4> argTypes; 312 argTypes.reserve(pyArgTypes.size()); 313 for (auto &pyArg : pyArgTypes) { 314 argTypes.push_back(pyArg.cast<PyType &>()); 315 } 316 317 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 318 mlirRegionAppendOwnedBlock(region, block); 319 return PyBlock(operation, block); 320 } 321 322 static void bind(py::module &m) { 323 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 324 .def("__getitem__", &PyBlockList::dunderGetItem) 325 .def("__iter__", &PyBlockList::dunderIter) 326 .def("__len__", &PyBlockList::dunderLen) 327 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 328 } 329 330 private: 331 PyOperationRef operation; 332 MlirRegion region; 333 }; 334 335 class PyOperationIterator { 336 public: 337 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 338 : parentOperation(std::move(parentOperation)), next(next) {} 339 340 PyOperationIterator &dunderIter() { return *this; } 341 342 py::object dunderNext() { 343 parentOperation->checkValid(); 344 if (mlirOperationIsNull(next)) { 345 throw py::stop_iteration(); 346 } 347 348 PyOperationRef returnOperation = 349 PyOperation::forOperation(parentOperation->getContext(), next); 350 next = mlirOperationGetNextInBlock(next); 351 return returnOperation->createOpView(); 352 } 353 354 static void bind(py::module &m) { 355 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 356 .def("__iter__", &PyOperationIterator::dunderIter) 357 .def("__next__", &PyOperationIterator::dunderNext); 358 } 359 360 private: 361 PyOperationRef parentOperation; 362 MlirOperation next; 363 }; 364 365 /// Operations are exposed by the C-API as a forward-only linked list. In 366 /// Python, we present them as a more full-featured list-like container but 367 /// optimize it for forward iteration. Iterable operations are always owned 368 /// by a block. 369 class PyOperationList { 370 public: 371 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 372 : parentOperation(std::move(parentOperation)), block(block) {} 373 374 PyOperationIterator dunderIter() { 375 parentOperation->checkValid(); 376 return PyOperationIterator(parentOperation, 377 mlirBlockGetFirstOperation(block)); 378 } 379 380 intptr_t dunderLen() { 381 parentOperation->checkValid(); 382 intptr_t count = 0; 383 MlirOperation childOp = mlirBlockGetFirstOperation(block); 384 while (!mlirOperationIsNull(childOp)) { 385 count += 1; 386 childOp = mlirOperationGetNextInBlock(childOp); 387 } 388 return count; 389 } 390 391 py::object dunderGetItem(intptr_t index) { 392 parentOperation->checkValid(); 393 if (index < 0) { 394 throw SetPyError(PyExc_IndexError, 395 "attempt to access out of bounds operation"); 396 } 397 MlirOperation childOp = mlirBlockGetFirstOperation(block); 398 while (!mlirOperationIsNull(childOp)) { 399 if (index == 0) { 400 return PyOperation::forOperation(parentOperation->getContext(), childOp) 401 ->createOpView(); 402 } 403 childOp = mlirOperationGetNextInBlock(childOp); 404 index -= 1; 405 } 406 throw SetPyError(PyExc_IndexError, 407 "attempt to access out of bounds operation"); 408 } 409 410 static void bind(py::module &m) { 411 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 412 .def("__getitem__", &PyOperationList::dunderGetItem) 413 .def("__iter__", &PyOperationList::dunderIter) 414 .def("__len__", &PyOperationList::dunderLen); 415 } 416 417 private: 418 PyOperationRef parentOperation; 419 MlirBlock block; 420 }; 421 422 } // namespace 423 424 //------------------------------------------------------------------------------ 425 // PyMlirContext 426 //------------------------------------------------------------------------------ 427 428 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 429 py::gil_scoped_acquire acquire; 430 auto &liveContexts = getLiveContexts(); 431 liveContexts[context.ptr] = this; 432 } 433 434 PyMlirContext::~PyMlirContext() { 435 // Note that the only public way to construct an instance is via the 436 // forContext method, which always puts the associated handle into 437 // liveContexts. 438 py::gil_scoped_acquire acquire; 439 getLiveContexts().erase(context.ptr); 440 mlirContextDestroy(context); 441 } 442 443 py::object PyMlirContext::getCapsule() { 444 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 445 } 446 447 py::object PyMlirContext::createFromCapsule(py::object capsule) { 448 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 449 if (mlirContextIsNull(rawContext)) 450 throw py::error_already_set(); 451 return forContext(rawContext).releaseObject(); 452 } 453 454 PyMlirContext *PyMlirContext::createNewContextForInit() { 455 MlirContext context = mlirContextCreate(); 456 mlirRegisterAllDialects(context); 457 return new PyMlirContext(context); 458 } 459 460 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 461 py::gil_scoped_acquire acquire; 462 auto &liveContexts = getLiveContexts(); 463 auto it = liveContexts.find(context.ptr); 464 if (it == liveContexts.end()) { 465 // Create. 466 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 467 py::object pyRef = py::cast(unownedContextWrapper); 468 assert(pyRef && "cast to py::object failed"); 469 liveContexts[context.ptr] = unownedContextWrapper; 470 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 471 } 472 // Use existing. 473 py::object pyRef = py::cast(it->second); 474 return PyMlirContextRef(it->second, std::move(pyRef)); 475 } 476 477 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 478 static LiveContextMap liveContexts; 479 return liveContexts; 480 } 481 482 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 483 484 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 485 486 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 487 488 pybind11::object PyMlirContext::contextEnter() { 489 return PyThreadContextEntry::pushContext(*this); 490 } 491 492 void PyMlirContext::contextExit(pybind11::object excType, 493 pybind11::object excVal, 494 pybind11::object excTb) { 495 PyThreadContextEntry::popContext(*this); 496 } 497 498 PyMlirContext &DefaultingPyMlirContext::resolve() { 499 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 500 if (!context) { 501 throw SetPyError( 502 PyExc_RuntimeError, 503 "An MLIR function requires a Context but none was provided in the call " 504 "or from the surrounding environment. Either pass to the function with " 505 "a 'context=' argument or establish a default using 'with Context():'"); 506 } 507 return *context; 508 } 509 510 //------------------------------------------------------------------------------ 511 // PyThreadContextEntry management 512 //------------------------------------------------------------------------------ 513 514 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 515 static thread_local std::vector<PyThreadContextEntry> stack; 516 return stack; 517 } 518 519 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 520 auto &stack = getStack(); 521 if (stack.empty()) 522 return nullptr; 523 return &stack.back(); 524 } 525 526 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 527 py::object insertionPoint, 528 py::object location) { 529 auto &stack = getStack(); 530 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 531 std::move(location)); 532 // If the new stack has more than one entry and the context of the new top 533 // entry matches the previous, copy the insertionPoint and location from the 534 // previous entry if missing from the new top entry. 535 if (stack.size() > 1) { 536 auto &prev = *(stack.rbegin() + 1); 537 auto ¤t = stack.back(); 538 if (current.context.is(prev.context)) { 539 // Default non-context objects from the previous entry. 540 if (!current.insertionPoint) 541 current.insertionPoint = prev.insertionPoint; 542 if (!current.location) 543 current.location = prev.location; 544 } 545 } 546 } 547 548 PyMlirContext *PyThreadContextEntry::getContext() { 549 if (!context) 550 return nullptr; 551 return py::cast<PyMlirContext *>(context); 552 } 553 554 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 555 if (!insertionPoint) 556 return nullptr; 557 return py::cast<PyInsertionPoint *>(insertionPoint); 558 } 559 560 PyLocation *PyThreadContextEntry::getLocation() { 561 if (!location) 562 return nullptr; 563 return py::cast<PyLocation *>(location); 564 } 565 566 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 567 auto *tos = getTopOfStack(); 568 return tos ? tos->getContext() : nullptr; 569 } 570 571 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 572 auto *tos = getTopOfStack(); 573 return tos ? tos->getInsertionPoint() : nullptr; 574 } 575 576 PyLocation *PyThreadContextEntry::getDefaultLocation() { 577 auto *tos = getTopOfStack(); 578 return tos ? tos->getLocation() : nullptr; 579 } 580 581 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 582 py::object contextObj = py::cast(context); 583 push(FrameKind::Context, /*context=*/contextObj, 584 /*insertionPoint=*/py::object(), 585 /*location=*/py::object()); 586 return contextObj; 587 } 588 589 void PyThreadContextEntry::popContext(PyMlirContext &context) { 590 auto &stack = getStack(); 591 if (stack.empty()) 592 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 593 auto &tos = stack.back(); 594 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 595 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 596 stack.pop_back(); 597 } 598 599 py::object 600 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 601 py::object contextObj = 602 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 603 py::object insertionPointObj = py::cast(insertionPoint); 604 push(FrameKind::InsertionPoint, 605 /*context=*/contextObj, 606 /*insertionPoint=*/insertionPointObj, 607 /*location=*/py::object()); 608 return insertionPointObj; 609 } 610 611 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 612 auto &stack = getStack(); 613 if (stack.empty()) 614 throw SetPyError(PyExc_RuntimeError, 615 "Unbalanced InsertionPoint enter/exit"); 616 auto &tos = stack.back(); 617 if (tos.frameKind != FrameKind::InsertionPoint && 618 tos.getInsertionPoint() != &insertionPoint) 619 throw SetPyError(PyExc_RuntimeError, 620 "Unbalanced InsertionPoint enter/exit"); 621 stack.pop_back(); 622 } 623 624 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 625 py::object contextObj = location.getContext().getObject(); 626 py::object locationObj = py::cast(location); 627 push(FrameKind::Location, /*context=*/contextObj, 628 /*insertionPoint=*/py::object(), 629 /*location=*/locationObj); 630 return locationObj; 631 } 632 633 void PyThreadContextEntry::popLocation(PyLocation &location) { 634 auto &stack = getStack(); 635 if (stack.empty()) 636 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 637 auto &tos = stack.back(); 638 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 639 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 640 stack.pop_back(); 641 } 642 643 //------------------------------------------------------------------------------ 644 // PyDialect, PyDialectDescriptor, PyDialects 645 //------------------------------------------------------------------------------ 646 647 MlirDialect PyDialects::getDialectForKey(const std::string &key, 648 bool attrError) { 649 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 650 {key.data(), key.size()}); 651 if (mlirDialectIsNull(dialect)) { 652 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 653 Twine("Dialect '") + key + "' not found"); 654 } 655 return dialect; 656 } 657 658 //------------------------------------------------------------------------------ 659 // PyLocation 660 //------------------------------------------------------------------------------ 661 662 py::object PyLocation::getCapsule() { 663 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 664 } 665 666 PyLocation PyLocation::createFromCapsule(py::object capsule) { 667 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 668 if (mlirLocationIsNull(rawLoc)) 669 throw py::error_already_set(); 670 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 671 rawLoc); 672 } 673 674 py::object PyLocation::contextEnter() { 675 return PyThreadContextEntry::pushLocation(*this); 676 } 677 678 void PyLocation::contextExit(py::object excType, py::object excVal, 679 py::object excTb) { 680 PyThreadContextEntry::popLocation(*this); 681 } 682 683 PyLocation &DefaultingPyLocation::resolve() { 684 auto *location = PyThreadContextEntry::getDefaultLocation(); 685 if (!location) { 686 throw SetPyError( 687 PyExc_RuntimeError, 688 "An MLIR function requires a Location but none was provided in the " 689 "call or from the surrounding environment. Either pass to the function " 690 "with a 'loc=' argument or establish a default using 'with loc:'"); 691 } 692 return *location; 693 } 694 695 //------------------------------------------------------------------------------ 696 // PyModule 697 //------------------------------------------------------------------------------ 698 699 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 700 : BaseContextObject(std::move(contextRef)), module(module) {} 701 702 PyModule::~PyModule() { 703 py::gil_scoped_acquire acquire; 704 auto &liveModules = getContext()->liveModules; 705 assert(liveModules.count(module.ptr) == 1 && 706 "destroying module not in live map"); 707 liveModules.erase(module.ptr); 708 mlirModuleDestroy(module); 709 } 710 711 PyModuleRef PyModule::forModule(MlirModule module) { 712 MlirContext context = mlirModuleGetContext(module); 713 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 714 715 py::gil_scoped_acquire acquire; 716 auto &liveModules = contextRef->liveModules; 717 auto it = liveModules.find(module.ptr); 718 if (it == liveModules.end()) { 719 // Create. 720 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 721 // Note that the default return value policy on cast is automatic_reference, 722 // which does not take ownership (delete will not be called). 723 // Just be explicit. 724 py::object pyRef = 725 py::cast(unownedModule, py::return_value_policy::take_ownership); 726 unownedModule->handle = pyRef; 727 liveModules[module.ptr] = 728 std::make_pair(unownedModule->handle, unownedModule); 729 return PyModuleRef(unownedModule, std::move(pyRef)); 730 } 731 // Use existing. 732 PyModule *existing = it->second.second; 733 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 734 return PyModuleRef(existing, std::move(pyRef)); 735 } 736 737 py::object PyModule::createFromCapsule(py::object capsule) { 738 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 739 if (mlirModuleIsNull(rawModule)) 740 throw py::error_already_set(); 741 return forModule(rawModule).releaseObject(); 742 } 743 744 py::object PyModule::getCapsule() { 745 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 746 } 747 748 //------------------------------------------------------------------------------ 749 // PyOperation 750 //------------------------------------------------------------------------------ 751 752 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 753 : BaseContextObject(std::move(contextRef)), operation(operation) {} 754 755 PyOperation::~PyOperation() { 756 // If the operation has already been invalidated there is nothing to do. 757 if (!valid) 758 return; 759 auto &liveOperations = getContext()->liveOperations; 760 assert(liveOperations.count(operation.ptr) == 1 && 761 "destroying operation not in live map"); 762 liveOperations.erase(operation.ptr); 763 if (!isAttached()) { 764 mlirOperationDestroy(operation); 765 } 766 } 767 768 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 769 MlirOperation operation, 770 py::object parentKeepAlive) { 771 auto &liveOperations = contextRef->liveOperations; 772 // Create. 773 PyOperation *unownedOperation = 774 new PyOperation(std::move(contextRef), operation); 775 // Note that the default return value policy on cast is automatic_reference, 776 // which does not take ownership (delete will not be called). 777 // Just be explicit. 778 py::object pyRef = 779 py::cast(unownedOperation, py::return_value_policy::take_ownership); 780 unownedOperation->handle = pyRef; 781 if (parentKeepAlive) { 782 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 783 } 784 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 785 return PyOperationRef(unownedOperation, std::move(pyRef)); 786 } 787 788 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 789 MlirOperation operation, 790 py::object parentKeepAlive) { 791 auto &liveOperations = contextRef->liveOperations; 792 auto it = liveOperations.find(operation.ptr); 793 if (it == liveOperations.end()) { 794 // Create. 795 return createInstance(std::move(contextRef), operation, 796 std::move(parentKeepAlive)); 797 } 798 // Use existing. 799 PyOperation *existing = it->second.second; 800 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 801 return PyOperationRef(existing, std::move(pyRef)); 802 } 803 804 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 805 MlirOperation operation, 806 py::object parentKeepAlive) { 807 auto &liveOperations = contextRef->liveOperations; 808 assert(liveOperations.count(operation.ptr) == 0 && 809 "cannot create detached operation that already exists"); 810 (void)liveOperations; 811 812 PyOperationRef created = createInstance(std::move(contextRef), operation, 813 std::move(parentKeepAlive)); 814 created->attached = false; 815 return created; 816 } 817 818 void PyOperation::checkValid() const { 819 if (!valid) { 820 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 821 } 822 } 823 824 void PyOperationBase::print(py::object fileObject, bool binary, 825 llvm::Optional<int64_t> largeElementsLimit, 826 bool enableDebugInfo, bool prettyDebugInfo, 827 bool printGenericOpForm, bool useLocalScope) { 828 PyOperation &operation = getOperation(); 829 operation.checkValid(); 830 if (fileObject.is_none()) 831 fileObject = py::module::import("sys").attr("stdout"); 832 833 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 834 fileObject.attr("write")("// Verification failed, printing generic form\n"); 835 printGenericOpForm = true; 836 } 837 838 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 839 if (largeElementsLimit) 840 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 841 if (enableDebugInfo) 842 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 843 if (printGenericOpForm) 844 mlirOpPrintingFlagsPrintGenericOpForm(flags); 845 846 PyFileAccumulator accum(fileObject, binary); 847 py::gil_scoped_release(); 848 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 849 accum.getUserData()); 850 mlirOpPrintingFlagsDestroy(flags); 851 } 852 853 py::object PyOperationBase::getAsm(bool binary, 854 llvm::Optional<int64_t> largeElementsLimit, 855 bool enableDebugInfo, bool prettyDebugInfo, 856 bool printGenericOpForm, 857 bool useLocalScope) { 858 py::object fileObject; 859 if (binary) { 860 fileObject = py::module::import("io").attr("BytesIO")(); 861 } else { 862 fileObject = py::module::import("io").attr("StringIO")(); 863 } 864 print(fileObject, /*binary=*/binary, 865 /*largeElementsLimit=*/largeElementsLimit, 866 /*enableDebugInfo=*/enableDebugInfo, 867 /*prettyDebugInfo=*/prettyDebugInfo, 868 /*printGenericOpForm=*/printGenericOpForm, 869 /*useLocalScope=*/useLocalScope); 870 871 return fileObject.attr("getvalue")(); 872 } 873 874 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 875 checkValid(); 876 if (!isAttached()) 877 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 878 MlirOperation operation = mlirOperationGetParentOperation(get()); 879 if (mlirOperationIsNull(operation)) 880 return {}; 881 return PyOperation::forOperation(getContext(), operation); 882 } 883 884 PyBlock PyOperation::getBlock() { 885 checkValid(); 886 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 887 MlirBlock block = mlirOperationGetBlock(get()); 888 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 889 assert(parentOperation && "Operation has no parent"); 890 return PyBlock{std::move(*parentOperation), block}; 891 } 892 893 py::object PyOperation::getCapsule() { 894 checkValid(); 895 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 896 } 897 898 py::object PyOperation::createFromCapsule(py::object capsule) { 899 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 900 if (mlirOperationIsNull(rawOperation)) 901 throw py::error_already_set(); 902 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 903 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 904 .releaseObject(); 905 } 906 907 py::object PyOperation::create( 908 std::string name, llvm::Optional<std::vector<PyType *>> results, 909 llvm::Optional<std::vector<PyValue *>> operands, 910 llvm::Optional<py::dict> attributes, 911 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 912 DefaultingPyLocation location, py::object maybeIp) { 913 llvm::SmallVector<MlirValue, 4> mlirOperands; 914 llvm::SmallVector<MlirType, 4> mlirResults; 915 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 916 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 917 918 // General parameter validation. 919 if (regions < 0) 920 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 921 922 // Unpack/validate operands. 923 if (operands) { 924 mlirOperands.reserve(operands->size()); 925 for (PyValue *operand : *operands) { 926 if (!operand) 927 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 928 mlirOperands.push_back(operand->get()); 929 } 930 } 931 932 // Unpack/validate results. 933 if (results) { 934 mlirResults.reserve(results->size()); 935 for (PyType *result : *results) { 936 // TODO: Verify result type originate from the same context. 937 if (!result) 938 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 939 mlirResults.push_back(*result); 940 } 941 } 942 // Unpack/validate attributes. 943 if (attributes) { 944 mlirAttributes.reserve(attributes->size()); 945 for (auto &it : *attributes) { 946 std::string key; 947 try { 948 key = it.first.cast<std::string>(); 949 } catch (py::cast_error &err) { 950 std::string msg = "Invalid attribute key (not a string) when " 951 "attempting to create the operation \"" + 952 name + "\" (" + err.what() + ")"; 953 throw py::cast_error(msg); 954 } 955 try { 956 auto &attribute = it.second.cast<PyAttribute &>(); 957 // TODO: Verify attribute originates from the same context. 958 mlirAttributes.emplace_back(std::move(key), attribute); 959 } catch (py::reference_cast_error &) { 960 // This exception seems thrown when the value is "None". 961 std::string msg = 962 "Found an invalid (`None`?) attribute value for the key \"" + key + 963 "\" when attempting to create the operation \"" + name + "\""; 964 throw py::cast_error(msg); 965 } catch (py::cast_error &err) { 966 std::string msg = "Invalid attribute value for the key \"" + key + 967 "\" when attempting to create the operation \"" + 968 name + "\" (" + err.what() + ")"; 969 throw py::cast_error(msg); 970 } 971 } 972 } 973 // Unpack/validate successors. 974 if (successors) { 975 mlirSuccessors.reserve(successors->size()); 976 for (auto *successor : *successors) { 977 // TODO: Verify successor originate from the same context. 978 if (!successor) 979 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 980 mlirSuccessors.push_back(successor->get()); 981 } 982 } 983 984 // Apply unpacked/validated to the operation state. Beyond this 985 // point, exceptions cannot be thrown or else the state will leak. 986 MlirOperationState state = 987 mlirOperationStateGet(toMlirStringRef(name), location); 988 if (!mlirOperands.empty()) 989 mlirOperationStateAddOperands(&state, mlirOperands.size(), 990 mlirOperands.data()); 991 if (!mlirResults.empty()) 992 mlirOperationStateAddResults(&state, mlirResults.size(), 993 mlirResults.data()); 994 if (!mlirAttributes.empty()) { 995 // Note that the attribute names directly reference bytes in 996 // mlirAttributes, so that vector must not be changed from here 997 // on. 998 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 999 mlirNamedAttributes.reserve(mlirAttributes.size()); 1000 for (auto &it : mlirAttributes) 1001 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1002 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1003 toMlirStringRef(it.first)), 1004 it.second)); 1005 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1006 mlirNamedAttributes.data()); 1007 } 1008 if (!mlirSuccessors.empty()) 1009 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1010 mlirSuccessors.data()); 1011 if (regions) { 1012 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1013 mlirRegions.resize(regions); 1014 for (int i = 0; i < regions; ++i) 1015 mlirRegions[i] = mlirRegionCreate(); 1016 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1017 mlirRegions.data()); 1018 } 1019 1020 // Construct the operation. 1021 MlirOperation operation = mlirOperationCreate(&state); 1022 PyOperationRef created = 1023 PyOperation::createDetached(location->getContext(), operation); 1024 1025 // InsertPoint active? 1026 if (!maybeIp.is(py::cast(false))) { 1027 PyInsertionPoint *ip; 1028 if (maybeIp.is_none()) { 1029 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1030 } else { 1031 ip = py::cast<PyInsertionPoint *>(maybeIp); 1032 } 1033 if (ip) 1034 ip->insert(*created.get()); 1035 } 1036 1037 return created->createOpView(); 1038 } 1039 1040 py::object PyOperation::createOpView() { 1041 checkValid(); 1042 MlirIdentifier ident = mlirOperationGetName(get()); 1043 MlirStringRef identStr = mlirIdentifierStr(ident); 1044 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1045 StringRef(identStr.data, identStr.length)); 1046 if (opViewClass) 1047 return (*opViewClass)(getRef().getObject()); 1048 return py::cast(PyOpView(getRef().getObject())); 1049 } 1050 1051 void PyOperation::erase() { 1052 checkValid(); 1053 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1054 // Python reference to a child operation is live. All children should also 1055 // have their `valid` bit set to false. 1056 auto &liveOperations = getContext()->liveOperations; 1057 if (liveOperations.count(operation.ptr)) 1058 liveOperations.erase(operation.ptr); 1059 mlirOperationDestroy(operation); 1060 valid = false; 1061 } 1062 1063 //------------------------------------------------------------------------------ 1064 // PyOpView 1065 //------------------------------------------------------------------------------ 1066 1067 py::object 1068 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1069 py::list operandList, 1070 llvm::Optional<py::dict> attributes, 1071 llvm::Optional<std::vector<PyBlock *>> successors, 1072 llvm::Optional<int> regions, 1073 DefaultingPyLocation location, py::object maybeIp) { 1074 PyMlirContextRef context = location->getContext(); 1075 // Class level operation construction metadata. 1076 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1077 // Operand and result segment specs are either none, which does no 1078 // variadic unpacking, or a list of ints with segment sizes, where each 1079 // element is either a positive number (typically 1 for a scalar) or -1 to 1080 // indicate that it is derived from the length of the same-indexed operand 1081 // or result (implying that it is a list at that position). 1082 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1083 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1084 1085 std::vector<uint32_t> operandSegmentLengths; 1086 std::vector<uint32_t> resultSegmentLengths; 1087 1088 // Validate/determine region count. 1089 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1090 int opMinRegionCount = std::get<0>(opRegionSpec); 1091 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1092 if (!regions) { 1093 regions = opMinRegionCount; 1094 } 1095 if (*regions < opMinRegionCount) { 1096 throw py::value_error( 1097 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1098 llvm::Twine(opMinRegionCount) + 1099 " regions but was built with regions=" + llvm::Twine(*regions)) 1100 .str()); 1101 } 1102 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1103 throw py::value_error( 1104 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1105 llvm::Twine(opMinRegionCount) + 1106 " regions but was built with regions=" + llvm::Twine(*regions)) 1107 .str()); 1108 } 1109 1110 // Unpack results. 1111 std::vector<PyType *> resultTypes; 1112 resultTypes.reserve(resultTypeList.size()); 1113 if (resultSegmentSpecObj.is_none()) { 1114 // Non-variadic result unpacking. 1115 for (auto it : llvm::enumerate(resultTypeList)) { 1116 try { 1117 resultTypes.push_back(py::cast<PyType *>(it.value())); 1118 if (!resultTypes.back()) 1119 throw py::cast_error(); 1120 } catch (py::cast_error &err) { 1121 throw py::value_error((llvm::Twine("Result ") + 1122 llvm::Twine(it.index()) + " of operation \"" + 1123 name + "\" must be a Type (" + err.what() + ")") 1124 .str()); 1125 } 1126 } 1127 } else { 1128 // Sized result unpacking. 1129 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1130 if (resultSegmentSpec.size() != resultTypeList.size()) { 1131 throw py::value_error((llvm::Twine("Operation \"") + name + 1132 "\" requires " + 1133 llvm::Twine(resultSegmentSpec.size()) + 1134 "result segments but was provided " + 1135 llvm::Twine(resultTypeList.size())) 1136 .str()); 1137 } 1138 resultSegmentLengths.reserve(resultTypeList.size()); 1139 for (auto it : 1140 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1141 int segmentSpec = std::get<1>(it.value()); 1142 if (segmentSpec == 1 || segmentSpec == 0) { 1143 // Unpack unary element. 1144 try { 1145 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1146 if (resultType) { 1147 resultTypes.push_back(resultType); 1148 resultSegmentLengths.push_back(1); 1149 } else if (segmentSpec == 0) { 1150 // Allowed to be optional. 1151 resultSegmentLengths.push_back(0); 1152 } else { 1153 throw py::cast_error("was None and result is not optional"); 1154 } 1155 } catch (py::cast_error &err) { 1156 throw py::value_error((llvm::Twine("Result ") + 1157 llvm::Twine(it.index()) + " of operation \"" + 1158 name + "\" must be a Type (" + err.what() + 1159 ")") 1160 .str()); 1161 } 1162 } else if (segmentSpec == -1) { 1163 // Unpack sequence by appending. 1164 try { 1165 if (std::get<0>(it.value()).is_none()) { 1166 // Treat it as an empty list. 1167 resultSegmentLengths.push_back(0); 1168 } else { 1169 // Unpack the list. 1170 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1171 for (py::object segmentItem : segment) { 1172 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1173 if (!resultTypes.back()) { 1174 throw py::cast_error("contained a None item"); 1175 } 1176 } 1177 resultSegmentLengths.push_back(segment.size()); 1178 } 1179 } catch (std::exception &err) { 1180 // NOTE: Sloppy to be using a catch-all here, but there are at least 1181 // three different unrelated exceptions that can be thrown in the 1182 // above "casts". Just keep the scope above small and catch them all. 1183 throw py::value_error((llvm::Twine("Result ") + 1184 llvm::Twine(it.index()) + " of operation \"" + 1185 name + "\" must be a Sequence of Types (" + 1186 err.what() + ")") 1187 .str()); 1188 } 1189 } else { 1190 throw py::value_error("Unexpected segment spec"); 1191 } 1192 } 1193 } 1194 1195 // Unpack operands. 1196 std::vector<PyValue *> operands; 1197 operands.reserve(operands.size()); 1198 if (operandSegmentSpecObj.is_none()) { 1199 // Non-sized operand unpacking. 1200 for (auto it : llvm::enumerate(operandList)) { 1201 try { 1202 operands.push_back(py::cast<PyValue *>(it.value())); 1203 if (!operands.back()) 1204 throw py::cast_error(); 1205 } catch (py::cast_error &err) { 1206 throw py::value_error((llvm::Twine("Operand ") + 1207 llvm::Twine(it.index()) + " of operation \"" + 1208 name + "\" must be a Value (" + err.what() + ")") 1209 .str()); 1210 } 1211 } 1212 } else { 1213 // Sized operand unpacking. 1214 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1215 if (operandSegmentSpec.size() != operandList.size()) { 1216 throw py::value_error((llvm::Twine("Operation \"") + name + 1217 "\" requires " + 1218 llvm::Twine(operandSegmentSpec.size()) + 1219 "operand segments but was provided " + 1220 llvm::Twine(operandList.size())) 1221 .str()); 1222 } 1223 operandSegmentLengths.reserve(operandList.size()); 1224 for (auto it : 1225 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1226 int segmentSpec = std::get<1>(it.value()); 1227 if (segmentSpec == 1 || segmentSpec == 0) { 1228 // Unpack unary element. 1229 try { 1230 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1231 if (operandValue) { 1232 operands.push_back(operandValue); 1233 operandSegmentLengths.push_back(1); 1234 } else if (segmentSpec == 0) { 1235 // Allowed to be optional. 1236 operandSegmentLengths.push_back(0); 1237 } else { 1238 throw py::cast_error("was None and operand is not optional"); 1239 } 1240 } catch (py::cast_error &err) { 1241 throw py::value_error((llvm::Twine("Operand ") + 1242 llvm::Twine(it.index()) + " of operation \"" + 1243 name + "\" must be a Value (" + err.what() + 1244 ")") 1245 .str()); 1246 } 1247 } else if (segmentSpec == -1) { 1248 // Unpack sequence by appending. 1249 try { 1250 if (std::get<0>(it.value()).is_none()) { 1251 // Treat it as an empty list. 1252 operandSegmentLengths.push_back(0); 1253 } else { 1254 // Unpack the list. 1255 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1256 for (py::object segmentItem : segment) { 1257 operands.push_back(py::cast<PyValue *>(segmentItem)); 1258 if (!operands.back()) { 1259 throw py::cast_error("contained a None item"); 1260 } 1261 } 1262 operandSegmentLengths.push_back(segment.size()); 1263 } 1264 } catch (std::exception &err) { 1265 // NOTE: Sloppy to be using a catch-all here, but there are at least 1266 // three different unrelated exceptions that can be thrown in the 1267 // above "casts". Just keep the scope above small and catch them all. 1268 throw py::value_error((llvm::Twine("Operand ") + 1269 llvm::Twine(it.index()) + " of operation \"" + 1270 name + "\" must be a Sequence of Values (" + 1271 err.what() + ")") 1272 .str()); 1273 } 1274 } else { 1275 throw py::value_error("Unexpected segment spec"); 1276 } 1277 } 1278 } 1279 1280 // Merge operand/result segment lengths into attributes if needed. 1281 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1282 // Dup. 1283 if (attributes) { 1284 attributes = py::dict(*attributes); 1285 } else { 1286 attributes = py::dict(); 1287 } 1288 if (attributes->contains("result_segment_sizes") || 1289 attributes->contains("operand_segment_sizes")) { 1290 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1291 "'operand_segment_sizes' attribute is unsupported. " 1292 "Use Operation.create for such low-level access."); 1293 } 1294 1295 // Add result_segment_sizes attribute. 1296 if (!resultSegmentLengths.empty()) { 1297 int64_t size = resultSegmentLengths.size(); 1298 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1299 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1300 resultSegmentLengths.size(), resultSegmentLengths.data()); 1301 (*attributes)["result_segment_sizes"] = 1302 PyAttribute(context, segmentLengthAttr); 1303 } 1304 1305 // Add operand_segment_sizes attribute. 1306 if (!operandSegmentLengths.empty()) { 1307 int64_t size = operandSegmentLengths.size(); 1308 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1309 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1310 operandSegmentLengths.size(), operandSegmentLengths.data()); 1311 (*attributes)["operand_segment_sizes"] = 1312 PyAttribute(context, segmentLengthAttr); 1313 } 1314 } 1315 1316 // Delegate to create. 1317 return PyOperation::create(std::move(name), 1318 /*results=*/std::move(resultTypes), 1319 /*operands=*/std::move(operands), 1320 /*attributes=*/std::move(attributes), 1321 /*successors=*/std::move(successors), 1322 /*regions=*/*regions, location, maybeIp); 1323 } 1324 1325 PyOpView::PyOpView(py::object operationObject) 1326 // Casting through the PyOperationBase base-class and then back to the 1327 // Operation lets us accept any PyOperationBase subclass. 1328 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1329 operationObject(operation.getRef().getObject()) {} 1330 1331 py::object PyOpView::createRawSubclass(py::object userClass) { 1332 // This is... a little gross. The typical pattern is to have a pure python 1333 // class that extends OpView like: 1334 // class AddFOp(_cext.ir.OpView): 1335 // def __init__(self, loc, lhs, rhs): 1336 // operation = loc.context.create_operation( 1337 // "addf", lhs, rhs, results=[lhs.type]) 1338 // super().__init__(operation) 1339 // 1340 // I.e. The goal of the user facing type is to provide a nice constructor 1341 // that has complete freedom for the op under construction. This is at odds 1342 // with our other desire to sometimes create this object by just passing an 1343 // operation (to initialize the base class). We could do *arg and **kwargs 1344 // munging to try to make it work, but instead, we synthesize a new class 1345 // on the fly which extends this user class (AddFOp in this example) and 1346 // *give it* the base class's __init__ method, thus bypassing the 1347 // intermediate subclass's __init__ method entirely. While slightly, 1348 // underhanded, this is safe/legal because the type hierarchy has not changed 1349 // (we just added a new leaf) and we aren't mucking around with __new__. 1350 // Typically, this new class will be stored on the original as "_Raw" and will 1351 // be used for casts and other things that need a variant of the class that 1352 // is initialized purely from an operation. 1353 py::object parentMetaclass = 1354 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1355 py::dict attributes; 1356 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1357 // now. 1358 // auto opViewType = py::type::of<PyOpView>(); 1359 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1360 attributes["__init__"] = opViewType.attr("__init__"); 1361 py::str origName = userClass.attr("__name__"); 1362 py::str newName = py::str("_") + origName; 1363 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1364 } 1365 1366 //------------------------------------------------------------------------------ 1367 // PyInsertionPoint. 1368 //------------------------------------------------------------------------------ 1369 1370 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1371 1372 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1373 : refOperation(beforeOperationBase.getOperation().getRef()), 1374 block((*refOperation)->getBlock()) {} 1375 1376 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1377 PyOperation &operation = operationBase.getOperation(); 1378 if (operation.isAttached()) 1379 throw SetPyError(PyExc_ValueError, 1380 "Attempt to insert operation that is already attached"); 1381 block.getParentOperation()->checkValid(); 1382 MlirOperation beforeOp = {nullptr}; 1383 if (refOperation) { 1384 // Insert before operation. 1385 (*refOperation)->checkValid(); 1386 beforeOp = (*refOperation)->get(); 1387 } else { 1388 // Insert at end (before null) is only valid if the block does not 1389 // already end in a known terminator (violating this will cause assertion 1390 // failures later). 1391 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1392 throw py::index_error("Cannot insert operation at the end of a block " 1393 "that already has a terminator. Did you mean to " 1394 "use 'InsertionPoint.at_block_terminator(block)' " 1395 "versus 'InsertionPoint(block)'?"); 1396 } 1397 } 1398 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1399 operation.setAttached(); 1400 } 1401 1402 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1403 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1404 if (mlirOperationIsNull(firstOp)) { 1405 // Just insert at end. 1406 return PyInsertionPoint(block); 1407 } 1408 1409 // Insert before first op. 1410 PyOperationRef firstOpRef = PyOperation::forOperation( 1411 block.getParentOperation()->getContext(), firstOp); 1412 return PyInsertionPoint{block, std::move(firstOpRef)}; 1413 } 1414 1415 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1416 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1417 if (mlirOperationIsNull(terminator)) 1418 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1419 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1420 block.getParentOperation()->getContext(), terminator); 1421 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1422 } 1423 1424 py::object PyInsertionPoint::contextEnter() { 1425 return PyThreadContextEntry::pushInsertionPoint(*this); 1426 } 1427 1428 void PyInsertionPoint::contextExit(pybind11::object excType, 1429 pybind11::object excVal, 1430 pybind11::object excTb) { 1431 PyThreadContextEntry::popInsertionPoint(*this); 1432 } 1433 1434 //------------------------------------------------------------------------------ 1435 // PyAttribute. 1436 //------------------------------------------------------------------------------ 1437 1438 bool PyAttribute::operator==(const PyAttribute &other) { 1439 return mlirAttributeEqual(attr, other.attr); 1440 } 1441 1442 py::object PyAttribute::getCapsule() { 1443 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1444 } 1445 1446 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1447 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1448 if (mlirAttributeIsNull(rawAttr)) 1449 throw py::error_already_set(); 1450 return PyAttribute( 1451 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1452 } 1453 1454 //------------------------------------------------------------------------------ 1455 // PyNamedAttribute. 1456 //------------------------------------------------------------------------------ 1457 1458 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1459 : ownedName(new std::string(std::move(ownedName))) { 1460 namedAttr = mlirNamedAttributeGet( 1461 mlirIdentifierGet(mlirAttributeGetContext(attr), 1462 toMlirStringRef(*this->ownedName)), 1463 attr); 1464 } 1465 1466 //------------------------------------------------------------------------------ 1467 // PyType. 1468 //------------------------------------------------------------------------------ 1469 1470 bool PyType::operator==(const PyType &other) { 1471 return mlirTypeEqual(type, other.type); 1472 } 1473 1474 py::object PyType::getCapsule() { 1475 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1476 } 1477 1478 PyType PyType::createFromCapsule(py::object capsule) { 1479 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1480 if (mlirTypeIsNull(rawType)) 1481 throw py::error_already_set(); 1482 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1483 rawType); 1484 } 1485 1486 //------------------------------------------------------------------------------ 1487 // PyValue and subclases. 1488 //------------------------------------------------------------------------------ 1489 1490 pybind11::object PyValue::getCapsule() { 1491 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1492 } 1493 1494 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1495 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1496 if (mlirValueIsNull(value)) 1497 throw py::error_already_set(); 1498 MlirOperation owner; 1499 if (mlirValueIsAOpResult(value)) 1500 owner = mlirOpResultGetOwner(value); 1501 if (mlirValueIsABlockArgument(value)) 1502 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1503 if (mlirOperationIsNull(owner)) 1504 throw py::error_already_set(); 1505 MlirContext ctx = mlirOperationGetContext(owner); 1506 PyOperationRef ownerRef = 1507 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1508 return PyValue(ownerRef, value); 1509 } 1510 1511 namespace { 1512 /// CRTP base class for Python MLIR values that subclass Value and should be 1513 /// castable from it. The value hierarchy is one level deep and is not supposed 1514 /// to accommodate other levels unless core MLIR changes. 1515 template <typename DerivedTy> 1516 class PyConcreteValue : public PyValue { 1517 public: 1518 // Derived classes must define statics for: 1519 // IsAFunctionTy isaFunction 1520 // const char *pyClassName 1521 // and redefine bindDerived. 1522 using ClassTy = py::class_<DerivedTy, PyValue>; 1523 using IsAFunctionTy = bool (*)(MlirValue); 1524 1525 PyConcreteValue() = default; 1526 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1527 : PyValue(operationRef, value) {} 1528 PyConcreteValue(PyValue &orig) 1529 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1530 1531 /// Attempts to cast the original value to the derived type and throws on 1532 /// type mismatches. 1533 static MlirValue castFrom(PyValue &orig) { 1534 if (!DerivedTy::isaFunction(orig.get())) { 1535 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1536 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1537 DerivedTy::pyClassName + 1538 " (from " + origRepr + ")"); 1539 } 1540 return orig.get(); 1541 } 1542 1543 /// Binds the Python module objects to functions of this class. 1544 static void bind(py::module &m) { 1545 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1546 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1547 DerivedTy::bindDerived(cls); 1548 } 1549 1550 /// Implemented by derived classes to add methods to the Python subclass. 1551 static void bindDerived(ClassTy &m) {} 1552 }; 1553 1554 /// Python wrapper for MlirBlockArgument. 1555 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1556 public: 1557 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1558 static constexpr const char *pyClassName = "BlockArgument"; 1559 using PyConcreteValue::PyConcreteValue; 1560 1561 static void bindDerived(ClassTy &c) { 1562 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1563 return PyBlock(self.getParentOperation(), 1564 mlirBlockArgumentGetOwner(self.get())); 1565 }); 1566 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1567 return mlirBlockArgumentGetArgNumber(self.get()); 1568 }); 1569 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1570 return mlirBlockArgumentSetType(self.get(), type); 1571 }); 1572 } 1573 }; 1574 1575 /// Python wrapper for MlirOpResult. 1576 class PyOpResult : public PyConcreteValue<PyOpResult> { 1577 public: 1578 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1579 static constexpr const char *pyClassName = "OpResult"; 1580 using PyConcreteValue::PyConcreteValue; 1581 1582 static void bindDerived(ClassTy &c) { 1583 c.def_property_readonly("owner", [](PyOpResult &self) { 1584 assert( 1585 mlirOperationEqual(self.getParentOperation()->get(), 1586 mlirOpResultGetOwner(self.get())) && 1587 "expected the owner of the value in Python to match that in the IR"); 1588 return self.getParentOperation().getObject(); 1589 }); 1590 c.def_property_readonly("result_number", [](PyOpResult &self) { 1591 return mlirOpResultGetResultNumber(self.get()); 1592 }); 1593 } 1594 }; 1595 1596 /// Returns the list of types of the values held by container. 1597 template <typename Container> 1598 static std::vector<PyType> getValueTypes(Container &container, 1599 PyMlirContextRef &context) { 1600 std::vector<PyType> result; 1601 result.reserve(container.getNumElements()); 1602 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1603 result.push_back( 1604 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1605 } 1606 return result; 1607 } 1608 1609 /// A list of block arguments. Internally, these are stored as consecutive 1610 /// elements, random access is cheap. The argument list is associated with the 1611 /// operation that contains the block (detached blocks are not allowed in 1612 /// Python bindings) and extends its lifetime. 1613 class PyBlockArgumentList 1614 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1615 public: 1616 static constexpr const char *pyClassName = "BlockArgumentList"; 1617 1618 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1619 intptr_t startIndex = 0, intptr_t length = -1, 1620 intptr_t step = 1) 1621 : Sliceable(startIndex, 1622 length == -1 ? mlirBlockGetNumArguments(block) : length, 1623 step), 1624 operation(std::move(operation)), block(block) {} 1625 1626 /// Returns the number of arguments in the list. 1627 intptr_t getNumElements() { 1628 operation->checkValid(); 1629 return mlirBlockGetNumArguments(block); 1630 } 1631 1632 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1633 PyBlockArgument getElement(intptr_t pos) { 1634 MlirValue argument = mlirBlockGetArgument(block, pos); 1635 return PyBlockArgument(operation, argument); 1636 } 1637 1638 /// Returns a sublist of this list. 1639 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1640 intptr_t step) { 1641 return PyBlockArgumentList(operation, block, startIndex, length, step); 1642 } 1643 1644 static void bindDerived(ClassTy &c) { 1645 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1646 return getValueTypes(self, self.operation->getContext()); 1647 }); 1648 } 1649 1650 private: 1651 PyOperationRef operation; 1652 MlirBlock block; 1653 }; 1654 1655 /// A list of operation operands. Internally, these are stored as consecutive 1656 /// elements, random access is cheap. The result list is associated with the 1657 /// operation whose results these are, and extends the lifetime of this 1658 /// operation. 1659 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1660 public: 1661 static constexpr const char *pyClassName = "OpOperandList"; 1662 1663 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1664 intptr_t length = -1, intptr_t step = 1) 1665 : Sliceable(startIndex, 1666 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1667 : length, 1668 step), 1669 operation(operation) {} 1670 1671 intptr_t getNumElements() { 1672 operation->checkValid(); 1673 return mlirOperationGetNumOperands(operation->get()); 1674 } 1675 1676 PyValue getElement(intptr_t pos) { 1677 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1678 MlirOperation owner; 1679 if (mlirValueIsAOpResult(operand)) 1680 owner = mlirOpResultGetOwner(operand); 1681 else if (mlirValueIsABlockArgument(operand)) 1682 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1683 else 1684 assert(false && "Value must be an block arg or op result."); 1685 PyOperationRef pyOwner = 1686 PyOperation::forOperation(operation->getContext(), owner); 1687 return PyValue(pyOwner, operand); 1688 } 1689 1690 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1691 return PyOpOperandList(operation, startIndex, length, step); 1692 } 1693 1694 void dunderSetItem(intptr_t index, PyValue value) { 1695 index = wrapIndex(index); 1696 mlirOperationSetOperand(operation->get(), index, value.get()); 1697 } 1698 1699 static void bindDerived(ClassTy &c) { 1700 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1701 } 1702 1703 private: 1704 PyOperationRef operation; 1705 }; 1706 1707 /// A list of operation results. Internally, these are stored as consecutive 1708 /// elements, random access is cheap. The result list is associated with the 1709 /// operation whose results these are, and extends the lifetime of this 1710 /// operation. 1711 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1712 public: 1713 static constexpr const char *pyClassName = "OpResultList"; 1714 1715 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1716 intptr_t length = -1, intptr_t step = 1) 1717 : Sliceable(startIndex, 1718 length == -1 ? mlirOperationGetNumResults(operation->get()) 1719 : length, 1720 step), 1721 operation(operation) {} 1722 1723 intptr_t getNumElements() { 1724 operation->checkValid(); 1725 return mlirOperationGetNumResults(operation->get()); 1726 } 1727 1728 PyOpResult getElement(intptr_t index) { 1729 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1730 return PyOpResult(value); 1731 } 1732 1733 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1734 return PyOpResultList(operation, startIndex, length, step); 1735 } 1736 1737 static void bindDerived(ClassTy &c) { 1738 c.def_property_readonly("types", [](PyOpResultList &self) { 1739 return getValueTypes(self, self.operation->getContext()); 1740 }); 1741 } 1742 1743 private: 1744 PyOperationRef operation; 1745 }; 1746 1747 /// A list of operation attributes. Can be indexed by name, producing 1748 /// attributes, or by index, producing named attributes. 1749 class PyOpAttributeMap { 1750 public: 1751 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1752 1753 PyAttribute dunderGetItemNamed(const std::string &name) { 1754 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1755 toMlirStringRef(name)); 1756 if (mlirAttributeIsNull(attr)) { 1757 throw SetPyError(PyExc_KeyError, 1758 "attempt to access a non-existent attribute"); 1759 } 1760 return PyAttribute(operation->getContext(), attr); 1761 } 1762 1763 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1764 if (index < 0 || index >= dunderLen()) { 1765 throw SetPyError(PyExc_IndexError, 1766 "attempt to access out of bounds attribute"); 1767 } 1768 MlirNamedAttribute namedAttr = 1769 mlirOperationGetAttribute(operation->get(), index); 1770 return PyNamedAttribute( 1771 namedAttr.attribute, 1772 std::string(mlirIdentifierStr(namedAttr.name).data)); 1773 } 1774 1775 void dunderSetItem(const std::string &name, PyAttribute attr) { 1776 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1777 attr); 1778 } 1779 1780 void dunderDelItem(const std::string &name) { 1781 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1782 toMlirStringRef(name)); 1783 if (!removed) 1784 throw SetPyError(PyExc_KeyError, 1785 "attempt to delete a non-existent attribute"); 1786 } 1787 1788 intptr_t dunderLen() { 1789 return mlirOperationGetNumAttributes(operation->get()); 1790 } 1791 1792 bool dunderContains(const std::string &name) { 1793 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1794 operation->get(), toMlirStringRef(name))); 1795 } 1796 1797 static void bind(py::module &m) { 1798 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1799 .def("__contains__", &PyOpAttributeMap::dunderContains) 1800 .def("__len__", &PyOpAttributeMap::dunderLen) 1801 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1802 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1803 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1804 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1805 } 1806 1807 private: 1808 PyOperationRef operation; 1809 }; 1810 1811 } // end namespace 1812 1813 //------------------------------------------------------------------------------ 1814 // Populates the core exports of the 'ir' submodule. 1815 //------------------------------------------------------------------------------ 1816 1817 void mlir::python::populateIRCore(py::module &m) { 1818 //---------------------------------------------------------------------------- 1819 // Mapping of MlirContext. 1820 //---------------------------------------------------------------------------- 1821 py::class_<PyMlirContext>(m, "Context", py::module_local()) 1822 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1823 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1824 .def("_get_context_again", 1825 [](PyMlirContext &self) { 1826 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1827 return ref.releaseObject(); 1828 }) 1829 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1830 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1831 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1832 &PyMlirContext::getCapsule) 1833 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1834 .def("__enter__", &PyMlirContext::contextEnter) 1835 .def("__exit__", &PyMlirContext::contextExit) 1836 .def_property_readonly_static( 1837 "current", 1838 [](py::object & /*class*/) { 1839 auto *context = PyThreadContextEntry::getDefaultContext(); 1840 if (!context) 1841 throw SetPyError(PyExc_ValueError, "No current Context"); 1842 return context; 1843 }, 1844 "Gets the Context bound to the current thread or raises ValueError") 1845 .def_property_readonly( 1846 "dialects", 1847 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1848 "Gets a container for accessing dialects by name") 1849 .def_property_readonly( 1850 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1851 "Alias for 'dialect'") 1852 .def( 1853 "get_dialect_descriptor", 1854 [=](PyMlirContext &self, std::string &name) { 1855 MlirDialect dialect = mlirContextGetOrLoadDialect( 1856 self.get(), {name.data(), name.size()}); 1857 if (mlirDialectIsNull(dialect)) { 1858 throw SetPyError(PyExc_ValueError, 1859 Twine("Dialect '") + name + "' not found"); 1860 } 1861 return PyDialectDescriptor(self.getRef(), dialect); 1862 }, 1863 "Gets or loads a dialect by name, returning its descriptor object") 1864 .def_property( 1865 "allow_unregistered_dialects", 1866 [](PyMlirContext &self) -> bool { 1867 return mlirContextGetAllowUnregisteredDialects(self.get()); 1868 }, 1869 [](PyMlirContext &self, bool value) { 1870 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1871 }) 1872 .def("enable_multithreading", 1873 [](PyMlirContext &self, bool enable) { 1874 mlirContextEnableMultithreading(self.get(), enable); 1875 }) 1876 .def("is_registered_operation", 1877 [](PyMlirContext &self, std::string &name) { 1878 return mlirContextIsRegisteredOperation( 1879 self.get(), MlirStringRef{name.data(), name.size()}); 1880 }); 1881 1882 //---------------------------------------------------------------------------- 1883 // Mapping of PyDialectDescriptor 1884 //---------------------------------------------------------------------------- 1885 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1886 .def_property_readonly("namespace", 1887 [](PyDialectDescriptor &self) { 1888 MlirStringRef ns = 1889 mlirDialectGetNamespace(self.get()); 1890 return py::str(ns.data, ns.length); 1891 }) 1892 .def("__repr__", [](PyDialectDescriptor &self) { 1893 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1894 std::string repr("<DialectDescriptor "); 1895 repr.append(ns.data, ns.length); 1896 repr.append(">"); 1897 return repr; 1898 }); 1899 1900 //---------------------------------------------------------------------------- 1901 // Mapping of PyDialects 1902 //---------------------------------------------------------------------------- 1903 py::class_<PyDialects>(m, "Dialects", py::module_local()) 1904 .def("__getitem__", 1905 [=](PyDialects &self, std::string keyName) { 1906 MlirDialect dialect = 1907 self.getDialectForKey(keyName, /*attrError=*/false); 1908 py::object descriptor = 1909 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1910 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1911 }) 1912 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1913 MlirDialect dialect = 1914 self.getDialectForKey(attrName, /*attrError=*/true); 1915 py::object descriptor = 1916 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1917 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1918 }); 1919 1920 //---------------------------------------------------------------------------- 1921 // Mapping of PyDialect 1922 //---------------------------------------------------------------------------- 1923 py::class_<PyDialect>(m, "Dialect", py::module_local()) 1924 .def(py::init<py::object>(), "descriptor") 1925 .def_property_readonly( 1926 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1927 .def("__repr__", [](py::object self) { 1928 auto clazz = self.attr("__class__"); 1929 return py::str("<Dialect ") + 1930 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1931 clazz.attr("__module__") + py::str(".") + 1932 clazz.attr("__name__") + py::str(")>"); 1933 }); 1934 1935 //---------------------------------------------------------------------------- 1936 // Mapping of Location 1937 //---------------------------------------------------------------------------- 1938 py::class_<PyLocation>(m, "Location", py::module_local()) 1939 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1940 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1941 .def("__enter__", &PyLocation::contextEnter) 1942 .def("__exit__", &PyLocation::contextExit) 1943 .def("__eq__", 1944 [](PyLocation &self, PyLocation &other) -> bool { 1945 return mlirLocationEqual(self, other); 1946 }) 1947 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1948 .def_property_readonly_static( 1949 "current", 1950 [](py::object & /*class*/) { 1951 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1952 if (!loc) 1953 throw SetPyError(PyExc_ValueError, "No current Location"); 1954 return loc; 1955 }, 1956 "Gets the Location bound to the current thread or raises ValueError") 1957 .def_static( 1958 "unknown", 1959 [](DefaultingPyMlirContext context) { 1960 return PyLocation(context->getRef(), 1961 mlirLocationUnknownGet(context->get())); 1962 }, 1963 py::arg("context") = py::none(), 1964 "Gets a Location representing an unknown location") 1965 .def_static( 1966 "file", 1967 [](std::string filename, int line, int col, 1968 DefaultingPyMlirContext context) { 1969 return PyLocation( 1970 context->getRef(), 1971 mlirLocationFileLineColGet( 1972 context->get(), toMlirStringRef(filename), line, col)); 1973 }, 1974 py::arg("filename"), py::arg("line"), py::arg("col"), 1975 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1976 .def_static( 1977 "name", 1978 [](std::string name, llvm::Optional<PyLocation> childLoc, 1979 DefaultingPyMlirContext context) { 1980 return PyLocation( 1981 context->getRef(), 1982 mlirLocationNameGet( 1983 context->get(), toMlirStringRef(name), 1984 childLoc ? childLoc->get() 1985 : mlirLocationUnknownGet(context->get()))); 1986 }, 1987 py::arg("name"), py::arg("childLoc") = py::none(), 1988 py::arg("context") = py::none(), kContextGetNameLocationDocString) 1989 .def_property_readonly( 1990 "context", 1991 [](PyLocation &self) { return self.getContext().getObject(); }, 1992 "Context that owns the Location") 1993 .def("__repr__", [](PyLocation &self) { 1994 PyPrintAccumulator printAccum; 1995 mlirLocationPrint(self, printAccum.getCallback(), 1996 printAccum.getUserData()); 1997 return printAccum.join(); 1998 }); 1999 2000 //---------------------------------------------------------------------------- 2001 // Mapping of Module 2002 //---------------------------------------------------------------------------- 2003 py::class_<PyModule>(m, "Module", py::module_local()) 2004 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2005 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2006 .def_static( 2007 "parse", 2008 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2009 MlirModule module = mlirModuleCreateParse( 2010 context->get(), toMlirStringRef(moduleAsm)); 2011 // TODO: Rework error reporting once diagnostic engine is exposed 2012 // in C API. 2013 if (mlirModuleIsNull(module)) { 2014 throw SetPyError( 2015 PyExc_ValueError, 2016 "Unable to parse module assembly (see diagnostics)"); 2017 } 2018 return PyModule::forModule(module).releaseObject(); 2019 }, 2020 py::arg("asm"), py::arg("context") = py::none(), 2021 kModuleParseDocstring) 2022 .def_static( 2023 "create", 2024 [](DefaultingPyLocation loc) { 2025 MlirModule module = mlirModuleCreateEmpty(loc); 2026 return PyModule::forModule(module).releaseObject(); 2027 }, 2028 py::arg("loc") = py::none(), "Creates an empty module") 2029 .def_property_readonly( 2030 "context", 2031 [](PyModule &self) { return self.getContext().getObject(); }, 2032 "Context that created the Module") 2033 .def_property_readonly( 2034 "operation", 2035 [](PyModule &self) { 2036 return PyOperation::forOperation(self.getContext(), 2037 mlirModuleGetOperation(self.get()), 2038 self.getRef().releaseObject()) 2039 .releaseObject(); 2040 }, 2041 "Accesses the module as an operation") 2042 .def_property_readonly( 2043 "body", 2044 [](PyModule &self) { 2045 PyOperationRef module_op = PyOperation::forOperation( 2046 self.getContext(), mlirModuleGetOperation(self.get()), 2047 self.getRef().releaseObject()); 2048 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2049 return returnBlock; 2050 }, 2051 "Return the block for this module") 2052 .def( 2053 "dump", 2054 [](PyModule &self) { 2055 mlirOperationDump(mlirModuleGetOperation(self.get())); 2056 }, 2057 kDumpDocstring) 2058 .def( 2059 "__str__", 2060 [](PyModule &self) { 2061 MlirOperation operation = mlirModuleGetOperation(self.get()); 2062 PyPrintAccumulator printAccum; 2063 mlirOperationPrint(operation, printAccum.getCallback(), 2064 printAccum.getUserData()); 2065 return printAccum.join(); 2066 }, 2067 kOperationStrDunderDocstring); 2068 2069 //---------------------------------------------------------------------------- 2070 // Mapping of Operation. 2071 //---------------------------------------------------------------------------- 2072 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2073 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2074 [](PyOperationBase &self) { 2075 return self.getOperation().getCapsule(); 2076 }) 2077 .def("__eq__", 2078 [](PyOperationBase &self, PyOperationBase &other) { 2079 return &self.getOperation() == &other.getOperation(); 2080 }) 2081 .def("__eq__", 2082 [](PyOperationBase &self, py::object other) { return false; }) 2083 .def_property_readonly("attributes", 2084 [](PyOperationBase &self) { 2085 return PyOpAttributeMap( 2086 self.getOperation().getRef()); 2087 }) 2088 .def_property_readonly("operands", 2089 [](PyOperationBase &self) { 2090 return PyOpOperandList( 2091 self.getOperation().getRef()); 2092 }) 2093 .def_property_readonly("regions", 2094 [](PyOperationBase &self) { 2095 return PyRegionList( 2096 self.getOperation().getRef()); 2097 }) 2098 .def_property_readonly( 2099 "results", 2100 [](PyOperationBase &self) { 2101 return PyOpResultList(self.getOperation().getRef()); 2102 }, 2103 "Returns the list of Operation results.") 2104 .def_property_readonly( 2105 "result", 2106 [](PyOperationBase &self) { 2107 auto &operation = self.getOperation(); 2108 auto numResults = mlirOperationGetNumResults(operation); 2109 if (numResults != 1) { 2110 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2111 throw SetPyError( 2112 PyExc_ValueError, 2113 Twine("Cannot call .result on operation ") + 2114 StringRef(name.data, name.length) + " which has " + 2115 Twine(numResults) + 2116 " results (it is only valid for operations with a " 2117 "single result)"); 2118 } 2119 return PyOpResult(operation.getRef(), 2120 mlirOperationGetResult(operation, 0)); 2121 }, 2122 "Shortcut to get an op result if it has only one (throws an error " 2123 "otherwise).") 2124 .def("__iter__", 2125 [](PyOperationBase &self) { 2126 return PyRegionIterator(self.getOperation().getRef()); 2127 }) 2128 .def( 2129 "__str__", 2130 [](PyOperationBase &self) { 2131 return self.getAsm(/*binary=*/false, 2132 /*largeElementsLimit=*/llvm::None, 2133 /*enableDebugInfo=*/false, 2134 /*prettyDebugInfo=*/false, 2135 /*printGenericOpForm=*/false, 2136 /*useLocalScope=*/false); 2137 }, 2138 "Returns the assembly form of the operation.") 2139 .def("print", &PyOperationBase::print, 2140 // Careful: Lots of arguments must match up with print method. 2141 py::arg("file") = py::none(), py::arg("binary") = false, 2142 py::arg("large_elements_limit") = py::none(), 2143 py::arg("enable_debug_info") = false, 2144 py::arg("pretty_debug_info") = false, 2145 py::arg("print_generic_op_form") = false, 2146 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2147 .def("get_asm", &PyOperationBase::getAsm, 2148 // Careful: Lots of arguments must match up with get_asm method. 2149 py::arg("binary") = false, 2150 py::arg("large_elements_limit") = py::none(), 2151 py::arg("enable_debug_info") = false, 2152 py::arg("pretty_debug_info") = false, 2153 py::arg("print_generic_op_form") = false, 2154 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2155 .def( 2156 "verify", 2157 [](PyOperationBase &self) { 2158 return mlirOperationVerify(self.getOperation()); 2159 }, 2160 "Verify the operation and return true if it passes, false if it " 2161 "fails."); 2162 2163 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2164 .def_static("create", &PyOperation::create, py::arg("name"), 2165 py::arg("results") = py::none(), 2166 py::arg("operands") = py::none(), 2167 py::arg("attributes") = py::none(), 2168 py::arg("successors") = py::none(), py::arg("regions") = 0, 2169 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2170 kOperationCreateDocstring) 2171 .def_property_readonly("parent", 2172 [](PyOperation &self) -> py::object { 2173 auto parent = self.getParentOperation(); 2174 if (parent) 2175 return parent->getObject(); 2176 return py::none(); 2177 }) 2178 .def("erase", &PyOperation::erase) 2179 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2180 &PyOperation::getCapsule) 2181 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2182 .def_property_readonly("name", 2183 [](PyOperation &self) { 2184 self.checkValid(); 2185 MlirOperation operation = self.get(); 2186 MlirStringRef name = mlirIdentifierStr( 2187 mlirOperationGetName(operation)); 2188 return py::str(name.data, name.length); 2189 }) 2190 .def_property_readonly( 2191 "context", 2192 [](PyOperation &self) { 2193 self.checkValid(); 2194 return self.getContext().getObject(); 2195 }, 2196 "Context that owns the Operation") 2197 .def_property_readonly("opview", &PyOperation::createOpView); 2198 2199 auto opViewClass = 2200 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2201 .def(py::init<py::object>()) 2202 .def_property_readonly("operation", &PyOpView::getOperationObject) 2203 .def_property_readonly( 2204 "context", 2205 [](PyOpView &self) { 2206 return self.getOperation().getContext().getObject(); 2207 }, 2208 "Context that owns the Operation") 2209 .def("__str__", [](PyOpView &self) { 2210 return py::str(self.getOperationObject()); 2211 }); 2212 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2213 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2214 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2215 opViewClass.attr("build_generic") = classmethod( 2216 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2217 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2218 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2219 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2220 "Builds a specific, generated OpView based on class level attributes."); 2221 2222 //---------------------------------------------------------------------------- 2223 // Mapping of PyRegion. 2224 //---------------------------------------------------------------------------- 2225 py::class_<PyRegion>(m, "Region", py::module_local()) 2226 .def_property_readonly( 2227 "blocks", 2228 [](PyRegion &self) { 2229 return PyBlockList(self.getParentOperation(), self.get()); 2230 }, 2231 "Returns a forward-optimized sequence of blocks.") 2232 .def( 2233 "__iter__", 2234 [](PyRegion &self) { 2235 self.checkValid(); 2236 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2237 return PyBlockIterator(self.getParentOperation(), firstBlock); 2238 }, 2239 "Iterates over blocks in the region.") 2240 .def("__eq__", 2241 [](PyRegion &self, PyRegion &other) { 2242 return self.get().ptr == other.get().ptr; 2243 }) 2244 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2245 2246 //---------------------------------------------------------------------------- 2247 // Mapping of PyBlock. 2248 //---------------------------------------------------------------------------- 2249 py::class_<PyBlock>(m, "Block", py::module_local()) 2250 .def_property_readonly( 2251 "owner", 2252 [](PyBlock &self) { 2253 return self.getParentOperation()->createOpView(); 2254 }, 2255 "Returns the owning operation of this block.") 2256 .def_property_readonly( 2257 "region", 2258 [](PyBlock &self) { 2259 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2260 return PyRegion(self.getParentOperation(), region); 2261 }, 2262 "Returns the owning region of this block.") 2263 .def_property_readonly( 2264 "arguments", 2265 [](PyBlock &self) { 2266 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2267 }, 2268 "Returns a list of block arguments.") 2269 .def_property_readonly( 2270 "operations", 2271 [](PyBlock &self) { 2272 return PyOperationList(self.getParentOperation(), self.get()); 2273 }, 2274 "Returns a forward-optimized sequence of operations.") 2275 .def( 2276 "create_before", 2277 [](PyBlock &self, py::args pyArgTypes) { 2278 self.checkValid(); 2279 llvm::SmallVector<MlirType, 4> argTypes; 2280 argTypes.reserve(pyArgTypes.size()); 2281 for (auto &pyArg : pyArgTypes) { 2282 argTypes.push_back(pyArg.cast<PyType &>()); 2283 } 2284 2285 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2286 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2287 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2288 return PyBlock(self.getParentOperation(), block); 2289 }, 2290 "Creates and returns a new Block before this block " 2291 "(with given argument types).") 2292 .def( 2293 "create_after", 2294 [](PyBlock &self, py::args pyArgTypes) { 2295 self.checkValid(); 2296 llvm::SmallVector<MlirType, 4> argTypes; 2297 argTypes.reserve(pyArgTypes.size()); 2298 for (auto &pyArg : pyArgTypes) { 2299 argTypes.push_back(pyArg.cast<PyType &>()); 2300 } 2301 2302 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2303 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2304 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2305 return PyBlock(self.getParentOperation(), block); 2306 }, 2307 "Creates and returns a new Block after this block " 2308 "(with given argument types).") 2309 .def( 2310 "__iter__", 2311 [](PyBlock &self) { 2312 self.checkValid(); 2313 MlirOperation firstOperation = 2314 mlirBlockGetFirstOperation(self.get()); 2315 return PyOperationIterator(self.getParentOperation(), 2316 firstOperation); 2317 }, 2318 "Iterates over operations in the block.") 2319 .def("__eq__", 2320 [](PyBlock &self, PyBlock &other) { 2321 return self.get().ptr == other.get().ptr; 2322 }) 2323 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2324 .def( 2325 "__str__", 2326 [](PyBlock &self) { 2327 self.checkValid(); 2328 PyPrintAccumulator printAccum; 2329 mlirBlockPrint(self.get(), printAccum.getCallback(), 2330 printAccum.getUserData()); 2331 return printAccum.join(); 2332 }, 2333 "Returns the assembly form of the block."); 2334 2335 //---------------------------------------------------------------------------- 2336 // Mapping of PyInsertionPoint. 2337 //---------------------------------------------------------------------------- 2338 2339 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2340 .def(py::init<PyBlock &>(), py::arg("block"), 2341 "Inserts after the last operation but still inside the block.") 2342 .def("__enter__", &PyInsertionPoint::contextEnter) 2343 .def("__exit__", &PyInsertionPoint::contextExit) 2344 .def_property_readonly_static( 2345 "current", 2346 [](py::object & /*class*/) { 2347 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2348 if (!ip) 2349 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2350 return ip; 2351 }, 2352 "Gets the InsertionPoint bound to the current thread or raises " 2353 "ValueError if none has been set") 2354 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2355 "Inserts before a referenced operation.") 2356 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2357 py::arg("block"), "Inserts at the beginning of the block.") 2358 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2359 py::arg("block"), "Inserts before the block terminator.") 2360 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2361 "Inserts an operation.") 2362 .def_property_readonly( 2363 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2364 "Returns the block that this InsertionPoint points to."); 2365 2366 //---------------------------------------------------------------------------- 2367 // Mapping of PyAttribute. 2368 //---------------------------------------------------------------------------- 2369 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2370 // Delegate to the PyAttribute copy constructor, which will also lifetime 2371 // extend the backing context which owns the MlirAttribute. 2372 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2373 "Casts the passed attribute to the generic Attribute") 2374 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2375 &PyAttribute::getCapsule) 2376 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2377 .def_static( 2378 "parse", 2379 [](std::string attrSpec, DefaultingPyMlirContext context) { 2380 MlirAttribute type = mlirAttributeParseGet( 2381 context->get(), toMlirStringRef(attrSpec)); 2382 // TODO: Rework error reporting once diagnostic engine is exposed 2383 // in C API. 2384 if (mlirAttributeIsNull(type)) { 2385 throw SetPyError(PyExc_ValueError, 2386 Twine("Unable to parse attribute: '") + 2387 attrSpec + "'"); 2388 } 2389 return PyAttribute(context->getRef(), type); 2390 }, 2391 py::arg("asm"), py::arg("context") = py::none(), 2392 "Parses an attribute from an assembly form") 2393 .def_property_readonly( 2394 "context", 2395 [](PyAttribute &self) { return self.getContext().getObject(); }, 2396 "Context that owns the Attribute") 2397 .def_property_readonly("type", 2398 [](PyAttribute &self) { 2399 return PyType(self.getContext()->getRef(), 2400 mlirAttributeGetType(self)); 2401 }) 2402 .def( 2403 "get_named", 2404 [](PyAttribute &self, std::string name) { 2405 return PyNamedAttribute(self, std::move(name)); 2406 }, 2407 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2408 .def("__eq__", 2409 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2410 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2411 .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; }) 2412 .def( 2413 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2414 kDumpDocstring) 2415 .def( 2416 "__str__", 2417 [](PyAttribute &self) { 2418 PyPrintAccumulator printAccum; 2419 mlirAttributePrint(self, printAccum.getCallback(), 2420 printAccum.getUserData()); 2421 return printAccum.join(); 2422 }, 2423 "Returns the assembly form of the Attribute.") 2424 .def("__repr__", [](PyAttribute &self) { 2425 // Generally, assembly formats are not printed for __repr__ because 2426 // this can cause exceptionally long debug output and exceptions. 2427 // However, attribute values are generally considered useful and are 2428 // printed. This may need to be re-evaluated if debug dumps end up 2429 // being excessive. 2430 PyPrintAccumulator printAccum; 2431 printAccum.parts.append("Attribute("); 2432 mlirAttributePrint(self, printAccum.getCallback(), 2433 printAccum.getUserData()); 2434 printAccum.parts.append(")"); 2435 return printAccum.join(); 2436 }); 2437 2438 //---------------------------------------------------------------------------- 2439 // Mapping of PyNamedAttribute 2440 //---------------------------------------------------------------------------- 2441 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2442 .def("__repr__", 2443 [](PyNamedAttribute &self) { 2444 PyPrintAccumulator printAccum; 2445 printAccum.parts.append("NamedAttribute("); 2446 printAccum.parts.append( 2447 mlirIdentifierStr(self.namedAttr.name).data); 2448 printAccum.parts.append("="); 2449 mlirAttributePrint(self.namedAttr.attribute, 2450 printAccum.getCallback(), 2451 printAccum.getUserData()); 2452 printAccum.parts.append(")"); 2453 return printAccum.join(); 2454 }) 2455 .def_property_readonly( 2456 "name", 2457 [](PyNamedAttribute &self) { 2458 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2459 mlirIdentifierStr(self.namedAttr.name).length); 2460 }, 2461 "The name of the NamedAttribute binding") 2462 .def_property_readonly( 2463 "attr", 2464 [](PyNamedAttribute &self) { 2465 // TODO: When named attribute is removed/refactored, also remove 2466 // this constructor (it does an inefficient table lookup). 2467 auto contextRef = PyMlirContext::forContext( 2468 mlirAttributeGetContext(self.namedAttr.attribute)); 2469 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2470 }, 2471 py::keep_alive<0, 1>(), 2472 "The underlying generic attribute of the NamedAttribute binding"); 2473 2474 //---------------------------------------------------------------------------- 2475 // Mapping of PyType. 2476 //---------------------------------------------------------------------------- 2477 py::class_<PyType>(m, "Type", py::module_local()) 2478 // Delegate to the PyType copy constructor, which will also lifetime 2479 // extend the backing context which owns the MlirType. 2480 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2481 "Casts the passed type to the generic Type") 2482 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2483 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2484 .def_static( 2485 "parse", 2486 [](std::string typeSpec, DefaultingPyMlirContext context) { 2487 MlirType type = 2488 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2489 // TODO: Rework error reporting once diagnostic engine is exposed 2490 // in C API. 2491 if (mlirTypeIsNull(type)) { 2492 throw SetPyError(PyExc_ValueError, 2493 Twine("Unable to parse type: '") + typeSpec + 2494 "'"); 2495 } 2496 return PyType(context->getRef(), type); 2497 }, 2498 py::arg("asm"), py::arg("context") = py::none(), 2499 kContextParseTypeDocstring) 2500 .def_property_readonly( 2501 "context", [](PyType &self) { return self.getContext().getObject(); }, 2502 "Context that owns the Type") 2503 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2504 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2505 .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; }) 2506 .def( 2507 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2508 .def( 2509 "__str__", 2510 [](PyType &self) { 2511 PyPrintAccumulator printAccum; 2512 mlirTypePrint(self, printAccum.getCallback(), 2513 printAccum.getUserData()); 2514 return printAccum.join(); 2515 }, 2516 "Returns the assembly form of the type.") 2517 .def("__repr__", [](PyType &self) { 2518 // Generally, assembly formats are not printed for __repr__ because 2519 // this can cause exceptionally long debug output and exceptions. 2520 // However, types are an exception as they typically have compact 2521 // assembly forms and printing them is useful. 2522 PyPrintAccumulator printAccum; 2523 printAccum.parts.append("Type("); 2524 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2525 printAccum.parts.append(")"); 2526 return printAccum.join(); 2527 }); 2528 2529 //---------------------------------------------------------------------------- 2530 // Mapping of Value. 2531 //---------------------------------------------------------------------------- 2532 py::class_<PyValue>(m, "Value", py::module_local()) 2533 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2534 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2535 .def_property_readonly( 2536 "context", 2537 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2538 "Context in which the value lives.") 2539 .def( 2540 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2541 kDumpDocstring) 2542 .def_property_readonly( 2543 "owner", 2544 [](PyValue &self) { 2545 assert(mlirOperationEqual(self.getParentOperation()->get(), 2546 mlirOpResultGetOwner(self.get())) && 2547 "expected the owner of the value in Python to match that in " 2548 "the IR"); 2549 return self.getParentOperation().getObject(); 2550 }) 2551 .def("__eq__", 2552 [](PyValue &self, PyValue &other) { 2553 return self.get().ptr == other.get().ptr; 2554 }) 2555 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2556 .def( 2557 "__str__", 2558 [](PyValue &self) { 2559 PyPrintAccumulator printAccum; 2560 printAccum.parts.append("Value("); 2561 mlirValuePrint(self.get(), printAccum.getCallback(), 2562 printAccum.getUserData()); 2563 printAccum.parts.append(")"); 2564 return printAccum.join(); 2565 }, 2566 kValueDunderStrDocstring) 2567 .def_property_readonly("type", [](PyValue &self) { 2568 return PyType(self.getParentOperation()->getContext(), 2569 mlirValueGetType(self.get())); 2570 }); 2571 PyBlockArgument::bind(m); 2572 PyOpResult::bind(m); 2573 2574 // Container bindings. 2575 PyBlockArgumentList::bind(m); 2576 PyBlockIterator::bind(m); 2577 PyBlockList::bind(m); 2578 PyOperationIterator::bind(m); 2579 PyOperationList::bind(m); 2580 PyOpAttributeMap::bind(m); 2581 PyOpOperandList::bind(m); 2582 PyOpResultList::bind(m); 2583 PyRegionIterator::bind(m); 2584 PyRegionList::bind(m); 2585 2586 // Debug bindings. 2587 PyGlobalDebugFlag::bind(m); 2588 } 2589