1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 namespace py = pybind11; 25 using namespace mlir; 26 using namespace mlir::python; 27 28 using llvm::SmallVector; 29 using llvm::StringRef; 30 using llvm::Twine; 31 32 //------------------------------------------------------------------------------ 33 // Docstrings (trivial, non-duplicated docstrings are included inline). 34 //------------------------------------------------------------------------------ 35 36 static const char kContextParseTypeDocstring[] = 37 R"(Parses the assembly form of a type. 38 39 Returns a Type object or raises a ValueError if the type cannot be parsed. 40 41 See also: https://mlir.llvm.org/docs/LangRef/#type-system 42 )"; 43 44 static const char kContextGetCallSiteLocationDocstring[] = 45 R"(Gets a Location representing a caller and callsite)"; 46 47 static const char kContextGetFileLocationDocstring[] = 48 R"(Gets a Location representing a file, line and column)"; 49 50 static const char kContextGetNameLocationDocString[] = 51 R"(Gets a Location representing a named location with optional child location)"; 52 53 static const char kModuleParseDocstring[] = 54 R"(Parses a module's assembly format from a string. 
55 56 Returns a new MlirModule or raises a ValueError if the parsing fails. 57 58 See also: https://mlir.llvm.org/docs/LangRef/ 59 )"; 60 61 static const char kOperationCreateDocstring[] = 62 R"(Creates a new operation. 63 64 Args: 65 name: Operation name (e.g. "dialect.operation"). 66 results: Sequence of Type representing op result types. 67 attributes: Dict of str:Attribute. 68 successors: List of Block for the operation's successors. 69 regions: Number of regions to create. 70 location: A Location object (defaults to resolve from context manager). 71 ip: An InsertionPoint (defaults to resolve from context manager or set to 72 False to disable insertion, even with an insertion point set in the 73 context manager). 74 Returns: 75 A new "detached" Operation object. Detached operations can be added 76 to blocks, which causes them to become "attached." 77 )"; 78 79 static const char kOperationPrintDocstring[] = 80 R"(Prints the assembly form of the operation to a file like object. 81 82 Args: 83 file: The file like object to write to. Defaults to sys.stdout. 84 binary: Whether to write bytes (True) or str (False). Defaults to False. 85 large_elements_limit: Whether to elide elements attributes above this 86 number of elements. Defaults to None (no limit). 87 enable_debug_info: Whether to print debug/location information. Defaults 88 to False. 89 pretty_debug_info: Whether to format debug information for easier reading 90 by a human (warning: the result is unparseable). 91 print_generic_op_form: Whether to print the generic assembly forms of all 92 ops. Defaults to False. 93 use_local_Scope: Whether to print in a way that is more optimized for 94 multi-threaded access but may not be consistent with how the overall 95 module prints. 96 )"; 97 98 static const char kOperationGetAsmDocstring[] = 99 R"(Gets the assembly form of the operation with all options available. 100 101 Args: 102 binary: Whether to return a bytes (True) or str (False) object. Defaults to 103 False. 
104 ... others ...: See the print() method for common keyword arguments for 105 configuring the printout. 106 Returns: 107 Either a bytes or str object, depending on the setting of the 'binary' 108 argument. 109 )"; 110 111 static const char kOperationStrDunderDocstring[] = 112 R"(Gets the assembly form of the operation with default options. 113 114 If more advanced control over the assembly formatting or I/O options is needed, 115 use the dedicated print or get_asm method, which supports keyword arguments to 116 customize behavior. 117 )"; 118 119 static const char kDumpDocstring[] = 120 R"(Dumps a debug representation of the object to stderr.)"; 121 122 static const char kAppendBlockDocstring[] = 123 R"(Appends a new block, with argument types as positional args. 124 125 Returns: 126 The created block. 127 )"; 128 129 static const char kValueDunderStrDocstring[] = 130 R"(Returns the string form of the value. 131 132 If the value is a block argument, this is the assembly form of its type and the 133 position in the argument list. If the value is an operation result, this is 134 equivalent to printing the operation that produced it. 135 )"; 136 137 //------------------------------------------------------------------------------ 138 // Utilities. 139 //------------------------------------------------------------------------------ 140 141 /// Helper for creating an @classmethod. 142 template <class Func, typename... Args> 143 py::object classmethod(Func f, Args... args) { 144 py::object cf = py::cpp_function(f, args...); 145 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 146 } 147 148 static py::object 149 createCustomDialectWrapper(const std::string &dialectNamespace, 150 py::object dialectDescriptor) { 151 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 152 if (!dialectClass) { 153 // Use the base class. 
154 return py::cast(PyDialect(std::move(dialectDescriptor))); 155 } 156 157 // Create the custom implementation. 158 return (*dialectClass)(std::move(dialectDescriptor)); 159 } 160 161 static MlirStringRef toMlirStringRef(const std::string &s) { 162 return mlirStringRefCreate(s.data(), s.size()); 163 } 164 165 /// Wrapper for the global LLVM debugging flag. 166 struct PyGlobalDebugFlag { 167 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 168 169 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 170 171 static void bind(py::module &m) { 172 // Debug flags. 173 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 174 .def_property_static("flag", &PyGlobalDebugFlag::get, 175 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 176 } 177 }; 178 179 //------------------------------------------------------------------------------ 180 // Collections. 181 //------------------------------------------------------------------------------ 182 183 namespace { 184 185 class PyRegionIterator { 186 public: 187 PyRegionIterator(PyOperationRef operation) 188 : operation(std::move(operation)) {} 189 190 PyRegionIterator &dunderIter() { return *this; } 191 192 PyRegion dunderNext() { 193 operation->checkValid(); 194 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 195 throw py::stop_iteration(); 196 } 197 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 198 return PyRegion(operation, region); 199 } 200 201 static void bind(py::module &m) { 202 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 203 .def("__iter__", &PyRegionIterator::dunderIter) 204 .def("__next__", &PyRegionIterator::dunderNext); 205 } 206 207 private: 208 PyOperationRef operation; 209 int nextIndex = 0; 210 }; 211 212 /// Regions of an op are fixed length and indexed numerically so are represented 213 /// with a sequence-like container. 
214 class PyRegionList { 215 public: 216 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 217 218 intptr_t dunderLen() { 219 operation->checkValid(); 220 return mlirOperationGetNumRegions(operation->get()); 221 } 222 223 PyRegion dunderGetItem(intptr_t index) { 224 // dunderLen checks validity. 225 if (index < 0 || index >= dunderLen()) { 226 throw SetPyError(PyExc_IndexError, 227 "attempt to access out of bounds region"); 228 } 229 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 230 return PyRegion(operation, region); 231 } 232 233 static void bind(py::module &m) { 234 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 235 .def("__len__", &PyRegionList::dunderLen) 236 .def("__getitem__", &PyRegionList::dunderGetItem); 237 } 238 239 private: 240 PyOperationRef operation; 241 }; 242 243 class PyBlockIterator { 244 public: 245 PyBlockIterator(PyOperationRef operation, MlirBlock next) 246 : operation(std::move(operation)), next(next) {} 247 248 PyBlockIterator &dunderIter() { return *this; } 249 250 PyBlock dunderNext() { 251 operation->checkValid(); 252 if (mlirBlockIsNull(next)) { 253 throw py::stop_iteration(); 254 } 255 256 PyBlock returnBlock(operation, next); 257 next = mlirBlockGetNextInRegion(next); 258 return returnBlock; 259 } 260 261 static void bind(py::module &m) { 262 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 263 .def("__iter__", &PyBlockIterator::dunderIter) 264 .def("__next__", &PyBlockIterator::dunderNext); 265 } 266 267 private: 268 PyOperationRef operation; 269 MlirBlock next; 270 }; 271 272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 273 /// we present them as a more full-featured list-like container but optimize 274 /// it for forward iteration. Blocks are always owned by a region. 
275 class PyBlockList { 276 public: 277 PyBlockList(PyOperationRef operation, MlirRegion region) 278 : operation(std::move(operation)), region(region) {} 279 280 PyBlockIterator dunderIter() { 281 operation->checkValid(); 282 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 283 } 284 285 intptr_t dunderLen() { 286 operation->checkValid(); 287 intptr_t count = 0; 288 MlirBlock block = mlirRegionGetFirstBlock(region); 289 while (!mlirBlockIsNull(block)) { 290 count += 1; 291 block = mlirBlockGetNextInRegion(block); 292 } 293 return count; 294 } 295 296 PyBlock dunderGetItem(intptr_t index) { 297 operation->checkValid(); 298 if (index < 0) { 299 throw SetPyError(PyExc_IndexError, 300 "attempt to access out of bounds block"); 301 } 302 MlirBlock block = mlirRegionGetFirstBlock(region); 303 while (!mlirBlockIsNull(block)) { 304 if (index == 0) { 305 return PyBlock(operation, block); 306 } 307 block = mlirBlockGetNextInRegion(block); 308 index -= 1; 309 } 310 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 311 } 312 313 PyBlock appendBlock(py::args pyArgTypes) { 314 operation->checkValid(); 315 llvm::SmallVector<MlirType, 4> argTypes; 316 argTypes.reserve(pyArgTypes.size()); 317 for (auto &pyArg : pyArgTypes) { 318 argTypes.push_back(pyArg.cast<PyType &>()); 319 } 320 321 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 322 mlirRegionAppendOwnedBlock(region, block); 323 return PyBlock(operation, block); 324 } 325 326 static void bind(py::module &m) { 327 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 328 .def("__getitem__", &PyBlockList::dunderGetItem) 329 .def("__iter__", &PyBlockList::dunderIter) 330 .def("__len__", &PyBlockList::dunderLen) 331 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 332 } 333 334 private: 335 PyOperationRef operation; 336 MlirRegion region; 337 }; 338 339 class PyOperationIterator { 340 public: 341 PyOperationIterator(PyOperationRef 
parentOperation, MlirOperation next) 342 : parentOperation(std::move(parentOperation)), next(next) {} 343 344 PyOperationIterator &dunderIter() { return *this; } 345 346 py::object dunderNext() { 347 parentOperation->checkValid(); 348 if (mlirOperationIsNull(next)) { 349 throw py::stop_iteration(); 350 } 351 352 PyOperationRef returnOperation = 353 PyOperation::forOperation(parentOperation->getContext(), next); 354 next = mlirOperationGetNextInBlock(next); 355 return returnOperation->createOpView(); 356 } 357 358 static void bind(py::module &m) { 359 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 360 .def("__iter__", &PyOperationIterator::dunderIter) 361 .def("__next__", &PyOperationIterator::dunderNext); 362 } 363 364 private: 365 PyOperationRef parentOperation; 366 MlirOperation next; 367 }; 368 369 /// Operations are exposed by the C-API as a forward-only linked list. In 370 /// Python, we present them as a more full-featured list-like container but 371 /// optimize it for forward iteration. Iterable operations are always owned 372 /// by a block. 
373 class PyOperationList { 374 public: 375 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 376 : parentOperation(std::move(parentOperation)), block(block) {} 377 378 PyOperationIterator dunderIter() { 379 parentOperation->checkValid(); 380 return PyOperationIterator(parentOperation, 381 mlirBlockGetFirstOperation(block)); 382 } 383 384 intptr_t dunderLen() { 385 parentOperation->checkValid(); 386 intptr_t count = 0; 387 MlirOperation childOp = mlirBlockGetFirstOperation(block); 388 while (!mlirOperationIsNull(childOp)) { 389 count += 1; 390 childOp = mlirOperationGetNextInBlock(childOp); 391 } 392 return count; 393 } 394 395 py::object dunderGetItem(intptr_t index) { 396 parentOperation->checkValid(); 397 if (index < 0) { 398 throw SetPyError(PyExc_IndexError, 399 "attempt to access out of bounds operation"); 400 } 401 MlirOperation childOp = mlirBlockGetFirstOperation(block); 402 while (!mlirOperationIsNull(childOp)) { 403 if (index == 0) { 404 return PyOperation::forOperation(parentOperation->getContext(), childOp) 405 ->createOpView(); 406 } 407 childOp = mlirOperationGetNextInBlock(childOp); 408 index -= 1; 409 } 410 throw SetPyError(PyExc_IndexError, 411 "attempt to access out of bounds operation"); 412 } 413 414 static void bind(py::module &m) { 415 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 416 .def("__getitem__", &PyOperationList::dunderGetItem) 417 .def("__iter__", &PyOperationList::dunderIter) 418 .def("__len__", &PyOperationList::dunderLen); 419 } 420 421 private: 422 PyOperationRef parentOperation; 423 MlirBlock block; 424 }; 425 426 } // namespace 427 428 //------------------------------------------------------------------------------ 429 // PyMlirContext 430 //------------------------------------------------------------------------------ 431 432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 433 py::gil_scoped_acquire acquire; 434 auto &liveContexts = getLiveContexts(); 435 
liveContexts[context.ptr] = this; 436 } 437 438 PyMlirContext::~PyMlirContext() { 439 // Note that the only public way to construct an instance is via the 440 // forContext method, which always puts the associated handle into 441 // liveContexts. 442 py::gil_scoped_acquire acquire; 443 getLiveContexts().erase(context.ptr); 444 mlirContextDestroy(context); 445 } 446 447 py::object PyMlirContext::getCapsule() { 448 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 449 } 450 451 py::object PyMlirContext::createFromCapsule(py::object capsule) { 452 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 453 if (mlirContextIsNull(rawContext)) 454 throw py::error_already_set(); 455 return forContext(rawContext).releaseObject(); 456 } 457 458 PyMlirContext *PyMlirContext::createNewContextForInit() { 459 MlirContext context = mlirContextCreate(); 460 mlirRegisterAllDialects(context); 461 return new PyMlirContext(context); 462 } 463 464 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 465 py::gil_scoped_acquire acquire; 466 auto &liveContexts = getLiveContexts(); 467 auto it = liveContexts.find(context.ptr); 468 if (it == liveContexts.end()) { 469 // Create. 470 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 471 py::object pyRef = py::cast(unownedContextWrapper); 472 assert(pyRef && "cast to py::object failed"); 473 liveContexts[context.ptr] = unownedContextWrapper; 474 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 475 } 476 // Use existing. 
477 py::object pyRef = py::cast(it->second); 478 return PyMlirContextRef(it->second, std::move(pyRef)); 479 } 480 481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 482 static LiveContextMap liveContexts; 483 return liveContexts; 484 } 485 486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 487 488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 489 490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 491 492 pybind11::object PyMlirContext::contextEnter() { 493 return PyThreadContextEntry::pushContext(*this); 494 } 495 496 void PyMlirContext::contextExit(pybind11::object excType, 497 pybind11::object excVal, 498 pybind11::object excTb) { 499 PyThreadContextEntry::popContext(*this); 500 } 501 502 PyMlirContext &DefaultingPyMlirContext::resolve() { 503 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 504 if (!context) { 505 throw SetPyError( 506 PyExc_RuntimeError, 507 "An MLIR function requires a Context but none was provided in the call " 508 "or from the surrounding environment. 
Either pass to the function with " 509 "a 'context=' argument or establish a default using 'with Context():'"); 510 } 511 return *context; 512 } 513 514 //------------------------------------------------------------------------------ 515 // PyThreadContextEntry management 516 //------------------------------------------------------------------------------ 517 518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 519 static thread_local std::vector<PyThreadContextEntry> stack; 520 return stack; 521 } 522 523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 524 auto &stack = getStack(); 525 if (stack.empty()) 526 return nullptr; 527 return &stack.back(); 528 } 529 530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 531 py::object insertionPoint, 532 py::object location) { 533 auto &stack = getStack(); 534 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 535 std::move(location)); 536 // If the new stack has more than one entry and the context of the new top 537 // entry matches the previous, copy the insertionPoint and location from the 538 // previous entry if missing from the new top entry. 539 if (stack.size() > 1) { 540 auto &prev = *(stack.rbegin() + 1); 541 auto ¤t = stack.back(); 542 if (current.context.is(prev.context)) { 543 // Default non-context objects from the previous entry. 
544 if (!current.insertionPoint) 545 current.insertionPoint = prev.insertionPoint; 546 if (!current.location) 547 current.location = prev.location; 548 } 549 } 550 } 551 552 PyMlirContext *PyThreadContextEntry::getContext() { 553 if (!context) 554 return nullptr; 555 return py::cast<PyMlirContext *>(context); 556 } 557 558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 559 if (!insertionPoint) 560 return nullptr; 561 return py::cast<PyInsertionPoint *>(insertionPoint); 562 } 563 564 PyLocation *PyThreadContextEntry::getLocation() { 565 if (!location) 566 return nullptr; 567 return py::cast<PyLocation *>(location); 568 } 569 570 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 571 auto *tos = getTopOfStack(); 572 return tos ? tos->getContext() : nullptr; 573 } 574 575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 576 auto *tos = getTopOfStack(); 577 return tos ? tos->getInsertionPoint() : nullptr; 578 } 579 580 PyLocation *PyThreadContextEntry::getDefaultLocation() { 581 auto *tos = getTopOfStack(); 582 return tos ? 
tos->getLocation() : nullptr; 583 } 584 585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 586 py::object contextObj = py::cast(context); 587 push(FrameKind::Context, /*context=*/contextObj, 588 /*insertionPoint=*/py::object(), 589 /*location=*/py::object()); 590 return contextObj; 591 } 592 593 void PyThreadContextEntry::popContext(PyMlirContext &context) { 594 auto &stack = getStack(); 595 if (stack.empty()) 596 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 597 auto &tos = stack.back(); 598 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 599 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 600 stack.pop_back(); 601 } 602 603 py::object 604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 605 py::object contextObj = 606 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 607 py::object insertionPointObj = py::cast(insertionPoint); 608 push(FrameKind::InsertionPoint, 609 /*context=*/contextObj, 610 /*insertionPoint=*/insertionPointObj, 611 /*location=*/py::object()); 612 return insertionPointObj; 613 } 614 615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 616 auto &stack = getStack(); 617 if (stack.empty()) 618 throw SetPyError(PyExc_RuntimeError, 619 "Unbalanced InsertionPoint enter/exit"); 620 auto &tos = stack.back(); 621 if (tos.frameKind != FrameKind::InsertionPoint && 622 tos.getInsertionPoint() != &insertionPoint) 623 throw SetPyError(PyExc_RuntimeError, 624 "Unbalanced InsertionPoint enter/exit"); 625 stack.pop_back(); 626 } 627 628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 629 py::object contextObj = location.getContext().getObject(); 630 py::object locationObj = py::cast(location); 631 push(FrameKind::Location, /*context=*/contextObj, 632 /*insertionPoint=*/py::object(), 633 /*location=*/locationObj); 634 return locationObj; 635 } 636 637 void 
PyThreadContextEntry::popLocation(PyLocation &location) { 638 auto &stack = getStack(); 639 if (stack.empty()) 640 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 641 auto &tos = stack.back(); 642 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 643 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 644 stack.pop_back(); 645 } 646 647 //------------------------------------------------------------------------------ 648 // PyDialect, PyDialectDescriptor, PyDialects 649 //------------------------------------------------------------------------------ 650 651 MlirDialect PyDialects::getDialectForKey(const std::string &key, 652 bool attrError) { 653 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 654 {key.data(), key.size()}); 655 if (mlirDialectIsNull(dialect)) { 656 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 657 Twine("Dialect '") + key + "' not found"); 658 } 659 return dialect; 660 } 661 662 //------------------------------------------------------------------------------ 663 // PyLocation 664 //------------------------------------------------------------------------------ 665 666 py::object PyLocation::getCapsule() { 667 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 668 } 669 670 PyLocation PyLocation::createFromCapsule(py::object capsule) { 671 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 672 if (mlirLocationIsNull(rawLoc)) 673 throw py::error_already_set(); 674 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 675 rawLoc); 676 } 677 678 py::object PyLocation::contextEnter() { 679 return PyThreadContextEntry::pushLocation(*this); 680 } 681 682 void PyLocation::contextExit(py::object excType, py::object excVal, 683 py::object excTb) { 684 PyThreadContextEntry::popLocation(*this); 685 } 686 687 PyLocation &DefaultingPyLocation::resolve() { 688 auto *location = 
PyThreadContextEntry::getDefaultLocation(); 689 if (!location) { 690 throw SetPyError( 691 PyExc_RuntimeError, 692 "An MLIR function requires a Location but none was provided in the " 693 "call or from the surrounding environment. Either pass to the function " 694 "with a 'loc=' argument or establish a default using 'with loc:'"); 695 } 696 return *location; 697 } 698 699 //------------------------------------------------------------------------------ 700 // PyModule 701 //------------------------------------------------------------------------------ 702 703 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 704 : BaseContextObject(std::move(contextRef)), module(module) {} 705 706 PyModule::~PyModule() { 707 py::gil_scoped_acquire acquire; 708 auto &liveModules = getContext()->liveModules; 709 assert(liveModules.count(module.ptr) == 1 && 710 "destroying module not in live map"); 711 liveModules.erase(module.ptr); 712 mlirModuleDestroy(module); 713 } 714 715 PyModuleRef PyModule::forModule(MlirModule module) { 716 MlirContext context = mlirModuleGetContext(module); 717 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 718 719 py::gil_scoped_acquire acquire; 720 auto &liveModules = contextRef->liveModules; 721 auto it = liveModules.find(module.ptr); 722 if (it == liveModules.end()) { 723 // Create. 724 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 725 // Note that the default return value policy on cast is automatic_reference, 726 // which does not take ownership (delete will not be called). 727 // Just be explicit. 728 py::object pyRef = 729 py::cast(unownedModule, py::return_value_policy::take_ownership); 730 unownedModule->handle = pyRef; 731 liveModules[module.ptr] = 732 std::make_pair(unownedModule->handle, unownedModule); 733 return PyModuleRef(unownedModule, std::move(pyRef)); 734 } 735 // Use existing. 
736 PyModule *existing = it->second.second; 737 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 738 return PyModuleRef(existing, std::move(pyRef)); 739 } 740 741 py::object PyModule::createFromCapsule(py::object capsule) { 742 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 743 if (mlirModuleIsNull(rawModule)) 744 throw py::error_already_set(); 745 return forModule(rawModule).releaseObject(); 746 } 747 748 py::object PyModule::getCapsule() { 749 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 750 } 751 752 //------------------------------------------------------------------------------ 753 // PyOperation 754 //------------------------------------------------------------------------------ 755 756 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 757 : BaseContextObject(std::move(contextRef)), operation(operation) {} 758 759 PyOperation::~PyOperation() { 760 // If the operation has already been invalidated there is nothing to do. 761 if (!valid) 762 return; 763 auto &liveOperations = getContext()->liveOperations; 764 assert(liveOperations.count(operation.ptr) == 1 && 765 "destroying operation not in live map"); 766 liveOperations.erase(operation.ptr); 767 if (!isAttached()) { 768 mlirOperationDestroy(operation); 769 } 770 } 771 772 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 773 MlirOperation operation, 774 py::object parentKeepAlive) { 775 auto &liveOperations = contextRef->liveOperations; 776 // Create. 777 PyOperation *unownedOperation = 778 new PyOperation(std::move(contextRef), operation); 779 // Note that the default return value policy on cast is automatic_reference, 780 // which does not take ownership (delete will not be called). 781 // Just be explicit. 
782 py::object pyRef = 783 py::cast(unownedOperation, py::return_value_policy::take_ownership); 784 unownedOperation->handle = pyRef; 785 if (parentKeepAlive) { 786 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 787 } 788 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 789 return PyOperationRef(unownedOperation, std::move(pyRef)); 790 } 791 792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 793 MlirOperation operation, 794 py::object parentKeepAlive) { 795 auto &liveOperations = contextRef->liveOperations; 796 auto it = liveOperations.find(operation.ptr); 797 if (it == liveOperations.end()) { 798 // Create. 799 return createInstance(std::move(contextRef), operation, 800 std::move(parentKeepAlive)); 801 } 802 // Use existing. 803 PyOperation *existing = it->second.second; 804 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 805 return PyOperationRef(existing, std::move(pyRef)); 806 } 807 808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 809 MlirOperation operation, 810 py::object parentKeepAlive) { 811 auto &liveOperations = contextRef->liveOperations; 812 assert(liveOperations.count(operation.ptr) == 0 && 813 "cannot create detached operation that already exists"); 814 (void)liveOperations; 815 816 PyOperationRef created = createInstance(std::move(contextRef), operation, 817 std::move(parentKeepAlive)); 818 created->attached = false; 819 return created; 820 } 821 822 void PyOperation::checkValid() const { 823 if (!valid) { 824 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 825 } 826 } 827 828 void PyOperationBase::print(py::object fileObject, bool binary, 829 llvm::Optional<int64_t> largeElementsLimit, 830 bool enableDebugInfo, bool prettyDebugInfo, 831 bool printGenericOpForm, bool useLocalScope) { 832 PyOperation &operation = getOperation(); 833 operation.checkValid(); 834 if (fileObject.is_none()) 835 fileObject = 
py::module::import("sys").attr("stdout"); 836 837 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 838 fileObject.attr("write")("// Verification failed, printing generic form\n"); 839 printGenericOpForm = true; 840 } 841 842 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 843 if (largeElementsLimit) 844 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 845 if (enableDebugInfo) 846 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 847 if (printGenericOpForm) 848 mlirOpPrintingFlagsPrintGenericOpForm(flags); 849 850 PyFileAccumulator accum(fileObject, binary); 851 py::gil_scoped_release(); 852 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 853 accum.getUserData()); 854 mlirOpPrintingFlagsDestroy(flags); 855 } 856 857 py::object PyOperationBase::getAsm(bool binary, 858 llvm::Optional<int64_t> largeElementsLimit, 859 bool enableDebugInfo, bool prettyDebugInfo, 860 bool printGenericOpForm, 861 bool useLocalScope) { 862 py::object fileObject; 863 if (binary) { 864 fileObject = py::module::import("io").attr("BytesIO")(); 865 } else { 866 fileObject = py::module::import("io").attr("StringIO")(); 867 } 868 print(fileObject, /*binary=*/binary, 869 /*largeElementsLimit=*/largeElementsLimit, 870 /*enableDebugInfo=*/enableDebugInfo, 871 /*prettyDebugInfo=*/prettyDebugInfo, 872 /*printGenericOpForm=*/printGenericOpForm, 873 /*useLocalScope=*/useLocalScope); 874 875 return fileObject.attr("getvalue")(); 876 } 877 878 void PyOperationBase::moveAfter(PyOperationBase &other) { 879 PyOperation &operation = getOperation(); 880 PyOperation &otherOp = other.getOperation(); 881 operation.checkValid(); 882 otherOp.checkValid(); 883 mlirOperationMoveAfter(operation, otherOp); 884 operation.parentKeepAlive = otherOp.parentKeepAlive; 885 } 886 887 void PyOperationBase::moveBefore(PyOperationBase &other) { 888 PyOperation &operation = getOperation(); 889 PyOperation &otherOp = other.getOperation(); 890 
operation.checkValid(); 891 otherOp.checkValid(); 892 mlirOperationMoveBefore(operation, otherOp); 893 operation.parentKeepAlive = otherOp.parentKeepAlive; 894 } 895 896 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 897 checkValid(); 898 if (!isAttached()) 899 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 900 MlirOperation operation = mlirOperationGetParentOperation(get()); 901 if (mlirOperationIsNull(operation)) 902 return {}; 903 return PyOperation::forOperation(getContext(), operation); 904 } 905 906 PyBlock PyOperation::getBlock() { 907 checkValid(); 908 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 909 MlirBlock block = mlirOperationGetBlock(get()); 910 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 911 assert(parentOperation && "Operation has no parent"); 912 return PyBlock{std::move(*parentOperation), block}; 913 } 914 915 py::object PyOperation::getCapsule() { 916 checkValid(); 917 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 918 } 919 920 py::object PyOperation::createFromCapsule(py::object capsule) { 921 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 922 if (mlirOperationIsNull(rawOperation)) 923 throw py::error_already_set(); 924 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 925 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 926 .releaseObject(); 927 } 928 929 py::object PyOperation::create( 930 std::string name, llvm::Optional<std::vector<PyType *>> results, 931 llvm::Optional<std::vector<PyValue *>> operands, 932 llvm::Optional<py::dict> attributes, 933 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 934 DefaultingPyLocation location, py::object maybeIp) { 935 llvm::SmallVector<MlirValue, 4> mlirOperands; 936 llvm::SmallVector<MlirType, 4> mlirResults; 937 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 938 
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 939 940 // General parameter validation. 941 if (regions < 0) 942 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 943 944 // Unpack/validate operands. 945 if (operands) { 946 mlirOperands.reserve(operands->size()); 947 for (PyValue *operand : *operands) { 948 if (!operand) 949 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 950 mlirOperands.push_back(operand->get()); 951 } 952 } 953 954 // Unpack/validate results. 955 if (results) { 956 mlirResults.reserve(results->size()); 957 for (PyType *result : *results) { 958 // TODO: Verify result type originate from the same context. 959 if (!result) 960 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 961 mlirResults.push_back(*result); 962 } 963 } 964 // Unpack/validate attributes. 965 if (attributes) { 966 mlirAttributes.reserve(attributes->size()); 967 for (auto &it : *attributes) { 968 std::string key; 969 try { 970 key = it.first.cast<std::string>(); 971 } catch (py::cast_error &err) { 972 std::string msg = "Invalid attribute key (not a string) when " 973 "attempting to create the operation \"" + 974 name + "\" (" + err.what() + ")"; 975 throw py::cast_error(msg); 976 } 977 try { 978 auto &attribute = it.second.cast<PyAttribute &>(); 979 // TODO: Verify attribute originates from the same context. 980 mlirAttributes.emplace_back(std::move(key), attribute); 981 } catch (py::reference_cast_error &) { 982 // This exception seems thrown when the value is "None". 983 std::string msg = 984 "Found an invalid (`None`?) 
attribute value for the key \"" + key + 985 "\" when attempting to create the operation \"" + name + "\""; 986 throw py::cast_error(msg); 987 } catch (py::cast_error &err) { 988 std::string msg = "Invalid attribute value for the key \"" + key + 989 "\" when attempting to create the operation \"" + 990 name + "\" (" + err.what() + ")"; 991 throw py::cast_error(msg); 992 } 993 } 994 } 995 // Unpack/validate successors. 996 if (successors) { 997 mlirSuccessors.reserve(successors->size()); 998 for (auto *successor : *successors) { 999 // TODO: Verify successor originate from the same context. 1000 if (!successor) 1001 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1002 mlirSuccessors.push_back(successor->get()); 1003 } 1004 } 1005 1006 // Apply unpacked/validated to the operation state. Beyond this 1007 // point, exceptions cannot be thrown or else the state will leak. 1008 MlirOperationState state = 1009 mlirOperationStateGet(toMlirStringRef(name), location); 1010 if (!mlirOperands.empty()) 1011 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1012 mlirOperands.data()); 1013 if (!mlirResults.empty()) 1014 mlirOperationStateAddResults(&state, mlirResults.size(), 1015 mlirResults.data()); 1016 if (!mlirAttributes.empty()) { 1017 // Note that the attribute names directly reference bytes in 1018 // mlirAttributes, so that vector must not be changed from here 1019 // on. 
1020 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1021 mlirNamedAttributes.reserve(mlirAttributes.size()); 1022 for (auto &it : mlirAttributes) 1023 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1024 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1025 toMlirStringRef(it.first)), 1026 it.second)); 1027 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1028 mlirNamedAttributes.data()); 1029 } 1030 if (!mlirSuccessors.empty()) 1031 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1032 mlirSuccessors.data()); 1033 if (regions) { 1034 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1035 mlirRegions.resize(regions); 1036 for (int i = 0; i < regions; ++i) 1037 mlirRegions[i] = mlirRegionCreate(); 1038 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1039 mlirRegions.data()); 1040 } 1041 1042 // Construct the operation. 1043 MlirOperation operation = mlirOperationCreate(&state); 1044 PyOperationRef created = 1045 PyOperation::createDetached(location->getContext(), operation); 1046 1047 // InsertPoint active? 
1048 if (!maybeIp.is(py::cast(false))) { 1049 PyInsertionPoint *ip; 1050 if (maybeIp.is_none()) { 1051 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1052 } else { 1053 ip = py::cast<PyInsertionPoint *>(maybeIp); 1054 } 1055 if (ip) 1056 ip->insert(*created.get()); 1057 } 1058 1059 return created->createOpView(); 1060 } 1061 1062 py::object PyOperation::createOpView() { 1063 checkValid(); 1064 MlirIdentifier ident = mlirOperationGetName(get()); 1065 MlirStringRef identStr = mlirIdentifierStr(ident); 1066 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1067 StringRef(identStr.data, identStr.length)); 1068 if (opViewClass) 1069 return (*opViewClass)(getRef().getObject()); 1070 return py::cast(PyOpView(getRef().getObject())); 1071 } 1072 1073 void PyOperation::erase() { 1074 checkValid(); 1075 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1076 // Python reference to a child operation is live. All children should also 1077 // have their `valid` bit set to false. 1078 auto &liveOperations = getContext()->liveOperations; 1079 if (liveOperations.count(operation.ptr)) 1080 liveOperations.erase(operation.ptr); 1081 mlirOperationDestroy(operation); 1082 valid = false; 1083 } 1084 1085 //------------------------------------------------------------------------------ 1086 // PyOpView 1087 //------------------------------------------------------------------------------ 1088 1089 py::object 1090 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1091 py::list operandList, 1092 llvm::Optional<py::dict> attributes, 1093 llvm::Optional<std::vector<PyBlock *>> successors, 1094 llvm::Optional<int> regions, 1095 DefaultingPyLocation location, py::object maybeIp) { 1096 PyMlirContextRef context = location->getContext(); 1097 // Class level operation construction metadata. 
1098 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1099 // Operand and result segment specs are either none, which does no 1100 // variadic unpacking, or a list of ints with segment sizes, where each 1101 // element is either a positive number (typically 1 for a scalar) or -1 to 1102 // indicate that it is derived from the length of the same-indexed operand 1103 // or result (implying that it is a list at that position). 1104 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1105 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1106 1107 std::vector<uint32_t> operandSegmentLengths; 1108 std::vector<uint32_t> resultSegmentLengths; 1109 1110 // Validate/determine region count. 1111 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1112 int opMinRegionCount = std::get<0>(opRegionSpec); 1113 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1114 if (!regions) { 1115 regions = opMinRegionCount; 1116 } 1117 if (*regions < opMinRegionCount) { 1118 throw py::value_error( 1119 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1120 llvm::Twine(opMinRegionCount) + 1121 " regions but was built with regions=" + llvm::Twine(*regions)) 1122 .str()); 1123 } 1124 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1125 throw py::value_error( 1126 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1127 llvm::Twine(opMinRegionCount) + 1128 " regions but was built with regions=" + llvm::Twine(*regions)) 1129 .str()); 1130 } 1131 1132 // Unpack results. 1133 std::vector<PyType *> resultTypes; 1134 resultTypes.reserve(resultTypeList.size()); 1135 if (resultSegmentSpecObj.is_none()) { 1136 // Non-variadic result unpacking. 
1137 for (auto it : llvm::enumerate(resultTypeList)) { 1138 try { 1139 resultTypes.push_back(py::cast<PyType *>(it.value())); 1140 if (!resultTypes.back()) 1141 throw py::cast_error(); 1142 } catch (py::cast_error &err) { 1143 throw py::value_error((llvm::Twine("Result ") + 1144 llvm::Twine(it.index()) + " of operation \"" + 1145 name + "\" must be a Type (" + err.what() + ")") 1146 .str()); 1147 } 1148 } 1149 } else { 1150 // Sized result unpacking. 1151 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1152 if (resultSegmentSpec.size() != resultTypeList.size()) { 1153 throw py::value_error((llvm::Twine("Operation \"") + name + 1154 "\" requires " + 1155 llvm::Twine(resultSegmentSpec.size()) + 1156 "result segments but was provided " + 1157 llvm::Twine(resultTypeList.size())) 1158 .str()); 1159 } 1160 resultSegmentLengths.reserve(resultTypeList.size()); 1161 for (auto it : 1162 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1163 int segmentSpec = std::get<1>(it.value()); 1164 if (segmentSpec == 1 || segmentSpec == 0) { 1165 // Unpack unary element. 1166 try { 1167 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1168 if (resultType) { 1169 resultTypes.push_back(resultType); 1170 resultSegmentLengths.push_back(1); 1171 } else if (segmentSpec == 0) { 1172 // Allowed to be optional. 1173 resultSegmentLengths.push_back(0); 1174 } else { 1175 throw py::cast_error("was None and result is not optional"); 1176 } 1177 } catch (py::cast_error &err) { 1178 throw py::value_error((llvm::Twine("Result ") + 1179 llvm::Twine(it.index()) + " of operation \"" + 1180 name + "\" must be a Type (" + err.what() + 1181 ")") 1182 .str()); 1183 } 1184 } else if (segmentSpec == -1) { 1185 // Unpack sequence by appending. 1186 try { 1187 if (std::get<0>(it.value()).is_none()) { 1188 // Treat it as an empty list. 1189 resultSegmentLengths.push_back(0); 1190 } else { 1191 // Unpack the list. 
1192 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1193 for (py::object segmentItem : segment) { 1194 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1195 if (!resultTypes.back()) { 1196 throw py::cast_error("contained a None item"); 1197 } 1198 } 1199 resultSegmentLengths.push_back(segment.size()); 1200 } 1201 } catch (std::exception &err) { 1202 // NOTE: Sloppy to be using a catch-all here, but there are at least 1203 // three different unrelated exceptions that can be thrown in the 1204 // above "casts". Just keep the scope above small and catch them all. 1205 throw py::value_error((llvm::Twine("Result ") + 1206 llvm::Twine(it.index()) + " of operation \"" + 1207 name + "\" must be a Sequence of Types (" + 1208 err.what() + ")") 1209 .str()); 1210 } 1211 } else { 1212 throw py::value_error("Unexpected segment spec"); 1213 } 1214 } 1215 } 1216 1217 // Unpack operands. 1218 std::vector<PyValue *> operands; 1219 operands.reserve(operands.size()); 1220 if (operandSegmentSpecObj.is_none()) { 1221 // Non-sized operand unpacking. 1222 for (auto it : llvm::enumerate(operandList)) { 1223 try { 1224 operands.push_back(py::cast<PyValue *>(it.value())); 1225 if (!operands.back()) 1226 throw py::cast_error(); 1227 } catch (py::cast_error &err) { 1228 throw py::value_error((llvm::Twine("Operand ") + 1229 llvm::Twine(it.index()) + " of operation \"" + 1230 name + "\" must be a Value (" + err.what() + ")") 1231 .str()); 1232 } 1233 } 1234 } else { 1235 // Sized operand unpacking. 
1236 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1237 if (operandSegmentSpec.size() != operandList.size()) { 1238 throw py::value_error((llvm::Twine("Operation \"") + name + 1239 "\" requires " + 1240 llvm::Twine(operandSegmentSpec.size()) + 1241 "operand segments but was provided " + 1242 llvm::Twine(operandList.size())) 1243 .str()); 1244 } 1245 operandSegmentLengths.reserve(operandList.size()); 1246 for (auto it : 1247 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1248 int segmentSpec = std::get<1>(it.value()); 1249 if (segmentSpec == 1 || segmentSpec == 0) { 1250 // Unpack unary element. 1251 try { 1252 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1253 if (operandValue) { 1254 operands.push_back(operandValue); 1255 operandSegmentLengths.push_back(1); 1256 } else if (segmentSpec == 0) { 1257 // Allowed to be optional. 1258 operandSegmentLengths.push_back(0); 1259 } else { 1260 throw py::cast_error("was None and operand is not optional"); 1261 } 1262 } catch (py::cast_error &err) { 1263 throw py::value_error((llvm::Twine("Operand ") + 1264 llvm::Twine(it.index()) + " of operation \"" + 1265 name + "\" must be a Value (" + err.what() + 1266 ")") 1267 .str()); 1268 } 1269 } else if (segmentSpec == -1) { 1270 // Unpack sequence by appending. 1271 try { 1272 if (std::get<0>(it.value()).is_none()) { 1273 // Treat it as an empty list. 1274 operandSegmentLengths.push_back(0); 1275 } else { 1276 // Unpack the list. 
1277 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1278 for (py::object segmentItem : segment) { 1279 operands.push_back(py::cast<PyValue *>(segmentItem)); 1280 if (!operands.back()) { 1281 throw py::cast_error("contained a None item"); 1282 } 1283 } 1284 operandSegmentLengths.push_back(segment.size()); 1285 } 1286 } catch (std::exception &err) { 1287 // NOTE: Sloppy to be using a catch-all here, but there are at least 1288 // three different unrelated exceptions that can be thrown in the 1289 // above "casts". Just keep the scope above small and catch them all. 1290 throw py::value_error((llvm::Twine("Operand ") + 1291 llvm::Twine(it.index()) + " of operation \"" + 1292 name + "\" must be a Sequence of Values (" + 1293 err.what() + ")") 1294 .str()); 1295 } 1296 } else { 1297 throw py::value_error("Unexpected segment spec"); 1298 } 1299 } 1300 } 1301 1302 // Merge operand/result segment lengths into attributes if needed. 1303 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1304 // Dup. 1305 if (attributes) { 1306 attributes = py::dict(*attributes); 1307 } else { 1308 attributes = py::dict(); 1309 } 1310 if (attributes->contains("result_segment_sizes") || 1311 attributes->contains("operand_segment_sizes")) { 1312 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1313 "'operand_segment_sizes' attribute is unsupported. " 1314 "Use Operation.create for such low-level access."); 1315 } 1316 1317 // Add result_segment_sizes attribute. 1318 if (!resultSegmentLengths.empty()) { 1319 int64_t size = resultSegmentLengths.size(); 1320 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1321 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1322 resultSegmentLengths.size(), resultSegmentLengths.data()); 1323 (*attributes)["result_segment_sizes"] = 1324 PyAttribute(context, segmentLengthAttr); 1325 } 1326 1327 // Add operand_segment_sizes attribute. 
1328 if (!operandSegmentLengths.empty()) { 1329 int64_t size = operandSegmentLengths.size(); 1330 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1331 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1332 operandSegmentLengths.size(), operandSegmentLengths.data()); 1333 (*attributes)["operand_segment_sizes"] = 1334 PyAttribute(context, segmentLengthAttr); 1335 } 1336 } 1337 1338 // Delegate to create. 1339 return PyOperation::create(std::move(name), 1340 /*results=*/std::move(resultTypes), 1341 /*operands=*/std::move(operands), 1342 /*attributes=*/std::move(attributes), 1343 /*successors=*/std::move(successors), 1344 /*regions=*/*regions, location, maybeIp); 1345 } 1346 1347 PyOpView::PyOpView(py::object operationObject) 1348 // Casting through the PyOperationBase base-class and then back to the 1349 // Operation lets us accept any PyOperationBase subclass. 1350 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1351 operationObject(operation.getRef().getObject()) {} 1352 1353 py::object PyOpView::createRawSubclass(py::object userClass) { 1354 // This is... a little gross. The typical pattern is to have a pure python 1355 // class that extends OpView like: 1356 // class AddFOp(_cext.ir.OpView): 1357 // def __init__(self, loc, lhs, rhs): 1358 // operation = loc.context.create_operation( 1359 // "addf", lhs, rhs, results=[lhs.type]) 1360 // super().__init__(operation) 1361 // 1362 // I.e. The goal of the user facing type is to provide a nice constructor 1363 // that has complete freedom for the op under construction. This is at odds 1364 // with our other desire to sometimes create this object by just passing an 1365 // operation (to initialize the base class). 
We could do *arg and **kwargs 1366 // munging to try to make it work, but instead, we synthesize a new class 1367 // on the fly which extends this user class (AddFOp in this example) and 1368 // *give it* the base class's __init__ method, thus bypassing the 1369 // intermediate subclass's __init__ method entirely. While slightly, 1370 // underhanded, this is safe/legal because the type hierarchy has not changed 1371 // (we just added a new leaf) and we aren't mucking around with __new__. 1372 // Typically, this new class will be stored on the original as "_Raw" and will 1373 // be used for casts and other things that need a variant of the class that 1374 // is initialized purely from an operation. 1375 py::object parentMetaclass = 1376 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1377 py::dict attributes; 1378 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1379 // now. 1380 // auto opViewType = py::type::of<PyOpView>(); 1381 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1382 attributes["__init__"] = opViewType.attr("__init__"); 1383 py::str origName = userClass.attr("__name__"); 1384 py::str newName = py::str("_") + origName; 1385 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1386 } 1387 1388 //------------------------------------------------------------------------------ 1389 // PyInsertionPoint. 
1390 //------------------------------------------------------------------------------ 1391 1392 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1393 1394 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1395 : refOperation(beforeOperationBase.getOperation().getRef()), 1396 block((*refOperation)->getBlock()) {} 1397 1398 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1399 PyOperation &operation = operationBase.getOperation(); 1400 if (operation.isAttached()) 1401 throw SetPyError(PyExc_ValueError, 1402 "Attempt to insert operation that is already attached"); 1403 block.getParentOperation()->checkValid(); 1404 MlirOperation beforeOp = {nullptr}; 1405 if (refOperation) { 1406 // Insert before operation. 1407 (*refOperation)->checkValid(); 1408 beforeOp = (*refOperation)->get(); 1409 } else { 1410 // Insert at end (before null) is only valid if the block does not 1411 // already end in a known terminator (violating this will cause assertion 1412 // failures later). 1413 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1414 throw py::index_error("Cannot insert operation at the end of a block " 1415 "that already has a terminator. Did you mean to " 1416 "use 'InsertionPoint.at_block_terminator(block)' " 1417 "versus 'InsertionPoint(block)'?"); 1418 } 1419 } 1420 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1421 operation.setAttached(); 1422 } 1423 1424 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1425 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1426 if (mlirOperationIsNull(firstOp)) { 1427 // Just insert at end. 1428 return PyInsertionPoint(block); 1429 } 1430 1431 // Insert before first op. 
1432 PyOperationRef firstOpRef = PyOperation::forOperation( 1433 block.getParentOperation()->getContext(), firstOp); 1434 return PyInsertionPoint{block, std::move(firstOpRef)}; 1435 } 1436 1437 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1438 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1439 if (mlirOperationIsNull(terminator)) 1440 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1441 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1442 block.getParentOperation()->getContext(), terminator); 1443 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1444 } 1445 1446 py::object PyInsertionPoint::contextEnter() { 1447 return PyThreadContextEntry::pushInsertionPoint(*this); 1448 } 1449 1450 void PyInsertionPoint::contextExit(pybind11::object excType, 1451 pybind11::object excVal, 1452 pybind11::object excTb) { 1453 PyThreadContextEntry::popInsertionPoint(*this); 1454 } 1455 1456 //------------------------------------------------------------------------------ 1457 // PyAttribute. 1458 //------------------------------------------------------------------------------ 1459 1460 bool PyAttribute::operator==(const PyAttribute &other) { 1461 return mlirAttributeEqual(attr, other.attr); 1462 } 1463 1464 py::object PyAttribute::getCapsule() { 1465 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1466 } 1467 1468 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1469 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1470 if (mlirAttributeIsNull(rawAttr)) 1471 throw py::error_already_set(); 1472 return PyAttribute( 1473 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1474 } 1475 1476 //------------------------------------------------------------------------------ 1477 // PyNamedAttribute. 
1478 //------------------------------------------------------------------------------ 1479 1480 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1481 : ownedName(new std::string(std::move(ownedName))) { 1482 namedAttr = mlirNamedAttributeGet( 1483 mlirIdentifierGet(mlirAttributeGetContext(attr), 1484 toMlirStringRef(*this->ownedName)), 1485 attr); 1486 } 1487 1488 //------------------------------------------------------------------------------ 1489 // PyType. 1490 //------------------------------------------------------------------------------ 1491 1492 bool PyType::operator==(const PyType &other) { 1493 return mlirTypeEqual(type, other.type); 1494 } 1495 1496 py::object PyType::getCapsule() { 1497 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1498 } 1499 1500 PyType PyType::createFromCapsule(py::object capsule) { 1501 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1502 if (mlirTypeIsNull(rawType)) 1503 throw py::error_already_set(); 1504 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1505 rawType); 1506 } 1507 1508 //------------------------------------------------------------------------------ 1509 // PyValue and subclases. 
1510 //------------------------------------------------------------------------------ 1511 1512 pybind11::object PyValue::getCapsule() { 1513 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1514 } 1515 1516 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1517 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1518 if (mlirValueIsNull(value)) 1519 throw py::error_already_set(); 1520 MlirOperation owner; 1521 if (mlirValueIsAOpResult(value)) 1522 owner = mlirOpResultGetOwner(value); 1523 if (mlirValueIsABlockArgument(value)) 1524 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1525 if (mlirOperationIsNull(owner)) 1526 throw py::error_already_set(); 1527 MlirContext ctx = mlirOperationGetContext(owner); 1528 PyOperationRef ownerRef = 1529 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1530 return PyValue(ownerRef, value); 1531 } 1532 1533 namespace { 1534 /// CRTP base class for Python MLIR values that subclass Value and should be 1535 /// castable from it. The value hierarchy is one level deep and is not supposed 1536 /// to accommodate other levels unless core MLIR changes. 1537 template <typename DerivedTy> 1538 class PyConcreteValue : public PyValue { 1539 public: 1540 // Derived classes must define statics for: 1541 // IsAFunctionTy isaFunction 1542 // const char *pyClassName 1543 // and redefine bindDerived. 1544 using ClassTy = py::class_<DerivedTy, PyValue>; 1545 using IsAFunctionTy = bool (*)(MlirValue); 1546 1547 PyConcreteValue() = default; 1548 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1549 : PyValue(operationRef, value) {} 1550 PyConcreteValue(PyValue &orig) 1551 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1552 1553 /// Attempts to cast the original value to the derived type and throws on 1554 /// type mismatches. 
1555 static MlirValue castFrom(PyValue &orig) { 1556 if (!DerivedTy::isaFunction(orig.get())) { 1557 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1558 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1559 DerivedTy::pyClassName + 1560 " (from " + origRepr + ")"); 1561 } 1562 return orig.get(); 1563 } 1564 1565 /// Binds the Python module objects to functions of this class. 1566 static void bind(py::module &m) { 1567 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1568 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1569 cls.def_static("isinstance", [](PyValue &otherValue) -> bool { 1570 return DerivedTy::isaFunction(otherValue); 1571 }); 1572 DerivedTy::bindDerived(cls); 1573 } 1574 1575 /// Implemented by derived classes to add methods to the Python subclass. 1576 static void bindDerived(ClassTy &m) {} 1577 }; 1578 1579 /// Python wrapper for MlirBlockArgument. 1580 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1581 public: 1582 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1583 static constexpr const char *pyClassName = "BlockArgument"; 1584 using PyConcreteValue::PyConcreteValue; 1585 1586 static void bindDerived(ClassTy &c) { 1587 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1588 return PyBlock(self.getParentOperation(), 1589 mlirBlockArgumentGetOwner(self.get())); 1590 }); 1591 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1592 return mlirBlockArgumentGetArgNumber(self.get()); 1593 }); 1594 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1595 return mlirBlockArgumentSetType(self.get(), type); 1596 }); 1597 } 1598 }; 1599 1600 /// Python wrapper for MlirOpResult. 
1601 class PyOpResult : public PyConcreteValue<PyOpResult> { 1602 public: 1603 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1604 static constexpr const char *pyClassName = "OpResult"; 1605 using PyConcreteValue::PyConcreteValue; 1606 1607 static void bindDerived(ClassTy &c) { 1608 c.def_property_readonly("owner", [](PyOpResult &self) { 1609 assert( 1610 mlirOperationEqual(self.getParentOperation()->get(), 1611 mlirOpResultGetOwner(self.get())) && 1612 "expected the owner of the value in Python to match that in the IR"); 1613 return self.getParentOperation().getObject(); 1614 }); 1615 c.def_property_readonly("result_number", [](PyOpResult &self) { 1616 return mlirOpResultGetResultNumber(self.get()); 1617 }); 1618 } 1619 }; 1620 1621 /// Returns the list of types of the values held by container. 1622 template <typename Container> 1623 static std::vector<PyType> getValueTypes(Container &container, 1624 PyMlirContextRef &context) { 1625 std::vector<PyType> result; 1626 result.reserve(container.getNumElements()); 1627 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1628 result.push_back( 1629 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1630 } 1631 return result; 1632 } 1633 1634 /// A list of block arguments. Internally, these are stored as consecutive 1635 /// elements, random access is cheap. The argument list is associated with the 1636 /// operation that contains the block (detached blocks are not allowed in 1637 /// Python bindings) and extends its lifetime. 1638 class PyBlockArgumentList 1639 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1640 public: 1641 static constexpr const char *pyClassName = "BlockArgumentList"; 1642 1643 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1644 intptr_t startIndex = 0, intptr_t length = -1, 1645 intptr_t step = 1) 1646 : Sliceable(startIndex, 1647 length == -1 ? 
mlirBlockGetNumArguments(block) : length, 1648 step), 1649 operation(std::move(operation)), block(block) {} 1650 1651 /// Returns the number of arguments in the list. 1652 intptr_t getNumElements() { 1653 operation->checkValid(); 1654 return mlirBlockGetNumArguments(block); 1655 } 1656 1657 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1658 PyBlockArgument getElement(intptr_t pos) { 1659 MlirValue argument = mlirBlockGetArgument(block, pos); 1660 return PyBlockArgument(operation, argument); 1661 } 1662 1663 /// Returns a sublist of this list. 1664 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1665 intptr_t step) { 1666 return PyBlockArgumentList(operation, block, startIndex, length, step); 1667 } 1668 1669 static void bindDerived(ClassTy &c) { 1670 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1671 return getValueTypes(self, self.operation->getContext()); 1672 }); 1673 } 1674 1675 private: 1676 PyOperationRef operation; 1677 MlirBlock block; 1678 }; 1679 1680 /// A list of operation operands. Internally, these are stored as consecutive 1681 /// elements, random access is cheap. The result list is associated with the 1682 /// operation whose results these are, and extends the lifetime of this 1683 /// operation. 1684 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1685 public: 1686 static constexpr const char *pyClassName = "OpOperandList"; 1687 1688 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1689 intptr_t length = -1, intptr_t step = 1) 1690 : Sliceable(startIndex, 1691 length == -1 ? 
mlirOperationGetNumOperands(operation->get()) 1692 : length, 1693 step), 1694 operation(operation) {} 1695 1696 intptr_t getNumElements() { 1697 operation->checkValid(); 1698 return mlirOperationGetNumOperands(operation->get()); 1699 } 1700 1701 PyValue getElement(intptr_t pos) { 1702 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1703 MlirOperation owner; 1704 if (mlirValueIsAOpResult(operand)) 1705 owner = mlirOpResultGetOwner(operand); 1706 else if (mlirValueIsABlockArgument(operand)) 1707 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1708 else 1709 assert(false && "Value must be an block arg or op result."); 1710 PyOperationRef pyOwner = 1711 PyOperation::forOperation(operation->getContext(), owner); 1712 return PyValue(pyOwner, operand); 1713 } 1714 1715 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1716 return PyOpOperandList(operation, startIndex, length, step); 1717 } 1718 1719 void dunderSetItem(intptr_t index, PyValue value) { 1720 index = wrapIndex(index); 1721 mlirOperationSetOperand(operation->get(), index, value.get()); 1722 } 1723 1724 static void bindDerived(ClassTy &c) { 1725 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1726 } 1727 1728 private: 1729 PyOperationRef operation; 1730 }; 1731 1732 /// A list of operation results. Internally, these are stored as consecutive 1733 /// elements, random access is cheap. The result list is associated with the 1734 /// operation whose results these are, and extends the lifetime of this 1735 /// operation. 1736 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1737 public: 1738 static constexpr const char *pyClassName = "OpResultList"; 1739 1740 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1741 intptr_t length = -1, intptr_t step = 1) 1742 : Sliceable(startIndex, 1743 length == -1 ? 
mlirOperationGetNumResults(operation->get()) 1744 : length, 1745 step), 1746 operation(operation) {} 1747 1748 intptr_t getNumElements() { 1749 operation->checkValid(); 1750 return mlirOperationGetNumResults(operation->get()); 1751 } 1752 1753 PyOpResult getElement(intptr_t index) { 1754 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1755 return PyOpResult(value); 1756 } 1757 1758 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1759 return PyOpResultList(operation, startIndex, length, step); 1760 } 1761 1762 static void bindDerived(ClassTy &c) { 1763 c.def_property_readonly("types", [](PyOpResultList &self) { 1764 return getValueTypes(self, self.operation->getContext()); 1765 }); 1766 } 1767 1768 private: 1769 PyOperationRef operation; 1770 }; 1771 1772 /// A list of operation attributes. Can be indexed by name, producing 1773 /// attributes, or by index, producing named attributes. 1774 class PyOpAttributeMap { 1775 public: 1776 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1777 1778 PyAttribute dunderGetItemNamed(const std::string &name) { 1779 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1780 toMlirStringRef(name)); 1781 if (mlirAttributeIsNull(attr)) { 1782 throw SetPyError(PyExc_KeyError, 1783 "attempt to access a non-existent attribute"); 1784 } 1785 return PyAttribute(operation->getContext(), attr); 1786 } 1787 1788 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1789 if (index < 0 || index >= dunderLen()) { 1790 throw SetPyError(PyExc_IndexError, 1791 "attempt to access out of bounds attribute"); 1792 } 1793 MlirNamedAttribute namedAttr = 1794 mlirOperationGetAttribute(operation->get(), index); 1795 return PyNamedAttribute( 1796 namedAttr.attribute, 1797 std::string(mlirIdentifierStr(namedAttr.name).data)); 1798 } 1799 1800 void dunderSetItem(const std::string &name, PyAttribute attr) { 1801 mlirOperationSetAttributeByName(operation->get(), 
toMlirStringRef(name), 1802 attr); 1803 } 1804 1805 void dunderDelItem(const std::string &name) { 1806 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1807 toMlirStringRef(name)); 1808 if (!removed) 1809 throw SetPyError(PyExc_KeyError, 1810 "attempt to delete a non-existent attribute"); 1811 } 1812 1813 intptr_t dunderLen() { 1814 return mlirOperationGetNumAttributes(operation->get()); 1815 } 1816 1817 bool dunderContains(const std::string &name) { 1818 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1819 operation->get(), toMlirStringRef(name))); 1820 } 1821 1822 static void bind(py::module &m) { 1823 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1824 .def("__contains__", &PyOpAttributeMap::dunderContains) 1825 .def("__len__", &PyOpAttributeMap::dunderLen) 1826 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1827 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1828 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1829 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1830 } 1831 1832 private: 1833 PyOperationRef operation; 1834 }; 1835 1836 } // end namespace 1837 1838 //------------------------------------------------------------------------------ 1839 // Populates the core exports of the 'ir' submodule. 1840 //------------------------------------------------------------------------------ 1841 1842 void mlir::python::populateIRCore(py::module &m) { 1843 //---------------------------------------------------------------------------- 1844 // Mapping of MlirContext. 
1845 //---------------------------------------------------------------------------- 1846 py::class_<PyMlirContext>(m, "Context", py::module_local()) 1847 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1848 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1849 .def("_get_context_again", 1850 [](PyMlirContext &self) { 1851 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1852 return ref.releaseObject(); 1853 }) 1854 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1855 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1856 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1857 &PyMlirContext::getCapsule) 1858 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1859 .def("__enter__", &PyMlirContext::contextEnter) 1860 .def("__exit__", &PyMlirContext::contextExit) 1861 .def_property_readonly_static( 1862 "current", 1863 [](py::object & /*class*/) { 1864 auto *context = PyThreadContextEntry::getDefaultContext(); 1865 if (!context) 1866 throw SetPyError(PyExc_ValueError, "No current Context"); 1867 return context; 1868 }, 1869 "Gets the Context bound to the current thread or raises ValueError") 1870 .def_property_readonly( 1871 "dialects", 1872 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1873 "Gets a container for accessing dialects by name") 1874 .def_property_readonly( 1875 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1876 "Alias for 'dialect'") 1877 .def( 1878 "get_dialect_descriptor", 1879 [=](PyMlirContext &self, std::string &name) { 1880 MlirDialect dialect = mlirContextGetOrLoadDialect( 1881 self.get(), {name.data(), name.size()}); 1882 if (mlirDialectIsNull(dialect)) { 1883 throw SetPyError(PyExc_ValueError, 1884 Twine("Dialect '") + name + "' not found"); 1885 } 1886 return PyDialectDescriptor(self.getRef(), dialect); 1887 }, 1888 "Gets or loads a dialect by name, returning its descriptor object") 1889 .def_property( 1890 
"allow_unregistered_dialects", 1891 [](PyMlirContext &self) -> bool { 1892 return mlirContextGetAllowUnregisteredDialects(self.get()); 1893 }, 1894 [](PyMlirContext &self, bool value) { 1895 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1896 }) 1897 .def("enable_multithreading", 1898 [](PyMlirContext &self, bool enable) { 1899 mlirContextEnableMultithreading(self.get(), enable); 1900 }) 1901 .def("is_registered_operation", 1902 [](PyMlirContext &self, std::string &name) { 1903 return mlirContextIsRegisteredOperation( 1904 self.get(), MlirStringRef{name.data(), name.size()}); 1905 }); 1906 1907 //---------------------------------------------------------------------------- 1908 // Mapping of PyDialectDescriptor 1909 //---------------------------------------------------------------------------- 1910 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1911 .def_property_readonly("namespace", 1912 [](PyDialectDescriptor &self) { 1913 MlirStringRef ns = 1914 mlirDialectGetNamespace(self.get()); 1915 return py::str(ns.data, ns.length); 1916 }) 1917 .def("__repr__", [](PyDialectDescriptor &self) { 1918 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1919 std::string repr("<DialectDescriptor "); 1920 repr.append(ns.data, ns.length); 1921 repr.append(">"); 1922 return repr; 1923 }); 1924 1925 //---------------------------------------------------------------------------- 1926 // Mapping of PyDialects 1927 //---------------------------------------------------------------------------- 1928 py::class_<PyDialects>(m, "Dialects", py::module_local()) 1929 .def("__getitem__", 1930 [=](PyDialects &self, std::string keyName) { 1931 MlirDialect dialect = 1932 self.getDialectForKey(keyName, /*attrError=*/false); 1933 py::object descriptor = 1934 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1935 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1936 }) 1937 .def("__getattr__", [=](PyDialects &self, std::string 
attrName) { 1938 MlirDialect dialect = 1939 self.getDialectForKey(attrName, /*attrError=*/true); 1940 py::object descriptor = 1941 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1942 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1943 }); 1944 1945 //---------------------------------------------------------------------------- 1946 // Mapping of PyDialect 1947 //---------------------------------------------------------------------------- 1948 py::class_<PyDialect>(m, "Dialect", py::module_local()) 1949 .def(py::init<py::object>(), "descriptor") 1950 .def_property_readonly( 1951 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1952 .def("__repr__", [](py::object self) { 1953 auto clazz = self.attr("__class__"); 1954 return py::str("<Dialect ") + 1955 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1956 clazz.attr("__module__") + py::str(".") + 1957 clazz.attr("__name__") + py::str(")>"); 1958 }); 1959 1960 //---------------------------------------------------------------------------- 1961 // Mapping of Location 1962 //---------------------------------------------------------------------------- 1963 py::class_<PyLocation>(m, "Location", py::module_local()) 1964 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1965 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1966 .def("__enter__", &PyLocation::contextEnter) 1967 .def("__exit__", &PyLocation::contextExit) 1968 .def("__eq__", 1969 [](PyLocation &self, PyLocation &other) -> bool { 1970 return mlirLocationEqual(self, other); 1971 }) 1972 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1973 .def_property_readonly_static( 1974 "current", 1975 [](py::object & /*class*/) { 1976 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1977 if (!loc) 1978 throw SetPyError(PyExc_ValueError, "No current Location"); 1979 return loc; 1980 }, 1981 "Gets the Location bound to the current 
thread or raises ValueError") 1982 .def_static( 1983 "unknown", 1984 [](DefaultingPyMlirContext context) { 1985 return PyLocation(context->getRef(), 1986 mlirLocationUnknownGet(context->get())); 1987 }, 1988 py::arg("context") = py::none(), 1989 "Gets a Location representing an unknown location") 1990 .def_static( 1991 "callsite", 1992 [](PyLocation callee, const std::vector<PyLocation> &frames, 1993 DefaultingPyMlirContext context) { 1994 if (frames.empty()) 1995 throw py::value_error("No caller frames provided"); 1996 MlirLocation caller = frames.back().get(); 1997 for (const PyLocation &frame : 1998 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 1999 caller = mlirLocationCallSiteGet(frame.get(), caller); 2000 return PyLocation(context->getRef(), 2001 mlirLocationCallSiteGet(callee.get(), caller)); 2002 }, 2003 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2004 kContextGetCallSiteLocationDocstring) 2005 .def_static( 2006 "file", 2007 [](std::string filename, int line, int col, 2008 DefaultingPyMlirContext context) { 2009 return PyLocation( 2010 context->getRef(), 2011 mlirLocationFileLineColGet( 2012 context->get(), toMlirStringRef(filename), line, col)); 2013 }, 2014 py::arg("filename"), py::arg("line"), py::arg("col"), 2015 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2016 .def_static( 2017 "name", 2018 [](std::string name, llvm::Optional<PyLocation> childLoc, 2019 DefaultingPyMlirContext context) { 2020 return PyLocation( 2021 context->getRef(), 2022 mlirLocationNameGet( 2023 context->get(), toMlirStringRef(name), 2024 childLoc ? 
childLoc->get() 2025 : mlirLocationUnknownGet(context->get()))); 2026 }, 2027 py::arg("name"), py::arg("childLoc") = py::none(), 2028 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2029 .def_property_readonly( 2030 "context", 2031 [](PyLocation &self) { return self.getContext().getObject(); }, 2032 "Context that owns the Location") 2033 .def("__repr__", [](PyLocation &self) { 2034 PyPrintAccumulator printAccum; 2035 mlirLocationPrint(self, printAccum.getCallback(), 2036 printAccum.getUserData()); 2037 return printAccum.join(); 2038 }); 2039 2040 //---------------------------------------------------------------------------- 2041 // Mapping of Module 2042 //---------------------------------------------------------------------------- 2043 py::class_<PyModule>(m, "Module", py::module_local()) 2044 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2045 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2046 .def_static( 2047 "parse", 2048 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2049 MlirModule module = mlirModuleCreateParse( 2050 context->get(), toMlirStringRef(moduleAsm)); 2051 // TODO: Rework error reporting once diagnostic engine is exposed 2052 // in C API. 
2053 if (mlirModuleIsNull(module)) { 2054 throw SetPyError( 2055 PyExc_ValueError, 2056 "Unable to parse module assembly (see diagnostics)"); 2057 } 2058 return PyModule::forModule(module).releaseObject(); 2059 }, 2060 py::arg("asm"), py::arg("context") = py::none(), 2061 kModuleParseDocstring) 2062 .def_static( 2063 "create", 2064 [](DefaultingPyLocation loc) { 2065 MlirModule module = mlirModuleCreateEmpty(loc); 2066 return PyModule::forModule(module).releaseObject(); 2067 }, 2068 py::arg("loc") = py::none(), "Creates an empty module") 2069 .def_property_readonly( 2070 "context", 2071 [](PyModule &self) { return self.getContext().getObject(); }, 2072 "Context that created the Module") 2073 .def_property_readonly( 2074 "operation", 2075 [](PyModule &self) { 2076 return PyOperation::forOperation(self.getContext(), 2077 mlirModuleGetOperation(self.get()), 2078 self.getRef().releaseObject()) 2079 .releaseObject(); 2080 }, 2081 "Accesses the module as an operation") 2082 .def_property_readonly( 2083 "body", 2084 [](PyModule &self) { 2085 PyOperationRef module_op = PyOperation::forOperation( 2086 self.getContext(), mlirModuleGetOperation(self.get()), 2087 self.getRef().releaseObject()); 2088 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2089 return returnBlock; 2090 }, 2091 "Return the block for this module") 2092 .def( 2093 "dump", 2094 [](PyModule &self) { 2095 mlirOperationDump(mlirModuleGetOperation(self.get())); 2096 }, 2097 kDumpDocstring) 2098 .def( 2099 "__str__", 2100 [](PyModule &self) { 2101 MlirOperation operation = mlirModuleGetOperation(self.get()); 2102 PyPrintAccumulator printAccum; 2103 mlirOperationPrint(operation, printAccum.getCallback(), 2104 printAccum.getUserData()); 2105 return printAccum.join(); 2106 }, 2107 kOperationStrDunderDocstring); 2108 2109 //---------------------------------------------------------------------------- 2110 // Mapping of Operation. 
//----------------------------------------------------------------------------
  // `_OperationBase` carries the accessors shared by `Operation` and `OpView`
  // (both are registered below with PyOperationBase as their base); each
  // accessor resolves the concrete PyOperation through getOperation().
  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             [](PyOperationBase &self) {
                               return self.getOperation().getCapsule();
                             })
      // Identity comparison: equal iff both resolve to the same PyOperation
      // object. The py::object overload, registered second, is the fallback
      // returning False for any other type (pybind11 tries same-named
      // overloads in registration order, so keep this ordering).
      .def("__eq__",
           [](PyOperationBase &self, PyOperationBase &other) {
             return &self.getOperation() == &other.getOperation();
           })
      .def("__eq__",
           [](PyOperationBase &self, py::object other) { return false; })
      .def_property_readonly("attributes",
                             [](PyOperationBase &self) {
                               return PyOpAttributeMap(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("operands",
                             [](PyOperationBase &self) {
                               return PyOpOperandList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("regions",
                             [](PyOperationBase &self) {
                               return PyRegionList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      // Raises ValueError unless the operation has exactly one result.
      .def_property_readonly(
          "result",
          [](PyOperationBase &self) {
            auto &operation = self.getOperation();
            auto numResults = mlirOperationGetNumResults(operation);
            if (numResults != 1) {
              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
              throw SetPyError(
                  PyExc_ValueError,
                  Twine("Cannot call .result on operation ") +
                      StringRef(name.data, name.length) + " which has " +
                      Twine(numResults) +
                      " results (it is only valid for operations with a "
                      "single result)");
            }
            return PyOpResult(operation.getRef(),
                              mlirOperationGetResult(operation, 0));
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      .def_property_readonly(
          "location",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            return PyLocation(operation.getContext(),
                              mlirOperationGetLocation(operation.get()));
          },
          "Returns the source location the operation was defined or derived "
          "from.")
      // str(op): every printing option is left off (no element limit, no
      // debug info, custom op form, no local scope).
      .def(
          "__str__",
          [](PyOperationBase &self) {
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/llvm::None,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false);
          },
          "Returns the assembly form of the operation.")
      .def("print", &PyOperationBase::print,
           // Careful: Lots of arguments must match up with print method.
           py::arg("file") = py::none(), py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false, kOperationPrintDocstring)
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
      .def(
          "verify",
          [](PyOperationBase &self) {
            return mlirOperationVerify(self.getOperation());
          },
          "Verify the operation and return true if it passes, false if it "
          "fails.")
      .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
           "Puts self immediately after the other operation in its parent "
           "block.")
      .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
           "Puts self immediately before the other operation in its parent "
           "block.")
      // Raises ValueError for an already-detached operation; otherwise
      // detaches it and returns its OpView.
      .def(
          "detach_from_parent",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            operation.checkValid();
            if (!operation.isAttached())
              throw py::value_error("Detached operation has no parent.");

            operation.detachFromParent();
            return operation.createOpView();
          },
          "Detaches the operation from its parent block.");

  py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
      .def_static("create", &PyOperation::create, py::arg("name"),
                  py::arg("results") = py::none(),
                  py::arg("operands") = py::none(),
                  py::arg("attributes") = py::none(),
                  py::arg("successors") = py::none(), py::arg("regions") = 0,
                  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                  kOperationCreateDocstring)
      // Returns the parent operation's Python object, or None if there is no
      // parent.
      .def_property_readonly("parent",
                             [](PyOperation &self) -> py::object {
                               auto parent = self.getParentOperation();
                               if (parent)
                                 return parent->getObject();
                               return py::none();
                             })
      .def("erase", &PyOperation::erase)
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             &PyOperation::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
      .def_property_readonly("name",
                             [](PyOperation &self) {
                               self.checkValid();
                               MlirOperation operation = self.get();
                               MlirStringRef name = mlirIdentifierStr(
                                   mlirOperationGetName(operation));
                               return py::str(name.data, name.length);
                             })
      .def_property_readonly(
          "context",
          [](PyOperation &self) {
            self.checkValid();
            return self.getContext().getObject();
          },
          "Context that owns the Operation")
      .def_property_readonly("opview", &PyOperation::createOpView);

  auto opViewClass =
      py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
          .def(py::init<py::object>())
          .def_property_readonly("operation", &PyOpView::getOperationObject)
          .def_property_readonly(
              "context",
              [](PyOpView &self) {
                return self.getOperation().getContext().getObject();
              },
              "Context that owns the Operation")
          .def("__str__", [](PyOpView &self) {
            return py::str(self.getOperationObject());
          });
  // Class-level ODS attributes on the generic OpView base; build_generic
  // below builds ops "based on class level attributes", so generated
  // subclasses presumably override these (TODO confirm the meaning of the
  // `true` in _ODS_REGIONS).
  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
  opViewClass.attr("build_generic") = classmethod(
      &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
      py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
      py::arg("successors") = py::none(), py::arg("regions") = py::none(),
      py::arg("loc") = py::none(), py::arg("ip") = py::none(),
      "Builds a specific, generated OpView based on class level attributes.");

  //----------------------------------------------------------------------------
  // Mapping of PyRegion.
2287 //---------------------------------------------------------------------------- 2288 py::class_<PyRegion>(m, "Region", py::module_local()) 2289 .def_property_readonly( 2290 "blocks", 2291 [](PyRegion &self) { 2292 return PyBlockList(self.getParentOperation(), self.get()); 2293 }, 2294 "Returns a forward-optimized sequence of blocks.") 2295 .def_property_readonly( 2296 "owner", 2297 [](PyRegion &self) { 2298 return self.getParentOperation()->createOpView(); 2299 }, 2300 "Returns the operation owning this region.") 2301 .def( 2302 "__iter__", 2303 [](PyRegion &self) { 2304 self.checkValid(); 2305 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2306 return PyBlockIterator(self.getParentOperation(), firstBlock); 2307 }, 2308 "Iterates over blocks in the region.") 2309 .def("__eq__", 2310 [](PyRegion &self, PyRegion &other) { 2311 return self.get().ptr == other.get().ptr; 2312 }) 2313 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2314 2315 //---------------------------------------------------------------------------- 2316 // Mapping of PyBlock. 
//----------------------------------------------------------------------------
  // Every Block accessor below pairs the MlirBlock with the block's parent
  // operation reference (getParentOperation()), presumably so the owning
  // operation stays alive while Python holds derived objects — TODO confirm.
  py::class_<PyBlock>(m, "Block", py::module_local())
      .def_property_readonly(
          "owner",
          [](PyBlock &self) {
            return self.getParentOperation()->createOpView();
          },
          "Returns the owning operation of this block.")
      .def_property_readonly(
          "region",
          [](PyBlock &self) {
            MlirRegion region = mlirBlockGetParentRegion(self.get());
            return PyRegion(self.getParentOperation(), region);
          },
          "Returns the owning region of this block.")
      .def_property_readonly(
          "arguments",
          [](PyBlock &self) {
            return PyBlockArgumentList(self.getParentOperation(), self.get());
          },
          "Returns a list of block arguments.")
      .def_property_readonly(
          "operations",
          [](PyBlock &self) {
            return PyOperationList(self.getParentOperation(), self.get());
          },
          "Returns a forward-optimized sequence of operations.")
      // Creates a block whose argument types come from a Python list and
      // inserts it at position 0 of `parent`. Each element must be castable
      // to PyType (pybind11 raises a cast error otherwise).
      .def_static(
          "create_at_start",
          [](PyRegion &parent, py::list pyArgTypes) {
            parent.checkValid();
            llvm::SmallVector<MlirType, 4> argTypes;
            argTypes.reserve(pyArgTypes.size());
            for (auto &pyArg : pyArgTypes) {
              argTypes.push_back(pyArg.cast<PyType &>());
            }

            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
            mlirRegionInsertOwnedBlock(parent, 0, block);
            return PyBlock(parent.getParentOperation(), block);
          },
          py::arg("parent"), py::arg("pyArgTypes") = py::list(),
          "Creates and returns a new Block at the beginning of the given "
          "region (with given argument types).")
      // Same conversion as above, but argument types are passed as *args and
      // the new block is inserted right before `self` in self's region.
      .def(
          "create_before",
          [](PyBlock &self, py::args pyArgTypes) {
            self.checkValid();
            llvm::SmallVector<MlirType, 4> argTypes;
            argTypes.reserve(pyArgTypes.size());
            for (auto &pyArg : pyArgTypes) {
              argTypes.push_back(pyArg.cast<PyType &>());
            }

            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
            MlirRegion region = mlirBlockGetParentRegion(self.get());
            mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
            return PyBlock(self.getParentOperation(), block);
          },
          "Creates and returns a new Block before this block "
          "(with given argument types).")
      // Mirror of create_before, inserting right after `self`.
      .def(
          "create_after",
          [](PyBlock &self, py::args pyArgTypes) {
            self.checkValid();
            llvm::SmallVector<MlirType, 4> argTypes;
            argTypes.reserve(pyArgTypes.size());
            for (auto &pyArg : pyArgTypes) {
              argTypes.push_back(pyArg.cast<PyType &>());
            }

            MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
            MlirRegion region = mlirBlockGetParentRegion(self.get());
            mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
            return PyBlock(self.getParentOperation(), block);
          },
          "Creates and returns a new Block after this block "
          "(with given argument types).")
      .def(
          "__iter__",
          [](PyBlock &self) {
            self.checkValid();
            MlirOperation firstOperation =
                mlirBlockGetFirstOperation(self.get());
            return PyOperationIterator(self.getParentOperation(),
                                       firstOperation);
          },
          "Iterates over operations in the block.")
      // Typed comparison by underlying MlirBlock pointer first, then the
      // py::object fallback returning False (registration order matters).
      .def("__eq__",
           [](PyBlock &self, PyBlock &other) {
             return self.get().ptr == other.get().ptr;
           })
      .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
      .def(
          "__str__",
          [](PyBlock &self) {
            self.checkValid();
            PyPrintAccumulator printAccum;
            mlirBlockPrint(self.get(), printAccum.getCallback(),
                           printAccum.getUserData());
            return printAccum.join();
          },
          "Returns the assembly form of the block.")
      // Detaches the operation from its current block (if any), appends it to
      // this block, and records this block's parent operation as its new
      // Python-side parent so the wrapper's attachment state matches the IR.
      .def(
          "append",
          [](PyBlock &self, PyOperationBase &operation) {
            if (operation.getOperation().isAttached())
              operation.getOperation().detachFromParent();

            MlirOperation mlirOperation = operation.getOperation().get();
            mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
            operation.getOperation().setAttached(
                self.getParentOperation().getObject());
          },
          "Appends an operation to this block. If the operation is currently "
          "in another block, it will be moved.");

  //----------------------------------------------------------------------------
  // Mapping of PyInsertionPoint.
  //----------------------------------------------------------------------------

  // Two constructors: from a Block (insert at the end of the block) and from
  // an operation (insert before it); see their docstrings.
  py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
      .def(py::init<PyBlock &>(), py::arg("block"),
           "Inserts after the last operation but still inside the block.")
      .def("__enter__", &PyInsertionPoint::contextEnter)
      .def("__exit__", &PyInsertionPoint::contextExit)
      .def_property_readonly_static(
          "current",
          [](py::object & /*class*/) {
            auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
            if (!ip)
              throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
            return ip;
          },
          "Gets the InsertionPoint bound to the current thread or raises "
          "ValueError if none has been set")
      .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
           "Inserts before a referenced operation.")
      .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
                  py::arg("block"), "Inserts at the beginning of the block.")
      .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
                  py::arg("block"), "Inserts before the block terminator.")
      .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
           "Inserts an operation.")
      .def_property_readonly(
          "block", [](PyInsertionPoint &self) { return self.getBlock(); },
          "Returns the block that this InsertionPoint points to.");

  //----------------------------------------------------------------------------
  // Mapping of PyAttribute.
2467 //---------------------------------------------------------------------------- 2468 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2469 // Delegate to the PyAttribute copy constructor, which will also lifetime 2470 // extend the backing context which owns the MlirAttribute. 2471 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2472 "Casts the passed attribute to the generic Attribute") 2473 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2474 &PyAttribute::getCapsule) 2475 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2476 .def_static( 2477 "parse", 2478 [](std::string attrSpec, DefaultingPyMlirContext context) { 2479 MlirAttribute type = mlirAttributeParseGet( 2480 context->get(), toMlirStringRef(attrSpec)); 2481 // TODO: Rework error reporting once diagnostic engine is exposed 2482 // in C API. 2483 if (mlirAttributeIsNull(type)) { 2484 throw SetPyError(PyExc_ValueError, 2485 Twine("Unable to parse attribute: '") + 2486 attrSpec + "'"); 2487 } 2488 return PyAttribute(context->getRef(), type); 2489 }, 2490 py::arg("asm"), py::arg("context") = py::none(), 2491 "Parses an attribute from an assembly form") 2492 .def_property_readonly( 2493 "context", 2494 [](PyAttribute &self) { return self.getContext().getObject(); }, 2495 "Context that owns the Attribute") 2496 .def_property_readonly("type", 2497 [](PyAttribute &self) { 2498 return PyType(self.getContext()->getRef(), 2499 mlirAttributeGetType(self)); 2500 }) 2501 .def( 2502 "get_named", 2503 [](PyAttribute &self, std::string name) { 2504 return PyNamedAttribute(self, std::move(name)); 2505 }, 2506 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2507 .def("__eq__", 2508 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2509 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2510 .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; }) 2511 .def( 2512 "dump", [](PyAttribute &self) { 
mlirAttributeDump(self); }, 2513 kDumpDocstring) 2514 .def( 2515 "__str__", 2516 [](PyAttribute &self) { 2517 PyPrintAccumulator printAccum; 2518 mlirAttributePrint(self, printAccum.getCallback(), 2519 printAccum.getUserData()); 2520 return printAccum.join(); 2521 }, 2522 "Returns the assembly form of the Attribute.") 2523 .def("__repr__", [](PyAttribute &self) { 2524 // Generally, assembly formats are not printed for __repr__ because 2525 // this can cause exceptionally long debug output and exceptions. 2526 // However, attribute values are generally considered useful and are 2527 // printed. This may need to be re-evaluated if debug dumps end up 2528 // being excessive. 2529 PyPrintAccumulator printAccum; 2530 printAccum.parts.append("Attribute("); 2531 mlirAttributePrint(self, printAccum.getCallback(), 2532 printAccum.getUserData()); 2533 printAccum.parts.append(")"); 2534 return printAccum.join(); 2535 }); 2536 2537 //---------------------------------------------------------------------------- 2538 // Mapping of PyNamedAttribute 2539 //---------------------------------------------------------------------------- 2540 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2541 .def("__repr__", 2542 [](PyNamedAttribute &self) { 2543 PyPrintAccumulator printAccum; 2544 printAccum.parts.append("NamedAttribute("); 2545 printAccum.parts.append( 2546 mlirIdentifierStr(self.namedAttr.name).data); 2547 printAccum.parts.append("="); 2548 mlirAttributePrint(self.namedAttr.attribute, 2549 printAccum.getCallback(), 2550 printAccum.getUserData()); 2551 printAccum.parts.append(")"); 2552 return printAccum.join(); 2553 }) 2554 .def_property_readonly( 2555 "name", 2556 [](PyNamedAttribute &self) { 2557 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2558 mlirIdentifierStr(self.namedAttr.name).length); 2559 }, 2560 "The name of the NamedAttribute binding") 2561 .def_property_readonly( 2562 "attr", 2563 [](PyNamedAttribute &self) { 2564 // TODO: When 
named attribute is removed/refactored, also remove 2565 // this constructor (it does an inefficient table lookup). 2566 auto contextRef = PyMlirContext::forContext( 2567 mlirAttributeGetContext(self.namedAttr.attribute)); 2568 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2569 }, 2570 py::keep_alive<0, 1>(), 2571 "The underlying generic attribute of the NamedAttribute binding"); 2572 2573 //---------------------------------------------------------------------------- 2574 // Mapping of PyType. 2575 //---------------------------------------------------------------------------- 2576 py::class_<PyType>(m, "Type", py::module_local()) 2577 // Delegate to the PyType copy constructor, which will also lifetime 2578 // extend the backing context which owns the MlirType. 2579 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2580 "Casts the passed type to the generic Type") 2581 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2582 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2583 .def_static( 2584 "parse", 2585 [](std::string typeSpec, DefaultingPyMlirContext context) { 2586 MlirType type = 2587 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2588 // TODO: Rework error reporting once diagnostic engine is exposed 2589 // in C API. 
2590 if (mlirTypeIsNull(type)) { 2591 throw SetPyError(PyExc_ValueError, 2592 Twine("Unable to parse type: '") + typeSpec + 2593 "'"); 2594 } 2595 return PyType(context->getRef(), type); 2596 }, 2597 py::arg("asm"), py::arg("context") = py::none(), 2598 kContextParseTypeDocstring) 2599 .def_property_readonly( 2600 "context", [](PyType &self) { return self.getContext().getObject(); }, 2601 "Context that owns the Type") 2602 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2603 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2604 .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; }) 2605 .def( 2606 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2607 .def( 2608 "__str__", 2609 [](PyType &self) { 2610 PyPrintAccumulator printAccum; 2611 mlirTypePrint(self, printAccum.getCallback(), 2612 printAccum.getUserData()); 2613 return printAccum.join(); 2614 }, 2615 "Returns the assembly form of the type.") 2616 .def("__repr__", [](PyType &self) { 2617 // Generally, assembly formats are not printed for __repr__ because 2618 // this can cause exceptionally long debug output and exceptions. 2619 // However, types are an exception as they typically have compact 2620 // assembly forms and printing them is useful. 2621 PyPrintAccumulator printAccum; 2622 printAccum.parts.append("Type("); 2623 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2624 printAccum.parts.append(")"); 2625 return printAccum.join(); 2626 }); 2627 2628 //---------------------------------------------------------------------------- 2629 // Mapping of Value. 
2630 //---------------------------------------------------------------------------- 2631 py::class_<PyValue>(m, "Value", py::module_local()) 2632 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2633 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2634 .def_property_readonly( 2635 "context", 2636 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2637 "Context in which the value lives.") 2638 .def( 2639 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2640 kDumpDocstring) 2641 .def_property_readonly( 2642 "owner", 2643 [](PyValue &self) { 2644 assert(mlirOperationEqual(self.getParentOperation()->get(), 2645 mlirOpResultGetOwner(self.get())) && 2646 "expected the owner of the value in Python to match that in " 2647 "the IR"); 2648 return self.getParentOperation().getObject(); 2649 }) 2650 .def("__eq__", 2651 [](PyValue &self, PyValue &other) { 2652 return self.get().ptr == other.get().ptr; 2653 }) 2654 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2655 .def( 2656 "__str__", 2657 [](PyValue &self) { 2658 PyPrintAccumulator printAccum; 2659 printAccum.parts.append("Value("); 2660 mlirValuePrint(self.get(), printAccum.getCallback(), 2661 printAccum.getUserData()); 2662 printAccum.parts.append(")"); 2663 return printAccum.join(); 2664 }, 2665 kValueDunderStrDocstring) 2666 .def_property_readonly("type", [](PyValue &self) { 2667 return PyType(self.getParentOperation()->getContext(), 2668 mlirValueGetType(self.get())); 2669 }); 2670 PyBlockArgument::bind(m); 2671 PyOpResult::bind(m); 2672 2673 // Container bindings. 2674 PyBlockArgumentList::bind(m); 2675 PyBlockIterator::bind(m); 2676 PyBlockList::bind(m); 2677 PyOperationIterator::bind(m); 2678 PyOperationList::bind(m); 2679 PyOpAttributeMap::bind(m); 2680 PyOpOperandList::bind(m); 2681 PyOpResultList::bind(m); 2682 PyRegionIterator::bind(m); 2683 PyRegionList::bind(m); 2684 2685 // Debug bindings. 
2686 PyGlobalDebugFlag::bind(m); 2687 } 2688