1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 namespace py = pybind11; 25 using namespace mlir; 26 using namespace mlir::python; 27 28 using llvm::SmallVector; 29 using llvm::StringRef; 30 using llvm::Twine; 31 32 //------------------------------------------------------------------------------ 33 // Docstrings (trivial, non-duplicated docstrings are included inline). 34 //------------------------------------------------------------------------------ 35 36 static const char kContextParseTypeDocstring[] = 37 R"(Parses the assembly form of a type. 38 39 Returns a Type object or raises a ValueError if the type cannot be parsed. 40 41 See also: https://mlir.llvm.org/docs/LangRef/#type-system 42 )"; 43 44 static const char kContextGetCallSiteLocationDocstring[] = 45 R"(Gets a Location representing a caller and callsite)"; 46 47 static const char kContextGetFileLocationDocstring[] = 48 R"(Gets a Location representing a file, line and column)"; 49 50 static const char kContextGetNameLocationDocString[] = 51 R"(Gets a Location representing a named location with optional child location)"; 52 53 static const char kModuleParseDocstring[] = 54 R"(Parses a module's assembly format from a string. 55 56 Returns a new MlirModule or raises a ValueError if the parsing fails. 57 58 See also: https://mlir.llvm.org/docs/LangRef/ 59 )"; 60 61 static const char kOperationCreateDocstring[] = 62 R"(Creates a new operation. 63 64 Args: 65 name: Operation name (e.g. "dialect.operation"). 66 results: Sequence of Type representing op result types. 67 attributes: Dict of str:Attribute. 68 successors: List of Block for the operation's successors. 69 regions: Number of regions to create. 70 location: A Location object (defaults to resolve from context manager). 71 ip: An InsertionPoint (defaults to resolve from context manager or set to 72 False to disable insertion, even with an insertion point set in the 73 context manager). 74 Returns: 75 A new "detached" Operation object. Detached operations can be added 76 to blocks, which causes them to become "attached." 77 )"; 78 79 static const char kOperationPrintDocstring[] = 80 R"(Prints the assembly form of the operation to a file like object. 81 82 Args: 83 file: The file like object to write to. Defaults to sys.stdout. 84 binary: Whether to write bytes (True) or str (False). Defaults to False. 85 large_elements_limit: Whether to elide elements attributes above this 86 number of elements. Defaults to None (no limit). 87 enable_debug_info: Whether to print debug/location information. Defaults 88 to False. 89 pretty_debug_info: Whether to format debug information for easier reading 90 by a human (warning: the result is unparseable). 91 print_generic_op_form: Whether to print the generic assembly forms of all 92 ops. Defaults to False. 93 use_local_Scope: Whether to print in a way that is more optimized for 94 multi-threaded access but may not be consistent with how the overall 95 module prints. 96 assume_verified: By default, if not printing generic form, the verifier 97 will be run and if it fails, generic form will be printed with a comment 98 about failed verification. While a reasonable default for interactive use, 99 for systematic use, it is often better for the caller to verify explicitly 100 and report failures in a more robust fashion. Set this to True if doing this 101 in order to avoid running a redundant verification. If the IR is actually 102 invalid, behavior is undefined. 103 )"; 104 105 static const char kOperationGetAsmDocstring[] = 106 R"(Gets the assembly form of the operation with all options available. 107 108 Args: 109 binary: Whether to return a bytes (True) or str (False) object. Defaults to 110 False. 111 ... others ...: See the print() method for common keyword arguments for 112 configuring the printout. 113 Returns: 114 Either a bytes or str object, depending on the setting of the 'binary' 115 argument. 116 )"; 117 118 static const char kOperationStrDunderDocstring[] = 119 R"(Gets the assembly form of the operation with default options. 120 121 If more advanced control over the assembly formatting or I/O options is needed, 122 use the dedicated print or get_asm method, which supports keyword arguments to 123 customize behavior. 124 )"; 125 126 static const char kDumpDocstring[] = 127 R"(Dumps a debug representation of the object to stderr.)"; 128 129 static const char kAppendBlockDocstring[] = 130 R"(Appends a new block, with argument types as positional args. 131 132 Returns: 133 The created block. 134 )"; 135 136 static const char kValueDunderStrDocstring[] = 137 R"(Returns the string form of the value. 138 139 If the value is a block argument, this is the assembly form of its type and the 140 position in the argument list. If the value is an operation result, this is 141 equivalent to printing the operation that produced it. 142 )"; 143 144 //------------------------------------------------------------------------------ 145 // Utilities. 146 //------------------------------------------------------------------------------ 147 148 /// Helper for creating an @classmethod. 149 template <class Func, typename... Args> 150 py::object classmethod(Func f, Args... args) { 151 py::object cf = py::cpp_function(f, args...); 152 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 153 } 154 155 static py::object 156 createCustomDialectWrapper(const std::string &dialectNamespace, 157 py::object dialectDescriptor) { 158 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 159 if (!dialectClass) { 160 // Use the base class. 161 return py::cast(PyDialect(std::move(dialectDescriptor))); 162 } 163 164 // Create the custom implementation. 165 return (*dialectClass)(std::move(dialectDescriptor)); 166 } 167 168 static MlirStringRef toMlirStringRef(const std::string &s) { 169 return mlirStringRefCreate(s.data(), s.size()); 170 } 171 172 /// Wrapper for the global LLVM debugging flag. 173 struct PyGlobalDebugFlag { 174 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 175 176 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 177 178 static void bind(py::module &m) { 179 // Debug flags. 180 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 181 .def_property_static("flag", &PyGlobalDebugFlag::get, 182 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 183 } 184 }; 185 186 //------------------------------------------------------------------------------ 187 // Collections. 188 //------------------------------------------------------------------------------ 189 190 namespace { 191 192 class PyRegionIterator { 193 public: 194 PyRegionIterator(PyOperationRef operation) 195 : operation(std::move(operation)) {} 196 197 PyRegionIterator &dunderIter() { return *this; } 198 199 PyRegion dunderNext() { 200 operation->checkValid(); 201 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 202 throw py::stop_iteration(); 203 } 204 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 205 return PyRegion(operation, region); 206 } 207 208 static void bind(py::module &m) { 209 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 210 .def("__iter__", &PyRegionIterator::dunderIter) 211 .def("__next__", &PyRegionIterator::dunderNext); 212 } 213 214 private: 215 PyOperationRef operation; 216 int nextIndex = 0; 217 }; 218 219 /// Regions of an op are fixed length and indexed numerically so are represented 220 /// with a sequence-like container. 221 class PyRegionList { 222 public: 223 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 224 225 intptr_t dunderLen() { 226 operation->checkValid(); 227 return mlirOperationGetNumRegions(operation->get()); 228 } 229 230 PyRegion dunderGetItem(intptr_t index) { 231 // dunderLen checks validity. 232 if (index < 0 || index >= dunderLen()) { 233 throw SetPyError(PyExc_IndexError, 234 "attempt to access out of bounds region"); 235 } 236 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 237 return PyRegion(operation, region); 238 } 239 240 static void bind(py::module &m) { 241 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 242 .def("__len__", &PyRegionList::dunderLen) 243 .def("__getitem__", &PyRegionList::dunderGetItem); 244 } 245 246 private: 247 PyOperationRef operation; 248 }; 249 250 class PyBlockIterator { 251 public: 252 PyBlockIterator(PyOperationRef operation, MlirBlock next) 253 : operation(std::move(operation)), next(next) {} 254 255 PyBlockIterator &dunderIter() { return *this; } 256 257 PyBlock dunderNext() { 258 operation->checkValid(); 259 if (mlirBlockIsNull(next)) { 260 throw py::stop_iteration(); 261 } 262 263 PyBlock returnBlock(operation, next); 264 next = mlirBlockGetNextInRegion(next); 265 return returnBlock; 266 } 267 268 static void bind(py::module &m) { 269 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 270 .def("__iter__", &PyBlockIterator::dunderIter) 271 .def("__next__", &PyBlockIterator::dunderNext); 272 } 273 274 private: 275 PyOperationRef operation; 276 MlirBlock next; 277 }; 278 279 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 280 /// we present them as a more full-featured list-like container but optimize 281 /// it for forward iteration. Blocks are always owned by a region. 282 class PyBlockList { 283 public: 284 PyBlockList(PyOperationRef operation, MlirRegion region) 285 : operation(std::move(operation)), region(region) {} 286 287 PyBlockIterator dunderIter() { 288 operation->checkValid(); 289 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 290 } 291 292 intptr_t dunderLen() { 293 operation->checkValid(); 294 intptr_t count = 0; 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 count += 1; 298 block = mlirBlockGetNextInRegion(block); 299 } 300 return count; 301 } 302 303 PyBlock dunderGetItem(intptr_t index) { 304 operation->checkValid(); 305 if (index < 0) { 306 throw SetPyError(PyExc_IndexError, 307 "attempt to access out of bounds block"); 308 } 309 MlirBlock block = mlirRegionGetFirstBlock(region); 310 while (!mlirBlockIsNull(block)) { 311 if (index == 0) { 312 return PyBlock(operation, block); 313 } 314 block = mlirBlockGetNextInRegion(block); 315 index -= 1; 316 } 317 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 318 } 319 320 PyBlock appendBlock(py::args pyArgTypes) { 321 operation->checkValid(); 322 llvm::SmallVector<MlirType, 4> argTypes; 323 argTypes.reserve(pyArgTypes.size()); 324 for (auto &pyArg : pyArgTypes) { 325 argTypes.push_back(pyArg.cast<PyType &>()); 326 } 327 328 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 329 mlirRegionAppendOwnedBlock(region, block); 330 return PyBlock(operation, block); 331 } 332 333 static void bind(py::module &m) { 334 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 335 .def("__getitem__", &PyBlockList::dunderGetItem) 336 .def("__iter__", &PyBlockList::dunderIter) 337 .def("__len__", &PyBlockList::dunderLen) 338 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 339 } 340 341 private: 342 PyOperationRef operation; 343 MlirRegion region; 344 }; 345 346 class PyOperationIterator { 347 public: 348 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 349 : parentOperation(std::move(parentOperation)), next(next) {} 350 351 PyOperationIterator &dunderIter() { return *this; } 352 353 py::object dunderNext() { 354 parentOperation->checkValid(); 355 if (mlirOperationIsNull(next)) { 356 throw py::stop_iteration(); 357 } 358 359 PyOperationRef returnOperation = 360 PyOperation::forOperation(parentOperation->getContext(), next); 361 next = mlirOperationGetNextInBlock(next); 362 return returnOperation->createOpView(); 363 } 364 365 static void bind(py::module &m) { 366 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 367 .def("__iter__", &PyOperationIterator::dunderIter) 368 .def("__next__", &PyOperationIterator::dunderNext); 369 } 370 371 private: 372 PyOperationRef parentOperation; 373 MlirOperation next; 374 }; 375 376 /// Operations are exposed by the C-API as a forward-only linked list. In 377 /// Python, we present them as a more full-featured list-like container but 378 /// optimize it for forward iteration. Iterable operations are always owned 379 /// by a block. 380 class PyOperationList { 381 public: 382 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 383 : parentOperation(std::move(parentOperation)), block(block) {} 384 385 PyOperationIterator dunderIter() { 386 parentOperation->checkValid(); 387 return PyOperationIterator(parentOperation, 388 mlirBlockGetFirstOperation(block)); 389 } 390 391 intptr_t dunderLen() { 392 parentOperation->checkValid(); 393 intptr_t count = 0; 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 count += 1; 397 childOp = mlirOperationGetNextInBlock(childOp); 398 } 399 return count; 400 } 401 402 py::object dunderGetItem(intptr_t index) { 403 parentOperation->checkValid(); 404 if (index < 0) { 405 throw SetPyError(PyExc_IndexError, 406 "attempt to access out of bounds operation"); 407 } 408 MlirOperation childOp = mlirBlockGetFirstOperation(block); 409 while (!mlirOperationIsNull(childOp)) { 410 if (index == 0) { 411 return PyOperation::forOperation(parentOperation->getContext(), childOp) 412 ->createOpView(); 413 } 414 childOp = mlirOperationGetNextInBlock(childOp); 415 index -= 1; 416 } 417 throw SetPyError(PyExc_IndexError, 418 "attempt to access out of bounds operation"); 419 } 420 421 static void bind(py::module &m) { 422 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 423 .def("__getitem__", &PyOperationList::dunderGetItem) 424 .def("__iter__", &PyOperationList::dunderIter) 425 .def("__len__", &PyOperationList::dunderLen); 426 } 427 428 private: 429 PyOperationRef parentOperation; 430 MlirBlock block; 431 }; 432 433 } // namespace 434 435 //------------------------------------------------------------------------------ 436 // PyMlirContext 437 //------------------------------------------------------------------------------ 438 439 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 440 py::gil_scoped_acquire acquire; 441 auto &liveContexts = getLiveContexts(); 442 liveContexts[context.ptr] = this; 443 } 444 445 PyMlirContext::~PyMlirContext() { 446 // Note that the only public way to construct an instance is via the 447 // forContext method, which always puts the associated handle into 448 // liveContexts. 449 py::gil_scoped_acquire acquire; 450 getLiveContexts().erase(context.ptr); 451 mlirContextDestroy(context); 452 } 453 454 py::object PyMlirContext::getCapsule() { 455 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 456 } 457 458 py::object PyMlirContext::createFromCapsule(py::object capsule) { 459 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 460 if (mlirContextIsNull(rawContext)) 461 throw py::error_already_set(); 462 return forContext(rawContext).releaseObject(); 463 } 464 465 PyMlirContext *PyMlirContext::createNewContextForInit() { 466 MlirContext context = mlirContextCreate(); 467 mlirRegisterAllDialects(context); 468 return new PyMlirContext(context); 469 } 470 471 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 472 py::gil_scoped_acquire acquire; 473 auto &liveContexts = getLiveContexts(); 474 auto it = liveContexts.find(context.ptr); 475 if (it == liveContexts.end()) { 476 // Create. 477 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 478 py::object pyRef = py::cast(unownedContextWrapper); 479 assert(pyRef && "cast to py::object failed"); 480 liveContexts[context.ptr] = unownedContextWrapper; 481 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 482 } 483 // Use existing. 484 py::object pyRef = py::cast(it->second); 485 return PyMlirContextRef(it->second, std::move(pyRef)); 486 } 487 488 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 489 static LiveContextMap liveContexts; 490 return liveContexts; 491 } 492 493 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 494 495 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 496 497 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 498 499 pybind11::object PyMlirContext::contextEnter() { 500 return PyThreadContextEntry::pushContext(*this); 501 } 502 503 void PyMlirContext::contextExit(pybind11::object excType, 504 pybind11::object excVal, 505 pybind11::object excTb) { 506 PyThreadContextEntry::popContext(*this); 507 } 508 509 PyMlirContext &DefaultingPyMlirContext::resolve() { 510 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 511 if (!context) { 512 throw SetPyError( 513 PyExc_RuntimeError, 514 "An MLIR function requires a Context but none was provided in the call " 515 "or from the surrounding environment. Either pass to the function with " 516 "a 'context=' argument or establish a default using 'with Context():'"); 517 } 518 return *context; 519 } 520 521 //------------------------------------------------------------------------------ 522 // PyThreadContextEntry management 523 //------------------------------------------------------------------------------ 524 525 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 526 static thread_local std::vector<PyThreadContextEntry> stack; 527 return stack; 528 } 529 530 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 531 auto &stack = getStack(); 532 if (stack.empty()) 533 return nullptr; 534 return &stack.back(); 535 } 536 537 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 538 py::object insertionPoint, 539 py::object location) { 540 auto &stack = getStack(); 541 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 542 std::move(location)); 543 // If the new stack has more than one entry and the context of the new top 544 // entry matches the previous, copy the insertionPoint and location from the 545 // previous entry if missing from the new top entry. 546 if (stack.size() > 1) { 547 auto &prev = *(stack.rbegin() + 1); 548 auto ¤t = stack.back(); 549 if (current.context.is(prev.context)) { 550 // Default non-context objects from the previous entry. 551 if (!current.insertionPoint) 552 current.insertionPoint = prev.insertionPoint; 553 if (!current.location) 554 current.location = prev.location; 555 } 556 } 557 } 558 559 PyMlirContext *PyThreadContextEntry::getContext() { 560 if (!context) 561 return nullptr; 562 return py::cast<PyMlirContext *>(context); 563 } 564 565 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 566 if (!insertionPoint) 567 return nullptr; 568 return py::cast<PyInsertionPoint *>(insertionPoint); 569 } 570 571 PyLocation *PyThreadContextEntry::getLocation() { 572 if (!location) 573 return nullptr; 574 return py::cast<PyLocation *>(location); 575 } 576 577 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 578 auto *tos = getTopOfStack(); 579 return tos ? tos->getContext() : nullptr; 580 } 581 582 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 583 auto *tos = getTopOfStack(); 584 return tos ? tos->getInsertionPoint() : nullptr; 585 } 586 587 PyLocation *PyThreadContextEntry::getDefaultLocation() { 588 auto *tos = getTopOfStack(); 589 return tos ? tos->getLocation() : nullptr; 590 } 591 592 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 593 py::object contextObj = py::cast(context); 594 push(FrameKind::Context, /*context=*/contextObj, 595 /*insertionPoint=*/py::object(), 596 /*location=*/py::object()); 597 return contextObj; 598 } 599 600 void PyThreadContextEntry::popContext(PyMlirContext &context) { 601 auto &stack = getStack(); 602 if (stack.empty()) 603 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 604 auto &tos = stack.back(); 605 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 606 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 607 stack.pop_back(); 608 } 609 610 py::object 611 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 612 py::object contextObj = 613 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 614 py::object insertionPointObj = py::cast(insertionPoint); 615 push(FrameKind::InsertionPoint, 616 /*context=*/contextObj, 617 /*insertionPoint=*/insertionPointObj, 618 /*location=*/py::object()); 619 return insertionPointObj; 620 } 621 622 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 623 auto &stack = getStack(); 624 if (stack.empty()) 625 throw SetPyError(PyExc_RuntimeError, 626 "Unbalanced InsertionPoint enter/exit"); 627 auto &tos = stack.back(); 628 if (tos.frameKind != FrameKind::InsertionPoint && 629 tos.getInsertionPoint() != &insertionPoint) 630 throw SetPyError(PyExc_RuntimeError, 631 "Unbalanced InsertionPoint enter/exit"); 632 stack.pop_back(); 633 } 634 635 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 636 py::object contextObj = location.getContext().getObject(); 637 py::object locationObj = py::cast(location); 638 push(FrameKind::Location, /*context=*/contextObj, 639 /*insertionPoint=*/py::object(), 640 /*location=*/locationObj); 641 return locationObj; 642 } 643 644 void PyThreadContextEntry::popLocation(PyLocation &location) { 645 auto &stack = getStack(); 646 if (stack.empty()) 647 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 648 auto &tos = stack.back(); 649 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 650 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 651 stack.pop_back(); 652 } 653 654 //------------------------------------------------------------------------------ 655 // PyDialect, PyDialectDescriptor, PyDialects 656 //------------------------------------------------------------------------------ 657 658 MlirDialect PyDialects::getDialectForKey(const std::string &key, 659 bool attrError) { 660 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 661 {key.data(), key.size()}); 662 if (mlirDialectIsNull(dialect)) { 663 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 664 Twine("Dialect '") + key + "' not found"); 665 } 666 return dialect; 667 } 668 669 //------------------------------------------------------------------------------ 670 // PyLocation 671 //------------------------------------------------------------------------------ 672 673 py::object PyLocation::getCapsule() { 674 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 675 } 676 677 PyLocation PyLocation::createFromCapsule(py::object capsule) { 678 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 679 if (mlirLocationIsNull(rawLoc)) 680 throw py::error_already_set(); 681 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 682 rawLoc); 683 } 684 685 py::object PyLocation::contextEnter() { 686 return PyThreadContextEntry::pushLocation(*this); 687 } 688 689 void PyLocation::contextExit(py::object excType, py::object excVal, 690 py::object excTb) { 691 PyThreadContextEntry::popLocation(*this); 692 } 693 694 PyLocation &DefaultingPyLocation::resolve() { 695 auto *location = PyThreadContextEntry::getDefaultLocation(); 696 if (!location) { 697 throw SetPyError( 698 PyExc_RuntimeError, 699 "An MLIR function requires a Location but none was provided in the " 700 "call or from the surrounding environment. Either pass to the function " 701 "with a 'loc=' argument or establish a default using 'with loc:'"); 702 } 703 return *location; 704 } 705 706 //------------------------------------------------------------------------------ 707 // PyModule 708 //------------------------------------------------------------------------------ 709 710 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 711 : BaseContextObject(std::move(contextRef)), module(module) {} 712 713 PyModule::~PyModule() { 714 py::gil_scoped_acquire acquire; 715 auto &liveModules = getContext()->liveModules; 716 assert(liveModules.count(module.ptr) == 1 && 717 "destroying module not in live map"); 718 liveModules.erase(module.ptr); 719 mlirModuleDestroy(module); 720 } 721 722 PyModuleRef PyModule::forModule(MlirModule module) { 723 MlirContext context = mlirModuleGetContext(module); 724 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 725 726 py::gil_scoped_acquire acquire; 727 auto &liveModules = contextRef->liveModules; 728 auto it = liveModules.find(module.ptr); 729 if (it == liveModules.end()) { 730 // Create. 731 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 732 // Note that the default return value policy on cast is automatic_reference, 733 // which does not take ownership (delete will not be called). 734 // Just be explicit. 735 py::object pyRef = 736 py::cast(unownedModule, py::return_value_policy::take_ownership); 737 unownedModule->handle = pyRef; 738 liveModules[module.ptr] = 739 std::make_pair(unownedModule->handle, unownedModule); 740 return PyModuleRef(unownedModule, std::move(pyRef)); 741 } 742 // Use existing. 743 PyModule *existing = it->second.second; 744 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 745 return PyModuleRef(existing, std::move(pyRef)); 746 } 747 748 py::object PyModule::createFromCapsule(py::object capsule) { 749 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 750 if (mlirModuleIsNull(rawModule)) 751 throw py::error_already_set(); 752 return forModule(rawModule).releaseObject(); 753 } 754 755 py::object PyModule::getCapsule() { 756 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 757 } 758 759 //------------------------------------------------------------------------------ 760 // PyOperation 761 //------------------------------------------------------------------------------ 762 763 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 764 : BaseContextObject(std::move(contextRef)), operation(operation) {} 765 766 PyOperation::~PyOperation() { 767 // If the operation has already been invalidated there is nothing to do. 768 if (!valid) 769 return; 770 auto &liveOperations = getContext()->liveOperations; 771 assert(liveOperations.count(operation.ptr) == 1 && 772 "destroying operation not in live map"); 773 liveOperations.erase(operation.ptr); 774 if (!isAttached()) { 775 mlirOperationDestroy(operation); 776 } 777 } 778 779 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 780 MlirOperation operation, 781 py::object parentKeepAlive) { 782 auto &liveOperations = contextRef->liveOperations; 783 // Create. 784 PyOperation *unownedOperation = 785 new PyOperation(std::move(contextRef), operation); 786 // Note that the default return value policy on cast is automatic_reference, 787 // which does not take ownership (delete will not be called). 788 // Just be explicit. 789 py::object pyRef = 790 py::cast(unownedOperation, py::return_value_policy::take_ownership); 791 unownedOperation->handle = pyRef; 792 if (parentKeepAlive) { 793 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 794 } 795 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 796 return PyOperationRef(unownedOperation, std::move(pyRef)); 797 } 798 799 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 800 MlirOperation operation, 801 py::object parentKeepAlive) { 802 auto &liveOperations = contextRef->liveOperations; 803 auto it = liveOperations.find(operation.ptr); 804 if (it == liveOperations.end()) { 805 // Create. 806 return createInstance(std::move(contextRef), operation, 807 std::move(parentKeepAlive)); 808 } 809 // Use existing. 810 PyOperation *existing = it->second.second; 811 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 812 return PyOperationRef(existing, std::move(pyRef)); 813 } 814 815 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 816 MlirOperation operation, 817 py::object parentKeepAlive) { 818 auto &liveOperations = contextRef->liveOperations; 819 assert(liveOperations.count(operation.ptr) == 0 && 820 "cannot create detached operation that already exists"); 821 (void)liveOperations; 822 823 PyOperationRef created = createInstance(std::move(contextRef), operation, 824 std::move(parentKeepAlive)); 825 created->attached = false; 826 return created; 827 } 828 829 void PyOperation::checkValid() const { 830 if (!valid) { 831 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 832 } 833 } 834 835 void PyOperationBase::print(py::object fileObject, bool binary, 836 llvm::Optional<int64_t> largeElementsLimit, 837 bool enableDebugInfo, bool prettyDebugInfo, 838 bool printGenericOpForm, bool useLocalScope, 839 bool assumeVerified) { 840 PyOperation &operation = getOperation(); 841 operation.checkValid(); 842 if (fileObject.is_none()) 843 fileObject = py::module::import("sys").attr("stdout"); 844 845 if (!assumeVerified && !printGenericOpForm && 846 !mlirOperationVerify(operation)) { 847 std::string message("// Verification failed, printing generic form\n"); 848 if (binary) { 849 fileObject.attr("write")(py::bytes(message)); 850 } else { 851 fileObject.attr("write")(py::str(message)); 852 } 853 printGenericOpForm = true; 854 } 855 856 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 857 if (largeElementsLimit) 858 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 859 if (enableDebugInfo) 860 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 861 if (printGenericOpForm) 862 mlirOpPrintingFlagsPrintGenericOpForm(flags); 863 864 PyFileAccumulator accum(fileObject, binary); 865 py::gil_scoped_release(); 866 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 867 accum.getUserData()); 868 mlirOpPrintingFlagsDestroy(flags); 869 } 870 871 py::object PyOperationBase::getAsm(bool binary, 872 llvm::Optional<int64_t> largeElementsLimit, 873 bool enableDebugInfo, bool prettyDebugInfo, 874 bool printGenericOpForm, bool useLocalScope, 875 bool assumeVerified) { 876 py::object fileObject; 877 if (binary) { 878 fileObject = py::module::import("io").attr("BytesIO")(); 879 } else { 880 fileObject = py::module::import("io").attr("StringIO")(); 881 } 882 print(fileObject, /*binary=*/binary, 883 /*largeElementsLimit=*/largeElementsLimit, 884 /*enableDebugInfo=*/enableDebugInfo, 885 /*prettyDebugInfo=*/prettyDebugInfo, 886 /*printGenericOpForm=*/printGenericOpForm, 887 /*useLocalScope=*/useLocalScope, 888 /*assumeVerified=*/assumeVerified); 889 890 return fileObject.attr("getvalue")(); 891 } 892 893 void PyOperationBase::moveAfter(PyOperationBase &other) { 894 PyOperation &operation = getOperation(); 895 PyOperation &otherOp = other.getOperation(); 896 operation.checkValid(); 897 otherOp.checkValid(); 898 mlirOperationMoveAfter(operation, otherOp); 899 operation.parentKeepAlive = otherOp.parentKeepAlive; 900 } 901 902 void PyOperationBase::moveBefore(PyOperationBase &other) { 903 PyOperation &operation = getOperation(); 904 PyOperation &otherOp = other.getOperation(); 905 operation.checkValid(); 906 otherOp.checkValid(); 907 mlirOperationMoveBefore(operation, otherOp); 908 operation.parentKeepAlive = otherOp.parentKeepAlive; 909 } 910 911 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 912 checkValid(); 913 if (!isAttached()) 914 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 915 MlirOperation operation = mlirOperationGetParentOperation(get()); 916 if (mlirOperationIsNull(operation)) 917 return {}; 918 return PyOperation::forOperation(getContext(), operation); 919 } 920 921 PyBlock PyOperation::getBlock() { 922 checkValid(); 923 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 924 MlirBlock block = mlirOperationGetBlock(get()); 925 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 926 assert(parentOperation && "Operation has no parent"); 927 return PyBlock{std::move(*parentOperation), block}; 928 } 929 930 py::object PyOperation::getCapsule() { 931 checkValid(); 932 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 933 } 934 935 py::object PyOperation::createFromCapsule(py::object capsule) { 936 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 937 if (mlirOperationIsNull(rawOperation)) 938 throw py::error_already_set(); 939 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 940 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 941 .releaseObject(); 942 } 943 944 py::object PyOperation::create( 945 std::string name, llvm::Optional<std::vector<PyType *>> results, 946 llvm::Optional<std::vector<PyValue *>> operands, 947 llvm::Optional<py::dict> attributes, 948 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 949 DefaultingPyLocation location, py::object maybeIp) { 950 llvm::SmallVector<MlirValue, 4> mlirOperands; 951 llvm::SmallVector<MlirType, 4> mlirResults; 952 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 953 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 954 955 // General parameter validation. 956 if (regions < 0) 957 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 958 959 // Unpack/validate operands. 960 if (operands) { 961 mlirOperands.reserve(operands->size()); 962 for (PyValue *operand : *operands) { 963 if (!operand) 964 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 965 mlirOperands.push_back(operand->get()); 966 } 967 } 968 969 // Unpack/validate results. 970 if (results) { 971 mlirResults.reserve(results->size()); 972 for (PyType *result : *results) { 973 // TODO: Verify result type originate from the same context. 974 if (!result) 975 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 976 mlirResults.push_back(*result); 977 } 978 } 979 // Unpack/validate attributes. 980 if (attributes) { 981 mlirAttributes.reserve(attributes->size()); 982 for (auto &it : *attributes) { 983 std::string key; 984 try { 985 key = it.first.cast<std::string>(); 986 } catch (py::cast_error &err) { 987 std::string msg = "Invalid attribute key (not a string) when " 988 "attempting to create the operation \"" + 989 name + "\" (" + err.what() + ")"; 990 throw py::cast_error(msg); 991 } 992 try { 993 auto &attribute = it.second.cast<PyAttribute &>(); 994 // TODO: Verify attribute originates from the same context. 995 mlirAttributes.emplace_back(std::move(key), attribute); 996 } catch (py::reference_cast_error &) { 997 // This exception seems thrown when the value is "None". 998 std::string msg = 999 "Found an invalid (`None`?) attribute value for the key \"" + key + 1000 "\" when attempting to create the operation \"" + name + "\""; 1001 throw py::cast_error(msg); 1002 } catch (py::cast_error &err) { 1003 std::string msg = "Invalid attribute value for the key \"" + key + 1004 "\" when attempting to create the operation \"" + 1005 name + "\" (" + err.what() + ")"; 1006 throw py::cast_error(msg); 1007 } 1008 } 1009 } 1010 // Unpack/validate successors. 1011 if (successors) { 1012 mlirSuccessors.reserve(successors->size()); 1013 for (auto *successor : *successors) { 1014 // TODO: Verify successor originate from the same context. 1015 if (!successor) 1016 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1017 mlirSuccessors.push_back(successor->get()); 1018 } 1019 } 1020 1021 // Apply unpacked/validated to the operation state. Beyond this 1022 // point, exceptions cannot be thrown or else the state will leak. 1023 MlirOperationState state = 1024 mlirOperationStateGet(toMlirStringRef(name), location); 1025 if (!mlirOperands.empty()) 1026 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1027 mlirOperands.data()); 1028 if (!mlirResults.empty()) 1029 mlirOperationStateAddResults(&state, mlirResults.size(), 1030 mlirResults.data()); 1031 if (!mlirAttributes.empty()) { 1032 // Note that the attribute names directly reference bytes in 1033 // mlirAttributes, so that vector must not be changed from here 1034 // on. 1035 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1036 mlirNamedAttributes.reserve(mlirAttributes.size()); 1037 for (auto &it : mlirAttributes) 1038 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1039 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1040 toMlirStringRef(it.first)), 1041 it.second)); 1042 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1043 mlirNamedAttributes.data()); 1044 } 1045 if (!mlirSuccessors.empty()) 1046 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1047 mlirSuccessors.data()); 1048 if (regions) { 1049 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1050 mlirRegions.resize(regions); 1051 for (int i = 0; i < regions; ++i) 1052 mlirRegions[i] = mlirRegionCreate(); 1053 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1054 mlirRegions.data()); 1055 } 1056 1057 // Construct the operation. 1058 MlirOperation operation = mlirOperationCreate(&state); 1059 PyOperationRef created = 1060 PyOperation::createDetached(location->getContext(), operation); 1061 1062 // InsertPoint active? 1063 if (!maybeIp.is(py::cast(false))) { 1064 PyInsertionPoint *ip; 1065 if (maybeIp.is_none()) { 1066 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1067 } else { 1068 ip = py::cast<PyInsertionPoint *>(maybeIp); 1069 } 1070 if (ip) 1071 ip->insert(*created.get()); 1072 } 1073 1074 return created->createOpView(); 1075 } 1076 1077 py::object PyOperation::createOpView() { 1078 checkValid(); 1079 MlirIdentifier ident = mlirOperationGetName(get()); 1080 MlirStringRef identStr = mlirIdentifierStr(ident); 1081 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1082 StringRef(identStr.data, identStr.length)); 1083 if (opViewClass) 1084 return (*opViewClass)(getRef().getObject()); 1085 return py::cast(PyOpView(getRef().getObject())); 1086 } 1087 1088 void PyOperation::erase() { 1089 checkValid(); 1090 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1091 // Python reference to a child operation is live. All children should also 1092 // have their `valid` bit set to false. 1093 auto &liveOperations = getContext()->liveOperations; 1094 if (liveOperations.count(operation.ptr)) 1095 liveOperations.erase(operation.ptr); 1096 mlirOperationDestroy(operation); 1097 valid = false; 1098 } 1099 1100 //------------------------------------------------------------------------------ 1101 // PyOpView 1102 //------------------------------------------------------------------------------ 1103 1104 py::object 1105 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1106 py::list operandList, 1107 llvm::Optional<py::dict> attributes, 1108 llvm::Optional<std::vector<PyBlock *>> successors, 1109 llvm::Optional<int> regions, 1110 DefaultingPyLocation location, py::object maybeIp) { 1111 PyMlirContextRef context = location->getContext(); 1112 // Class level operation construction metadata. 1113 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1114 // Operand and result segment specs are either none, which does no 1115 // variadic unpacking, or a list of ints with segment sizes, where each 1116 // element is either a positive number (typically 1 for a scalar) or -1 to 1117 // indicate that it is derived from the length of the same-indexed operand 1118 // or result (implying that it is a list at that position). 1119 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1120 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1121 1122 std::vector<uint32_t> operandSegmentLengths; 1123 std::vector<uint32_t> resultSegmentLengths; 1124 1125 // Validate/determine region count. 1126 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1127 int opMinRegionCount = std::get<0>(opRegionSpec); 1128 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1129 if (!regions) { 1130 regions = opMinRegionCount; 1131 } 1132 if (*regions < opMinRegionCount) { 1133 throw py::value_error( 1134 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1135 llvm::Twine(opMinRegionCount) + 1136 " regions but was built with regions=" + llvm::Twine(*regions)) 1137 .str()); 1138 } 1139 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1140 throw py::value_error( 1141 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1142 llvm::Twine(opMinRegionCount) + 1143 " regions but was built with regions=" + llvm::Twine(*regions)) 1144 .str()); 1145 } 1146 1147 // Unpack results. 1148 std::vector<PyType *> resultTypes; 1149 resultTypes.reserve(resultTypeList.size()); 1150 if (resultSegmentSpecObj.is_none()) { 1151 // Non-variadic result unpacking. 1152 for (auto it : llvm::enumerate(resultTypeList)) { 1153 try { 1154 resultTypes.push_back(py::cast<PyType *>(it.value())); 1155 if (!resultTypes.back()) 1156 throw py::cast_error(); 1157 } catch (py::cast_error &err) { 1158 throw py::value_error((llvm::Twine("Result ") + 1159 llvm::Twine(it.index()) + " of operation \"" + 1160 name + "\" must be a Type (" + err.what() + ")") 1161 .str()); 1162 } 1163 } 1164 } else { 1165 // Sized result unpacking. 1166 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1167 if (resultSegmentSpec.size() != resultTypeList.size()) { 1168 throw py::value_error((llvm::Twine("Operation \"") + name + 1169 "\" requires " + 1170 llvm::Twine(resultSegmentSpec.size()) + 1171 " result segments but was provided " + 1172 llvm::Twine(resultTypeList.size())) 1173 .str()); 1174 } 1175 resultSegmentLengths.reserve(resultTypeList.size()); 1176 for (auto it : 1177 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1178 int segmentSpec = std::get<1>(it.value()); 1179 if (segmentSpec == 1 || segmentSpec == 0) { 1180 // Unpack unary element. 1181 try { 1182 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1183 if (resultType) { 1184 resultTypes.push_back(resultType); 1185 resultSegmentLengths.push_back(1); 1186 } else if (segmentSpec == 0) { 1187 // Allowed to be optional. 1188 resultSegmentLengths.push_back(0); 1189 } else { 1190 throw py::cast_error("was None and result is not optional"); 1191 } 1192 } catch (py::cast_error &err) { 1193 throw py::value_error((llvm::Twine("Result ") + 1194 llvm::Twine(it.index()) + " of operation \"" + 1195 name + "\" must be a Type (" + err.what() + 1196 ")") 1197 .str()); 1198 } 1199 } else if (segmentSpec == -1) { 1200 // Unpack sequence by appending. 1201 try { 1202 if (std::get<0>(it.value()).is_none()) { 1203 // Treat it as an empty list. 1204 resultSegmentLengths.push_back(0); 1205 } else { 1206 // Unpack the list. 1207 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1208 for (py::object segmentItem : segment) { 1209 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1210 if (!resultTypes.back()) { 1211 throw py::cast_error("contained a None item"); 1212 } 1213 } 1214 resultSegmentLengths.push_back(segment.size()); 1215 } 1216 } catch (std::exception &err) { 1217 // NOTE: Sloppy to be using a catch-all here, but there are at least 1218 // three different unrelated exceptions that can be thrown in the 1219 // above "casts". Just keep the scope above small and catch them all. 1220 throw py::value_error((llvm::Twine("Result ") + 1221 llvm::Twine(it.index()) + " of operation \"" + 1222 name + "\" must be a Sequence of Types (" + 1223 err.what() + ")") 1224 .str()); 1225 } 1226 } else { 1227 throw py::value_error("Unexpected segment spec"); 1228 } 1229 } 1230 } 1231 1232 // Unpack operands. 1233 std::vector<PyValue *> operands; 1234 operands.reserve(operands.size()); 1235 if (operandSegmentSpecObj.is_none()) { 1236 // Non-sized operand unpacking. 1237 for (auto it : llvm::enumerate(operandList)) { 1238 try { 1239 operands.push_back(py::cast<PyValue *>(it.value())); 1240 if (!operands.back()) 1241 throw py::cast_error(); 1242 } catch (py::cast_error &err) { 1243 throw py::value_error((llvm::Twine("Operand ") + 1244 llvm::Twine(it.index()) + " of operation \"" + 1245 name + "\" must be a Value (" + err.what() + ")") 1246 .str()); 1247 } 1248 } 1249 } else { 1250 // Sized operand unpacking. 1251 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1252 if (operandSegmentSpec.size() != operandList.size()) { 1253 throw py::value_error((llvm::Twine("Operation \"") + name + 1254 "\" requires " + 1255 llvm::Twine(operandSegmentSpec.size()) + 1256 "operand segments but was provided " + 1257 llvm::Twine(operandList.size())) 1258 .str()); 1259 } 1260 operandSegmentLengths.reserve(operandList.size()); 1261 for (auto it : 1262 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1263 int segmentSpec = std::get<1>(it.value()); 1264 if (segmentSpec == 1 || segmentSpec == 0) { 1265 // Unpack unary element. 1266 try { 1267 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1268 if (operandValue) { 1269 operands.push_back(operandValue); 1270 operandSegmentLengths.push_back(1); 1271 } else if (segmentSpec == 0) { 1272 // Allowed to be optional. 1273 operandSegmentLengths.push_back(0); 1274 } else { 1275 throw py::cast_error("was None and operand is not optional"); 1276 } 1277 } catch (py::cast_error &err) { 1278 throw py::value_error((llvm::Twine("Operand ") + 1279 llvm::Twine(it.index()) + " of operation \"" + 1280 name + "\" must be a Value (" + err.what() + 1281 ")") 1282 .str()); 1283 } 1284 } else if (segmentSpec == -1) { 1285 // Unpack sequence by appending. 1286 try { 1287 if (std::get<0>(it.value()).is_none()) { 1288 // Treat it as an empty list. 1289 operandSegmentLengths.push_back(0); 1290 } else { 1291 // Unpack the list. 1292 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1293 for (py::object segmentItem : segment) { 1294 operands.push_back(py::cast<PyValue *>(segmentItem)); 1295 if (!operands.back()) { 1296 throw py::cast_error("contained a None item"); 1297 } 1298 } 1299 operandSegmentLengths.push_back(segment.size()); 1300 } 1301 } catch (std::exception &err) { 1302 // NOTE: Sloppy to be using a catch-all here, but there are at least 1303 // three different unrelated exceptions that can be thrown in the 1304 // above "casts". Just keep the scope above small and catch them all. 1305 throw py::value_error((llvm::Twine("Operand ") + 1306 llvm::Twine(it.index()) + " of operation \"" + 1307 name + "\" must be a Sequence of Values (" + 1308 err.what() + ")") 1309 .str()); 1310 } 1311 } else { 1312 throw py::value_error("Unexpected segment spec"); 1313 } 1314 } 1315 } 1316 1317 // Merge operand/result segment lengths into attributes if needed. 1318 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1319 // Dup. 1320 if (attributes) { 1321 attributes = py::dict(*attributes); 1322 } else { 1323 attributes = py::dict(); 1324 } 1325 if (attributes->contains("result_segment_sizes") || 1326 attributes->contains("operand_segment_sizes")) { 1327 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1328 "'operand_segment_sizes' attribute is unsupported. " 1329 "Use Operation.create for such low-level access."); 1330 } 1331 1332 // Add result_segment_sizes attribute. 1333 if (!resultSegmentLengths.empty()) { 1334 int64_t size = resultSegmentLengths.size(); 1335 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1336 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1337 resultSegmentLengths.size(), resultSegmentLengths.data()); 1338 (*attributes)["result_segment_sizes"] = 1339 PyAttribute(context, segmentLengthAttr); 1340 } 1341 1342 // Add operand_segment_sizes attribute. 1343 if (!operandSegmentLengths.empty()) { 1344 int64_t size = operandSegmentLengths.size(); 1345 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1346 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1347 operandSegmentLengths.size(), operandSegmentLengths.data()); 1348 (*attributes)["operand_segment_sizes"] = 1349 PyAttribute(context, segmentLengthAttr); 1350 } 1351 } 1352 1353 // Delegate to create. 1354 return PyOperation::create(std::move(name), 1355 /*results=*/std::move(resultTypes), 1356 /*operands=*/std::move(operands), 1357 /*attributes=*/std::move(attributes), 1358 /*successors=*/std::move(successors), 1359 /*regions=*/*regions, location, maybeIp); 1360 } 1361 1362 PyOpView::PyOpView(py::object operationObject) 1363 // Casting through the PyOperationBase base-class and then back to the 1364 // Operation lets us accept any PyOperationBase subclass. 1365 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1366 operationObject(operation.getRef().getObject()) {} 1367 1368 py::object PyOpView::createRawSubclass(py::object userClass) { 1369 // This is... a little gross. The typical pattern is to have a pure python 1370 // class that extends OpView like: 1371 // class AddFOp(_cext.ir.OpView): 1372 // def __init__(self, loc, lhs, rhs): 1373 // operation = loc.context.create_operation( 1374 // "addf", lhs, rhs, results=[lhs.type]) 1375 // super().__init__(operation) 1376 // 1377 // I.e. The goal of the user facing type is to provide a nice constructor 1378 // that has complete freedom for the op under construction. This is at odds 1379 // with our other desire to sometimes create this object by just passing an 1380 // operation (to initialize the base class). We could do *arg and **kwargs 1381 // munging to try to make it work, but instead, we synthesize a new class 1382 // on the fly which extends this user class (AddFOp in this example) and 1383 // *give it* the base class's __init__ method, thus bypassing the 1384 // intermediate subclass's __init__ method entirely. While slightly, 1385 // underhanded, this is safe/legal because the type hierarchy has not changed 1386 // (we just added a new leaf) and we aren't mucking around with __new__. 1387 // Typically, this new class will be stored on the original as "_Raw" and will 1388 // be used for casts and other things that need a variant of the class that 1389 // is initialized purely from an operation. 1390 py::object parentMetaclass = 1391 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1392 py::dict attributes; 1393 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1394 // now. 1395 // auto opViewType = py::type::of<PyOpView>(); 1396 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1397 attributes["__init__"] = opViewType.attr("__init__"); 1398 py::str origName = userClass.attr("__name__"); 1399 py::str newName = py::str("_") + origName; 1400 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1401 } 1402 1403 //------------------------------------------------------------------------------ 1404 // PyInsertionPoint. 1405 //------------------------------------------------------------------------------ 1406 1407 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1408 1409 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1410 : refOperation(beforeOperationBase.getOperation().getRef()), 1411 block((*refOperation)->getBlock()) {} 1412 1413 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1414 PyOperation &operation = operationBase.getOperation(); 1415 if (operation.isAttached()) 1416 throw SetPyError(PyExc_ValueError, 1417 "Attempt to insert operation that is already attached"); 1418 block.getParentOperation()->checkValid(); 1419 MlirOperation beforeOp = {nullptr}; 1420 if (refOperation) { 1421 // Insert before operation. 1422 (*refOperation)->checkValid(); 1423 beforeOp = (*refOperation)->get(); 1424 } else { 1425 // Insert at end (before null) is only valid if the block does not 1426 // already end in a known terminator (violating this will cause assertion 1427 // failures later). 1428 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1429 throw py::index_error("Cannot insert operation at the end of a block " 1430 "that already has a terminator. Did you mean to " 1431 "use 'InsertionPoint.at_block_terminator(block)' " 1432 "versus 'InsertionPoint(block)'?"); 1433 } 1434 } 1435 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1436 operation.setAttached(); 1437 } 1438 1439 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1440 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1441 if (mlirOperationIsNull(firstOp)) { 1442 // Just insert at end. 1443 return PyInsertionPoint(block); 1444 } 1445 1446 // Insert before first op. 1447 PyOperationRef firstOpRef = PyOperation::forOperation( 1448 block.getParentOperation()->getContext(), firstOp); 1449 return PyInsertionPoint{block, std::move(firstOpRef)}; 1450 } 1451 1452 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1453 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1454 if (mlirOperationIsNull(terminator)) 1455 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1456 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1457 block.getParentOperation()->getContext(), terminator); 1458 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1459 } 1460 1461 py::object PyInsertionPoint::contextEnter() { 1462 return PyThreadContextEntry::pushInsertionPoint(*this); 1463 } 1464 1465 void PyInsertionPoint::contextExit(pybind11::object excType, 1466 pybind11::object excVal, 1467 pybind11::object excTb) { 1468 PyThreadContextEntry::popInsertionPoint(*this); 1469 } 1470 1471 //------------------------------------------------------------------------------ 1472 // PyAttribute. 1473 //------------------------------------------------------------------------------ 1474 1475 bool PyAttribute::operator==(const PyAttribute &other) { 1476 return mlirAttributeEqual(attr, other.attr); 1477 } 1478 1479 py::object PyAttribute::getCapsule() { 1480 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1481 } 1482 1483 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1484 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1485 if (mlirAttributeIsNull(rawAttr)) 1486 throw py::error_already_set(); 1487 return PyAttribute( 1488 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1489 } 1490 1491 //------------------------------------------------------------------------------ 1492 // PyNamedAttribute. 1493 //------------------------------------------------------------------------------ 1494 1495 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1496 : ownedName(new std::string(std::move(ownedName))) { 1497 namedAttr = mlirNamedAttributeGet( 1498 mlirIdentifierGet(mlirAttributeGetContext(attr), 1499 toMlirStringRef(*this->ownedName)), 1500 attr); 1501 } 1502 1503 //------------------------------------------------------------------------------ 1504 // PyType. 1505 //------------------------------------------------------------------------------ 1506 1507 bool PyType::operator==(const PyType &other) { 1508 return mlirTypeEqual(type, other.type); 1509 } 1510 1511 py::object PyType::getCapsule() { 1512 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1513 } 1514 1515 PyType PyType::createFromCapsule(py::object capsule) { 1516 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1517 if (mlirTypeIsNull(rawType)) 1518 throw py::error_already_set(); 1519 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1520 rawType); 1521 } 1522 1523 //------------------------------------------------------------------------------ 1524 // PyValue and subclases. 1525 //------------------------------------------------------------------------------ 1526 1527 pybind11::object PyValue::getCapsule() { 1528 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1529 } 1530 1531 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1532 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1533 if (mlirValueIsNull(value)) 1534 throw py::error_already_set(); 1535 MlirOperation owner; 1536 if (mlirValueIsAOpResult(value)) 1537 owner = mlirOpResultGetOwner(value); 1538 if (mlirValueIsABlockArgument(value)) 1539 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1540 if (mlirOperationIsNull(owner)) 1541 throw py::error_already_set(); 1542 MlirContext ctx = mlirOperationGetContext(owner); 1543 PyOperationRef ownerRef = 1544 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1545 return PyValue(ownerRef, value); 1546 } 1547 1548 //------------------------------------------------------------------------------ 1549 // PySymbolTable. 1550 //------------------------------------------------------------------------------ 1551 1552 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1553 : operation(operation.getOperation().getRef()) { 1554 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1555 if (mlirSymbolTableIsNull(symbolTable)) { 1556 throw py::cast_error("Operation is not a Symbol Table."); 1557 } 1558 } 1559 1560 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1561 operation->checkValid(); 1562 MlirOperation symbol = mlirSymbolTableLookup( 1563 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1564 if (mlirOperationIsNull(symbol)) 1565 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1566 1567 return PyOperation::forOperation(operation->getContext(), symbol, 1568 operation.getObject()) 1569 ->createOpView(); 1570 } 1571 1572 void PySymbolTable::erase(PyOperationBase &symbol) { 1573 operation->checkValid(); 1574 symbol.getOperation().checkValid(); 1575 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1576 // The operation is also erased, so we must invalidate it. There may be Python 1577 // references to this operation so we don't want to delete it from the list of 1578 // live operations here. 1579 symbol.getOperation().valid = false; 1580 } 1581 1582 void PySymbolTable::dunderDel(const std::string &name) { 1583 py::object operation = dunderGetItem(name); 1584 erase(py::cast<PyOperationBase &>(operation)); 1585 } 1586 1587 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1588 operation->checkValid(); 1589 symbol.getOperation().checkValid(); 1590 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1591 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1592 if (mlirAttributeIsNull(symbolAttr)) 1593 throw py::value_error("Expected operation to have a symbol name."); 1594 return PyAttribute( 1595 symbol.getOperation().getContext(), 1596 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1597 } 1598 1599 namespace { 1600 /// CRTP base class for Python MLIR values that subclass Value and should be 1601 /// castable from it. The value hierarchy is one level deep and is not supposed 1602 /// to accommodate other levels unless core MLIR changes. 1603 template <typename DerivedTy> 1604 class PyConcreteValue : public PyValue { 1605 public: 1606 // Derived classes must define statics for: 1607 // IsAFunctionTy isaFunction 1608 // const char *pyClassName 1609 // and redefine bindDerived. 1610 using ClassTy = py::class_<DerivedTy, PyValue>; 1611 using IsAFunctionTy = bool (*)(MlirValue); 1612 1613 PyConcreteValue() = default; 1614 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1615 : PyValue(operationRef, value) {} 1616 PyConcreteValue(PyValue &orig) 1617 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1618 1619 /// Attempts to cast the original value to the derived type and throws on 1620 /// type mismatches. 1621 static MlirValue castFrom(PyValue &orig) { 1622 if (!DerivedTy::isaFunction(orig.get())) { 1623 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1624 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1625 DerivedTy::pyClassName + 1626 " (from " + origRepr + ")"); 1627 } 1628 return orig.get(); 1629 } 1630 1631 /// Binds the Python module objects to functions of this class. 1632 static void bind(py::module &m) { 1633 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1634 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1635 cls.def_static( 1636 "isinstance", 1637 [](PyValue &otherValue) -> bool { 1638 return DerivedTy::isaFunction(otherValue); 1639 }, 1640 py::arg("other_value")); 1641 DerivedTy::bindDerived(cls); 1642 } 1643 1644 /// Implemented by derived classes to add methods to the Python subclass. 1645 static void bindDerived(ClassTy &m) {} 1646 }; 1647 1648 /// Python wrapper for MlirBlockArgument. 1649 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1650 public: 1651 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1652 static constexpr const char *pyClassName = "BlockArgument"; 1653 using PyConcreteValue::PyConcreteValue; 1654 1655 static void bindDerived(ClassTy &c) { 1656 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1657 return PyBlock(self.getParentOperation(), 1658 mlirBlockArgumentGetOwner(self.get())); 1659 }); 1660 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1661 return mlirBlockArgumentGetArgNumber(self.get()); 1662 }); 1663 c.def( 1664 "set_type", 1665 [](PyBlockArgument &self, PyType type) { 1666 return mlirBlockArgumentSetType(self.get(), type); 1667 }, 1668 py::arg("type")); 1669 } 1670 }; 1671 1672 /// Python wrapper for MlirOpResult. 1673 class PyOpResult : public PyConcreteValue<PyOpResult> { 1674 public: 1675 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1676 static constexpr const char *pyClassName = "OpResult"; 1677 using PyConcreteValue::PyConcreteValue; 1678 1679 static void bindDerived(ClassTy &c) { 1680 c.def_property_readonly("owner", [](PyOpResult &self) { 1681 assert( 1682 mlirOperationEqual(self.getParentOperation()->get(), 1683 mlirOpResultGetOwner(self.get())) && 1684 "expected the owner of the value in Python to match that in the IR"); 1685 return self.getParentOperation().getObject(); 1686 }); 1687 c.def_property_readonly("result_number", [](PyOpResult &self) { 1688 return mlirOpResultGetResultNumber(self.get()); 1689 }); 1690 } 1691 }; 1692 1693 /// Returns the list of types of the values held by container. 1694 template <typename Container> 1695 static std::vector<PyType> getValueTypes(Container &container, 1696 PyMlirContextRef &context) { 1697 std::vector<PyType> result; 1698 result.reserve(container.getNumElements()); 1699 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1700 result.push_back( 1701 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1702 } 1703 return result; 1704 } 1705 1706 /// A list of block arguments. Internally, these are stored as consecutive 1707 /// elements, random access is cheap. The argument list is associated with the 1708 /// operation that contains the block (detached blocks are not allowed in 1709 /// Python bindings) and extends its lifetime. 1710 class PyBlockArgumentList 1711 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1712 public: 1713 static constexpr const char *pyClassName = "BlockArgumentList"; 1714 1715 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1716 intptr_t startIndex = 0, intptr_t length = -1, 1717 intptr_t step = 1) 1718 : Sliceable(startIndex, 1719 length == -1 ? mlirBlockGetNumArguments(block) : length, 1720 step), 1721 operation(std::move(operation)), block(block) {} 1722 1723 /// Returns the number of arguments in the list. 1724 intptr_t getNumElements() { 1725 operation->checkValid(); 1726 return mlirBlockGetNumArguments(block); 1727 } 1728 1729 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1730 PyBlockArgument getElement(intptr_t pos) { 1731 MlirValue argument = mlirBlockGetArgument(block, pos); 1732 return PyBlockArgument(operation, argument); 1733 } 1734 1735 /// Returns a sublist of this list. 1736 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1737 intptr_t step) { 1738 return PyBlockArgumentList(operation, block, startIndex, length, step); 1739 } 1740 1741 static void bindDerived(ClassTy &c) { 1742 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1743 return getValueTypes(self, self.operation->getContext()); 1744 }); 1745 } 1746 1747 private: 1748 PyOperationRef operation; 1749 MlirBlock block; 1750 }; 1751 1752 /// A list of operation operands. Internally, these are stored as consecutive 1753 /// elements, random access is cheap. The result list is associated with the 1754 /// operation whose results these are, and extends the lifetime of this 1755 /// operation. 1756 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1757 public: 1758 static constexpr const char *pyClassName = "OpOperandList"; 1759 1760 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1761 intptr_t length = -1, intptr_t step = 1) 1762 : Sliceable(startIndex, 1763 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1764 : length, 1765 step), 1766 operation(operation) {} 1767 1768 intptr_t getNumElements() { 1769 operation->checkValid(); 1770 return mlirOperationGetNumOperands(operation->get()); 1771 } 1772 1773 PyValue getElement(intptr_t pos) { 1774 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1775 MlirOperation owner; 1776 if (mlirValueIsAOpResult(operand)) 1777 owner = mlirOpResultGetOwner(operand); 1778 else if (mlirValueIsABlockArgument(operand)) 1779 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1780 else 1781 assert(false && "Value must be an block arg or op result."); 1782 PyOperationRef pyOwner = 1783 PyOperation::forOperation(operation->getContext(), owner); 1784 return PyValue(pyOwner, operand); 1785 } 1786 1787 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1788 return PyOpOperandList(operation, startIndex, length, step); 1789 } 1790 1791 void dunderSetItem(intptr_t index, PyValue value) { 1792 index = wrapIndex(index); 1793 mlirOperationSetOperand(operation->get(), index, value.get()); 1794 } 1795 1796 static void bindDerived(ClassTy &c) { 1797 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1798 } 1799 1800 private: 1801 PyOperationRef operation; 1802 }; 1803 1804 /// A list of operation results. Internally, these are stored as consecutive 1805 /// elements, random access is cheap. The result list is associated with the 1806 /// operation whose results these are, and extends the lifetime of this 1807 /// operation. 1808 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1809 public: 1810 static constexpr const char *pyClassName = "OpResultList"; 1811 1812 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1813 intptr_t length = -1, intptr_t step = 1) 1814 : Sliceable(startIndex, 1815 length == -1 ? mlirOperationGetNumResults(operation->get()) 1816 : length, 1817 step), 1818 operation(operation) {} 1819 1820 intptr_t getNumElements() { 1821 operation->checkValid(); 1822 return mlirOperationGetNumResults(operation->get()); 1823 } 1824 1825 PyOpResult getElement(intptr_t index) { 1826 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1827 return PyOpResult(value); 1828 } 1829 1830 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1831 return PyOpResultList(operation, startIndex, length, step); 1832 } 1833 1834 static void bindDerived(ClassTy &c) { 1835 c.def_property_readonly("types", [](PyOpResultList &self) { 1836 return getValueTypes(self, self.operation->getContext()); 1837 }); 1838 } 1839 1840 private: 1841 PyOperationRef operation; 1842 }; 1843 1844 /// A list of operation attributes. Can be indexed by name, producing 1845 /// attributes, or by index, producing named attributes. 1846 class PyOpAttributeMap { 1847 public: 1848 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1849 1850 PyAttribute dunderGetItemNamed(const std::string &name) { 1851 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1852 toMlirStringRef(name)); 1853 if (mlirAttributeIsNull(attr)) { 1854 throw SetPyError(PyExc_KeyError, 1855 "attempt to access a non-existent attribute"); 1856 } 1857 return PyAttribute(operation->getContext(), attr); 1858 } 1859 1860 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1861 if (index < 0 || index >= dunderLen()) { 1862 throw SetPyError(PyExc_IndexError, 1863 "attempt to access out of bounds attribute"); 1864 } 1865 MlirNamedAttribute namedAttr = 1866 mlirOperationGetAttribute(operation->get(), index); 1867 return PyNamedAttribute( 1868 namedAttr.attribute, 1869 std::string(mlirIdentifierStr(namedAttr.name).data, 1870 mlirIdentifierStr(namedAttr.name).length)); 1871 } 1872 1873 void dunderSetItem(const std::string &name, PyAttribute attr) { 1874 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1875 attr); 1876 } 1877 1878 void dunderDelItem(const std::string &name) { 1879 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1880 toMlirStringRef(name)); 1881 if (!removed) 1882 throw SetPyError(PyExc_KeyError, 1883 "attempt to delete a non-existent attribute"); 1884 } 1885 1886 intptr_t dunderLen() { 1887 return mlirOperationGetNumAttributes(operation->get()); 1888 } 1889 1890 bool dunderContains(const std::string &name) { 1891 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1892 operation->get(), toMlirStringRef(name))); 1893 } 1894 1895 static void bind(py::module &m) { 1896 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1897 .def("__contains__", &PyOpAttributeMap::dunderContains) 1898 .def("__len__", &PyOpAttributeMap::dunderLen) 1899 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1900 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1901 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1902 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1903 } 1904 1905 private: 1906 PyOperationRef operation; 1907 }; 1908 1909 } // end namespace 1910 1911 //------------------------------------------------------------------------------ 1912 // Populates the core exports of the 'ir' submodule. 1913 //------------------------------------------------------------------------------ 1914 1915 void mlir::python::populateIRCore(py::module &m) { 1916 //---------------------------------------------------------------------------- 1917 // Mapping of MlirContext. 1918 //---------------------------------------------------------------------------- 1919 py::class_<PyMlirContext>(m, "Context", py::module_local()) 1920 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1921 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1922 .def("_get_context_again", 1923 [](PyMlirContext &self) { 1924 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1925 return ref.releaseObject(); 1926 }) 1927 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1928 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1929 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1930 &PyMlirContext::getCapsule) 1931 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1932 .def("__enter__", &PyMlirContext::contextEnter) 1933 .def("__exit__", &PyMlirContext::contextExit) 1934 .def_property_readonly_static( 1935 "current", 1936 [](py::object & /*class*/) { 1937 auto *context = PyThreadContextEntry::getDefaultContext(); 1938 if (!context) 1939 throw SetPyError(PyExc_ValueError, "No current Context"); 1940 return context; 1941 }, 1942 "Gets the Context bound to the current thread or raises ValueError") 1943 .def_property_readonly( 1944 "dialects", 1945 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1946 "Gets a container for accessing dialects by name") 1947 .def_property_readonly( 1948 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1949 "Alias for 'dialect'") 1950 .def( 1951 "get_dialect_descriptor", 1952 [=](PyMlirContext &self, std::string &name) { 1953 MlirDialect dialect = mlirContextGetOrLoadDialect( 1954 self.get(), {name.data(), name.size()}); 1955 if (mlirDialectIsNull(dialect)) { 1956 throw SetPyError(PyExc_ValueError, 1957 Twine("Dialect '") + name + "' not found"); 1958 } 1959 return PyDialectDescriptor(self.getRef(), dialect); 1960 }, 1961 py::arg("dialect_name"), 1962 "Gets or loads a dialect by name, returning its descriptor object") 1963 .def_property( 1964 "allow_unregistered_dialects", 1965 [](PyMlirContext &self) -> bool { 1966 return mlirContextGetAllowUnregisteredDialects(self.get()); 1967 }, 1968 [](PyMlirContext &self, bool value) { 1969 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1970 }) 1971 .def( 1972 "enable_multithreading", 1973 [](PyMlirContext &self, bool enable) { 1974 mlirContextEnableMultithreading(self.get(), enable); 1975 }, 1976 py::arg("enable")) 1977 .def( 1978 "is_registered_operation", 1979 [](PyMlirContext &self, std::string &name) { 1980 return mlirContextIsRegisteredOperation( 1981 self.get(), MlirStringRef{name.data(), name.size()}); 1982 }, 1983 py::arg("operation_name")); 1984 1985 //---------------------------------------------------------------------------- 1986 // Mapping of PyDialectDescriptor 1987 //---------------------------------------------------------------------------- 1988 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1989 .def_property_readonly("namespace", 1990 [](PyDialectDescriptor &self) { 1991 MlirStringRef ns = 1992 mlirDialectGetNamespace(self.get()); 1993 return py::str(ns.data, ns.length); 1994 }) 1995 .def("__repr__", [](PyDialectDescriptor &self) { 1996 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1997 std::string repr("<DialectDescriptor "); 1998 repr.append(ns.data, ns.length); 1999 repr.append(">"); 2000 return repr; 2001 }); 2002 2003 //---------------------------------------------------------------------------- 2004 // Mapping of PyDialects 2005 //---------------------------------------------------------------------------- 2006 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2007 .def("__getitem__", 2008 [=](PyDialects &self, std::string keyName) { 2009 MlirDialect dialect = 2010 self.getDialectForKey(keyName, /*attrError=*/false); 2011 py::object descriptor = 2012 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2013 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2014 }) 2015 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2016 MlirDialect dialect = 2017 self.getDialectForKey(attrName, /*attrError=*/true); 2018 py::object descriptor = 2019 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2020 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2021 }); 2022 2023 //---------------------------------------------------------------------------- 2024 // Mapping of PyDialect 2025 //---------------------------------------------------------------------------- 2026 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2027 .def(py::init<py::object>(), py::arg("descriptor")) 2028 .def_property_readonly( 2029 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2030 .def("__repr__", [](py::object self) { 2031 auto clazz = self.attr("__class__"); 2032 return py::str("<Dialect ") + 2033 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2034 clazz.attr("__module__") + py::str(".") + 2035 clazz.attr("__name__") + py::str(")>"); 2036 }); 2037 2038 //---------------------------------------------------------------------------- 2039 // Mapping of Location 2040 //---------------------------------------------------------------------------- 2041 py::class_<PyLocation>(m, "Location", py::module_local()) 2042 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2043 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2044 .def("__enter__", &PyLocation::contextEnter) 2045 .def("__exit__", &PyLocation::contextExit) 2046 .def("__eq__", 2047 [](PyLocation &self, PyLocation &other) -> bool { 2048 return mlirLocationEqual(self, other); 2049 }) 2050 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2051 .def_property_readonly_static( 2052 "current", 2053 [](py::object & /*class*/) { 2054 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2055 if (!loc) 2056 throw SetPyError(PyExc_ValueError, "No current Location"); 2057 return loc; 2058 }, 2059 "Gets the Location bound to the current thread or raises ValueError") 2060 .def_static( 2061 "unknown", 2062 [](DefaultingPyMlirContext context) { 2063 return PyLocation(context->getRef(), 2064 mlirLocationUnknownGet(context->get())); 2065 }, 2066 py::arg("context") = py::none(), 2067 "Gets a Location representing an unknown location") 2068 .def_static( 2069 "callsite", 2070 [](PyLocation callee, const std::vector<PyLocation> &frames, 2071 DefaultingPyMlirContext context) { 2072 if (frames.empty()) 2073 throw py::value_error("No caller frames provided"); 2074 MlirLocation caller = frames.back().get(); 2075 for (const PyLocation &frame : 2076 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2077 caller = mlirLocationCallSiteGet(frame.get(), caller); 2078 return PyLocation(context->getRef(), 2079 mlirLocationCallSiteGet(callee.get(), caller)); 2080 }, 2081 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2082 kContextGetCallSiteLocationDocstring) 2083 .def_static( 2084 "file", 2085 [](std::string filename, int line, int col, 2086 DefaultingPyMlirContext context) { 2087 return PyLocation( 2088 context->getRef(), 2089 mlirLocationFileLineColGet( 2090 context->get(), toMlirStringRef(filename), line, col)); 2091 }, 2092 py::arg("filename"), py::arg("line"), py::arg("col"), 2093 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2094 .def_static( 2095 "name", 2096 [](std::string name, llvm::Optional<PyLocation> childLoc, 2097 DefaultingPyMlirContext context) { 2098 return PyLocation( 2099 context->getRef(), 2100 mlirLocationNameGet( 2101 context->get(), toMlirStringRef(name), 2102 childLoc ? childLoc->get() 2103 : mlirLocationUnknownGet(context->get()))); 2104 }, 2105 py::arg("name"), py::arg("childLoc") = py::none(), 2106 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2107 .def_property_readonly( 2108 "context", 2109 [](PyLocation &self) { return self.getContext().getObject(); }, 2110 "Context that owns the Location") 2111 .def("__repr__", [](PyLocation &self) { 2112 PyPrintAccumulator printAccum; 2113 mlirLocationPrint(self, printAccum.getCallback(), 2114 printAccum.getUserData()); 2115 return printAccum.join(); 2116 }); 2117 2118 //---------------------------------------------------------------------------- 2119 // Mapping of Module 2120 //---------------------------------------------------------------------------- 2121 py::class_<PyModule>(m, "Module", py::module_local()) 2122 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2123 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2124 .def_static( 2125 "parse", 2126 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2127 MlirModule module = mlirModuleCreateParse( 2128 context->get(), toMlirStringRef(moduleAsm)); 2129 // TODO: Rework error reporting once diagnostic engine is exposed 2130 // in C API. 2131 if (mlirModuleIsNull(module)) { 2132 throw SetPyError( 2133 PyExc_ValueError, 2134 "Unable to parse module assembly (see diagnostics)"); 2135 } 2136 return PyModule::forModule(module).releaseObject(); 2137 }, 2138 py::arg("asm"), py::arg("context") = py::none(), 2139 kModuleParseDocstring) 2140 .def_static( 2141 "create", 2142 [](DefaultingPyLocation loc) { 2143 MlirModule module = mlirModuleCreateEmpty(loc); 2144 return PyModule::forModule(module).releaseObject(); 2145 }, 2146 py::arg("loc") = py::none(), "Creates an empty module") 2147 .def_property_readonly( 2148 "context", 2149 [](PyModule &self) { return self.getContext().getObject(); }, 2150 "Context that created the Module") 2151 .def_property_readonly( 2152 "operation", 2153 [](PyModule &self) { 2154 return PyOperation::forOperation(self.getContext(), 2155 mlirModuleGetOperation(self.get()), 2156 self.getRef().releaseObject()) 2157 .releaseObject(); 2158 }, 2159 "Accesses the module as an operation") 2160 .def_property_readonly( 2161 "body", 2162 [](PyModule &self) { 2163 PyOperationRef module_op = PyOperation::forOperation( 2164 self.getContext(), mlirModuleGetOperation(self.get()), 2165 self.getRef().releaseObject()); 2166 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2167 return returnBlock; 2168 }, 2169 "Return the block for this module") 2170 .def( 2171 "dump", 2172 [](PyModule &self) { 2173 mlirOperationDump(mlirModuleGetOperation(self.get())); 2174 }, 2175 kDumpDocstring) 2176 .def( 2177 "__str__", 2178 [](py::object self) { 2179 // Defer to the operation's __str__. 2180 return self.attr("operation").attr("__str__")(); 2181 }, 2182 kOperationStrDunderDocstring); 2183 2184 //---------------------------------------------------------------------------- 2185 // Mapping of Operation. 2186 //---------------------------------------------------------------------------- 2187 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2188 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2189 [](PyOperationBase &self) { 2190 return self.getOperation().getCapsule(); 2191 }) 2192 .def("__eq__", 2193 [](PyOperationBase &self, PyOperationBase &other) { 2194 return &self.getOperation() == &other.getOperation(); 2195 }) 2196 .def("__eq__", 2197 [](PyOperationBase &self, py::object other) { return false; }) 2198 .def("__hash__", 2199 [](PyOperationBase &self) { 2200 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2201 }) 2202 .def_property_readonly("attributes", 2203 [](PyOperationBase &self) { 2204 return PyOpAttributeMap( 2205 self.getOperation().getRef()); 2206 }) 2207 .def_property_readonly("operands", 2208 [](PyOperationBase &self) { 2209 return PyOpOperandList( 2210 self.getOperation().getRef()); 2211 }) 2212 .def_property_readonly("regions", 2213 [](PyOperationBase &self) { 2214 return PyRegionList( 2215 self.getOperation().getRef()); 2216 }) 2217 .def_property_readonly( 2218 "results", 2219 [](PyOperationBase &self) { 2220 return PyOpResultList(self.getOperation().getRef()); 2221 }, 2222 "Returns the list of Operation results.") 2223 .def_property_readonly( 2224 "result", 2225 [](PyOperationBase &self) { 2226 auto &operation = self.getOperation(); 2227 auto numResults = mlirOperationGetNumResults(operation); 2228 if (numResults != 1) { 2229 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2230 throw SetPyError( 2231 PyExc_ValueError, 2232 Twine("Cannot call .result on operation ") + 2233 StringRef(name.data, name.length) + " which has " + 2234 Twine(numResults) + 2235 " results (it is only valid for operations with a " 2236 "single result)"); 2237 } 2238 return PyOpResult(operation.getRef(), 2239 mlirOperationGetResult(operation, 0)); 2240 }, 2241 "Shortcut to get an op result if it has only one (throws an error " 2242 "otherwise).") 2243 .def_property_readonly( 2244 "location", 2245 [](PyOperationBase &self) { 2246 PyOperation &operation = self.getOperation(); 2247 return PyLocation(operation.getContext(), 2248 mlirOperationGetLocation(operation.get())); 2249 }, 2250 "Returns the source location the operation was defined or derived " 2251 "from.") 2252 .def( 2253 "__str__", 2254 [](PyOperationBase &self) { 2255 return self.getAsm(/*binary=*/false, 2256 /*largeElementsLimit=*/llvm::None, 2257 /*enableDebugInfo=*/false, 2258 /*prettyDebugInfo=*/false, 2259 /*printGenericOpForm=*/false, 2260 /*useLocalScope=*/false, 2261 /*assumeVerified=*/false); 2262 }, 2263 "Returns the assembly form of the operation.") 2264 .def("print", &PyOperationBase::print, 2265 // Careful: Lots of arguments must match up with print method. 2266 py::arg("file") = py::none(), py::arg("binary") = false, 2267 py::arg("large_elements_limit") = py::none(), 2268 py::arg("enable_debug_info") = false, 2269 py::arg("pretty_debug_info") = false, 2270 py::arg("print_generic_op_form") = false, 2271 py::arg("use_local_scope") = false, 2272 py::arg("assume_verified") = false, kOperationPrintDocstring) 2273 .def("get_asm", &PyOperationBase::getAsm, 2274 // Careful: Lots of arguments must match up with get_asm method. 2275 py::arg("binary") = false, 2276 py::arg("large_elements_limit") = py::none(), 2277 py::arg("enable_debug_info") = false, 2278 py::arg("pretty_debug_info") = false, 2279 py::arg("print_generic_op_form") = false, 2280 py::arg("use_local_scope") = false, 2281 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2282 .def( 2283 "verify", 2284 [](PyOperationBase &self) { 2285 return mlirOperationVerify(self.getOperation()); 2286 }, 2287 "Verify the operation and return true if it passes, false if it " 2288 "fails.") 2289 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2290 "Puts self immediately after the other operation in its parent " 2291 "block.") 2292 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2293 "Puts self immediately before the other operation in its parent " 2294 "block.") 2295 .def( 2296 "detach_from_parent", 2297 [](PyOperationBase &self) { 2298 PyOperation &operation = self.getOperation(); 2299 operation.checkValid(); 2300 if (!operation.isAttached()) 2301 throw py::value_error("Detached operation has no parent."); 2302 2303 operation.detachFromParent(); 2304 return operation.createOpView(); 2305 }, 2306 "Detaches the operation from its parent block."); 2307 2308 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2309 .def_static("create", &PyOperation::create, py::arg("name"), 2310 py::arg("results") = py::none(), 2311 py::arg("operands") = py::none(), 2312 py::arg("attributes") = py::none(), 2313 py::arg("successors") = py::none(), py::arg("regions") = 0, 2314 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2315 kOperationCreateDocstring) 2316 .def_property_readonly("parent", 2317 [](PyOperation &self) -> py::object { 2318 auto parent = self.getParentOperation(); 2319 if (parent) 2320 return parent->getObject(); 2321 return py::none(); 2322 }) 2323 .def("erase", &PyOperation::erase) 2324 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2325 &PyOperation::getCapsule) 2326 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2327 .def_property_readonly("name", 2328 [](PyOperation &self) { 2329 self.checkValid(); 2330 MlirOperation operation = self.get(); 2331 MlirStringRef name = mlirIdentifierStr( 2332 mlirOperationGetName(operation)); 2333 return py::str(name.data, name.length); 2334 }) 2335 .def_property_readonly( 2336 "context", 2337 [](PyOperation &self) { 2338 self.checkValid(); 2339 return self.getContext().getObject(); 2340 }, 2341 "Context that owns the Operation") 2342 .def_property_readonly("opview", &PyOperation::createOpView); 2343 2344 auto opViewClass = 2345 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2346 .def(py::init<py::object>(), py::arg("operation")) 2347 .def_property_readonly("operation", &PyOpView::getOperationObject) 2348 .def_property_readonly( 2349 "context", 2350 [](PyOpView &self) { 2351 return self.getOperation().getContext().getObject(); 2352 }, 2353 "Context that owns the Operation") 2354 .def("__str__", [](PyOpView &self) { 2355 return py::str(self.getOperationObject()); 2356 }); 2357 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2358 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2359 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2360 opViewClass.attr("build_generic") = classmethod( 2361 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2362 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2363 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2364 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2365 "Builds a specific, generated OpView based on class level attributes."); 2366 2367 //---------------------------------------------------------------------------- 2368 // Mapping of PyRegion. 2369 //---------------------------------------------------------------------------- 2370 py::class_<PyRegion>(m, "Region", py::module_local()) 2371 .def_property_readonly( 2372 "blocks", 2373 [](PyRegion &self) { 2374 return PyBlockList(self.getParentOperation(), self.get()); 2375 }, 2376 "Returns a forward-optimized sequence of blocks.") 2377 .def_property_readonly( 2378 "owner", 2379 [](PyRegion &self) { 2380 return self.getParentOperation()->createOpView(); 2381 }, 2382 "Returns the operation owning this region.") 2383 .def( 2384 "__iter__", 2385 [](PyRegion &self) { 2386 self.checkValid(); 2387 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2388 return PyBlockIterator(self.getParentOperation(), firstBlock); 2389 }, 2390 "Iterates over blocks in the region.") 2391 .def("__eq__", 2392 [](PyRegion &self, PyRegion &other) { 2393 return self.get().ptr == other.get().ptr; 2394 }) 2395 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2396 2397 //---------------------------------------------------------------------------- 2398 // Mapping of PyBlock. 2399 //---------------------------------------------------------------------------- 2400 py::class_<PyBlock>(m, "Block", py::module_local()) 2401 .def_property_readonly( 2402 "owner", 2403 [](PyBlock &self) { 2404 return self.getParentOperation()->createOpView(); 2405 }, 2406 "Returns the owning operation of this block.") 2407 .def_property_readonly( 2408 "region", 2409 [](PyBlock &self) { 2410 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2411 return PyRegion(self.getParentOperation(), region); 2412 }, 2413 "Returns the owning region of this block.") 2414 .def_property_readonly( 2415 "arguments", 2416 [](PyBlock &self) { 2417 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2418 }, 2419 "Returns a list of block arguments.") 2420 .def_property_readonly( 2421 "operations", 2422 [](PyBlock &self) { 2423 return PyOperationList(self.getParentOperation(), self.get()); 2424 }, 2425 "Returns a forward-optimized sequence of operations.") 2426 .def_static( 2427 "create_at_start", 2428 [](PyRegion &parent, py::list pyArgTypes) { 2429 parent.checkValid(); 2430 llvm::SmallVector<MlirType, 4> argTypes; 2431 argTypes.reserve(pyArgTypes.size()); 2432 for (auto &pyArg : pyArgTypes) { 2433 argTypes.push_back(pyArg.cast<PyType &>()); 2434 } 2435 2436 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2437 mlirRegionInsertOwnedBlock(parent, 0, block); 2438 return PyBlock(parent.getParentOperation(), block); 2439 }, 2440 py::arg("parent"), py::arg("arg_types") = py::list(), 2441 "Creates and returns a new Block at the beginning of the given " 2442 "region (with given argument types).") 2443 .def( 2444 "create_before", 2445 [](PyBlock &self, py::args pyArgTypes) { 2446 self.checkValid(); 2447 llvm::SmallVector<MlirType, 4> argTypes; 2448 argTypes.reserve(pyArgTypes.size()); 2449 for (auto &pyArg : pyArgTypes) { 2450 argTypes.push_back(pyArg.cast<PyType &>()); 2451 } 2452 2453 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2454 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2455 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2456 return PyBlock(self.getParentOperation(), block); 2457 }, 2458 "Creates and returns a new Block before this block " 2459 "(with given argument types).") 2460 .def( 2461 "create_after", 2462 [](PyBlock &self, py::args pyArgTypes) { 2463 self.checkValid(); 2464 llvm::SmallVector<MlirType, 4> argTypes; 2465 argTypes.reserve(pyArgTypes.size()); 2466 for (auto &pyArg : pyArgTypes) { 2467 argTypes.push_back(pyArg.cast<PyType &>()); 2468 } 2469 2470 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2471 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2472 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2473 return PyBlock(self.getParentOperation(), block); 2474 }, 2475 "Creates and returns a new Block after this block " 2476 "(with given argument types).") 2477 .def( 2478 "__iter__", 2479 [](PyBlock &self) { 2480 self.checkValid(); 2481 MlirOperation firstOperation = 2482 mlirBlockGetFirstOperation(self.get()); 2483 return PyOperationIterator(self.getParentOperation(), 2484 firstOperation); 2485 }, 2486 "Iterates over operations in the block.") 2487 .def("__eq__", 2488 [](PyBlock &self, PyBlock &other) { 2489 return self.get().ptr == other.get().ptr; 2490 }) 2491 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2492 .def( 2493 "__str__", 2494 [](PyBlock &self) { 2495 self.checkValid(); 2496 PyPrintAccumulator printAccum; 2497 mlirBlockPrint(self.get(), printAccum.getCallback(), 2498 printAccum.getUserData()); 2499 return printAccum.join(); 2500 }, 2501 "Returns the assembly form of the block.") 2502 .def( 2503 "append", 2504 [](PyBlock &self, PyOperationBase &operation) { 2505 if (operation.getOperation().isAttached()) 2506 operation.getOperation().detachFromParent(); 2507 2508 MlirOperation mlirOperation = operation.getOperation().get(); 2509 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2510 operation.getOperation().setAttached( 2511 self.getParentOperation().getObject()); 2512 }, 2513 py::arg("operation"), 2514 "Appends an operation to this block. If the operation is currently " 2515 "in another block, it will be moved."); 2516 2517 //---------------------------------------------------------------------------- 2518 // Mapping of PyInsertionPoint. 2519 //---------------------------------------------------------------------------- 2520 2521 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2522 .def(py::init<PyBlock &>(), py::arg("block"), 2523 "Inserts after the last operation but still inside the block.") 2524 .def("__enter__", &PyInsertionPoint::contextEnter) 2525 .def("__exit__", &PyInsertionPoint::contextExit) 2526 .def_property_readonly_static( 2527 "current", 2528 [](py::object & /*class*/) { 2529 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2530 if (!ip) 2531 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2532 return ip; 2533 }, 2534 "Gets the InsertionPoint bound to the current thread or raises " 2535 "ValueError if none has been set") 2536 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2537 "Inserts before a referenced operation.") 2538 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2539 py::arg("block"), "Inserts at the beginning of the block.") 2540 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2541 py::arg("block"), "Inserts before the block terminator.") 2542 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2543 "Inserts an operation.") 2544 .def_property_readonly( 2545 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2546 "Returns the block that this InsertionPoint points to."); 2547 2548 //---------------------------------------------------------------------------- 2549 // Mapping of PyAttribute. 2550 //---------------------------------------------------------------------------- 2551 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2552 // Delegate to the PyAttribute copy constructor, which will also lifetime 2553 // extend the backing context which owns the MlirAttribute. 2554 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2555 "Casts the passed attribute to the generic Attribute") 2556 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2557 &PyAttribute::getCapsule) 2558 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2559 .def_static( 2560 "parse", 2561 [](std::string attrSpec, DefaultingPyMlirContext context) { 2562 MlirAttribute type = mlirAttributeParseGet( 2563 context->get(), toMlirStringRef(attrSpec)); 2564 // TODO: Rework error reporting once diagnostic engine is exposed 2565 // in C API. 2566 if (mlirAttributeIsNull(type)) { 2567 throw SetPyError(PyExc_ValueError, 2568 Twine("Unable to parse attribute: '") + 2569 attrSpec + "'"); 2570 } 2571 return PyAttribute(context->getRef(), type); 2572 }, 2573 py::arg("asm"), py::arg("context") = py::none(), 2574 "Parses an attribute from an assembly form") 2575 .def_property_readonly( 2576 "context", 2577 [](PyAttribute &self) { return self.getContext().getObject(); }, 2578 "Context that owns the Attribute") 2579 .def_property_readonly("type", 2580 [](PyAttribute &self) { 2581 return PyType(self.getContext()->getRef(), 2582 mlirAttributeGetType(self)); 2583 }) 2584 .def( 2585 "get_named", 2586 [](PyAttribute &self, std::string name) { 2587 return PyNamedAttribute(self, std::move(name)); 2588 }, 2589 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2590 .def("__eq__", 2591 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2592 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2593 .def("__hash__", 2594 [](PyAttribute &self) { 2595 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2596 }) 2597 .def( 2598 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2599 kDumpDocstring) 2600 .def( 2601 "__str__", 2602 [](PyAttribute &self) { 2603 PyPrintAccumulator printAccum; 2604 mlirAttributePrint(self, printAccum.getCallback(), 2605 printAccum.getUserData()); 2606 return printAccum.join(); 2607 }, 2608 "Returns the assembly form of the Attribute.") 2609 .def("__repr__", [](PyAttribute &self) { 2610 // Generally, assembly formats are not printed for __repr__ because 2611 // this can cause exceptionally long debug output and exceptions. 2612 // However, attribute values are generally considered useful and are 2613 // printed. This may need to be re-evaluated if debug dumps end up 2614 // being excessive. 2615 PyPrintAccumulator printAccum; 2616 printAccum.parts.append("Attribute("); 2617 mlirAttributePrint(self, printAccum.getCallback(), 2618 printAccum.getUserData()); 2619 printAccum.parts.append(")"); 2620 return printAccum.join(); 2621 }); 2622 2623 //---------------------------------------------------------------------------- 2624 // Mapping of PyNamedAttribute 2625 //---------------------------------------------------------------------------- 2626 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2627 .def("__repr__", 2628 [](PyNamedAttribute &self) { 2629 PyPrintAccumulator printAccum; 2630 printAccum.parts.append("NamedAttribute("); 2631 printAccum.parts.append( 2632 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2633 mlirIdentifierStr(self.namedAttr.name).length)); 2634 printAccum.parts.append("="); 2635 mlirAttributePrint(self.namedAttr.attribute, 2636 printAccum.getCallback(), 2637 printAccum.getUserData()); 2638 printAccum.parts.append(")"); 2639 return printAccum.join(); 2640 }) 2641 .def_property_readonly( 2642 "name", 2643 [](PyNamedAttribute &self) { 2644 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2645 mlirIdentifierStr(self.namedAttr.name).length); 2646 }, 2647 "The name of the NamedAttribute binding") 2648 .def_property_readonly( 2649 "attr", 2650 [](PyNamedAttribute &self) { 2651 // TODO: When named attribute is removed/refactored, also remove 2652 // this constructor (it does an inefficient table lookup). 2653 auto contextRef = PyMlirContext::forContext( 2654 mlirAttributeGetContext(self.namedAttr.attribute)); 2655 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2656 }, 2657 py::keep_alive<0, 1>(), 2658 "The underlying generic attribute of the NamedAttribute binding"); 2659 2660 //---------------------------------------------------------------------------- 2661 // Mapping of PyType. 2662 //---------------------------------------------------------------------------- 2663 py::class_<PyType>(m, "Type", py::module_local()) 2664 // Delegate to the PyType copy constructor, which will also lifetime 2665 // extend the backing context which owns the MlirType. 2666 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2667 "Casts the passed type to the generic Type") 2668 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2669 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2670 .def_static( 2671 "parse", 2672 [](std::string typeSpec, DefaultingPyMlirContext context) { 2673 MlirType type = 2674 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2675 // TODO: Rework error reporting once diagnostic engine is exposed 2676 // in C API. 2677 if (mlirTypeIsNull(type)) { 2678 throw SetPyError(PyExc_ValueError, 2679 Twine("Unable to parse type: '") + typeSpec + 2680 "'"); 2681 } 2682 return PyType(context->getRef(), type); 2683 }, 2684 py::arg("asm"), py::arg("context") = py::none(), 2685 kContextParseTypeDocstring) 2686 .def_property_readonly( 2687 "context", [](PyType &self) { return self.getContext().getObject(); }, 2688 "Context that owns the Type") 2689 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2690 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2691 .def("__hash__", 2692 [](PyType &self) { 2693 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2694 }) 2695 .def( 2696 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2697 .def( 2698 "__str__", 2699 [](PyType &self) { 2700 PyPrintAccumulator printAccum; 2701 mlirTypePrint(self, printAccum.getCallback(), 2702 printAccum.getUserData()); 2703 return printAccum.join(); 2704 }, 2705 "Returns the assembly form of the type.") 2706 .def("__repr__", [](PyType &self) { 2707 // Generally, assembly formats are not printed for __repr__ because 2708 // this can cause exceptionally long debug output and exceptions. 2709 // However, types are an exception as they typically have compact 2710 // assembly forms and printing them is useful. 2711 PyPrintAccumulator printAccum; 2712 printAccum.parts.append("Type("); 2713 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2714 printAccum.parts.append(")"); 2715 return printAccum.join(); 2716 }); 2717 2718 //---------------------------------------------------------------------------- 2719 // Mapping of Value. 2720 //---------------------------------------------------------------------------- 2721 py::class_<PyValue>(m, "Value", py::module_local()) 2722 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2723 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2724 .def_property_readonly( 2725 "context", 2726 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2727 "Context in which the value lives.") 2728 .def( 2729 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2730 kDumpDocstring) 2731 .def_property_readonly( 2732 "owner", 2733 [](PyValue &self) { 2734 assert(mlirOperationEqual(self.getParentOperation()->get(), 2735 mlirOpResultGetOwner(self.get())) && 2736 "expected the owner of the value in Python to match that in " 2737 "the IR"); 2738 return self.getParentOperation().getObject(); 2739 }) 2740 .def("__eq__", 2741 [](PyValue &self, PyValue &other) { 2742 return self.get().ptr == other.get().ptr; 2743 }) 2744 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2745 .def("__hash__", 2746 [](PyValue &self) { 2747 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2748 }) 2749 .def( 2750 "__str__", 2751 [](PyValue &self) { 2752 PyPrintAccumulator printAccum; 2753 printAccum.parts.append("Value("); 2754 mlirValuePrint(self.get(), printAccum.getCallback(), 2755 printAccum.getUserData()); 2756 printAccum.parts.append(")"); 2757 return printAccum.join(); 2758 }, 2759 kValueDunderStrDocstring) 2760 .def_property_readonly("type", [](PyValue &self) { 2761 return PyType(self.getParentOperation()->getContext(), 2762 mlirValueGetType(self.get())); 2763 }); 2764 PyBlockArgument::bind(m); 2765 PyOpResult::bind(m); 2766 2767 //---------------------------------------------------------------------------- 2768 // Mapping of SymbolTable. 2769 //---------------------------------------------------------------------------- 2770 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 2771 .def(py::init<PyOperationBase &>()) 2772 .def("__getitem__", &PySymbolTable::dunderGetItem) 2773 .def("insert", &PySymbolTable::insert, py::arg("operation")) 2774 .def("erase", &PySymbolTable::erase, py::arg("operation")) 2775 .def("__delitem__", &PySymbolTable::dunderDel) 2776 .def("__contains__", [](PySymbolTable &table, const std::string &name) { 2777 return !mlirOperationIsNull(mlirSymbolTableLookup( 2778 table, mlirStringRefCreate(name.data(), name.length()))); 2779 }); 2780 2781 // Container bindings. 2782 PyBlockArgumentList::bind(m); 2783 PyBlockIterator::bind(m); 2784 PyBlockList::bind(m); 2785 PyOperationIterator::bind(m); 2786 PyOperationList::bind(m); 2787 PyOpAttributeMap::bind(m); 2788 PyOpOperandList::bind(m); 2789 PyOpResultList::bind(m); 2790 PyRegionIterator::bind(m); 2791 PyRegionList::bind(m); 2792 2793 // Debug bindings. 2794 PyGlobalDebugFlag::bind(m); 2795 } 2796