1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 namespace py = pybind11; 25 using namespace mlir; 26 using namespace mlir::python; 27 28 using llvm::SmallVector; 29 using llvm::StringRef; 30 using llvm::Twine; 31 32 //------------------------------------------------------------------------------ 33 // Docstrings (trivial, non-duplicated docstrings are included inline). 34 //------------------------------------------------------------------------------ 35 36 static const char kContextParseTypeDocstring[] = 37 R"(Parses the assembly form of a type. 38 39 Returns a Type object or raises a ValueError if the type cannot be parsed. 40 41 See also: https://mlir.llvm.org/docs/LangRef/#type-system 42 )"; 43 44 static const char kContextGetCallSiteLocationDocstring[] = 45 R"(Gets a Location representing a caller and callsite)"; 46 47 static const char kContextGetFileLocationDocstring[] = 48 R"(Gets a Location representing a file, line and column)"; 49 50 static const char kContextGetNameLocationDocString[] = 51 R"(Gets a Location representing a named location with optional child location)"; 52 53 static const char kModuleParseDocstring[] = 54 R"(Parses a module's assembly format from a string. 
55 56 Returns a new MlirModule or raises a ValueError if the parsing fails. 57 58 See also: https://mlir.llvm.org/docs/LangRef/ 59 )"; 60 61 static const char kOperationCreateDocstring[] = 62 R"(Creates a new operation. 63 64 Args: 65 name: Operation name (e.g. "dialect.operation"). 66 results: Sequence of Type representing op result types. 67 attributes: Dict of str:Attribute. 68 successors: List of Block for the operation's successors. 69 regions: Number of regions to create. 70 location: A Location object (defaults to resolve from context manager). 71 ip: An InsertionPoint (defaults to resolve from context manager or set to 72 False to disable insertion, even with an insertion point set in the 73 context manager). 74 Returns: 75 A new "detached" Operation object. Detached operations can be added 76 to blocks, which causes them to become "attached." 77 )"; 78 79 static const char kOperationPrintDocstring[] = 80 R"(Prints the assembly form of the operation to a file like object. 81 82 Args: 83 file: The file like object to write to. Defaults to sys.stdout. 84 binary: Whether to write bytes (True) or str (False). Defaults to False. 85 large_elements_limit: Whether to elide elements attributes above this 86 number of elements. Defaults to None (no limit). 87 enable_debug_info: Whether to print debug/location information. Defaults 88 to False. 89 pretty_debug_info: Whether to format debug information for easier reading 90 by a human (warning: the result is unparseable). 91 print_generic_op_form: Whether to print the generic assembly forms of all 92 ops. Defaults to False. 93 use_local_Scope: Whether to print in a way that is more optimized for 94 multi-threaded access but may not be consistent with how the overall 95 module prints. 96 assume_verified: By default, if not printing generic form, the verifier 97 will be run and if it fails, generic form will be printed with a comment 98 about failed verification. 
While a reasonable default for interactive use, 99 for systematic use, it is often better for the caller to verify explicitly 100 and report failures in a more robust fashion. Set this to True if doing this 101 in order to avoid running a redundant verification. If the IR is actually 102 invalid, behavior is undefined. 103 )"; 104 105 static const char kOperationGetAsmDocstring[] = 106 R"(Gets the assembly form of the operation with all options available. 107 108 Args: 109 binary: Whether to return a bytes (True) or str (False) object. Defaults to 110 False. 111 ... others ...: See the print() method for common keyword arguments for 112 configuring the printout. 113 Returns: 114 Either a bytes or str object, depending on the setting of the 'binary' 115 argument. 116 )"; 117 118 static const char kOperationStrDunderDocstring[] = 119 R"(Gets the assembly form of the operation with default options. 120 121 If more advanced control over the assembly formatting or I/O options is needed, 122 use the dedicated print or get_asm method, which supports keyword arguments to 123 customize behavior. 124 )"; 125 126 static const char kDumpDocstring[] = 127 R"(Dumps a debug representation of the object to stderr.)"; 128 129 static const char kAppendBlockDocstring[] = 130 R"(Appends a new block, with argument types as positional args. 131 132 Returns: 133 The created block. 134 )"; 135 136 static const char kValueDunderStrDocstring[] = 137 R"(Returns the string form of the value. 138 139 If the value is a block argument, this is the assembly form of its type and the 140 position in the argument list. If the value is an operation result, this is 141 equivalent to printing the operation that produced it. 142 )"; 143 144 //------------------------------------------------------------------------------ 145 // Utilities. 146 //------------------------------------------------------------------------------ 147 148 /// Helper for creating an @classmethod. 
149 template <class Func, typename... Args> 150 py::object classmethod(Func f, Args... args) { 151 py::object cf = py::cpp_function(f, args...); 152 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 153 } 154 155 static py::object 156 createCustomDialectWrapper(const std::string &dialectNamespace, 157 py::object dialectDescriptor) { 158 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 159 if (!dialectClass) { 160 // Use the base class. 161 return py::cast(PyDialect(std::move(dialectDescriptor))); 162 } 163 164 // Create the custom implementation. 165 return (*dialectClass)(std::move(dialectDescriptor)); 166 } 167 168 static MlirStringRef toMlirStringRef(const std::string &s) { 169 return mlirStringRefCreate(s.data(), s.size()); 170 } 171 172 /// Wrapper for the global LLVM debugging flag. 173 struct PyGlobalDebugFlag { 174 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 175 176 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 177 178 static void bind(py::module &m) { 179 // Debug flags. 180 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 181 .def_property_static("flag", &PyGlobalDebugFlag::get, 182 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 183 } 184 }; 185 186 //------------------------------------------------------------------------------ 187 // Collections. 
188 //------------------------------------------------------------------------------ 189 190 namespace { 191 192 class PyRegionIterator { 193 public: 194 PyRegionIterator(PyOperationRef operation) 195 : operation(std::move(operation)) {} 196 197 PyRegionIterator &dunderIter() { return *this; } 198 199 PyRegion dunderNext() { 200 operation->checkValid(); 201 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 202 throw py::stop_iteration(); 203 } 204 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 205 return PyRegion(operation, region); 206 } 207 208 static void bind(py::module &m) { 209 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 210 .def("__iter__", &PyRegionIterator::dunderIter) 211 .def("__next__", &PyRegionIterator::dunderNext); 212 } 213 214 private: 215 PyOperationRef operation; 216 int nextIndex = 0; 217 }; 218 219 /// Regions of an op are fixed length and indexed numerically so are represented 220 /// with a sequence-like container. 221 class PyRegionList { 222 public: 223 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 224 225 intptr_t dunderLen() { 226 operation->checkValid(); 227 return mlirOperationGetNumRegions(operation->get()); 228 } 229 230 PyRegion dunderGetItem(intptr_t index) { 231 // dunderLen checks validity. 
232 if (index < 0 || index >= dunderLen()) { 233 throw SetPyError(PyExc_IndexError, 234 "attempt to access out of bounds region"); 235 } 236 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 237 return PyRegion(operation, region); 238 } 239 240 static void bind(py::module &m) { 241 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 242 .def("__len__", &PyRegionList::dunderLen) 243 .def("__getitem__", &PyRegionList::dunderGetItem); 244 } 245 246 private: 247 PyOperationRef operation; 248 }; 249 250 class PyBlockIterator { 251 public: 252 PyBlockIterator(PyOperationRef operation, MlirBlock next) 253 : operation(std::move(operation)), next(next) {} 254 255 PyBlockIterator &dunderIter() { return *this; } 256 257 PyBlock dunderNext() { 258 operation->checkValid(); 259 if (mlirBlockIsNull(next)) { 260 throw py::stop_iteration(); 261 } 262 263 PyBlock returnBlock(operation, next); 264 next = mlirBlockGetNextInRegion(next); 265 return returnBlock; 266 } 267 268 static void bind(py::module &m) { 269 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 270 .def("__iter__", &PyBlockIterator::dunderIter) 271 .def("__next__", &PyBlockIterator::dunderNext); 272 } 273 274 private: 275 PyOperationRef operation; 276 MlirBlock next; 277 }; 278 279 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 280 /// we present them as a more full-featured list-like container but optimize 281 /// it for forward iteration. Blocks are always owned by a region. 
282 class PyBlockList { 283 public: 284 PyBlockList(PyOperationRef operation, MlirRegion region) 285 : operation(std::move(operation)), region(region) {} 286 287 PyBlockIterator dunderIter() { 288 operation->checkValid(); 289 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 290 } 291 292 intptr_t dunderLen() { 293 operation->checkValid(); 294 intptr_t count = 0; 295 MlirBlock block = mlirRegionGetFirstBlock(region); 296 while (!mlirBlockIsNull(block)) { 297 count += 1; 298 block = mlirBlockGetNextInRegion(block); 299 } 300 return count; 301 } 302 303 PyBlock dunderGetItem(intptr_t index) { 304 operation->checkValid(); 305 if (index < 0) { 306 throw SetPyError(PyExc_IndexError, 307 "attempt to access out of bounds block"); 308 } 309 MlirBlock block = mlirRegionGetFirstBlock(region); 310 while (!mlirBlockIsNull(block)) { 311 if (index == 0) { 312 return PyBlock(operation, block); 313 } 314 block = mlirBlockGetNextInRegion(block); 315 index -= 1; 316 } 317 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 318 } 319 320 PyBlock appendBlock(py::args pyArgTypes) { 321 operation->checkValid(); 322 llvm::SmallVector<MlirType, 4> argTypes; 323 argTypes.reserve(pyArgTypes.size()); 324 for (auto &pyArg : pyArgTypes) { 325 argTypes.push_back(pyArg.cast<PyType &>()); 326 } 327 328 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 329 mlirRegionAppendOwnedBlock(region, block); 330 return PyBlock(operation, block); 331 } 332 333 static void bind(py::module &m) { 334 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 335 .def("__getitem__", &PyBlockList::dunderGetItem) 336 .def("__iter__", &PyBlockList::dunderIter) 337 .def("__len__", &PyBlockList::dunderLen) 338 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 339 } 340 341 private: 342 PyOperationRef operation; 343 MlirRegion region; 344 }; 345 346 class PyOperationIterator { 347 public: 348 PyOperationIterator(PyOperationRef 
parentOperation, MlirOperation next) 349 : parentOperation(std::move(parentOperation)), next(next) {} 350 351 PyOperationIterator &dunderIter() { return *this; } 352 353 py::object dunderNext() { 354 parentOperation->checkValid(); 355 if (mlirOperationIsNull(next)) { 356 throw py::stop_iteration(); 357 } 358 359 PyOperationRef returnOperation = 360 PyOperation::forOperation(parentOperation->getContext(), next); 361 next = mlirOperationGetNextInBlock(next); 362 return returnOperation->createOpView(); 363 } 364 365 static void bind(py::module &m) { 366 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 367 .def("__iter__", &PyOperationIterator::dunderIter) 368 .def("__next__", &PyOperationIterator::dunderNext); 369 } 370 371 private: 372 PyOperationRef parentOperation; 373 MlirOperation next; 374 }; 375 376 /// Operations are exposed by the C-API as a forward-only linked list. In 377 /// Python, we present them as a more full-featured list-like container but 378 /// optimize it for forward iteration. Iterable operations are always owned 379 /// by a block. 
380 class PyOperationList { 381 public: 382 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 383 : parentOperation(std::move(parentOperation)), block(block) {} 384 385 PyOperationIterator dunderIter() { 386 parentOperation->checkValid(); 387 return PyOperationIterator(parentOperation, 388 mlirBlockGetFirstOperation(block)); 389 } 390 391 intptr_t dunderLen() { 392 parentOperation->checkValid(); 393 intptr_t count = 0; 394 MlirOperation childOp = mlirBlockGetFirstOperation(block); 395 while (!mlirOperationIsNull(childOp)) { 396 count += 1; 397 childOp = mlirOperationGetNextInBlock(childOp); 398 } 399 return count; 400 } 401 402 py::object dunderGetItem(intptr_t index) { 403 parentOperation->checkValid(); 404 if (index < 0) { 405 throw SetPyError(PyExc_IndexError, 406 "attempt to access out of bounds operation"); 407 } 408 MlirOperation childOp = mlirBlockGetFirstOperation(block); 409 while (!mlirOperationIsNull(childOp)) { 410 if (index == 0) { 411 return PyOperation::forOperation(parentOperation->getContext(), childOp) 412 ->createOpView(); 413 } 414 childOp = mlirOperationGetNextInBlock(childOp); 415 index -= 1; 416 } 417 throw SetPyError(PyExc_IndexError, 418 "attempt to access out of bounds operation"); 419 } 420 421 static void bind(py::module &m) { 422 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 423 .def("__getitem__", &PyOperationList::dunderGetItem) 424 .def("__iter__", &PyOperationList::dunderIter) 425 .def("__len__", &PyOperationList::dunderLen); 426 } 427 428 private: 429 PyOperationRef parentOperation; 430 MlirBlock block; 431 }; 432 433 } // namespace 434 435 //------------------------------------------------------------------------------ 436 // PyMlirContext 437 //------------------------------------------------------------------------------ 438 439 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 440 py::gil_scoped_acquire acquire; 441 auto &liveContexts = getLiveContexts(); 442 
liveContexts[context.ptr] = this; 443 } 444 445 PyMlirContext::~PyMlirContext() { 446 // Note that the only public way to construct an instance is via the 447 // forContext method, which always puts the associated handle into 448 // liveContexts. 449 py::gil_scoped_acquire acquire; 450 getLiveContexts().erase(context.ptr); 451 mlirContextDestroy(context); 452 } 453 454 py::object PyMlirContext::getCapsule() { 455 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 456 } 457 458 py::object PyMlirContext::createFromCapsule(py::object capsule) { 459 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 460 if (mlirContextIsNull(rawContext)) 461 throw py::error_already_set(); 462 return forContext(rawContext).releaseObject(); 463 } 464 465 PyMlirContext *PyMlirContext::createNewContextForInit() { 466 MlirContext context = mlirContextCreate(); 467 mlirRegisterAllDialects(context); 468 return new PyMlirContext(context); 469 } 470 471 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 472 py::gil_scoped_acquire acquire; 473 auto &liveContexts = getLiveContexts(); 474 auto it = liveContexts.find(context.ptr); 475 if (it == liveContexts.end()) { 476 // Create. 477 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 478 py::object pyRef = py::cast(unownedContextWrapper); 479 assert(pyRef && "cast to py::object failed"); 480 liveContexts[context.ptr] = unownedContextWrapper; 481 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 482 } 483 // Use existing. 
484 py::object pyRef = py::cast(it->second); 485 return PyMlirContextRef(it->second, std::move(pyRef)); 486 } 487 488 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 489 static LiveContextMap liveContexts; 490 return liveContexts; 491 } 492 493 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 494 495 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 496 497 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 498 499 pybind11::object PyMlirContext::contextEnter() { 500 return PyThreadContextEntry::pushContext(*this); 501 } 502 503 void PyMlirContext::contextExit(pybind11::object excType, 504 pybind11::object excVal, 505 pybind11::object excTb) { 506 PyThreadContextEntry::popContext(*this); 507 } 508 509 PyMlirContext &DefaultingPyMlirContext::resolve() { 510 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 511 if (!context) { 512 throw SetPyError( 513 PyExc_RuntimeError, 514 "An MLIR function requires a Context but none was provided in the call " 515 "or from the surrounding environment. 
Either pass to the function with " 516 "a 'context=' argument or establish a default using 'with Context():'"); 517 } 518 return *context; 519 } 520 521 //------------------------------------------------------------------------------ 522 // PyThreadContextEntry management 523 //------------------------------------------------------------------------------ 524 525 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 526 static thread_local std::vector<PyThreadContextEntry> stack; 527 return stack; 528 } 529 530 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 531 auto &stack = getStack(); 532 if (stack.empty()) 533 return nullptr; 534 return &stack.back(); 535 } 536 537 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 538 py::object insertionPoint, 539 py::object location) { 540 auto &stack = getStack(); 541 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 542 std::move(location)); 543 // If the new stack has more than one entry and the context of the new top 544 // entry matches the previous, copy the insertionPoint and location from the 545 // previous entry if missing from the new top entry. 546 if (stack.size() > 1) { 547 auto &prev = *(stack.rbegin() + 1); 548 auto ¤t = stack.back(); 549 if (current.context.is(prev.context)) { 550 // Default non-context objects from the previous entry. 
551 if (!current.insertionPoint) 552 current.insertionPoint = prev.insertionPoint; 553 if (!current.location) 554 current.location = prev.location; 555 } 556 } 557 } 558 559 PyMlirContext *PyThreadContextEntry::getContext() { 560 if (!context) 561 return nullptr; 562 return py::cast<PyMlirContext *>(context); 563 } 564 565 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 566 if (!insertionPoint) 567 return nullptr; 568 return py::cast<PyInsertionPoint *>(insertionPoint); 569 } 570 571 PyLocation *PyThreadContextEntry::getLocation() { 572 if (!location) 573 return nullptr; 574 return py::cast<PyLocation *>(location); 575 } 576 577 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 578 auto *tos = getTopOfStack(); 579 return tos ? tos->getContext() : nullptr; 580 } 581 582 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 583 auto *tos = getTopOfStack(); 584 return tos ? tos->getInsertionPoint() : nullptr; 585 } 586 587 PyLocation *PyThreadContextEntry::getDefaultLocation() { 588 auto *tos = getTopOfStack(); 589 return tos ? 
tos->getLocation() : nullptr; 590 } 591 592 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 593 py::object contextObj = py::cast(context); 594 push(FrameKind::Context, /*context=*/contextObj, 595 /*insertionPoint=*/py::object(), 596 /*location=*/py::object()); 597 return contextObj; 598 } 599 600 void PyThreadContextEntry::popContext(PyMlirContext &context) { 601 auto &stack = getStack(); 602 if (stack.empty()) 603 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 604 auto &tos = stack.back(); 605 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 606 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 607 stack.pop_back(); 608 } 609 610 py::object 611 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 612 py::object contextObj = 613 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 614 py::object insertionPointObj = py::cast(insertionPoint); 615 push(FrameKind::InsertionPoint, 616 /*context=*/contextObj, 617 /*insertionPoint=*/insertionPointObj, 618 /*location=*/py::object()); 619 return insertionPointObj; 620 } 621 622 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 623 auto &stack = getStack(); 624 if (stack.empty()) 625 throw SetPyError(PyExc_RuntimeError, 626 "Unbalanced InsertionPoint enter/exit"); 627 auto &tos = stack.back(); 628 if (tos.frameKind != FrameKind::InsertionPoint && 629 tos.getInsertionPoint() != &insertionPoint) 630 throw SetPyError(PyExc_RuntimeError, 631 "Unbalanced InsertionPoint enter/exit"); 632 stack.pop_back(); 633 } 634 635 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 636 py::object contextObj = location.getContext().getObject(); 637 py::object locationObj = py::cast(location); 638 push(FrameKind::Location, /*context=*/contextObj, 639 /*insertionPoint=*/py::object(), 640 /*location=*/locationObj); 641 return locationObj; 642 } 643 644 void 
PyThreadContextEntry::popLocation(PyLocation &location) { 645 auto &stack = getStack(); 646 if (stack.empty()) 647 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 648 auto &tos = stack.back(); 649 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 650 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 651 stack.pop_back(); 652 } 653 654 //------------------------------------------------------------------------------ 655 // PyDialect, PyDialectDescriptor, PyDialects 656 //------------------------------------------------------------------------------ 657 658 MlirDialect PyDialects::getDialectForKey(const std::string &key, 659 bool attrError) { 660 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 661 {key.data(), key.size()}); 662 if (mlirDialectIsNull(dialect)) { 663 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 664 Twine("Dialect '") + key + "' not found"); 665 } 666 return dialect; 667 } 668 669 //------------------------------------------------------------------------------ 670 // PyLocation 671 //------------------------------------------------------------------------------ 672 673 py::object PyLocation::getCapsule() { 674 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 675 } 676 677 PyLocation PyLocation::createFromCapsule(py::object capsule) { 678 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 679 if (mlirLocationIsNull(rawLoc)) 680 throw py::error_already_set(); 681 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 682 rawLoc); 683 } 684 685 py::object PyLocation::contextEnter() { 686 return PyThreadContextEntry::pushLocation(*this); 687 } 688 689 void PyLocation::contextExit(py::object excType, py::object excVal, 690 py::object excTb) { 691 PyThreadContextEntry::popLocation(*this); 692 } 693 694 PyLocation &DefaultingPyLocation::resolve() { 695 auto *location = 
PyThreadContextEntry::getDefaultLocation(); 696 if (!location) { 697 throw SetPyError( 698 PyExc_RuntimeError, 699 "An MLIR function requires a Location but none was provided in the " 700 "call or from the surrounding environment. Either pass to the function " 701 "with a 'loc=' argument or establish a default using 'with loc:'"); 702 } 703 return *location; 704 } 705 706 //------------------------------------------------------------------------------ 707 // PyModule 708 //------------------------------------------------------------------------------ 709 710 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 711 : BaseContextObject(std::move(contextRef)), module(module) {} 712 713 PyModule::~PyModule() { 714 py::gil_scoped_acquire acquire; 715 auto &liveModules = getContext()->liveModules; 716 assert(liveModules.count(module.ptr) == 1 && 717 "destroying module not in live map"); 718 liveModules.erase(module.ptr); 719 mlirModuleDestroy(module); 720 } 721 722 PyModuleRef PyModule::forModule(MlirModule module) { 723 MlirContext context = mlirModuleGetContext(module); 724 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 725 726 py::gil_scoped_acquire acquire; 727 auto &liveModules = contextRef->liveModules; 728 auto it = liveModules.find(module.ptr); 729 if (it == liveModules.end()) { 730 // Create. 731 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 732 // Note that the default return value policy on cast is automatic_reference, 733 // which does not take ownership (delete will not be called). 734 // Just be explicit. 735 py::object pyRef = 736 py::cast(unownedModule, py::return_value_policy::take_ownership); 737 unownedModule->handle = pyRef; 738 liveModules[module.ptr] = 739 std::make_pair(unownedModule->handle, unownedModule); 740 return PyModuleRef(unownedModule, std::move(pyRef)); 741 } 742 // Use existing. 
743 PyModule *existing = it->second.second; 744 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 745 return PyModuleRef(existing, std::move(pyRef)); 746 } 747 748 py::object PyModule::createFromCapsule(py::object capsule) { 749 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 750 if (mlirModuleIsNull(rawModule)) 751 throw py::error_already_set(); 752 return forModule(rawModule).releaseObject(); 753 } 754 755 py::object PyModule::getCapsule() { 756 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 757 } 758 759 //------------------------------------------------------------------------------ 760 // PyOperation 761 //------------------------------------------------------------------------------ 762 763 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 764 : BaseContextObject(std::move(contextRef)), operation(operation) {} 765 766 PyOperation::~PyOperation() { 767 // If the operation has already been invalidated there is nothing to do. 768 if (!valid) 769 return; 770 auto &liveOperations = getContext()->liveOperations; 771 assert(liveOperations.count(operation.ptr) == 1 && 772 "destroying operation not in live map"); 773 liveOperations.erase(operation.ptr); 774 if (!isAttached()) { 775 mlirOperationDestroy(operation); 776 } 777 } 778 779 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 780 MlirOperation operation, 781 py::object parentKeepAlive) { 782 auto &liveOperations = contextRef->liveOperations; 783 // Create. 784 PyOperation *unownedOperation = 785 new PyOperation(std::move(contextRef), operation); 786 // Note that the default return value policy on cast is automatic_reference, 787 // which does not take ownership (delete will not be called). 788 // Just be explicit. 
789 py::object pyRef = 790 py::cast(unownedOperation, py::return_value_policy::take_ownership); 791 unownedOperation->handle = pyRef; 792 if (parentKeepAlive) { 793 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 794 } 795 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 796 return PyOperationRef(unownedOperation, std::move(pyRef)); 797 } 798 799 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 800 MlirOperation operation, 801 py::object parentKeepAlive) { 802 auto &liveOperations = contextRef->liveOperations; 803 auto it = liveOperations.find(operation.ptr); 804 if (it == liveOperations.end()) { 805 // Create. 806 return createInstance(std::move(contextRef), operation, 807 std::move(parentKeepAlive)); 808 } 809 // Use existing. 810 PyOperation *existing = it->second.second; 811 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 812 return PyOperationRef(existing, std::move(pyRef)); 813 } 814 815 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 816 MlirOperation operation, 817 py::object parentKeepAlive) { 818 auto &liveOperations = contextRef->liveOperations; 819 assert(liveOperations.count(operation.ptr) == 0 && 820 "cannot create detached operation that already exists"); 821 (void)liveOperations; 822 823 PyOperationRef created = createInstance(std::move(contextRef), operation, 824 std::move(parentKeepAlive)); 825 created->attached = false; 826 return created; 827 } 828 829 void PyOperation::checkValid() const { 830 if (!valid) { 831 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 832 } 833 } 834 835 void PyOperationBase::print(py::object fileObject, bool binary, 836 llvm::Optional<int64_t> largeElementsLimit, 837 bool enableDebugInfo, bool prettyDebugInfo, 838 bool printGenericOpForm, bool useLocalScope, 839 bool assumeVerified) { 840 PyOperation &operation = getOperation(); 841 operation.checkValid(); 842 if 
(fileObject.is_none()) 843 fileObject = py::module::import("sys").attr("stdout"); 844 845 if (!assumeVerified && !printGenericOpForm && 846 !mlirOperationVerify(operation)) { 847 std::string message("// Verification failed, printing generic form\n"); 848 if (binary) { 849 fileObject.attr("write")(py::bytes(message)); 850 } else { 851 fileObject.attr("write")(py::str(message)); 852 } 853 printGenericOpForm = true; 854 } 855 856 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 857 if (largeElementsLimit) 858 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 859 if (enableDebugInfo) 860 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 861 if (printGenericOpForm) 862 mlirOpPrintingFlagsPrintGenericOpForm(flags); 863 864 PyFileAccumulator accum(fileObject, binary); 865 py::gil_scoped_release(); 866 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 867 accum.getUserData()); 868 mlirOpPrintingFlagsDestroy(flags); 869 } 870 871 py::object PyOperationBase::getAsm(bool binary, 872 llvm::Optional<int64_t> largeElementsLimit, 873 bool enableDebugInfo, bool prettyDebugInfo, 874 bool printGenericOpForm, bool useLocalScope, 875 bool assumeVerified) { 876 py::object fileObject; 877 if (binary) { 878 fileObject = py::module::import("io").attr("BytesIO")(); 879 } else { 880 fileObject = py::module::import("io").attr("StringIO")(); 881 } 882 print(fileObject, /*binary=*/binary, 883 /*largeElementsLimit=*/largeElementsLimit, 884 /*enableDebugInfo=*/enableDebugInfo, 885 /*prettyDebugInfo=*/prettyDebugInfo, 886 /*printGenericOpForm=*/printGenericOpForm, 887 /*useLocalScope=*/useLocalScope, 888 /*assumeVerified=*/assumeVerified); 889 890 return fileObject.attr("getvalue")(); 891 } 892 893 void PyOperationBase::moveAfter(PyOperationBase &other) { 894 PyOperation &operation = getOperation(); 895 PyOperation &otherOp = other.getOperation(); 896 operation.checkValid(); 897 otherOp.checkValid(); 898 
mlirOperationMoveAfter(operation, otherOp); 899 operation.parentKeepAlive = otherOp.parentKeepAlive; 900 } 901 902 void PyOperationBase::moveBefore(PyOperationBase &other) { 903 PyOperation &operation = getOperation(); 904 PyOperation &otherOp = other.getOperation(); 905 operation.checkValid(); 906 otherOp.checkValid(); 907 mlirOperationMoveBefore(operation, otherOp); 908 operation.parentKeepAlive = otherOp.parentKeepAlive; 909 } 910 911 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 912 checkValid(); 913 if (!isAttached()) 914 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 915 MlirOperation operation = mlirOperationGetParentOperation(get()); 916 if (mlirOperationIsNull(operation)) 917 return {}; 918 return PyOperation::forOperation(getContext(), operation); 919 } 920 921 PyBlock PyOperation::getBlock() { 922 checkValid(); 923 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 924 MlirBlock block = mlirOperationGetBlock(get()); 925 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 926 assert(parentOperation && "Operation has no parent"); 927 return PyBlock{std::move(*parentOperation), block}; 928 } 929 930 py::object PyOperation::getCapsule() { 931 checkValid(); 932 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 933 } 934 935 py::object PyOperation::createFromCapsule(py::object capsule) { 936 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 937 if (mlirOperationIsNull(rawOperation)) 938 throw py::error_already_set(); 939 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 940 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 941 .releaseObject(); 942 } 943 944 py::object PyOperation::create( 945 std::string name, llvm::Optional<std::vector<PyType *>> results, 946 llvm::Optional<std::vector<PyValue *>> operands, 947 llvm::Optional<py::dict> attributes, 948 llvm::Optional<std::vector<PyBlock 
*>> successors, int regions, 949 DefaultingPyLocation location, py::object maybeIp) { 950 llvm::SmallVector<MlirValue, 4> mlirOperands; 951 llvm::SmallVector<MlirType, 4> mlirResults; 952 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 953 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 954 955 // General parameter validation. 956 if (regions < 0) 957 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 958 959 // Unpack/validate operands. 960 if (operands) { 961 mlirOperands.reserve(operands->size()); 962 for (PyValue *operand : *operands) { 963 if (!operand) 964 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 965 mlirOperands.push_back(operand->get()); 966 } 967 } 968 969 // Unpack/validate results. 970 if (results) { 971 mlirResults.reserve(results->size()); 972 for (PyType *result : *results) { 973 // TODO: Verify result type originate from the same context. 974 if (!result) 975 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 976 mlirResults.push_back(*result); 977 } 978 } 979 // Unpack/validate attributes. 980 if (attributes) { 981 mlirAttributes.reserve(attributes->size()); 982 for (auto &it : *attributes) { 983 std::string key; 984 try { 985 key = it.first.cast<std::string>(); 986 } catch (py::cast_error &err) { 987 std::string msg = "Invalid attribute key (not a string) when " 988 "attempting to create the operation \"" + 989 name + "\" (" + err.what() + ")"; 990 throw py::cast_error(msg); 991 } 992 try { 993 auto &attribute = it.second.cast<PyAttribute &>(); 994 // TODO: Verify attribute originates from the same context. 995 mlirAttributes.emplace_back(std::move(key), attribute); 996 } catch (py::reference_cast_error &) { 997 // This exception seems thrown when the value is "None". 998 std::string msg = 999 "Found an invalid (`None`?) 
attribute value for the key \"" + key + 1000 "\" when attempting to create the operation \"" + name + "\""; 1001 throw py::cast_error(msg); 1002 } catch (py::cast_error &err) { 1003 std::string msg = "Invalid attribute value for the key \"" + key + 1004 "\" when attempting to create the operation \"" + 1005 name + "\" (" + err.what() + ")"; 1006 throw py::cast_error(msg); 1007 } 1008 } 1009 } 1010 // Unpack/validate successors. 1011 if (successors) { 1012 mlirSuccessors.reserve(successors->size()); 1013 for (auto *successor : *successors) { 1014 // TODO: Verify successor originate from the same context. 1015 if (!successor) 1016 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1017 mlirSuccessors.push_back(successor->get()); 1018 } 1019 } 1020 1021 // Apply unpacked/validated to the operation state. Beyond this 1022 // point, exceptions cannot be thrown or else the state will leak. 1023 MlirOperationState state = 1024 mlirOperationStateGet(toMlirStringRef(name), location); 1025 if (!mlirOperands.empty()) 1026 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1027 mlirOperands.data()); 1028 if (!mlirResults.empty()) 1029 mlirOperationStateAddResults(&state, mlirResults.size(), 1030 mlirResults.data()); 1031 if (!mlirAttributes.empty()) { 1032 // Note that the attribute names directly reference bytes in 1033 // mlirAttributes, so that vector must not be changed from here 1034 // on. 
1035 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1036 mlirNamedAttributes.reserve(mlirAttributes.size()); 1037 for (auto &it : mlirAttributes) 1038 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1039 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1040 toMlirStringRef(it.first)), 1041 it.second)); 1042 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1043 mlirNamedAttributes.data()); 1044 } 1045 if (!mlirSuccessors.empty()) 1046 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1047 mlirSuccessors.data()); 1048 if (regions) { 1049 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1050 mlirRegions.resize(regions); 1051 for (int i = 0; i < regions; ++i) 1052 mlirRegions[i] = mlirRegionCreate(); 1053 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1054 mlirRegions.data()); 1055 } 1056 1057 // Construct the operation. 1058 MlirOperation operation = mlirOperationCreate(&state); 1059 PyOperationRef created = 1060 PyOperation::createDetached(location->getContext(), operation); 1061 1062 // InsertPoint active? 
1063 if (!maybeIp.is(py::cast(false))) { 1064 PyInsertionPoint *ip; 1065 if (maybeIp.is_none()) { 1066 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1067 } else { 1068 ip = py::cast<PyInsertionPoint *>(maybeIp); 1069 } 1070 if (ip) 1071 ip->insert(*created.get()); 1072 } 1073 1074 return created->createOpView(); 1075 } 1076 1077 py::object PyOperation::createOpView() { 1078 checkValid(); 1079 MlirIdentifier ident = mlirOperationGetName(get()); 1080 MlirStringRef identStr = mlirIdentifierStr(ident); 1081 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1082 StringRef(identStr.data, identStr.length)); 1083 if (opViewClass) 1084 return (*opViewClass)(getRef().getObject()); 1085 return py::cast(PyOpView(getRef().getObject())); 1086 } 1087 1088 void PyOperation::erase() { 1089 checkValid(); 1090 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1091 // Python reference to a child operation is live. All children should also 1092 // have their `valid` bit set to false. 1093 auto &liveOperations = getContext()->liveOperations; 1094 if (liveOperations.count(operation.ptr)) 1095 liveOperations.erase(operation.ptr); 1096 mlirOperationDestroy(operation); 1097 valid = false; 1098 } 1099 1100 //------------------------------------------------------------------------------ 1101 // PyOpView 1102 //------------------------------------------------------------------------------ 1103 1104 py::object 1105 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1106 py::list operandList, 1107 llvm::Optional<py::dict> attributes, 1108 llvm::Optional<std::vector<PyBlock *>> successors, 1109 llvm::Optional<int> regions, 1110 DefaultingPyLocation location, py::object maybeIp) { 1111 PyMlirContextRef context = location->getContext(); 1112 // Class level operation construction metadata. 
1113 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1114 // Operand and result segment specs are either none, which does no 1115 // variadic unpacking, or a list of ints with segment sizes, where each 1116 // element is either a positive number (typically 1 for a scalar) or -1 to 1117 // indicate that it is derived from the length of the same-indexed operand 1118 // or result (implying that it is a list at that position). 1119 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1120 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1121 1122 std::vector<uint32_t> operandSegmentLengths; 1123 std::vector<uint32_t> resultSegmentLengths; 1124 1125 // Validate/determine region count. 1126 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1127 int opMinRegionCount = std::get<0>(opRegionSpec); 1128 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1129 if (!regions) { 1130 regions = opMinRegionCount; 1131 } 1132 if (*regions < opMinRegionCount) { 1133 throw py::value_error( 1134 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1135 llvm::Twine(opMinRegionCount) + 1136 " regions but was built with regions=" + llvm::Twine(*regions)) 1137 .str()); 1138 } 1139 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1140 throw py::value_error( 1141 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1142 llvm::Twine(opMinRegionCount) + 1143 " regions but was built with regions=" + llvm::Twine(*regions)) 1144 .str()); 1145 } 1146 1147 // Unpack results. 1148 std::vector<PyType *> resultTypes; 1149 resultTypes.reserve(resultTypeList.size()); 1150 if (resultSegmentSpecObj.is_none()) { 1151 // Non-variadic result unpacking. 
1152 for (auto it : llvm::enumerate(resultTypeList)) { 1153 try { 1154 resultTypes.push_back(py::cast<PyType *>(it.value())); 1155 if (!resultTypes.back()) 1156 throw py::cast_error(); 1157 } catch (py::cast_error &err) { 1158 throw py::value_error((llvm::Twine("Result ") + 1159 llvm::Twine(it.index()) + " of operation \"" + 1160 name + "\" must be a Type (" + err.what() + ")") 1161 .str()); 1162 } 1163 } 1164 } else { 1165 // Sized result unpacking. 1166 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1167 if (resultSegmentSpec.size() != resultTypeList.size()) { 1168 throw py::value_error((llvm::Twine("Operation \"") + name + 1169 "\" requires " + 1170 llvm::Twine(resultSegmentSpec.size()) + 1171 " result segments but was provided " + 1172 llvm::Twine(resultTypeList.size())) 1173 .str()); 1174 } 1175 resultSegmentLengths.reserve(resultTypeList.size()); 1176 for (auto it : 1177 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1178 int segmentSpec = std::get<1>(it.value()); 1179 if (segmentSpec == 1 || segmentSpec == 0) { 1180 // Unpack unary element. 1181 try { 1182 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1183 if (resultType) { 1184 resultTypes.push_back(resultType); 1185 resultSegmentLengths.push_back(1); 1186 } else if (segmentSpec == 0) { 1187 // Allowed to be optional. 1188 resultSegmentLengths.push_back(0); 1189 } else { 1190 throw py::cast_error("was None and result is not optional"); 1191 } 1192 } catch (py::cast_error &err) { 1193 throw py::value_error((llvm::Twine("Result ") + 1194 llvm::Twine(it.index()) + " of operation \"" + 1195 name + "\" must be a Type (" + err.what() + 1196 ")") 1197 .str()); 1198 } 1199 } else if (segmentSpec == -1) { 1200 // Unpack sequence by appending. 1201 try { 1202 if (std::get<0>(it.value()).is_none()) { 1203 // Treat it as an empty list. 1204 resultSegmentLengths.push_back(0); 1205 } else { 1206 // Unpack the list. 
1207 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1208 for (py::object segmentItem : segment) { 1209 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1210 if (!resultTypes.back()) { 1211 throw py::cast_error("contained a None item"); 1212 } 1213 } 1214 resultSegmentLengths.push_back(segment.size()); 1215 } 1216 } catch (std::exception &err) { 1217 // NOTE: Sloppy to be using a catch-all here, but there are at least 1218 // three different unrelated exceptions that can be thrown in the 1219 // above "casts". Just keep the scope above small and catch them all. 1220 throw py::value_error((llvm::Twine("Result ") + 1221 llvm::Twine(it.index()) + " of operation \"" + 1222 name + "\" must be a Sequence of Types (" + 1223 err.what() + ")") 1224 .str()); 1225 } 1226 } else { 1227 throw py::value_error("Unexpected segment spec"); 1228 } 1229 } 1230 } 1231 1232 // Unpack operands. 1233 std::vector<PyValue *> operands; 1234 operands.reserve(operands.size()); 1235 if (operandSegmentSpecObj.is_none()) { 1236 // Non-sized operand unpacking. 1237 for (auto it : llvm::enumerate(operandList)) { 1238 try { 1239 operands.push_back(py::cast<PyValue *>(it.value())); 1240 if (!operands.back()) 1241 throw py::cast_error(); 1242 } catch (py::cast_error &err) { 1243 throw py::value_error((llvm::Twine("Operand ") + 1244 llvm::Twine(it.index()) + " of operation \"" + 1245 name + "\" must be a Value (" + err.what() + ")") 1246 .str()); 1247 } 1248 } 1249 } else { 1250 // Sized operand unpacking. 
1251 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1252 if (operandSegmentSpec.size() != operandList.size()) { 1253 throw py::value_error((llvm::Twine("Operation \"") + name + 1254 "\" requires " + 1255 llvm::Twine(operandSegmentSpec.size()) + 1256 "operand segments but was provided " + 1257 llvm::Twine(operandList.size())) 1258 .str()); 1259 } 1260 operandSegmentLengths.reserve(operandList.size()); 1261 for (auto it : 1262 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1263 int segmentSpec = std::get<1>(it.value()); 1264 if (segmentSpec == 1 || segmentSpec == 0) { 1265 // Unpack unary element. 1266 try { 1267 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1268 if (operandValue) { 1269 operands.push_back(operandValue); 1270 operandSegmentLengths.push_back(1); 1271 } else if (segmentSpec == 0) { 1272 // Allowed to be optional. 1273 operandSegmentLengths.push_back(0); 1274 } else { 1275 throw py::cast_error("was None and operand is not optional"); 1276 } 1277 } catch (py::cast_error &err) { 1278 throw py::value_error((llvm::Twine("Operand ") + 1279 llvm::Twine(it.index()) + " of operation \"" + 1280 name + "\" must be a Value (" + err.what() + 1281 ")") 1282 .str()); 1283 } 1284 } else if (segmentSpec == -1) { 1285 // Unpack sequence by appending. 1286 try { 1287 if (std::get<0>(it.value()).is_none()) { 1288 // Treat it as an empty list. 1289 operandSegmentLengths.push_back(0); 1290 } else { 1291 // Unpack the list. 
1292 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1293 for (py::object segmentItem : segment) { 1294 operands.push_back(py::cast<PyValue *>(segmentItem)); 1295 if (!operands.back()) { 1296 throw py::cast_error("contained a None item"); 1297 } 1298 } 1299 operandSegmentLengths.push_back(segment.size()); 1300 } 1301 } catch (std::exception &err) { 1302 // NOTE: Sloppy to be using a catch-all here, but there are at least 1303 // three different unrelated exceptions that can be thrown in the 1304 // above "casts". Just keep the scope above small and catch them all. 1305 throw py::value_error((llvm::Twine("Operand ") + 1306 llvm::Twine(it.index()) + " of operation \"" + 1307 name + "\" must be a Sequence of Values (" + 1308 err.what() + ")") 1309 .str()); 1310 } 1311 } else { 1312 throw py::value_error("Unexpected segment spec"); 1313 } 1314 } 1315 } 1316 1317 // Merge operand/result segment lengths into attributes if needed. 1318 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1319 // Dup. 1320 if (attributes) { 1321 attributes = py::dict(*attributes); 1322 } else { 1323 attributes = py::dict(); 1324 } 1325 if (attributes->contains("result_segment_sizes") || 1326 attributes->contains("operand_segment_sizes")) { 1327 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1328 "'operand_segment_sizes' attribute is unsupported. " 1329 "Use Operation.create for such low-level access."); 1330 } 1331 1332 // Add result_segment_sizes attribute. 1333 if (!resultSegmentLengths.empty()) { 1334 int64_t size = resultSegmentLengths.size(); 1335 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1336 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1337 resultSegmentLengths.size(), resultSegmentLengths.data()); 1338 (*attributes)["result_segment_sizes"] = 1339 PyAttribute(context, segmentLengthAttr); 1340 } 1341 1342 // Add operand_segment_sizes attribute. 
1343 if (!operandSegmentLengths.empty()) { 1344 int64_t size = operandSegmentLengths.size(); 1345 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1346 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1347 operandSegmentLengths.size(), operandSegmentLengths.data()); 1348 (*attributes)["operand_segment_sizes"] = 1349 PyAttribute(context, segmentLengthAttr); 1350 } 1351 } 1352 1353 // Delegate to create. 1354 return PyOperation::create(std::move(name), 1355 /*results=*/std::move(resultTypes), 1356 /*operands=*/std::move(operands), 1357 /*attributes=*/std::move(attributes), 1358 /*successors=*/std::move(successors), 1359 /*regions=*/*regions, location, maybeIp); 1360 } 1361 1362 PyOpView::PyOpView(py::object operationObject) 1363 // Casting through the PyOperationBase base-class and then back to the 1364 // Operation lets us accept any PyOperationBase subclass. 1365 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1366 operationObject(operation.getRef().getObject()) {} 1367 1368 py::object PyOpView::createRawSubclass(py::object userClass) { 1369 // This is... a little gross. The typical pattern is to have a pure python 1370 // class that extends OpView like: 1371 // class AddFOp(_cext.ir.OpView): 1372 // def __init__(self, loc, lhs, rhs): 1373 // operation = loc.context.create_operation( 1374 // "addf", lhs, rhs, results=[lhs.type]) 1375 // super().__init__(operation) 1376 // 1377 // I.e. The goal of the user facing type is to provide a nice constructor 1378 // that has complete freedom for the op under construction. This is at odds 1379 // with our other desire to sometimes create this object by just passing an 1380 // operation (to initialize the base class). 
We could do *arg and **kwargs 1381 // munging to try to make it work, but instead, we synthesize a new class 1382 // on the fly which extends this user class (AddFOp in this example) and 1383 // *give it* the base class's __init__ method, thus bypassing the 1384 // intermediate subclass's __init__ method entirely. While slightly, 1385 // underhanded, this is safe/legal because the type hierarchy has not changed 1386 // (we just added a new leaf) and we aren't mucking around with __new__. 1387 // Typically, this new class will be stored on the original as "_Raw" and will 1388 // be used for casts and other things that need a variant of the class that 1389 // is initialized purely from an operation. 1390 py::object parentMetaclass = 1391 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1392 py::dict attributes; 1393 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1394 // now. 1395 // auto opViewType = py::type::of<PyOpView>(); 1396 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1397 attributes["__init__"] = opViewType.attr("__init__"); 1398 py::str origName = userClass.attr("__name__"); 1399 py::str newName = py::str("_") + origName; 1400 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1401 } 1402 1403 //------------------------------------------------------------------------------ 1404 // PyInsertionPoint. 
1405 //------------------------------------------------------------------------------ 1406 1407 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1408 1409 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1410 : refOperation(beforeOperationBase.getOperation().getRef()), 1411 block((*refOperation)->getBlock()) {} 1412 1413 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1414 PyOperation &operation = operationBase.getOperation(); 1415 if (operation.isAttached()) 1416 throw SetPyError(PyExc_ValueError, 1417 "Attempt to insert operation that is already attached"); 1418 block.getParentOperation()->checkValid(); 1419 MlirOperation beforeOp = {nullptr}; 1420 if (refOperation) { 1421 // Insert before operation. 1422 (*refOperation)->checkValid(); 1423 beforeOp = (*refOperation)->get(); 1424 } else { 1425 // Insert at end (before null) is only valid if the block does not 1426 // already end in a known terminator (violating this will cause assertion 1427 // failures later). 1428 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1429 throw py::index_error("Cannot insert operation at the end of a block " 1430 "that already has a terminator. Did you mean to " 1431 "use 'InsertionPoint.at_block_terminator(block)' " 1432 "versus 'InsertionPoint(block)'?"); 1433 } 1434 } 1435 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1436 operation.setAttached(); 1437 } 1438 1439 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1440 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1441 if (mlirOperationIsNull(firstOp)) { 1442 // Just insert at end. 1443 return PyInsertionPoint(block); 1444 } 1445 1446 // Insert before first op. 
1447 PyOperationRef firstOpRef = PyOperation::forOperation( 1448 block.getParentOperation()->getContext(), firstOp); 1449 return PyInsertionPoint{block, std::move(firstOpRef)}; 1450 } 1451 1452 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1453 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1454 if (mlirOperationIsNull(terminator)) 1455 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1456 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1457 block.getParentOperation()->getContext(), terminator); 1458 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1459 } 1460 1461 py::object PyInsertionPoint::contextEnter() { 1462 return PyThreadContextEntry::pushInsertionPoint(*this); 1463 } 1464 1465 void PyInsertionPoint::contextExit(pybind11::object excType, 1466 pybind11::object excVal, 1467 pybind11::object excTb) { 1468 PyThreadContextEntry::popInsertionPoint(*this); 1469 } 1470 1471 //------------------------------------------------------------------------------ 1472 // PyAttribute. 1473 //------------------------------------------------------------------------------ 1474 1475 bool PyAttribute::operator==(const PyAttribute &other) { 1476 return mlirAttributeEqual(attr, other.attr); 1477 } 1478 1479 py::object PyAttribute::getCapsule() { 1480 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1481 } 1482 1483 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1484 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1485 if (mlirAttributeIsNull(rawAttr)) 1486 throw py::error_already_set(); 1487 return PyAttribute( 1488 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1489 } 1490 1491 //------------------------------------------------------------------------------ 1492 // PyNamedAttribute. 
1493 //------------------------------------------------------------------------------ 1494 1495 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1496 : ownedName(new std::string(std::move(ownedName))) { 1497 namedAttr = mlirNamedAttributeGet( 1498 mlirIdentifierGet(mlirAttributeGetContext(attr), 1499 toMlirStringRef(*this->ownedName)), 1500 attr); 1501 } 1502 1503 //------------------------------------------------------------------------------ 1504 // PyType. 1505 //------------------------------------------------------------------------------ 1506 1507 bool PyType::operator==(const PyType &other) { 1508 return mlirTypeEqual(type, other.type); 1509 } 1510 1511 py::object PyType::getCapsule() { 1512 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1513 } 1514 1515 PyType PyType::createFromCapsule(py::object capsule) { 1516 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1517 if (mlirTypeIsNull(rawType)) 1518 throw py::error_already_set(); 1519 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1520 rawType); 1521 } 1522 1523 //------------------------------------------------------------------------------ 1524 // PyValue and subclases. 
1525 //------------------------------------------------------------------------------ 1526 1527 pybind11::object PyValue::getCapsule() { 1528 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1529 } 1530 1531 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1532 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1533 if (mlirValueIsNull(value)) 1534 throw py::error_already_set(); 1535 MlirOperation owner; 1536 if (mlirValueIsAOpResult(value)) 1537 owner = mlirOpResultGetOwner(value); 1538 if (mlirValueIsABlockArgument(value)) 1539 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1540 if (mlirOperationIsNull(owner)) 1541 throw py::error_already_set(); 1542 MlirContext ctx = mlirOperationGetContext(owner); 1543 PyOperationRef ownerRef = 1544 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1545 return PyValue(ownerRef, value); 1546 } 1547 1548 //------------------------------------------------------------------------------ 1549 // PySymbolTable. 
1550 //------------------------------------------------------------------------------ 1551 1552 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1553 : operation(operation.getOperation().getRef()) { 1554 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1555 if (mlirSymbolTableIsNull(symbolTable)) { 1556 throw py::cast_error("Operation is not a Symbol Table."); 1557 } 1558 } 1559 1560 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1561 operation->checkValid(); 1562 MlirOperation symbol = mlirSymbolTableLookup( 1563 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1564 if (mlirOperationIsNull(symbol)) 1565 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1566 1567 return PyOperation::forOperation(operation->getContext(), symbol, 1568 operation.getObject()) 1569 ->createOpView(); 1570 } 1571 1572 void PySymbolTable::erase(PyOperationBase &symbol) { 1573 operation->checkValid(); 1574 symbol.getOperation().checkValid(); 1575 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1576 // The operation is also erased, so we must invalidate it. There may be Python 1577 // references to this operation so we don't want to delete it from the list of 1578 // live operations here. 
1579 symbol.getOperation().valid = false; 1580 } 1581 1582 void PySymbolTable::dunderDel(const std::string &name) { 1583 py::object operation = dunderGetItem(name); 1584 erase(py::cast<PyOperationBase &>(operation)); 1585 } 1586 1587 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1588 operation->checkValid(); 1589 symbol.getOperation().checkValid(); 1590 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1591 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1592 if (mlirAttributeIsNull(symbolAttr)) 1593 throw py::value_error("Expected operation to have a symbol name."); 1594 return PyAttribute( 1595 symbol.getOperation().getContext(), 1596 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1597 } 1598 1599 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1600 // Op must already be a symbol. 1601 PyOperation &operation = symbol.getOperation(); 1602 operation.checkValid(); 1603 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1604 MlirAttribute existingNameAttr = 1605 mlirOperationGetAttributeByName(operation.get(), attrName); 1606 if (mlirAttributeIsNull(existingNameAttr)) 1607 throw py::value_error("Expected operation to have a symbol name."); 1608 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1609 } 1610 1611 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1612 const std::string &name) { 1613 // Op must already be a symbol. 
1614 PyOperation &operation = symbol.getOperation(); 1615 operation.checkValid(); 1616 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1617 MlirAttribute existingNameAttr = 1618 mlirOperationGetAttributeByName(operation.get(), attrName); 1619 if (mlirAttributeIsNull(existingNameAttr)) 1620 throw py::value_error("Expected operation to have a symbol name."); 1621 MlirAttribute newNameAttr = 1622 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1623 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1624 } 1625 1626 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1627 PyOperation &operation = symbol.getOperation(); 1628 operation.checkValid(); 1629 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1630 MlirAttribute existingVisAttr = 1631 mlirOperationGetAttributeByName(operation.get(), attrName); 1632 if (mlirAttributeIsNull(existingVisAttr)) 1633 throw py::value_error("Expected operation to have a symbol visibility."); 1634 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1635 } 1636 1637 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1638 const std::string &visibility) { 1639 if (visibility != "public" && visibility != "private" && 1640 visibility != "nested") 1641 throw py::value_error( 1642 "Expected visibility to be 'public', 'private' or 'nested'"); 1643 PyOperation &operation = symbol.getOperation(); 1644 operation.checkValid(); 1645 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1646 MlirAttribute existingVisAttr = 1647 mlirOperationGetAttributeByName(operation.get(), attrName); 1648 if (mlirAttributeIsNull(existingVisAttr)) 1649 throw py::value_error("Expected operation to have a symbol visibility."); 1650 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1651 toMlirStringRef(visibility)); 1652 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1653 } 
/// Replaces all uses of `oldSymbol` with `newSymbol` starting from the `from`
/// operation. Raises a Python ValueError when the C API reports failure.
void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
                                         const std::string &newSymbol,
                                         PyOperationBase &from) {
  PyOperation &fromOperation = from.getOperation();
  fromOperation.checkValid();
  if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
          toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
          from.getOperation())))

    throw py::value_error("Symbol rename failed");
}

/// Walks the symbol tables reachable from `from`, invoking `callback` with
/// each found operation's Python object and the visibility flag reported by
/// the C API.
///
/// If the callback raises, the first exception is recorded, the callback is
/// not invoked again for the remaining operations, and a std::runtime_error
/// carrying the recorded message is thrown once the walk returns.
void PySymbolTable::walkSymbolTables(PyOperationBase &from,
                                     bool allSymUsesVisible,
                                     py::object callback) {
  PyOperation &fromOperation = from.getOperation();
  fromOperation.checkValid();
  // State shared with the C callback through its `void *` user-data pointer.
  struct UserData {
    PyMlirContextRef context;
    py::object callback;
    bool gotException;
    std::string exceptionWhat;
    // Recorded when the callback raises; not read again in this function.
    py::object exceptionType;
  };
  UserData userData{
      fromOperation.getContext(), std::move(callback), false, {}, {}};
  mlirSymbolTableWalkSymbolTables(
      fromOperation.get(), allSymUsesVisible,
      [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
        UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
        auto pyFoundOp =
            PyOperation::forOperation(calleeUserData->context, foundOp);
        // After the first failure, skip the callback for the rest of the walk.
        if (calleeUserData->gotException)
          return;
        try {
          calleeUserData->callback(pyFoundOp.getObject(), isVisible);
        } catch (py::error_already_set &e) {
          // Record the failure here and rethrow after the walk returns, rather
          // than letting the exception propagate out of this callback.
          calleeUserData->gotException = true;
          calleeUserData->exceptionWhat = e.what();
          calleeUserData->exceptionType = e.type();
        }
      },
      static_cast<void *>(&userData));
  if (userData.gotException) {
    std::string message("Exception raised in callback: ");
    message.append(userData.exceptionWhat);
    throw std::runtime_error(std::move(message));
  }
}

namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
/// to accommodate other levels unless core MLIR changes.
template <typename DerivedTy>
class PyConcreteValue : public PyValue {
public:
  // Derived classes must define statics for:
  //   IsAFunctionTy isaFunction
  //   const char *pyClassName
  // and redefine bindDerived.
  using ClassTy = py::class_<DerivedTy, PyValue>;
  using IsAFunctionTy = bool (*)(MlirValue);

  PyConcreteValue() = default;
  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
      : PyValue(operationRef, value) {}
  /// Converting constructor from a generic PyValue; throws (via castFrom) when
  /// `orig` is not of the derived kind.
  PyConcreteValue(PyValue &orig)
      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}

  /// Attempts to cast the original value to the derived type and throws on
  /// type mismatches.
  static MlirValue castFrom(PyValue &orig) {
    if (!DerivedTy::isaFunction(orig.get())) {
      // Include the Python repr of the offending value in the error message.
      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
      throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
                                             DerivedTy::pyClassName +
                                             " (from " + origRepr + ")");
    }
    return orig.get();
  }

  /// Binds the Python module objects to functions of this class.
  static void bind(py::module &m) {
    auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
    cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
    cls.def_static(
        "isinstance",
        [](PyValue &otherValue) -> bool {
          return DerivedTy::isaFunction(otherValue);
        },
        py::arg("other_value"));
    // Let the derived class add its own properties and methods.
    DerivedTy::bindDerived(cls);
  }

  /// Implemented by derived classes to add methods to the Python subclass.
  static void bindDerived(ClassTy &m) {}
};

/// Python wrapper for MlirBlockArgument.
class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
  static constexpr const char *pyClassName = "BlockArgument";
  using PyConcreteValue::PyConcreteValue;

  /// Adds the `owner` and `arg_number` read-only properties and the `set_type`
  /// method to the Python class.
  static void bindDerived(ClassTy &c) {
    c.def_property_readonly("owner", [](PyBlockArgument &self) {
      return PyBlock(self.getParentOperation(),
                     mlirBlockArgumentGetOwner(self.get()));
    });
    c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
      return mlirBlockArgumentGetArgNumber(self.get());
    });
    c.def(
        "set_type",
        [](PyBlockArgument &self, PyType type) {
          return mlirBlockArgumentSetType(self.get(), type);
        },
        py::arg("type"));
  }
};

/// Python wrapper for MlirOpResult.
class PyOpResult : public PyConcreteValue<PyOpResult> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
  static constexpr const char *pyClassName = "OpResult";
  using PyConcreteValue::PyConcreteValue;

  /// Adds the `owner` and `result_number` read-only properties to the Python
  /// class.
  static void bindDerived(ClassTy &c) {
    c.def_property_readonly("owner", [](PyOpResult &self) {
      // The Python-side parent operation is expected to be the op that defines
      // this result; checked by assertion only.
      assert(
          mlirOperationEqual(self.getParentOperation()->get(),
                             mlirOpResultGetOwner(self.get())) &&
          "expected the owner of the value in Python to match that in the IR");
      return self.getParentOperation().getObject();
    });
    c.def_property_readonly("result_number", [](PyOpResult &self) {
      return mlirOpResultGetResultNumber(self.get());
    });
  }
};

/// Returns the list of types of the values held by container.
1800 template <typename Container> 1801 static std::vector<PyType> getValueTypes(Container &container, 1802 PyMlirContextRef &context) { 1803 std::vector<PyType> result; 1804 result.reserve(container.getNumElements()); 1805 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1806 result.push_back( 1807 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1808 } 1809 return result; 1810 } 1811 1812 /// A list of block arguments. Internally, these are stored as consecutive 1813 /// elements, random access is cheap. The argument list is associated with the 1814 /// operation that contains the block (detached blocks are not allowed in 1815 /// Python bindings) and extends its lifetime. 1816 class PyBlockArgumentList 1817 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1818 public: 1819 static constexpr const char *pyClassName = "BlockArgumentList"; 1820 1821 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1822 intptr_t startIndex = 0, intptr_t length = -1, 1823 intptr_t step = 1) 1824 : Sliceable(startIndex, 1825 length == -1 ? mlirBlockGetNumArguments(block) : length, 1826 step), 1827 operation(std::move(operation)), block(block) {} 1828 1829 /// Returns the number of arguments in the list. 1830 intptr_t getNumElements() { 1831 operation->checkValid(); 1832 return mlirBlockGetNumArguments(block); 1833 } 1834 1835 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1836 PyBlockArgument getElement(intptr_t pos) { 1837 MlirValue argument = mlirBlockGetArgument(block, pos); 1838 return PyBlockArgument(operation, argument); 1839 } 1840 1841 /// Returns a sublist of this list. 
1842 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1843 intptr_t step) { 1844 return PyBlockArgumentList(operation, block, startIndex, length, step); 1845 } 1846 1847 static void bindDerived(ClassTy &c) { 1848 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1849 return getValueTypes(self, self.operation->getContext()); 1850 }); 1851 } 1852 1853 private: 1854 PyOperationRef operation; 1855 MlirBlock block; 1856 }; 1857 1858 /// A list of operation operands. Internally, these are stored as consecutive 1859 /// elements, random access is cheap. The result list is associated with the 1860 /// operation whose results these are, and extends the lifetime of this 1861 /// operation. 1862 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1863 public: 1864 static constexpr const char *pyClassName = "OpOperandList"; 1865 1866 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1867 intptr_t length = -1, intptr_t step = 1) 1868 : Sliceable(startIndex, 1869 length == -1 ? 
mlirOperationGetNumOperands(operation->get()) 1870 : length, 1871 step), 1872 operation(operation) {} 1873 1874 intptr_t getNumElements() { 1875 operation->checkValid(); 1876 return mlirOperationGetNumOperands(operation->get()); 1877 } 1878 1879 PyValue getElement(intptr_t pos) { 1880 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1881 MlirOperation owner; 1882 if (mlirValueIsAOpResult(operand)) 1883 owner = mlirOpResultGetOwner(operand); 1884 else if (mlirValueIsABlockArgument(operand)) 1885 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1886 else 1887 assert(false && "Value must be an block arg or op result."); 1888 PyOperationRef pyOwner = 1889 PyOperation::forOperation(operation->getContext(), owner); 1890 return PyValue(pyOwner, operand); 1891 } 1892 1893 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1894 return PyOpOperandList(operation, startIndex, length, step); 1895 } 1896 1897 void dunderSetItem(intptr_t index, PyValue value) { 1898 index = wrapIndex(index); 1899 mlirOperationSetOperand(operation->get(), index, value.get()); 1900 } 1901 1902 static void bindDerived(ClassTy &c) { 1903 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1904 } 1905 1906 private: 1907 PyOperationRef operation; 1908 }; 1909 1910 /// A list of operation results. Internally, these are stored as consecutive 1911 /// elements, random access is cheap. The result list is associated with the 1912 /// operation whose results these are, and extends the lifetime of this 1913 /// operation. 1914 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1915 public: 1916 static constexpr const char *pyClassName = "OpResultList"; 1917 1918 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1919 intptr_t length = -1, intptr_t step = 1) 1920 : Sliceable(startIndex, 1921 length == -1 ? 
mlirOperationGetNumResults(operation->get()) 1922 : length, 1923 step), 1924 operation(operation) {} 1925 1926 intptr_t getNumElements() { 1927 operation->checkValid(); 1928 return mlirOperationGetNumResults(operation->get()); 1929 } 1930 1931 PyOpResult getElement(intptr_t index) { 1932 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1933 return PyOpResult(value); 1934 } 1935 1936 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1937 return PyOpResultList(operation, startIndex, length, step); 1938 } 1939 1940 static void bindDerived(ClassTy &c) { 1941 c.def_property_readonly("types", [](PyOpResultList &self) { 1942 return getValueTypes(self, self.operation->getContext()); 1943 }); 1944 } 1945 1946 private: 1947 PyOperationRef operation; 1948 }; 1949 1950 /// A list of operation attributes. Can be indexed by name, producing 1951 /// attributes, or by index, producing named attributes. 1952 class PyOpAttributeMap { 1953 public: 1954 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1955 1956 PyAttribute dunderGetItemNamed(const std::string &name) { 1957 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1958 toMlirStringRef(name)); 1959 if (mlirAttributeIsNull(attr)) { 1960 throw SetPyError(PyExc_KeyError, 1961 "attempt to access a non-existent attribute"); 1962 } 1963 return PyAttribute(operation->getContext(), attr); 1964 } 1965 1966 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1967 if (index < 0 || index >= dunderLen()) { 1968 throw SetPyError(PyExc_IndexError, 1969 "attempt to access out of bounds attribute"); 1970 } 1971 MlirNamedAttribute namedAttr = 1972 mlirOperationGetAttribute(operation->get(), index); 1973 return PyNamedAttribute( 1974 namedAttr.attribute, 1975 std::string(mlirIdentifierStr(namedAttr.name).data, 1976 mlirIdentifierStr(namedAttr.name).length)); 1977 } 1978 1979 void dunderSetItem(const std::string &name, PyAttribute attr) { 1980 
mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1981 attr); 1982 } 1983 1984 void dunderDelItem(const std::string &name) { 1985 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1986 toMlirStringRef(name)); 1987 if (!removed) 1988 throw SetPyError(PyExc_KeyError, 1989 "attempt to delete a non-existent attribute"); 1990 } 1991 1992 intptr_t dunderLen() { 1993 return mlirOperationGetNumAttributes(operation->get()); 1994 } 1995 1996 bool dunderContains(const std::string &name) { 1997 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1998 operation->get(), toMlirStringRef(name))); 1999 } 2000 2001 static void bind(py::module &m) { 2002 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2003 .def("__contains__", &PyOpAttributeMap::dunderContains) 2004 .def("__len__", &PyOpAttributeMap::dunderLen) 2005 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2006 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2007 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2008 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2009 } 2010 2011 private: 2012 PyOperationRef operation; 2013 }; 2014 2015 } // end namespace 2016 2017 //------------------------------------------------------------------------------ 2018 // Populates the core exports of the 'ir' submodule. 2019 //------------------------------------------------------------------------------ 2020 2021 void mlir::python::populateIRCore(py::module &m) { 2022 //---------------------------------------------------------------------------- 2023 // Mapping of MlirContext. 
2024 //---------------------------------------------------------------------------- 2025 py::class_<PyMlirContext>(m, "Context", py::module_local()) 2026 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2027 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2028 .def("_get_context_again", 2029 [](PyMlirContext &self) { 2030 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2031 return ref.releaseObject(); 2032 }) 2033 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2034 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2035 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2036 &PyMlirContext::getCapsule) 2037 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2038 .def("__enter__", &PyMlirContext::contextEnter) 2039 .def("__exit__", &PyMlirContext::contextExit) 2040 .def_property_readonly_static( 2041 "current", 2042 [](py::object & /*class*/) { 2043 auto *context = PyThreadContextEntry::getDefaultContext(); 2044 if (!context) 2045 throw SetPyError(PyExc_ValueError, "No current Context"); 2046 return context; 2047 }, 2048 "Gets the Context bound to the current thread or raises ValueError") 2049 .def_property_readonly( 2050 "dialects", 2051 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2052 "Gets a container for accessing dialects by name") 2053 .def_property_readonly( 2054 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2055 "Alias for 'dialect'") 2056 .def( 2057 "get_dialect_descriptor", 2058 [=](PyMlirContext &self, std::string &name) { 2059 MlirDialect dialect = mlirContextGetOrLoadDialect( 2060 self.get(), {name.data(), name.size()}); 2061 if (mlirDialectIsNull(dialect)) { 2062 throw SetPyError(PyExc_ValueError, 2063 Twine("Dialect '") + name + "' not found"); 2064 } 2065 return PyDialectDescriptor(self.getRef(), dialect); 2066 }, 2067 py::arg("dialect_name"), 2068 "Gets or loads a dialect by name, returning its descriptor object") 
2069 .def_property( 2070 "allow_unregistered_dialects", 2071 [](PyMlirContext &self) -> bool { 2072 return mlirContextGetAllowUnregisteredDialects(self.get()); 2073 }, 2074 [](PyMlirContext &self, bool value) { 2075 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2076 }) 2077 .def( 2078 "enable_multithreading", 2079 [](PyMlirContext &self, bool enable) { 2080 mlirContextEnableMultithreading(self.get(), enable); 2081 }, 2082 py::arg("enable")) 2083 .def( 2084 "is_registered_operation", 2085 [](PyMlirContext &self, std::string &name) { 2086 return mlirContextIsRegisteredOperation( 2087 self.get(), MlirStringRef{name.data(), name.size()}); 2088 }, 2089 py::arg("operation_name")); 2090 2091 //---------------------------------------------------------------------------- 2092 // Mapping of PyDialectDescriptor 2093 //---------------------------------------------------------------------------- 2094 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2095 .def_property_readonly("namespace", 2096 [](PyDialectDescriptor &self) { 2097 MlirStringRef ns = 2098 mlirDialectGetNamespace(self.get()); 2099 return py::str(ns.data, ns.length); 2100 }) 2101 .def("__repr__", [](PyDialectDescriptor &self) { 2102 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2103 std::string repr("<DialectDescriptor "); 2104 repr.append(ns.data, ns.length); 2105 repr.append(">"); 2106 return repr; 2107 }); 2108 2109 //---------------------------------------------------------------------------- 2110 // Mapping of PyDialects 2111 //---------------------------------------------------------------------------- 2112 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2113 .def("__getitem__", 2114 [=](PyDialects &self, std::string keyName) { 2115 MlirDialect dialect = 2116 self.getDialectForKey(keyName, /*attrError=*/false); 2117 py::object descriptor = 2118 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2119 return createCustomDialectWrapper(keyName, 
std::move(descriptor)); 2120 }) 2121 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2122 MlirDialect dialect = 2123 self.getDialectForKey(attrName, /*attrError=*/true); 2124 py::object descriptor = 2125 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2126 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2127 }); 2128 2129 //---------------------------------------------------------------------------- 2130 // Mapping of PyDialect 2131 //---------------------------------------------------------------------------- 2132 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2133 .def(py::init<py::object>(), py::arg("descriptor")) 2134 .def_property_readonly( 2135 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2136 .def("__repr__", [](py::object self) { 2137 auto clazz = self.attr("__class__"); 2138 return py::str("<Dialect ") + 2139 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2140 clazz.attr("__module__") + py::str(".") + 2141 clazz.attr("__name__") + py::str(")>"); 2142 }); 2143 2144 //---------------------------------------------------------------------------- 2145 // Mapping of Location 2146 //---------------------------------------------------------------------------- 2147 py::class_<PyLocation>(m, "Location", py::module_local()) 2148 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2149 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2150 .def("__enter__", &PyLocation::contextEnter) 2151 .def("__exit__", &PyLocation::contextExit) 2152 .def("__eq__", 2153 [](PyLocation &self, PyLocation &other) -> bool { 2154 return mlirLocationEqual(self, other); 2155 }) 2156 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2157 .def_property_readonly_static( 2158 "current", 2159 [](py::object & /*class*/) { 2160 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2161 if (!loc) 2162 throw SetPyError(PyExc_ValueError, 
"No current Location"); 2163 return loc; 2164 }, 2165 "Gets the Location bound to the current thread or raises ValueError") 2166 .def_static( 2167 "unknown", 2168 [](DefaultingPyMlirContext context) { 2169 return PyLocation(context->getRef(), 2170 mlirLocationUnknownGet(context->get())); 2171 }, 2172 py::arg("context") = py::none(), 2173 "Gets a Location representing an unknown location") 2174 .def_static( 2175 "callsite", 2176 [](PyLocation callee, const std::vector<PyLocation> &frames, 2177 DefaultingPyMlirContext context) { 2178 if (frames.empty()) 2179 throw py::value_error("No caller frames provided"); 2180 MlirLocation caller = frames.back().get(); 2181 for (const PyLocation &frame : 2182 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2183 caller = mlirLocationCallSiteGet(frame.get(), caller); 2184 return PyLocation(context->getRef(), 2185 mlirLocationCallSiteGet(callee.get(), caller)); 2186 }, 2187 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2188 kContextGetCallSiteLocationDocstring) 2189 .def_static( 2190 "file", 2191 [](std::string filename, int line, int col, 2192 DefaultingPyMlirContext context) { 2193 return PyLocation( 2194 context->getRef(), 2195 mlirLocationFileLineColGet( 2196 context->get(), toMlirStringRef(filename), line, col)); 2197 }, 2198 py::arg("filename"), py::arg("line"), py::arg("col"), 2199 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2200 .def_static( 2201 "name", 2202 [](std::string name, llvm::Optional<PyLocation> childLoc, 2203 DefaultingPyMlirContext context) { 2204 return PyLocation( 2205 context->getRef(), 2206 mlirLocationNameGet( 2207 context->get(), toMlirStringRef(name), 2208 childLoc ? 
childLoc->get() 2209 : mlirLocationUnknownGet(context->get()))); 2210 }, 2211 py::arg("name"), py::arg("childLoc") = py::none(), 2212 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2213 .def_property_readonly( 2214 "context", 2215 [](PyLocation &self) { return self.getContext().getObject(); }, 2216 "Context that owns the Location") 2217 .def("__repr__", [](PyLocation &self) { 2218 PyPrintAccumulator printAccum; 2219 mlirLocationPrint(self, printAccum.getCallback(), 2220 printAccum.getUserData()); 2221 return printAccum.join(); 2222 }); 2223 2224 //---------------------------------------------------------------------------- 2225 // Mapping of Module 2226 //---------------------------------------------------------------------------- 2227 py::class_<PyModule>(m, "Module", py::module_local()) 2228 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2229 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2230 .def_static( 2231 "parse", 2232 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2233 MlirModule module = mlirModuleCreateParse( 2234 context->get(), toMlirStringRef(moduleAsm)); 2235 // TODO: Rework error reporting once diagnostic engine is exposed 2236 // in C API. 
2237 if (mlirModuleIsNull(module)) { 2238 throw SetPyError( 2239 PyExc_ValueError, 2240 "Unable to parse module assembly (see diagnostics)"); 2241 } 2242 return PyModule::forModule(module).releaseObject(); 2243 }, 2244 py::arg("asm"), py::arg("context") = py::none(), 2245 kModuleParseDocstring) 2246 .def_static( 2247 "create", 2248 [](DefaultingPyLocation loc) { 2249 MlirModule module = mlirModuleCreateEmpty(loc); 2250 return PyModule::forModule(module).releaseObject(); 2251 }, 2252 py::arg("loc") = py::none(), "Creates an empty module") 2253 .def_property_readonly( 2254 "context", 2255 [](PyModule &self) { return self.getContext().getObject(); }, 2256 "Context that created the Module") 2257 .def_property_readonly( 2258 "operation", 2259 [](PyModule &self) { 2260 return PyOperation::forOperation(self.getContext(), 2261 mlirModuleGetOperation(self.get()), 2262 self.getRef().releaseObject()) 2263 .releaseObject(); 2264 }, 2265 "Accesses the module as an operation") 2266 .def_property_readonly( 2267 "body", 2268 [](PyModule &self) { 2269 PyOperationRef module_op = PyOperation::forOperation( 2270 self.getContext(), mlirModuleGetOperation(self.get()), 2271 self.getRef().releaseObject()); 2272 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2273 return returnBlock; 2274 }, 2275 "Return the block for this module") 2276 .def( 2277 "dump", 2278 [](PyModule &self) { 2279 mlirOperationDump(mlirModuleGetOperation(self.get())); 2280 }, 2281 kDumpDocstring) 2282 .def( 2283 "__str__", 2284 [](py::object self) { 2285 // Defer to the operation's __str__. 2286 return self.attr("operation").attr("__str__")(); 2287 }, 2288 kOperationStrDunderDocstring); 2289 2290 //---------------------------------------------------------------------------- 2291 // Mapping of Operation. 
2292 //---------------------------------------------------------------------------- 2293 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2294 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2295 [](PyOperationBase &self) { 2296 return self.getOperation().getCapsule(); 2297 }) 2298 .def("__eq__", 2299 [](PyOperationBase &self, PyOperationBase &other) { 2300 return &self.getOperation() == &other.getOperation(); 2301 }) 2302 .def("__eq__", 2303 [](PyOperationBase &self, py::object other) { return false; }) 2304 .def("__hash__", 2305 [](PyOperationBase &self) { 2306 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2307 }) 2308 .def_property_readonly("attributes", 2309 [](PyOperationBase &self) { 2310 return PyOpAttributeMap( 2311 self.getOperation().getRef()); 2312 }) 2313 .def_property_readonly("operands", 2314 [](PyOperationBase &self) { 2315 return PyOpOperandList( 2316 self.getOperation().getRef()); 2317 }) 2318 .def_property_readonly("regions", 2319 [](PyOperationBase &self) { 2320 return PyRegionList( 2321 self.getOperation().getRef()); 2322 }) 2323 .def_property_readonly( 2324 "results", 2325 [](PyOperationBase &self) { 2326 return PyOpResultList(self.getOperation().getRef()); 2327 }, 2328 "Returns the list of Operation results.") 2329 .def_property_readonly( 2330 "result", 2331 [](PyOperationBase &self) { 2332 auto &operation = self.getOperation(); 2333 auto numResults = mlirOperationGetNumResults(operation); 2334 if (numResults != 1) { 2335 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2336 throw SetPyError( 2337 PyExc_ValueError, 2338 Twine("Cannot call .result on operation ") + 2339 StringRef(name.data, name.length) + " which has " + 2340 Twine(numResults) + 2341 " results (it is only valid for operations with a " 2342 "single result)"); 2343 } 2344 return PyOpResult(operation.getRef(), 2345 mlirOperationGetResult(operation, 0)); 2346 }, 2347 "Shortcut to get an op result if it has only one (throws 
an error " 2348 "otherwise).") 2349 .def_property_readonly( 2350 "location", 2351 [](PyOperationBase &self) { 2352 PyOperation &operation = self.getOperation(); 2353 return PyLocation(operation.getContext(), 2354 mlirOperationGetLocation(operation.get())); 2355 }, 2356 "Returns the source location the operation was defined or derived " 2357 "from.") 2358 .def( 2359 "__str__", 2360 [](PyOperationBase &self) { 2361 return self.getAsm(/*binary=*/false, 2362 /*largeElementsLimit=*/llvm::None, 2363 /*enableDebugInfo=*/false, 2364 /*prettyDebugInfo=*/false, 2365 /*printGenericOpForm=*/false, 2366 /*useLocalScope=*/false, 2367 /*assumeVerified=*/false); 2368 }, 2369 "Returns the assembly form of the operation.") 2370 .def("print", &PyOperationBase::print, 2371 // Careful: Lots of arguments must match up with print method. 2372 py::arg("file") = py::none(), py::arg("binary") = false, 2373 py::arg("large_elements_limit") = py::none(), 2374 py::arg("enable_debug_info") = false, 2375 py::arg("pretty_debug_info") = false, 2376 py::arg("print_generic_op_form") = false, 2377 py::arg("use_local_scope") = false, 2378 py::arg("assume_verified") = false, kOperationPrintDocstring) 2379 .def("get_asm", &PyOperationBase::getAsm, 2380 // Careful: Lots of arguments must match up with get_asm method. 
2381 py::arg("binary") = false, 2382 py::arg("large_elements_limit") = py::none(), 2383 py::arg("enable_debug_info") = false, 2384 py::arg("pretty_debug_info") = false, 2385 py::arg("print_generic_op_form") = false, 2386 py::arg("use_local_scope") = false, 2387 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2388 .def( 2389 "verify", 2390 [](PyOperationBase &self) { 2391 return mlirOperationVerify(self.getOperation()); 2392 }, 2393 "Verify the operation and return true if it passes, false if it " 2394 "fails.") 2395 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2396 "Puts self immediately after the other operation in its parent " 2397 "block.") 2398 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2399 "Puts self immediately before the other operation in its parent " 2400 "block.") 2401 .def( 2402 "detach_from_parent", 2403 [](PyOperationBase &self) { 2404 PyOperation &operation = self.getOperation(); 2405 operation.checkValid(); 2406 if (!operation.isAttached()) 2407 throw py::value_error("Detached operation has no parent."); 2408 2409 operation.detachFromParent(); 2410 return operation.createOpView(); 2411 }, 2412 "Detaches the operation from its parent block."); 2413 2414 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2415 .def_static("create", &PyOperation::create, py::arg("name"), 2416 py::arg("results") = py::none(), 2417 py::arg("operands") = py::none(), 2418 py::arg("attributes") = py::none(), 2419 py::arg("successors") = py::none(), py::arg("regions") = 0, 2420 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2421 kOperationCreateDocstring) 2422 .def_property_readonly("parent", 2423 [](PyOperation &self) -> py::object { 2424 auto parent = self.getParentOperation(); 2425 if (parent) 2426 return parent->getObject(); 2427 return py::none(); 2428 }) 2429 .def("erase", &PyOperation::erase) 2430 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2431 
&PyOperation::getCapsule) 2432 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2433 .def_property_readonly("name", 2434 [](PyOperation &self) { 2435 self.checkValid(); 2436 MlirOperation operation = self.get(); 2437 MlirStringRef name = mlirIdentifierStr( 2438 mlirOperationGetName(operation)); 2439 return py::str(name.data, name.length); 2440 }) 2441 .def_property_readonly( 2442 "context", 2443 [](PyOperation &self) { 2444 self.checkValid(); 2445 return self.getContext().getObject(); 2446 }, 2447 "Context that owns the Operation") 2448 .def_property_readonly("opview", &PyOperation::createOpView); 2449 2450 auto opViewClass = 2451 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2452 .def(py::init<py::object>(), py::arg("operation")) 2453 .def_property_readonly("operation", &PyOpView::getOperationObject) 2454 .def_property_readonly( 2455 "context", 2456 [](PyOpView &self) { 2457 return self.getOperation().getContext().getObject(); 2458 }, 2459 "Context that owns the Operation") 2460 .def("__str__", [](PyOpView &self) { 2461 return py::str(self.getOperationObject()); 2462 }); 2463 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2464 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2465 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2466 opViewClass.attr("build_generic") = classmethod( 2467 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2468 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2469 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2470 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2471 "Builds a specific, generated OpView based on class level attributes."); 2472 2473 //---------------------------------------------------------------------------- 2474 // Mapping of PyRegion. 
2475 //---------------------------------------------------------------------------- 2476 py::class_<PyRegion>(m, "Region", py::module_local()) 2477 .def_property_readonly( 2478 "blocks", 2479 [](PyRegion &self) { 2480 return PyBlockList(self.getParentOperation(), self.get()); 2481 }, 2482 "Returns a forward-optimized sequence of blocks.") 2483 .def_property_readonly( 2484 "owner", 2485 [](PyRegion &self) { 2486 return self.getParentOperation()->createOpView(); 2487 }, 2488 "Returns the operation owning this region.") 2489 .def( 2490 "__iter__", 2491 [](PyRegion &self) { 2492 self.checkValid(); 2493 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2494 return PyBlockIterator(self.getParentOperation(), firstBlock); 2495 }, 2496 "Iterates over blocks in the region.") 2497 .def("__eq__", 2498 [](PyRegion &self, PyRegion &other) { 2499 return self.get().ptr == other.get().ptr; 2500 }) 2501 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2502 2503 //---------------------------------------------------------------------------- 2504 // Mapping of PyBlock. 
2505 //---------------------------------------------------------------------------- 2506 py::class_<PyBlock>(m, "Block", py::module_local()) 2507 .def_property_readonly( 2508 "owner", 2509 [](PyBlock &self) { 2510 return self.getParentOperation()->createOpView(); 2511 }, 2512 "Returns the owning operation of this block.") 2513 .def_property_readonly( 2514 "region", 2515 [](PyBlock &self) { 2516 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2517 return PyRegion(self.getParentOperation(), region); 2518 }, 2519 "Returns the owning region of this block.") 2520 .def_property_readonly( 2521 "arguments", 2522 [](PyBlock &self) { 2523 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2524 }, 2525 "Returns a list of block arguments.") 2526 .def_property_readonly( 2527 "operations", 2528 [](PyBlock &self) { 2529 return PyOperationList(self.getParentOperation(), self.get()); 2530 }, 2531 "Returns a forward-optimized sequence of operations.") 2532 .def_static( 2533 "create_at_start", 2534 [](PyRegion &parent, py::list pyArgTypes) { 2535 parent.checkValid(); 2536 llvm::SmallVector<MlirType, 4> argTypes; 2537 argTypes.reserve(pyArgTypes.size()); 2538 for (auto &pyArg : pyArgTypes) { 2539 argTypes.push_back(pyArg.cast<PyType &>()); 2540 } 2541 2542 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2543 mlirRegionInsertOwnedBlock(parent, 0, block); 2544 return PyBlock(parent.getParentOperation(), block); 2545 }, 2546 py::arg("parent"), py::arg("arg_types") = py::list(), 2547 "Creates and returns a new Block at the beginning of the given " 2548 "region (with given argument types).") 2549 .def( 2550 "create_before", 2551 [](PyBlock &self, py::args pyArgTypes) { 2552 self.checkValid(); 2553 llvm::SmallVector<MlirType, 4> argTypes; 2554 argTypes.reserve(pyArgTypes.size()); 2555 for (auto &pyArg : pyArgTypes) { 2556 argTypes.push_back(pyArg.cast<PyType &>()); 2557 } 2558 2559 MlirBlock block = mlirBlockCreate(argTypes.size(), 
argTypes.data()); 2560 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2561 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2562 return PyBlock(self.getParentOperation(), block); 2563 }, 2564 "Creates and returns a new Block before this block " 2565 "(with given argument types).") 2566 .def( 2567 "create_after", 2568 [](PyBlock &self, py::args pyArgTypes) { 2569 self.checkValid(); 2570 llvm::SmallVector<MlirType, 4> argTypes; 2571 argTypes.reserve(pyArgTypes.size()); 2572 for (auto &pyArg : pyArgTypes) { 2573 argTypes.push_back(pyArg.cast<PyType &>()); 2574 } 2575 2576 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2577 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2578 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2579 return PyBlock(self.getParentOperation(), block); 2580 }, 2581 "Creates and returns a new Block after this block " 2582 "(with given argument types).") 2583 .def( 2584 "__iter__", 2585 [](PyBlock &self) { 2586 self.checkValid(); 2587 MlirOperation firstOperation = 2588 mlirBlockGetFirstOperation(self.get()); 2589 return PyOperationIterator(self.getParentOperation(), 2590 firstOperation); 2591 }, 2592 "Iterates over operations in the block.") 2593 .def("__eq__", 2594 [](PyBlock &self, PyBlock &other) { 2595 return self.get().ptr == other.get().ptr; 2596 }) 2597 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2598 .def( 2599 "__str__", 2600 [](PyBlock &self) { 2601 self.checkValid(); 2602 PyPrintAccumulator printAccum; 2603 mlirBlockPrint(self.get(), printAccum.getCallback(), 2604 printAccum.getUserData()); 2605 return printAccum.join(); 2606 }, 2607 "Returns the assembly form of the block.") 2608 .def( 2609 "append", 2610 [](PyBlock &self, PyOperationBase &operation) { 2611 if (operation.getOperation().isAttached()) 2612 operation.getOperation().detachFromParent(); 2613 2614 MlirOperation mlirOperation = operation.getOperation().get(); 2615 
mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2616 operation.getOperation().setAttached( 2617 self.getParentOperation().getObject()); 2618 }, 2619 py::arg("operation"), 2620 "Appends an operation to this block. If the operation is currently " 2621 "in another block, it will be moved."); 2622 2623 //---------------------------------------------------------------------------- 2624 // Mapping of PyInsertionPoint. 2625 //---------------------------------------------------------------------------- 2626 2627 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2628 .def(py::init<PyBlock &>(), py::arg("block"), 2629 "Inserts after the last operation but still inside the block.") 2630 .def("__enter__", &PyInsertionPoint::contextEnter) 2631 .def("__exit__", &PyInsertionPoint::contextExit) 2632 .def_property_readonly_static( 2633 "current", 2634 [](py::object & /*class*/) { 2635 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2636 if (!ip) 2637 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2638 return ip; 2639 }, 2640 "Gets the InsertionPoint bound to the current thread or raises " 2641 "ValueError if none has been set") 2642 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2643 "Inserts before a referenced operation.") 2644 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2645 py::arg("block"), "Inserts at the beginning of the block.") 2646 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2647 py::arg("block"), "Inserts before the block terminator.") 2648 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2649 "Inserts an operation.") 2650 .def_property_readonly( 2651 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2652 "Returns the block that this InsertionPoint points to."); 2653 2654 //---------------------------------------------------------------------------- 2655 // Mapping of PyAttribute. 
2656 //---------------------------------------------------------------------------- 2657 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2658 // Delegate to the PyAttribute copy constructor, which will also lifetime 2659 // extend the backing context which owns the MlirAttribute. 2660 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2661 "Casts the passed attribute to the generic Attribute") 2662 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2663 &PyAttribute::getCapsule) 2664 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2665 .def_static( 2666 "parse", 2667 [](std::string attrSpec, DefaultingPyMlirContext context) { 2668 MlirAttribute type = mlirAttributeParseGet( 2669 context->get(), toMlirStringRef(attrSpec)); 2670 // TODO: Rework error reporting once diagnostic engine is exposed 2671 // in C API. 2672 if (mlirAttributeIsNull(type)) { 2673 throw SetPyError(PyExc_ValueError, 2674 Twine("Unable to parse attribute: '") + 2675 attrSpec + "'"); 2676 } 2677 return PyAttribute(context->getRef(), type); 2678 }, 2679 py::arg("asm"), py::arg("context") = py::none(), 2680 "Parses an attribute from an assembly form") 2681 .def_property_readonly( 2682 "context", 2683 [](PyAttribute &self) { return self.getContext().getObject(); }, 2684 "Context that owns the Attribute") 2685 .def_property_readonly("type", 2686 [](PyAttribute &self) { 2687 return PyType(self.getContext()->getRef(), 2688 mlirAttributeGetType(self)); 2689 }) 2690 .def( 2691 "get_named", 2692 [](PyAttribute &self, std::string name) { 2693 return PyNamedAttribute(self, std::move(name)); 2694 }, 2695 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2696 .def("__eq__", 2697 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2698 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2699 .def("__hash__", 2700 [](PyAttribute &self) { 2701 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2702 }) 2703 .def( 2704 
"dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2705 kDumpDocstring) 2706 .def( 2707 "__str__", 2708 [](PyAttribute &self) { 2709 PyPrintAccumulator printAccum; 2710 mlirAttributePrint(self, printAccum.getCallback(), 2711 printAccum.getUserData()); 2712 return printAccum.join(); 2713 }, 2714 "Returns the assembly form of the Attribute.") 2715 .def("__repr__", [](PyAttribute &self) { 2716 // Generally, assembly formats are not printed for __repr__ because 2717 // this can cause exceptionally long debug output and exceptions. 2718 // However, attribute values are generally considered useful and are 2719 // printed. This may need to be re-evaluated if debug dumps end up 2720 // being excessive. 2721 PyPrintAccumulator printAccum; 2722 printAccum.parts.append("Attribute("); 2723 mlirAttributePrint(self, printAccum.getCallback(), 2724 printAccum.getUserData()); 2725 printAccum.parts.append(")"); 2726 return printAccum.join(); 2727 }); 2728 2729 //---------------------------------------------------------------------------- 2730 // Mapping of PyNamedAttribute 2731 //---------------------------------------------------------------------------- 2732 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2733 .def("__repr__", 2734 [](PyNamedAttribute &self) { 2735 PyPrintAccumulator printAccum; 2736 printAccum.parts.append("NamedAttribute("); 2737 printAccum.parts.append( 2738 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2739 mlirIdentifierStr(self.namedAttr.name).length)); 2740 printAccum.parts.append("="); 2741 mlirAttributePrint(self.namedAttr.attribute, 2742 printAccum.getCallback(), 2743 printAccum.getUserData()); 2744 printAccum.parts.append(")"); 2745 return printAccum.join(); 2746 }) 2747 .def_property_readonly( 2748 "name", 2749 [](PyNamedAttribute &self) { 2750 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2751 mlirIdentifierStr(self.namedAttr.name).length); 2752 }, 2753 "The name of the NamedAttribute binding") 2754 
.def_property_readonly( 2755 "attr", 2756 [](PyNamedAttribute &self) { 2757 // TODO: When named attribute is removed/refactored, also remove 2758 // this constructor (it does an inefficient table lookup). 2759 auto contextRef = PyMlirContext::forContext( 2760 mlirAttributeGetContext(self.namedAttr.attribute)); 2761 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2762 }, 2763 py::keep_alive<0, 1>(), 2764 "The underlying generic attribute of the NamedAttribute binding"); 2765 2766 //---------------------------------------------------------------------------- 2767 // Mapping of PyType. 2768 //---------------------------------------------------------------------------- 2769 py::class_<PyType>(m, "Type", py::module_local()) 2770 // Delegate to the PyType copy constructor, which will also lifetime 2771 // extend the backing context which owns the MlirType. 2772 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2773 "Casts the passed type to the generic Type") 2774 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2775 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2776 .def_static( 2777 "parse", 2778 [](std::string typeSpec, DefaultingPyMlirContext context) { 2779 MlirType type = 2780 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2781 // TODO: Rework error reporting once diagnostic engine is exposed 2782 // in C API. 
2783 if (mlirTypeIsNull(type)) { 2784 throw SetPyError(PyExc_ValueError, 2785 Twine("Unable to parse type: '") + typeSpec + 2786 "'"); 2787 } 2788 return PyType(context->getRef(), type); 2789 }, 2790 py::arg("asm"), py::arg("context") = py::none(), 2791 kContextParseTypeDocstring) 2792 .def_property_readonly( 2793 "context", [](PyType &self) { return self.getContext().getObject(); }, 2794 "Context that owns the Type") 2795 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2796 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2797 .def("__hash__", 2798 [](PyType &self) { 2799 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2800 }) 2801 .def( 2802 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2803 .def( 2804 "__str__", 2805 [](PyType &self) { 2806 PyPrintAccumulator printAccum; 2807 mlirTypePrint(self, printAccum.getCallback(), 2808 printAccum.getUserData()); 2809 return printAccum.join(); 2810 }, 2811 "Returns the assembly form of the type.") 2812 .def("__repr__", [](PyType &self) { 2813 // Generally, assembly formats are not printed for __repr__ because 2814 // this can cause exceptionally long debug output and exceptions. 2815 // However, types are an exception as they typically have compact 2816 // assembly forms and printing them is useful. 2817 PyPrintAccumulator printAccum; 2818 printAccum.parts.append("Type("); 2819 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2820 printAccum.parts.append(")"); 2821 return printAccum.join(); 2822 }); 2823 2824 //---------------------------------------------------------------------------- 2825 // Mapping of Value. 
2826 //---------------------------------------------------------------------------- 2827 py::class_<PyValue>(m, "Value", py::module_local()) 2828 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2829 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2830 .def_property_readonly( 2831 "context", 2832 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2833 "Context in which the value lives.") 2834 .def( 2835 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2836 kDumpDocstring) 2837 .def_property_readonly( 2838 "owner", 2839 [](PyValue &self) { 2840 assert(mlirOperationEqual(self.getParentOperation()->get(), 2841 mlirOpResultGetOwner(self.get())) && 2842 "expected the owner of the value in Python to match that in " 2843 "the IR"); 2844 return self.getParentOperation().getObject(); 2845 }) 2846 .def("__eq__", 2847 [](PyValue &self, PyValue &other) { 2848 return self.get().ptr == other.get().ptr; 2849 }) 2850 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2851 .def("__hash__", 2852 [](PyValue &self) { 2853 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2854 }) 2855 .def( 2856 "__str__", 2857 [](PyValue &self) { 2858 PyPrintAccumulator printAccum; 2859 printAccum.parts.append("Value("); 2860 mlirValuePrint(self.get(), printAccum.getCallback(), 2861 printAccum.getUserData()); 2862 printAccum.parts.append(")"); 2863 return printAccum.join(); 2864 }, 2865 kValueDunderStrDocstring) 2866 .def_property_readonly("type", [](PyValue &self) { 2867 return PyType(self.getParentOperation()->getContext(), 2868 mlirValueGetType(self.get())); 2869 }); 2870 PyBlockArgument::bind(m); 2871 PyOpResult::bind(m); 2872 2873 //---------------------------------------------------------------------------- 2874 // Mapping of SymbolTable. 
2875 //---------------------------------------------------------------------------- 2876 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 2877 .def(py::init<PyOperationBase &>()) 2878 .def("__getitem__", &PySymbolTable::dunderGetItem) 2879 .def("insert", &PySymbolTable::insert, py::arg("operation")) 2880 .def("erase", &PySymbolTable::erase, py::arg("operation")) 2881 .def("__delitem__", &PySymbolTable::dunderDel) 2882 .def("__contains__", 2883 [](PySymbolTable &table, const std::string &name) { 2884 return !mlirOperationIsNull(mlirSymbolTableLookup( 2885 table, mlirStringRefCreate(name.data(), name.length()))); 2886 }) 2887 // Static helpers. 2888 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 2889 py::arg("symbol"), py::arg("name")) 2890 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 2891 py::arg("symbol")) 2892 .def_static("get_visibility", &PySymbolTable::getVisibility, 2893 py::arg("symbol")) 2894 .def_static("set_visibility", &PySymbolTable::setVisibility, 2895 py::arg("symbol"), py::arg("visibility")) 2896 .def_static("replace_all_symbol_uses", 2897 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 2898 py::arg("new_symbol"), py::arg("from_op")) 2899 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 2900 py::arg("from_op"), py::arg("all_sym_uses_visible"), 2901 py::arg("callback")); 2902 2903 // Container bindings. 2904 PyBlockArgumentList::bind(m); 2905 PyBlockIterator::bind(m); 2906 PyBlockList::bind(m); 2907 PyOperationIterator::bind(m); 2908 PyOperationList::bind(m); 2909 PyOpAttributeMap::bind(m); 2910 PyOpOperandList::bind(m); 2911 PyOpResultList::bind(m); 2912 PyRegionIterator::bind(m); 2913 PyRegionList::bind(m); 2914 2915 // Debug bindings. 2916 PyGlobalDebugFlag::bind(m); 2917 } 2918