1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 namespace py = pybind11; 25 using namespace mlir; 26 using namespace mlir::python; 27 28 using llvm::SmallVector; 29 using llvm::StringRef; 30 using llvm::Twine; 31 32 //------------------------------------------------------------------------------ 33 // Docstrings (trivial, non-duplicated docstrings are included inline). 34 //------------------------------------------------------------------------------ 35 36 static const char kContextParseTypeDocstring[] = 37 R"(Parses the assembly form of a type. 38 39 Returns a Type object or raises a ValueError if the type cannot be parsed. 40 41 See also: https://mlir.llvm.org/docs/LangRef/#type-system 42 )"; 43 44 static const char kContextGetCallSiteLocationDocstring[] = 45 R"(Gets a Location representing a caller and callsite)"; 46 47 static const char kContextGetFileLocationDocstring[] = 48 R"(Gets a Location representing a file, line and column)"; 49 50 static const char kContextGetNameLocationDocString[] = 51 R"(Gets a Location representing a named location with optional child location)"; 52 53 static const char kModuleParseDocstring[] = 54 R"(Parses a module's assembly format from a string. 55 56 Returns a new MlirModule or raises a ValueError if the parsing fails. 57 58 See also: https://mlir.llvm.org/docs/LangRef/ 59 )"; 60 61 static const char kOperationCreateDocstring[] = 62 R"(Creates a new operation. 63 64 Args: 65 name: Operation name (e.g. "dialect.operation"). 66 results: Sequence of Type representing op result types. 67 attributes: Dict of str:Attribute. 68 successors: List of Block for the operation's successors. 69 regions: Number of regions to create. 70 location: A Location object (defaults to resolve from context manager). 71 ip: An InsertionPoint (defaults to resolve from context manager or set to 72 False to disable insertion, even with an insertion point set in the 73 context manager). 74 Returns: 75 A new "detached" Operation object. Detached operations can be added 76 to blocks, which causes them to become "attached." 77 )"; 78 79 static const char kOperationPrintDocstring[] = 80 R"(Prints the assembly form of the operation to a file like object. 81 82 Args: 83 file: The file like object to write to. Defaults to sys.stdout. 84 binary: Whether to write bytes (True) or str (False). Defaults to False. 85 large_elements_limit: Whether to elide elements attributes above this 86 number of elements. Defaults to None (no limit). 87 enable_debug_info: Whether to print debug/location information. Defaults 88 to False. 89 pretty_debug_info: Whether to format debug information for easier reading 90 by a human (warning: the result is unparseable). 91 print_generic_op_form: Whether to print the generic assembly forms of all 92 ops. Defaults to False. 93 use_local_Scope: Whether to print in a way that is more optimized for 94 multi-threaded access but may not be consistent with how the overall 95 module prints. 96 )"; 97 98 static const char kOperationGetAsmDocstring[] = 99 R"(Gets the assembly form of the operation with all options available. 100 101 Args: 102 binary: Whether to return a bytes (True) or str (False) object. Defaults to 103 False. 104 ... others ...: See the print() method for common keyword arguments for 105 configuring the printout. 106 Returns: 107 Either a bytes or str object, depending on the setting of the 'binary' 108 argument. 109 )"; 110 111 static const char kOperationStrDunderDocstring[] = 112 R"(Gets the assembly form of the operation with default options. 113 114 If more advanced control over the assembly formatting or I/O options is needed, 115 use the dedicated print or get_asm method, which supports keyword arguments to 116 customize behavior. 117 )"; 118 119 static const char kDumpDocstring[] = 120 R"(Dumps a debug representation of the object to stderr.)"; 121 122 static const char kAppendBlockDocstring[] = 123 R"(Appends a new block, with argument types as positional args. 124 125 Returns: 126 The created block. 127 )"; 128 129 static const char kValueDunderStrDocstring[] = 130 R"(Returns the string form of the value. 131 132 If the value is a block argument, this is the assembly form of its type and the 133 position in the argument list. If the value is an operation result, this is 134 equivalent to printing the operation that produced it. 135 )"; 136 137 //------------------------------------------------------------------------------ 138 // Utilities. 139 //------------------------------------------------------------------------------ 140 141 /// Helper for creating an @classmethod. 142 template <class Func, typename... Args> 143 py::object classmethod(Func f, Args... args) { 144 py::object cf = py::cpp_function(f, args...); 145 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 146 } 147 148 static py::object 149 createCustomDialectWrapper(const std::string &dialectNamespace, 150 py::object dialectDescriptor) { 151 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 152 if (!dialectClass) { 153 // Use the base class. 154 return py::cast(PyDialect(std::move(dialectDescriptor))); 155 } 156 157 // Create the custom implementation. 158 return (*dialectClass)(std::move(dialectDescriptor)); 159 } 160 161 static MlirStringRef toMlirStringRef(const std::string &s) { 162 return mlirStringRefCreate(s.data(), s.size()); 163 } 164 165 /// Wrapper for the global LLVM debugging flag. 166 struct PyGlobalDebugFlag { 167 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 168 169 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 170 171 static void bind(py::module &m) { 172 // Debug flags. 173 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 174 .def_property_static("flag", &PyGlobalDebugFlag::get, 175 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 176 } 177 }; 178 179 //------------------------------------------------------------------------------ 180 // Collections. 181 //------------------------------------------------------------------------------ 182 183 namespace { 184 185 class PyRegionIterator { 186 public: 187 PyRegionIterator(PyOperationRef operation) 188 : operation(std::move(operation)) {} 189 190 PyRegionIterator &dunderIter() { return *this; } 191 192 PyRegion dunderNext() { 193 operation->checkValid(); 194 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 195 throw py::stop_iteration(); 196 } 197 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 198 return PyRegion(operation, region); 199 } 200 201 static void bind(py::module &m) { 202 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 203 .def("__iter__", &PyRegionIterator::dunderIter) 204 .def("__next__", &PyRegionIterator::dunderNext); 205 } 206 207 private: 208 PyOperationRef operation; 209 int nextIndex = 0; 210 }; 211 212 /// Regions of an op are fixed length and indexed numerically so are represented 213 /// with a sequence-like container. 214 class PyRegionList { 215 public: 216 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 217 218 intptr_t dunderLen() { 219 operation->checkValid(); 220 return mlirOperationGetNumRegions(operation->get()); 221 } 222 223 PyRegion dunderGetItem(intptr_t index) { 224 // dunderLen checks validity. 225 if (index < 0 || index >= dunderLen()) { 226 throw SetPyError(PyExc_IndexError, 227 "attempt to access out of bounds region"); 228 } 229 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 230 return PyRegion(operation, region); 231 } 232 233 static void bind(py::module &m) { 234 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 235 .def("__len__", &PyRegionList::dunderLen) 236 .def("__getitem__", &PyRegionList::dunderGetItem); 237 } 238 239 private: 240 PyOperationRef operation; 241 }; 242 243 class PyBlockIterator { 244 public: 245 PyBlockIterator(PyOperationRef operation, MlirBlock next) 246 : operation(std::move(operation)), next(next) {} 247 248 PyBlockIterator &dunderIter() { return *this; } 249 250 PyBlock dunderNext() { 251 operation->checkValid(); 252 if (mlirBlockIsNull(next)) { 253 throw py::stop_iteration(); 254 } 255 256 PyBlock returnBlock(operation, next); 257 next = mlirBlockGetNextInRegion(next); 258 return returnBlock; 259 } 260 261 static void bind(py::module &m) { 262 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 263 .def("__iter__", &PyBlockIterator::dunderIter) 264 .def("__next__", &PyBlockIterator::dunderNext); 265 } 266 267 private: 268 PyOperationRef operation; 269 MlirBlock next; 270 }; 271 272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 273 /// we present them as a more full-featured list-like container but optimize 274 /// it for forward iteration. Blocks are always owned by a region. 275 class PyBlockList { 276 public: 277 PyBlockList(PyOperationRef operation, MlirRegion region) 278 : operation(std::move(operation)), region(region) {} 279 280 PyBlockIterator dunderIter() { 281 operation->checkValid(); 282 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 283 } 284 285 intptr_t dunderLen() { 286 operation->checkValid(); 287 intptr_t count = 0; 288 MlirBlock block = mlirRegionGetFirstBlock(region); 289 while (!mlirBlockIsNull(block)) { 290 count += 1; 291 block = mlirBlockGetNextInRegion(block); 292 } 293 return count; 294 } 295 296 PyBlock dunderGetItem(intptr_t index) { 297 operation->checkValid(); 298 if (index < 0) { 299 throw SetPyError(PyExc_IndexError, 300 "attempt to access out of bounds block"); 301 } 302 MlirBlock block = mlirRegionGetFirstBlock(region); 303 while (!mlirBlockIsNull(block)) { 304 if (index == 0) { 305 return PyBlock(operation, block); 306 } 307 block = mlirBlockGetNextInRegion(block); 308 index -= 1; 309 } 310 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 311 } 312 313 PyBlock appendBlock(py::args pyArgTypes) { 314 operation->checkValid(); 315 llvm::SmallVector<MlirType, 4> argTypes; 316 argTypes.reserve(pyArgTypes.size()); 317 for (auto &pyArg : pyArgTypes) { 318 argTypes.push_back(pyArg.cast<PyType &>()); 319 } 320 321 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 322 mlirRegionAppendOwnedBlock(region, block); 323 return PyBlock(operation, block); 324 } 325 326 static void bind(py::module &m) { 327 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 328 .def("__getitem__", &PyBlockList::dunderGetItem) 329 .def("__iter__", &PyBlockList::dunderIter) 330 .def("__len__", &PyBlockList::dunderLen) 331 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 332 } 333 334 private: 335 PyOperationRef operation; 336 MlirRegion region; 337 }; 338 339 class PyOperationIterator { 340 public: 341 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 342 : parentOperation(std::move(parentOperation)), next(next) {} 343 344 PyOperationIterator &dunderIter() { return *this; } 345 346 py::object dunderNext() { 347 parentOperation->checkValid(); 348 if (mlirOperationIsNull(next)) { 349 throw py::stop_iteration(); 350 } 351 352 PyOperationRef returnOperation = 353 PyOperation::forOperation(parentOperation->getContext(), next); 354 next = mlirOperationGetNextInBlock(next); 355 return returnOperation->createOpView(); 356 } 357 358 static void bind(py::module &m) { 359 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 360 .def("__iter__", &PyOperationIterator::dunderIter) 361 .def("__next__", &PyOperationIterator::dunderNext); 362 } 363 364 private: 365 PyOperationRef parentOperation; 366 MlirOperation next; 367 }; 368 369 /// Operations are exposed by the C-API as a forward-only linked list. In 370 /// Python, we present them as a more full-featured list-like container but 371 /// optimize it for forward iteration. Iterable operations are always owned 372 /// by a block. 373 class PyOperationList { 374 public: 375 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 376 : parentOperation(std::move(parentOperation)), block(block) {} 377 378 PyOperationIterator dunderIter() { 379 parentOperation->checkValid(); 380 return PyOperationIterator(parentOperation, 381 mlirBlockGetFirstOperation(block)); 382 } 383 384 intptr_t dunderLen() { 385 parentOperation->checkValid(); 386 intptr_t count = 0; 387 MlirOperation childOp = mlirBlockGetFirstOperation(block); 388 while (!mlirOperationIsNull(childOp)) { 389 count += 1; 390 childOp = mlirOperationGetNextInBlock(childOp); 391 } 392 return count; 393 } 394 395 py::object dunderGetItem(intptr_t index) { 396 parentOperation->checkValid(); 397 if (index < 0) { 398 throw SetPyError(PyExc_IndexError, 399 "attempt to access out of bounds operation"); 400 } 401 MlirOperation childOp = mlirBlockGetFirstOperation(block); 402 while (!mlirOperationIsNull(childOp)) { 403 if (index == 0) { 404 return PyOperation::forOperation(parentOperation->getContext(), childOp) 405 ->createOpView(); 406 } 407 childOp = mlirOperationGetNextInBlock(childOp); 408 index -= 1; 409 } 410 throw SetPyError(PyExc_IndexError, 411 "attempt to access out of bounds operation"); 412 } 413 414 static void bind(py::module &m) { 415 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 416 .def("__getitem__", &PyOperationList::dunderGetItem) 417 .def("__iter__", &PyOperationList::dunderIter) 418 .def("__len__", &PyOperationList::dunderLen); 419 } 420 421 private: 422 PyOperationRef parentOperation; 423 MlirBlock block; 424 }; 425 426 } // namespace 427 428 //------------------------------------------------------------------------------ 429 // PyMlirContext 430 //------------------------------------------------------------------------------ 431 432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 433 py::gil_scoped_acquire acquire; 434 auto &liveContexts = getLiveContexts(); 435 liveContexts[context.ptr] = this; 436 } 437 438 PyMlirContext::~PyMlirContext() { 439 // Note that the only public way to construct an instance is via the 440 // forContext method, which always puts the associated handle into 441 // liveContexts. 442 py::gil_scoped_acquire acquire; 443 getLiveContexts().erase(context.ptr); 444 mlirContextDestroy(context); 445 } 446 447 py::object PyMlirContext::getCapsule() { 448 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 449 } 450 451 py::object PyMlirContext::createFromCapsule(py::object capsule) { 452 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 453 if (mlirContextIsNull(rawContext)) 454 throw py::error_already_set(); 455 return forContext(rawContext).releaseObject(); 456 } 457 458 PyMlirContext *PyMlirContext::createNewContextForInit() { 459 MlirContext context = mlirContextCreate(); 460 mlirRegisterAllDialects(context); 461 return new PyMlirContext(context); 462 } 463 464 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 465 py::gil_scoped_acquire acquire; 466 auto &liveContexts = getLiveContexts(); 467 auto it = liveContexts.find(context.ptr); 468 if (it == liveContexts.end()) { 469 // Create. 470 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 471 py::object pyRef = py::cast(unownedContextWrapper); 472 assert(pyRef && "cast to py::object failed"); 473 liveContexts[context.ptr] = unownedContextWrapper; 474 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 475 } 476 // Use existing. 477 py::object pyRef = py::cast(it->second); 478 return PyMlirContextRef(it->second, std::move(pyRef)); 479 } 480 481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 482 static LiveContextMap liveContexts; 483 return liveContexts; 484 } 485 486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 487 488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 489 490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 491 492 pybind11::object PyMlirContext::contextEnter() { 493 return PyThreadContextEntry::pushContext(*this); 494 } 495 496 void PyMlirContext::contextExit(pybind11::object excType, 497 pybind11::object excVal, 498 pybind11::object excTb) { 499 PyThreadContextEntry::popContext(*this); 500 } 501 502 PyMlirContext &DefaultingPyMlirContext::resolve() { 503 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 504 if (!context) { 505 throw SetPyError( 506 PyExc_RuntimeError, 507 "An MLIR function requires a Context but none was provided in the call " 508 "or from the surrounding environment. Either pass to the function with " 509 "a 'context=' argument or establish a default using 'with Context():'"); 510 } 511 return *context; 512 } 513 514 //------------------------------------------------------------------------------ 515 // PyThreadContextEntry management 516 //------------------------------------------------------------------------------ 517 518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 519 static thread_local std::vector<PyThreadContextEntry> stack; 520 return stack; 521 } 522 523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 524 auto &stack = getStack(); 525 if (stack.empty()) 526 return nullptr; 527 return &stack.back(); 528 } 529 530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 531 py::object insertionPoint, 532 py::object location) { 533 auto &stack = getStack(); 534 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 535 std::move(location)); 536 // If the new stack has more than one entry and the context of the new top 537 // entry matches the previous, copy the insertionPoint and location from the 538 // previous entry if missing from the new top entry. 539 if (stack.size() > 1) { 540 auto &prev = *(stack.rbegin() + 1); 541 auto ¤t = stack.back(); 542 if (current.context.is(prev.context)) { 543 // Default non-context objects from the previous entry. 544 if (!current.insertionPoint) 545 current.insertionPoint = prev.insertionPoint; 546 if (!current.location) 547 current.location = prev.location; 548 } 549 } 550 } 551 552 PyMlirContext *PyThreadContextEntry::getContext() { 553 if (!context) 554 return nullptr; 555 return py::cast<PyMlirContext *>(context); 556 } 557 558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 559 if (!insertionPoint) 560 return nullptr; 561 return py::cast<PyInsertionPoint *>(insertionPoint); 562 } 563 564 PyLocation *PyThreadContextEntry::getLocation() { 565 if (!location) 566 return nullptr; 567 return py::cast<PyLocation *>(location); 568 } 569 570 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 571 auto *tos = getTopOfStack(); 572 return tos ? tos->getContext() : nullptr; 573 } 574 575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 576 auto *tos = getTopOfStack(); 577 return tos ? tos->getInsertionPoint() : nullptr; 578 } 579 580 PyLocation *PyThreadContextEntry::getDefaultLocation() { 581 auto *tos = getTopOfStack(); 582 return tos ? tos->getLocation() : nullptr; 583 } 584 585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 586 py::object contextObj = py::cast(context); 587 push(FrameKind::Context, /*context=*/contextObj, 588 /*insertionPoint=*/py::object(), 589 /*location=*/py::object()); 590 return contextObj; 591 } 592 593 void PyThreadContextEntry::popContext(PyMlirContext &context) { 594 auto &stack = getStack(); 595 if (stack.empty()) 596 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 597 auto &tos = stack.back(); 598 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 599 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 600 stack.pop_back(); 601 } 602 603 py::object 604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 605 py::object contextObj = 606 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 607 py::object insertionPointObj = py::cast(insertionPoint); 608 push(FrameKind::InsertionPoint, 609 /*context=*/contextObj, 610 /*insertionPoint=*/insertionPointObj, 611 /*location=*/py::object()); 612 return insertionPointObj; 613 } 614 615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 616 auto &stack = getStack(); 617 if (stack.empty()) 618 throw SetPyError(PyExc_RuntimeError, 619 "Unbalanced InsertionPoint enter/exit"); 620 auto &tos = stack.back(); 621 if (tos.frameKind != FrameKind::InsertionPoint && 622 tos.getInsertionPoint() != &insertionPoint) 623 throw SetPyError(PyExc_RuntimeError, 624 "Unbalanced InsertionPoint enter/exit"); 625 stack.pop_back(); 626 } 627 628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 629 py::object contextObj = location.getContext().getObject(); 630 py::object locationObj = py::cast(location); 631 push(FrameKind::Location, /*context=*/contextObj, 632 /*insertionPoint=*/py::object(), 633 /*location=*/locationObj); 634 return locationObj; 635 } 636 637 void PyThreadContextEntry::popLocation(PyLocation &location) { 638 auto &stack = getStack(); 639 if (stack.empty()) 640 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 641 auto &tos = stack.back(); 642 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 643 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 644 stack.pop_back(); 645 } 646 647 //------------------------------------------------------------------------------ 648 // PyDialect, PyDialectDescriptor, PyDialects 649 //------------------------------------------------------------------------------ 650 651 MlirDialect PyDialects::getDialectForKey(const std::string &key, 652 bool attrError) { 653 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 654 {key.data(), key.size()}); 655 if (mlirDialectIsNull(dialect)) { 656 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 657 Twine("Dialect '") + key + "' not found"); 658 } 659 return dialect; 660 } 661 662 //------------------------------------------------------------------------------ 663 // PyLocation 664 //------------------------------------------------------------------------------ 665 666 py::object PyLocation::getCapsule() { 667 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 668 } 669 670 PyLocation PyLocation::createFromCapsule(py::object capsule) { 671 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 672 if (mlirLocationIsNull(rawLoc)) 673 throw py::error_already_set(); 674 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 675 rawLoc); 676 } 677 678 py::object PyLocation::contextEnter() { 679 return PyThreadContextEntry::pushLocation(*this); 680 } 681 682 void PyLocation::contextExit(py::object excType, py::object excVal, 683 py::object excTb) { 684 PyThreadContextEntry::popLocation(*this); 685 } 686 687 PyLocation &DefaultingPyLocation::resolve() { 688 auto *location = PyThreadContextEntry::getDefaultLocation(); 689 if (!location) { 690 throw SetPyError( 691 PyExc_RuntimeError, 692 "An MLIR function requires a Location but none was provided in the " 693 "call or from the surrounding environment. Either pass to the function " 694 "with a 'loc=' argument or establish a default using 'with loc:'"); 695 } 696 return *location; 697 } 698 699 //------------------------------------------------------------------------------ 700 // PyModule 701 //------------------------------------------------------------------------------ 702 703 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 704 : BaseContextObject(std::move(contextRef)), module(module) {} 705 706 PyModule::~PyModule() { 707 py::gil_scoped_acquire acquire; 708 auto &liveModules = getContext()->liveModules; 709 assert(liveModules.count(module.ptr) == 1 && 710 "destroying module not in live map"); 711 liveModules.erase(module.ptr); 712 mlirModuleDestroy(module); 713 } 714 715 PyModuleRef PyModule::forModule(MlirModule module) { 716 MlirContext context = mlirModuleGetContext(module); 717 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 718 719 py::gil_scoped_acquire acquire; 720 auto &liveModules = contextRef->liveModules; 721 auto it = liveModules.find(module.ptr); 722 if (it == liveModules.end()) { 723 // Create. 724 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 725 // Note that the default return value policy on cast is automatic_reference, 726 // which does not take ownership (delete will not be called). 727 // Just be explicit. 728 py::object pyRef = 729 py::cast(unownedModule, py::return_value_policy::take_ownership); 730 unownedModule->handle = pyRef; 731 liveModules[module.ptr] = 732 std::make_pair(unownedModule->handle, unownedModule); 733 return PyModuleRef(unownedModule, std::move(pyRef)); 734 } 735 // Use existing. 736 PyModule *existing = it->second.second; 737 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 738 return PyModuleRef(existing, std::move(pyRef)); 739 } 740 741 py::object PyModule::createFromCapsule(py::object capsule) { 742 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 743 if (mlirModuleIsNull(rawModule)) 744 throw py::error_already_set(); 745 return forModule(rawModule).releaseObject(); 746 } 747 748 py::object PyModule::getCapsule() { 749 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 750 } 751 752 //------------------------------------------------------------------------------ 753 // PyOperation 754 //------------------------------------------------------------------------------ 755 756 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 757 : BaseContextObject(std::move(contextRef)), operation(operation) {} 758 759 PyOperation::~PyOperation() { 760 // If the operation has already been invalidated there is nothing to do. 761 if (!valid) 762 return; 763 auto &liveOperations = getContext()->liveOperations; 764 assert(liveOperations.count(operation.ptr) == 1 && 765 "destroying operation not in live map"); 766 liveOperations.erase(operation.ptr); 767 if (!isAttached()) { 768 mlirOperationDestroy(operation); 769 } 770 } 771 772 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 773 MlirOperation operation, 774 py::object parentKeepAlive) { 775 auto &liveOperations = contextRef->liveOperations; 776 // Create. 777 PyOperation *unownedOperation = 778 new PyOperation(std::move(contextRef), operation); 779 // Note that the default return value policy on cast is automatic_reference, 780 // which does not take ownership (delete will not be called). 781 // Just be explicit. 782 py::object pyRef = 783 py::cast(unownedOperation, py::return_value_policy::take_ownership); 784 unownedOperation->handle = pyRef; 785 if (parentKeepAlive) { 786 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 787 } 788 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 789 return PyOperationRef(unownedOperation, std::move(pyRef)); 790 } 791 792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 793 MlirOperation operation, 794 py::object parentKeepAlive) { 795 auto &liveOperations = contextRef->liveOperations; 796 auto it = liveOperations.find(operation.ptr); 797 if (it == liveOperations.end()) { 798 // Create. 799 return createInstance(std::move(contextRef), operation, 800 std::move(parentKeepAlive)); 801 } 802 // Use existing. 803 PyOperation *existing = it->second.second; 804 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 805 return PyOperationRef(existing, std::move(pyRef)); 806 } 807 808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 809 MlirOperation operation, 810 py::object parentKeepAlive) { 811 auto &liveOperations = contextRef->liveOperations; 812 assert(liveOperations.count(operation.ptr) == 0 && 813 "cannot create detached operation that already exists"); 814 (void)liveOperations; 815 816 PyOperationRef created = createInstance(std::move(contextRef), operation, 817 std::move(parentKeepAlive)); 818 created->attached = false; 819 return created; 820 } 821 822 void PyOperation::checkValid() const { 823 if (!valid) { 824 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 825 } 826 } 827 828 void PyOperationBase::print(py::object fileObject, bool binary, 829 llvm::Optional<int64_t> largeElementsLimit, 830 bool enableDebugInfo, bool prettyDebugInfo, 831 bool printGenericOpForm, bool useLocalScope) { 832 PyOperation &operation = getOperation(); 833 operation.checkValid(); 834 if (fileObject.is_none()) 835 fileObject = py::module::import("sys").attr("stdout"); 836 837 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 838 fileObject.attr("write")("// Verification failed, printing generic form\n"); 839 printGenericOpForm = true; 840 } 841 842 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 843 if (largeElementsLimit) 844 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 845 if (enableDebugInfo) 846 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 847 if (printGenericOpForm) 848 mlirOpPrintingFlagsPrintGenericOpForm(flags); 849 850 PyFileAccumulator accum(fileObject, binary); 851 py::gil_scoped_release(); 852 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 853 accum.getUserData()); 854 mlirOpPrintingFlagsDestroy(flags); 855 } 856 857 py::object PyOperationBase::getAsm(bool binary, 858 llvm::Optional<int64_t> largeElementsLimit, 859 bool enableDebugInfo, bool prettyDebugInfo, 860 bool printGenericOpForm, 861 bool useLocalScope) { 862 py::object fileObject; 863 if (binary) { 864 fileObject = py::module::import("io").attr("BytesIO")(); 865 } else { 866 fileObject = py::module::import("io").attr("StringIO")(); 867 } 868 print(fileObject, /*binary=*/binary, 869 /*largeElementsLimit=*/largeElementsLimit, 870 /*enableDebugInfo=*/enableDebugInfo, 871 /*prettyDebugInfo=*/prettyDebugInfo, 872 /*printGenericOpForm=*/printGenericOpForm, 873 /*useLocalScope=*/useLocalScope); 874 875 return fileObject.attr("getvalue")(); 876 } 877 878 void PyOperationBase::moveAfter(PyOperationBase &other) { 879 PyOperation &operation = getOperation(); 880 PyOperation &otherOp = other.getOperation(); 881 operation.checkValid(); 882 otherOp.checkValid(); 883 mlirOperationMoveAfter(operation, otherOp); 884 operation.parentKeepAlive = otherOp.parentKeepAlive; 885 } 886 887 void PyOperationBase::moveBefore(PyOperationBase &other) { 888 PyOperation &operation = getOperation(); 889 PyOperation &otherOp = other.getOperation(); 890 operation.checkValid(); 891 otherOp.checkValid(); 892 mlirOperationMoveBefore(operation, otherOp); 893 operation.parentKeepAlive = otherOp.parentKeepAlive; 894 } 895 896 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 897 checkValid(); 898 if (!isAttached()) 899 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 900 MlirOperation operation = mlirOperationGetParentOperation(get()); 901 if (mlirOperationIsNull(operation)) 902 return {}; 903 return PyOperation::forOperation(getContext(), operation); 904 } 905 906 PyBlock PyOperation::getBlock() { 907 checkValid(); 908 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 909 MlirBlock block = mlirOperationGetBlock(get()); 910 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 911 assert(parentOperation && "Operation has no parent"); 912 return PyBlock{std::move(*parentOperation), block}; 913 } 914 915 py::object PyOperation::getCapsule() { 916 checkValid(); 917 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 918 } 919 920 py::object PyOperation::createFromCapsule(py::object capsule) { 921 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 922 if (mlirOperationIsNull(rawOperation)) 923 throw py::error_already_set(); 924 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 925 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 926 .releaseObject(); 927 } 928 929 py::object PyOperation::create( 930 std::string name, llvm::Optional<std::vector<PyType *>> results, 931 llvm::Optional<std::vector<PyValue *>> operands, 932 llvm::Optional<py::dict> attributes, 933 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 934 DefaultingPyLocation location, py::object maybeIp) { 935 llvm::SmallVector<MlirValue, 4> mlirOperands; 936 llvm::SmallVector<MlirType, 4> mlirResults; 937 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 938 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 939 940 // General parameter validation. 941 if (regions < 0) 942 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 943 944 // Unpack/validate operands. 945 if (operands) { 946 mlirOperands.reserve(operands->size()); 947 for (PyValue *operand : *operands) { 948 if (!operand) 949 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 950 mlirOperands.push_back(operand->get()); 951 } 952 } 953 954 // Unpack/validate results. 955 if (results) { 956 mlirResults.reserve(results->size()); 957 for (PyType *result : *results) { 958 // TODO: Verify result type originate from the same context. 959 if (!result) 960 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 961 mlirResults.push_back(*result); 962 } 963 } 964 // Unpack/validate attributes. 965 if (attributes) { 966 mlirAttributes.reserve(attributes->size()); 967 for (auto &it : *attributes) { 968 std::string key; 969 try { 970 key = it.first.cast<std::string>(); 971 } catch (py::cast_error &err) { 972 std::string msg = "Invalid attribute key (not a string) when " 973 "attempting to create the operation \"" + 974 name + "\" (" + err.what() + ")"; 975 throw py::cast_error(msg); 976 } 977 try { 978 auto &attribute = it.second.cast<PyAttribute &>(); 979 // TODO: Verify attribute originates from the same context. 980 mlirAttributes.emplace_back(std::move(key), attribute); 981 } catch (py::reference_cast_error &) { 982 // This exception seems thrown when the value is "None". 983 std::string msg = 984 "Found an invalid (`None`?) attribute value for the key \"" + key + 985 "\" when attempting to create the operation \"" + name + "\""; 986 throw py::cast_error(msg); 987 } catch (py::cast_error &err) { 988 std::string msg = "Invalid attribute value for the key \"" + key + 989 "\" when attempting to create the operation \"" + 990 name + "\" (" + err.what() + ")"; 991 throw py::cast_error(msg); 992 } 993 } 994 } 995 // Unpack/validate successors. 996 if (successors) { 997 mlirSuccessors.reserve(successors->size()); 998 for (auto *successor : *successors) { 999 // TODO: Verify successor originate from the same context. 1000 if (!successor) 1001 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1002 mlirSuccessors.push_back(successor->get()); 1003 } 1004 } 1005 1006 // Apply unpacked/validated to the operation state. Beyond this 1007 // point, exceptions cannot be thrown or else the state will leak. 1008 MlirOperationState state = 1009 mlirOperationStateGet(toMlirStringRef(name), location); 1010 if (!mlirOperands.empty()) 1011 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1012 mlirOperands.data()); 1013 if (!mlirResults.empty()) 1014 mlirOperationStateAddResults(&state, mlirResults.size(), 1015 mlirResults.data()); 1016 if (!mlirAttributes.empty()) { 1017 // Note that the attribute names directly reference bytes in 1018 // mlirAttributes, so that vector must not be changed from here 1019 // on. 1020 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1021 mlirNamedAttributes.reserve(mlirAttributes.size()); 1022 for (auto &it : mlirAttributes) 1023 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1024 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1025 toMlirStringRef(it.first)), 1026 it.second)); 1027 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1028 mlirNamedAttributes.data()); 1029 } 1030 if (!mlirSuccessors.empty()) 1031 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1032 mlirSuccessors.data()); 1033 if (regions) { 1034 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1035 mlirRegions.resize(regions); 1036 for (int i = 0; i < regions; ++i) 1037 mlirRegions[i] = mlirRegionCreate(); 1038 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1039 mlirRegions.data()); 1040 } 1041 1042 // Construct the operation. 1043 MlirOperation operation = mlirOperationCreate(&state); 1044 PyOperationRef created = 1045 PyOperation::createDetached(location->getContext(), operation); 1046 1047 // InsertPoint active? 1048 if (!maybeIp.is(py::cast(false))) { 1049 PyInsertionPoint *ip; 1050 if (maybeIp.is_none()) { 1051 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1052 } else { 1053 ip = py::cast<PyInsertionPoint *>(maybeIp); 1054 } 1055 if (ip) 1056 ip->insert(*created.get()); 1057 } 1058 1059 return created->createOpView(); 1060 } 1061 1062 py::object PyOperation::createOpView() { 1063 checkValid(); 1064 MlirIdentifier ident = mlirOperationGetName(get()); 1065 MlirStringRef identStr = mlirIdentifierStr(ident); 1066 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1067 StringRef(identStr.data, identStr.length)); 1068 if (opViewClass) 1069 return (*opViewClass)(getRef().getObject()); 1070 return py::cast(PyOpView(getRef().getObject())); 1071 } 1072 1073 void PyOperation::erase() { 1074 checkValid(); 1075 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1076 // Python reference to a child operation is live. All children should also 1077 // have their `valid` bit set to false. 1078 auto &liveOperations = getContext()->liveOperations; 1079 if (liveOperations.count(operation.ptr)) 1080 liveOperations.erase(operation.ptr); 1081 mlirOperationDestroy(operation); 1082 valid = false; 1083 } 1084 1085 //------------------------------------------------------------------------------ 1086 // PyOpView 1087 //------------------------------------------------------------------------------ 1088 1089 py::object 1090 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1091 py::list operandList, 1092 llvm::Optional<py::dict> attributes, 1093 llvm::Optional<std::vector<PyBlock *>> successors, 1094 llvm::Optional<int> regions, 1095 DefaultingPyLocation location, py::object maybeIp) { 1096 PyMlirContextRef context = location->getContext(); 1097 // Class level operation construction metadata. 1098 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1099 // Operand and result segment specs are either none, which does no 1100 // variadic unpacking, or a list of ints with segment sizes, where each 1101 // element is either a positive number (typically 1 for a scalar) or -1 to 1102 // indicate that it is derived from the length of the same-indexed operand 1103 // or result (implying that it is a list at that position). 1104 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1105 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1106 1107 std::vector<uint32_t> operandSegmentLengths; 1108 std::vector<uint32_t> resultSegmentLengths; 1109 1110 // Validate/determine region count. 1111 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1112 int opMinRegionCount = std::get<0>(opRegionSpec); 1113 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1114 if (!regions) { 1115 regions = opMinRegionCount; 1116 } 1117 if (*regions < opMinRegionCount) { 1118 throw py::value_error( 1119 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1120 llvm::Twine(opMinRegionCount) + 1121 " regions but was built with regions=" + llvm::Twine(*regions)) 1122 .str()); 1123 } 1124 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1125 throw py::value_error( 1126 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1127 llvm::Twine(opMinRegionCount) + 1128 " regions but was built with regions=" + llvm::Twine(*regions)) 1129 .str()); 1130 } 1131 1132 // Unpack results. 1133 std::vector<PyType *> resultTypes; 1134 resultTypes.reserve(resultTypeList.size()); 1135 if (resultSegmentSpecObj.is_none()) { 1136 // Non-variadic result unpacking. 1137 for (auto it : llvm::enumerate(resultTypeList)) { 1138 try { 1139 resultTypes.push_back(py::cast<PyType *>(it.value())); 1140 if (!resultTypes.back()) 1141 throw py::cast_error(); 1142 } catch (py::cast_error &err) { 1143 throw py::value_error((llvm::Twine("Result ") + 1144 llvm::Twine(it.index()) + " of operation \"" + 1145 name + "\" must be a Type (" + err.what() + ")") 1146 .str()); 1147 } 1148 } 1149 } else { 1150 // Sized result unpacking. 1151 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1152 if (resultSegmentSpec.size() != resultTypeList.size()) { 1153 throw py::value_error((llvm::Twine("Operation \"") + name + 1154 "\" requires " + 1155 llvm::Twine(resultSegmentSpec.size()) + 1156 " result segments but was provided " + 1157 llvm::Twine(resultTypeList.size())) 1158 .str()); 1159 } 1160 resultSegmentLengths.reserve(resultTypeList.size()); 1161 for (auto it : 1162 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1163 int segmentSpec = std::get<1>(it.value()); 1164 if (segmentSpec == 1 || segmentSpec == 0) { 1165 // Unpack unary element. 1166 try { 1167 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1168 if (resultType) { 1169 resultTypes.push_back(resultType); 1170 resultSegmentLengths.push_back(1); 1171 } else if (segmentSpec == 0) { 1172 // Allowed to be optional. 1173 resultSegmentLengths.push_back(0); 1174 } else { 1175 throw py::cast_error("was None and result is not optional"); 1176 } 1177 } catch (py::cast_error &err) { 1178 throw py::value_error((llvm::Twine("Result ") + 1179 llvm::Twine(it.index()) + " of operation \"" + 1180 name + "\" must be a Type (" + err.what() + 1181 ")") 1182 .str()); 1183 } 1184 } else if (segmentSpec == -1) { 1185 // Unpack sequence by appending. 1186 try { 1187 if (std::get<0>(it.value()).is_none()) { 1188 // Treat it as an empty list. 1189 resultSegmentLengths.push_back(0); 1190 } else { 1191 // Unpack the list. 1192 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1193 for (py::object segmentItem : segment) { 1194 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1195 if (!resultTypes.back()) { 1196 throw py::cast_error("contained a None item"); 1197 } 1198 } 1199 resultSegmentLengths.push_back(segment.size()); 1200 } 1201 } catch (std::exception &err) { 1202 // NOTE: Sloppy to be using a catch-all here, but there are at least 1203 // three different unrelated exceptions that can be thrown in the 1204 // above "casts". Just keep the scope above small and catch them all. 1205 throw py::value_error((llvm::Twine("Result ") + 1206 llvm::Twine(it.index()) + " of operation \"" + 1207 name + "\" must be a Sequence of Types (" + 1208 err.what() + ")") 1209 .str()); 1210 } 1211 } else { 1212 throw py::value_error("Unexpected segment spec"); 1213 } 1214 } 1215 } 1216 1217 // Unpack operands. 1218 std::vector<PyValue *> operands; 1219 operands.reserve(operands.size()); 1220 if (operandSegmentSpecObj.is_none()) { 1221 // Non-sized operand unpacking. 1222 for (auto it : llvm::enumerate(operandList)) { 1223 try { 1224 operands.push_back(py::cast<PyValue *>(it.value())); 1225 if (!operands.back()) 1226 throw py::cast_error(); 1227 } catch (py::cast_error &err) { 1228 throw py::value_error((llvm::Twine("Operand ") + 1229 llvm::Twine(it.index()) + " of operation \"" + 1230 name + "\" must be a Value (" + err.what() + ")") 1231 .str()); 1232 } 1233 } 1234 } else { 1235 // Sized operand unpacking. 1236 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1237 if (operandSegmentSpec.size() != operandList.size()) { 1238 throw py::value_error((llvm::Twine("Operation \"") + name + 1239 "\" requires " + 1240 llvm::Twine(operandSegmentSpec.size()) + 1241 "operand segments but was provided " + 1242 llvm::Twine(operandList.size())) 1243 .str()); 1244 } 1245 operandSegmentLengths.reserve(operandList.size()); 1246 for (auto it : 1247 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1248 int segmentSpec = std::get<1>(it.value()); 1249 if (segmentSpec == 1 || segmentSpec == 0) { 1250 // Unpack unary element. 1251 try { 1252 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1253 if (operandValue) { 1254 operands.push_back(operandValue); 1255 operandSegmentLengths.push_back(1); 1256 } else if (segmentSpec == 0) { 1257 // Allowed to be optional. 1258 operandSegmentLengths.push_back(0); 1259 } else { 1260 throw py::cast_error("was None and operand is not optional"); 1261 } 1262 } catch (py::cast_error &err) { 1263 throw py::value_error((llvm::Twine("Operand ") + 1264 llvm::Twine(it.index()) + " of operation \"" + 1265 name + "\" must be a Value (" + err.what() + 1266 ")") 1267 .str()); 1268 } 1269 } else if (segmentSpec == -1) { 1270 // Unpack sequence by appending. 1271 try { 1272 if (std::get<0>(it.value()).is_none()) { 1273 // Treat it as an empty list. 1274 operandSegmentLengths.push_back(0); 1275 } else { 1276 // Unpack the list. 1277 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1278 for (py::object segmentItem : segment) { 1279 operands.push_back(py::cast<PyValue *>(segmentItem)); 1280 if (!operands.back()) { 1281 throw py::cast_error("contained a None item"); 1282 } 1283 } 1284 operandSegmentLengths.push_back(segment.size()); 1285 } 1286 } catch (std::exception &err) { 1287 // NOTE: Sloppy to be using a catch-all here, but there are at least 1288 // three different unrelated exceptions that can be thrown in the 1289 // above "casts". Just keep the scope above small and catch them all. 1290 throw py::value_error((llvm::Twine("Operand ") + 1291 llvm::Twine(it.index()) + " of operation \"" + 1292 name + "\" must be a Sequence of Values (" + 1293 err.what() + ")") 1294 .str()); 1295 } 1296 } else { 1297 throw py::value_error("Unexpected segment spec"); 1298 } 1299 } 1300 } 1301 1302 // Merge operand/result segment lengths into attributes if needed. 1303 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1304 // Dup. 1305 if (attributes) { 1306 attributes = py::dict(*attributes); 1307 } else { 1308 attributes = py::dict(); 1309 } 1310 if (attributes->contains("result_segment_sizes") || 1311 attributes->contains("operand_segment_sizes")) { 1312 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1313 "'operand_segment_sizes' attribute is unsupported. " 1314 "Use Operation.create for such low-level access."); 1315 } 1316 1317 // Add result_segment_sizes attribute. 1318 if (!resultSegmentLengths.empty()) { 1319 int64_t size = resultSegmentLengths.size(); 1320 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1321 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1322 resultSegmentLengths.size(), resultSegmentLengths.data()); 1323 (*attributes)["result_segment_sizes"] = 1324 PyAttribute(context, segmentLengthAttr); 1325 } 1326 1327 // Add operand_segment_sizes attribute. 1328 if (!operandSegmentLengths.empty()) { 1329 int64_t size = operandSegmentLengths.size(); 1330 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1331 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1332 operandSegmentLengths.size(), operandSegmentLengths.data()); 1333 (*attributes)["operand_segment_sizes"] = 1334 PyAttribute(context, segmentLengthAttr); 1335 } 1336 } 1337 1338 // Delegate to create. 1339 return PyOperation::create(std::move(name), 1340 /*results=*/std::move(resultTypes), 1341 /*operands=*/std::move(operands), 1342 /*attributes=*/std::move(attributes), 1343 /*successors=*/std::move(successors), 1344 /*regions=*/*regions, location, maybeIp); 1345 } 1346 1347 PyOpView::PyOpView(py::object operationObject) 1348 // Casting through the PyOperationBase base-class and then back to the 1349 // Operation lets us accept any PyOperationBase subclass. 1350 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1351 operationObject(operation.getRef().getObject()) {} 1352 1353 py::object PyOpView::createRawSubclass(py::object userClass) { 1354 // This is... a little gross. The typical pattern is to have a pure python 1355 // class that extends OpView like: 1356 // class AddFOp(_cext.ir.OpView): 1357 // def __init__(self, loc, lhs, rhs): 1358 // operation = loc.context.create_operation( 1359 // "addf", lhs, rhs, results=[lhs.type]) 1360 // super().__init__(operation) 1361 // 1362 // I.e. The goal of the user facing type is to provide a nice constructor 1363 // that has complete freedom for the op under construction. This is at odds 1364 // with our other desire to sometimes create this object by just passing an 1365 // operation (to initialize the base class). We could do *arg and **kwargs 1366 // munging to try to make it work, but instead, we synthesize a new class 1367 // on the fly which extends this user class (AddFOp in this example) and 1368 // *give it* the base class's __init__ method, thus bypassing the 1369 // intermediate subclass's __init__ method entirely. While slightly, 1370 // underhanded, this is safe/legal because the type hierarchy has not changed 1371 // (we just added a new leaf) and we aren't mucking around with __new__. 1372 // Typically, this new class will be stored on the original as "_Raw" and will 1373 // be used for casts and other things that need a variant of the class that 1374 // is initialized purely from an operation. 1375 py::object parentMetaclass = 1376 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1377 py::dict attributes; 1378 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1379 // now. 1380 // auto opViewType = py::type::of<PyOpView>(); 1381 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1382 attributes["__init__"] = opViewType.attr("__init__"); 1383 py::str origName = userClass.attr("__name__"); 1384 py::str newName = py::str("_") + origName; 1385 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1386 } 1387 1388 //------------------------------------------------------------------------------ 1389 // PyInsertionPoint. 1390 //------------------------------------------------------------------------------ 1391 1392 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1393 1394 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1395 : refOperation(beforeOperationBase.getOperation().getRef()), 1396 block((*refOperation)->getBlock()) {} 1397 1398 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1399 PyOperation &operation = operationBase.getOperation(); 1400 if (operation.isAttached()) 1401 throw SetPyError(PyExc_ValueError, 1402 "Attempt to insert operation that is already attached"); 1403 block.getParentOperation()->checkValid(); 1404 MlirOperation beforeOp = {nullptr}; 1405 if (refOperation) { 1406 // Insert before operation. 1407 (*refOperation)->checkValid(); 1408 beforeOp = (*refOperation)->get(); 1409 } else { 1410 // Insert at end (before null) is only valid if the block does not 1411 // already end in a known terminator (violating this will cause assertion 1412 // failures later). 1413 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1414 throw py::index_error("Cannot insert operation at the end of a block " 1415 "that already has a terminator. Did you mean to " 1416 "use 'InsertionPoint.at_block_terminator(block)' " 1417 "versus 'InsertionPoint(block)'?"); 1418 } 1419 } 1420 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1421 operation.setAttached(); 1422 } 1423 1424 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1425 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1426 if (mlirOperationIsNull(firstOp)) { 1427 // Just insert at end. 1428 return PyInsertionPoint(block); 1429 } 1430 1431 // Insert before first op. 1432 PyOperationRef firstOpRef = PyOperation::forOperation( 1433 block.getParentOperation()->getContext(), firstOp); 1434 return PyInsertionPoint{block, std::move(firstOpRef)}; 1435 } 1436 1437 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1438 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1439 if (mlirOperationIsNull(terminator)) 1440 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1441 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1442 block.getParentOperation()->getContext(), terminator); 1443 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1444 } 1445 1446 py::object PyInsertionPoint::contextEnter() { 1447 return PyThreadContextEntry::pushInsertionPoint(*this); 1448 } 1449 1450 void PyInsertionPoint::contextExit(pybind11::object excType, 1451 pybind11::object excVal, 1452 pybind11::object excTb) { 1453 PyThreadContextEntry::popInsertionPoint(*this); 1454 } 1455 1456 //------------------------------------------------------------------------------ 1457 // PyAttribute. 1458 //------------------------------------------------------------------------------ 1459 1460 bool PyAttribute::operator==(const PyAttribute &other) { 1461 return mlirAttributeEqual(attr, other.attr); 1462 } 1463 1464 py::object PyAttribute::getCapsule() { 1465 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1466 } 1467 1468 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1469 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1470 if (mlirAttributeIsNull(rawAttr)) 1471 throw py::error_already_set(); 1472 return PyAttribute( 1473 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1474 } 1475 1476 //------------------------------------------------------------------------------ 1477 // PyNamedAttribute. 1478 //------------------------------------------------------------------------------ 1479 1480 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1481 : ownedName(new std::string(std::move(ownedName))) { 1482 namedAttr = mlirNamedAttributeGet( 1483 mlirIdentifierGet(mlirAttributeGetContext(attr), 1484 toMlirStringRef(*this->ownedName)), 1485 attr); 1486 } 1487 1488 //------------------------------------------------------------------------------ 1489 // PyType. 1490 //------------------------------------------------------------------------------ 1491 1492 bool PyType::operator==(const PyType &other) { 1493 return mlirTypeEqual(type, other.type); 1494 } 1495 1496 py::object PyType::getCapsule() { 1497 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1498 } 1499 1500 PyType PyType::createFromCapsule(py::object capsule) { 1501 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1502 if (mlirTypeIsNull(rawType)) 1503 throw py::error_already_set(); 1504 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1505 rawType); 1506 } 1507 1508 //------------------------------------------------------------------------------ 1509 // PyValue and subclases. 1510 //------------------------------------------------------------------------------ 1511 1512 pybind11::object PyValue::getCapsule() { 1513 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1514 } 1515 1516 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1517 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1518 if (mlirValueIsNull(value)) 1519 throw py::error_already_set(); 1520 MlirOperation owner; 1521 if (mlirValueIsAOpResult(value)) 1522 owner = mlirOpResultGetOwner(value); 1523 if (mlirValueIsABlockArgument(value)) 1524 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1525 if (mlirOperationIsNull(owner)) 1526 throw py::error_already_set(); 1527 MlirContext ctx = mlirOperationGetContext(owner); 1528 PyOperationRef ownerRef = 1529 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1530 return PyValue(ownerRef, value); 1531 } 1532 1533 //------------------------------------------------------------------------------ 1534 // PySymbolTable. 1535 //------------------------------------------------------------------------------ 1536 1537 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1538 : operation(operation.getOperation().getRef()) { 1539 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1540 if (mlirSymbolTableIsNull(symbolTable)) { 1541 throw py::cast_error("Operation is not a Symbol Table."); 1542 } 1543 } 1544 1545 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1546 operation->checkValid(); 1547 MlirOperation symbol = mlirSymbolTableLookup( 1548 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1549 if (mlirOperationIsNull(symbol)) 1550 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1551 1552 return PyOperation::forOperation(operation->getContext(), symbol, 1553 operation.getObject()) 1554 ->createOpView(); 1555 } 1556 1557 void PySymbolTable::erase(PyOperationBase &symbol) { 1558 operation->checkValid(); 1559 symbol.getOperation().checkValid(); 1560 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1561 // The operation is also erased, so we must invalidate it. There may be Python 1562 // references to this operation so we don't want to delete it from the list of 1563 // live operations here. 1564 symbol.getOperation().valid = false; 1565 } 1566 1567 void PySymbolTable::dunderDel(const std::string &name) { 1568 py::object operation = dunderGetItem(name); 1569 erase(py::cast<PyOperationBase &>(operation)); 1570 } 1571 1572 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1573 operation->checkValid(); 1574 symbol.getOperation().checkValid(); 1575 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1576 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1577 if (mlirAttributeIsNull(symbolAttr)) 1578 throw py::value_error("Expected operation to have a symbol name."); 1579 return PyAttribute( 1580 symbol.getOperation().getContext(), 1581 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1582 } 1583 1584 namespace { 1585 /// CRTP base class for Python MLIR values that subclass Value and should be 1586 /// castable from it. The value hierarchy is one level deep and is not supposed 1587 /// to accommodate other levels unless core MLIR changes. 1588 template <typename DerivedTy> 1589 class PyConcreteValue : public PyValue { 1590 public: 1591 // Derived classes must define statics for: 1592 // IsAFunctionTy isaFunction 1593 // const char *pyClassName 1594 // and redefine bindDerived. 1595 using ClassTy = py::class_<DerivedTy, PyValue>; 1596 using IsAFunctionTy = bool (*)(MlirValue); 1597 1598 PyConcreteValue() = default; 1599 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1600 : PyValue(operationRef, value) {} 1601 PyConcreteValue(PyValue &orig) 1602 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1603 1604 /// Attempts to cast the original value to the derived type and throws on 1605 /// type mismatches. 1606 static MlirValue castFrom(PyValue &orig) { 1607 if (!DerivedTy::isaFunction(orig.get())) { 1608 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1609 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1610 DerivedTy::pyClassName + 1611 " (from " + origRepr + ")"); 1612 } 1613 return orig.get(); 1614 } 1615 1616 /// Binds the Python module objects to functions of this class. 1617 static void bind(py::module &m) { 1618 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1619 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1620 cls.def_static("isinstance", [](PyValue &otherValue) -> bool { 1621 return DerivedTy::isaFunction(otherValue); 1622 }); 1623 DerivedTy::bindDerived(cls); 1624 } 1625 1626 /// Implemented by derived classes to add methods to the Python subclass. 1627 static void bindDerived(ClassTy &m) {} 1628 }; 1629 1630 /// Python wrapper for MlirBlockArgument. 1631 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1632 public: 1633 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1634 static constexpr const char *pyClassName = "BlockArgument"; 1635 using PyConcreteValue::PyConcreteValue; 1636 1637 static void bindDerived(ClassTy &c) { 1638 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1639 return PyBlock(self.getParentOperation(), 1640 mlirBlockArgumentGetOwner(self.get())); 1641 }); 1642 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1643 return mlirBlockArgumentGetArgNumber(self.get()); 1644 }); 1645 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1646 return mlirBlockArgumentSetType(self.get(), type); 1647 }); 1648 } 1649 }; 1650 1651 /// Python wrapper for MlirOpResult. 1652 class PyOpResult : public PyConcreteValue<PyOpResult> { 1653 public: 1654 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1655 static constexpr const char *pyClassName = "OpResult"; 1656 using PyConcreteValue::PyConcreteValue; 1657 1658 static void bindDerived(ClassTy &c) { 1659 c.def_property_readonly("owner", [](PyOpResult &self) { 1660 assert( 1661 mlirOperationEqual(self.getParentOperation()->get(), 1662 mlirOpResultGetOwner(self.get())) && 1663 "expected the owner of the value in Python to match that in the IR"); 1664 return self.getParentOperation().getObject(); 1665 }); 1666 c.def_property_readonly("result_number", [](PyOpResult &self) { 1667 return mlirOpResultGetResultNumber(self.get()); 1668 }); 1669 } 1670 }; 1671 1672 /// Returns the list of types of the values held by container. 1673 template <typename Container> 1674 static std::vector<PyType> getValueTypes(Container &container, 1675 PyMlirContextRef &context) { 1676 std::vector<PyType> result; 1677 result.reserve(container.getNumElements()); 1678 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1679 result.push_back( 1680 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1681 } 1682 return result; 1683 } 1684 1685 /// A list of block arguments. Internally, these are stored as consecutive 1686 /// elements, random access is cheap. The argument list is associated with the 1687 /// operation that contains the block (detached blocks are not allowed in 1688 /// Python bindings) and extends its lifetime. 1689 class PyBlockArgumentList 1690 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1691 public: 1692 static constexpr const char *pyClassName = "BlockArgumentList"; 1693 1694 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1695 intptr_t startIndex = 0, intptr_t length = -1, 1696 intptr_t step = 1) 1697 : Sliceable(startIndex, 1698 length == -1 ? mlirBlockGetNumArguments(block) : length, 1699 step), 1700 operation(std::move(operation)), block(block) {} 1701 1702 /// Returns the number of arguments in the list. 1703 intptr_t getNumElements() { 1704 operation->checkValid(); 1705 return mlirBlockGetNumArguments(block); 1706 } 1707 1708 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1709 PyBlockArgument getElement(intptr_t pos) { 1710 MlirValue argument = mlirBlockGetArgument(block, pos); 1711 return PyBlockArgument(operation, argument); 1712 } 1713 1714 /// Returns a sublist of this list. 1715 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1716 intptr_t step) { 1717 return PyBlockArgumentList(operation, block, startIndex, length, step); 1718 } 1719 1720 static void bindDerived(ClassTy &c) { 1721 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1722 return getValueTypes(self, self.operation->getContext()); 1723 }); 1724 } 1725 1726 private: 1727 PyOperationRef operation; 1728 MlirBlock block; 1729 }; 1730 1731 /// A list of operation operands. Internally, these are stored as consecutive 1732 /// elements, random access is cheap. The result list is associated with the 1733 /// operation whose results these are, and extends the lifetime of this 1734 /// operation. 1735 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1736 public: 1737 static constexpr const char *pyClassName = "OpOperandList"; 1738 1739 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1740 intptr_t length = -1, intptr_t step = 1) 1741 : Sliceable(startIndex, 1742 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1743 : length, 1744 step), 1745 operation(operation) {} 1746 1747 intptr_t getNumElements() { 1748 operation->checkValid(); 1749 return mlirOperationGetNumOperands(operation->get()); 1750 } 1751 1752 PyValue getElement(intptr_t pos) { 1753 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1754 MlirOperation owner; 1755 if (mlirValueIsAOpResult(operand)) 1756 owner = mlirOpResultGetOwner(operand); 1757 else if (mlirValueIsABlockArgument(operand)) 1758 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1759 else 1760 assert(false && "Value must be an block arg or op result."); 1761 PyOperationRef pyOwner = 1762 PyOperation::forOperation(operation->getContext(), owner); 1763 return PyValue(pyOwner, operand); 1764 } 1765 1766 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1767 return PyOpOperandList(operation, startIndex, length, step); 1768 } 1769 1770 void dunderSetItem(intptr_t index, PyValue value) { 1771 index = wrapIndex(index); 1772 mlirOperationSetOperand(operation->get(), index, value.get()); 1773 } 1774 1775 static void bindDerived(ClassTy &c) { 1776 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1777 } 1778 1779 private: 1780 PyOperationRef operation; 1781 }; 1782 1783 /// A list of operation results. Internally, these are stored as consecutive 1784 /// elements, random access is cheap. The result list is associated with the 1785 /// operation whose results these are, and extends the lifetime of this 1786 /// operation. 1787 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1788 public: 1789 static constexpr const char *pyClassName = "OpResultList"; 1790 1791 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1792 intptr_t length = -1, intptr_t step = 1) 1793 : Sliceable(startIndex, 1794 length == -1 ? mlirOperationGetNumResults(operation->get()) 1795 : length, 1796 step), 1797 operation(operation) {} 1798 1799 intptr_t getNumElements() { 1800 operation->checkValid(); 1801 return mlirOperationGetNumResults(operation->get()); 1802 } 1803 1804 PyOpResult getElement(intptr_t index) { 1805 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1806 return PyOpResult(value); 1807 } 1808 1809 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1810 return PyOpResultList(operation, startIndex, length, step); 1811 } 1812 1813 static void bindDerived(ClassTy &c) { 1814 c.def_property_readonly("types", [](PyOpResultList &self) { 1815 return getValueTypes(self, self.operation->getContext()); 1816 }); 1817 } 1818 1819 private: 1820 PyOperationRef operation; 1821 }; 1822 1823 /// A list of operation attributes. Can be indexed by name, producing 1824 /// attributes, or by index, producing named attributes. 1825 class PyOpAttributeMap { 1826 public: 1827 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1828 1829 PyAttribute dunderGetItemNamed(const std::string &name) { 1830 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1831 toMlirStringRef(name)); 1832 if (mlirAttributeIsNull(attr)) { 1833 throw SetPyError(PyExc_KeyError, 1834 "attempt to access a non-existent attribute"); 1835 } 1836 return PyAttribute(operation->getContext(), attr); 1837 } 1838 1839 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1840 if (index < 0 || index >= dunderLen()) { 1841 throw SetPyError(PyExc_IndexError, 1842 "attempt to access out of bounds attribute"); 1843 } 1844 MlirNamedAttribute namedAttr = 1845 mlirOperationGetAttribute(operation->get(), index); 1846 return PyNamedAttribute( 1847 namedAttr.attribute, 1848 std::string(mlirIdentifierStr(namedAttr.name).data)); 1849 } 1850 1851 void dunderSetItem(const std::string &name, PyAttribute attr) { 1852 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1853 attr); 1854 } 1855 1856 void dunderDelItem(const std::string &name) { 1857 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1858 toMlirStringRef(name)); 1859 if (!removed) 1860 throw SetPyError(PyExc_KeyError, 1861 "attempt to delete a non-existent attribute"); 1862 } 1863 1864 intptr_t dunderLen() { 1865 return mlirOperationGetNumAttributes(operation->get()); 1866 } 1867 1868 bool dunderContains(const std::string &name) { 1869 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1870 operation->get(), toMlirStringRef(name))); 1871 } 1872 1873 static void bind(py::module &m) { 1874 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1875 .def("__contains__", &PyOpAttributeMap::dunderContains) 1876 .def("__len__", &PyOpAttributeMap::dunderLen) 1877 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1878 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1879 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1880 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1881 } 1882 1883 private: 1884 PyOperationRef operation; 1885 }; 1886 1887 } // end namespace 1888 1889 //------------------------------------------------------------------------------ 1890 // Populates the core exports of the 'ir' submodule. 1891 //------------------------------------------------------------------------------ 1892 1893 void mlir::python::populateIRCore(py::module &m) { 1894 //---------------------------------------------------------------------------- 1895 // Mapping of MlirContext. 1896 //---------------------------------------------------------------------------- 1897 py::class_<PyMlirContext>(m, "Context", py::module_local()) 1898 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1899 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1900 .def("_get_context_again", 1901 [](PyMlirContext &self) { 1902 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1903 return ref.releaseObject(); 1904 }) 1905 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1906 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1907 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1908 &PyMlirContext::getCapsule) 1909 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1910 .def("__enter__", &PyMlirContext::contextEnter) 1911 .def("__exit__", &PyMlirContext::contextExit) 1912 .def_property_readonly_static( 1913 "current", 1914 [](py::object & /*class*/) { 1915 auto *context = PyThreadContextEntry::getDefaultContext(); 1916 if (!context) 1917 throw SetPyError(PyExc_ValueError, "No current Context"); 1918 return context; 1919 }, 1920 "Gets the Context bound to the current thread or raises ValueError") 1921 .def_property_readonly( 1922 "dialects", 1923 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1924 "Gets a container for accessing dialects by name") 1925 .def_property_readonly( 1926 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1927 "Alias for 'dialect'") 1928 .def( 1929 "get_dialect_descriptor", 1930 [=](PyMlirContext &self, std::string &name) { 1931 MlirDialect dialect = mlirContextGetOrLoadDialect( 1932 self.get(), {name.data(), name.size()}); 1933 if (mlirDialectIsNull(dialect)) { 1934 throw SetPyError(PyExc_ValueError, 1935 Twine("Dialect '") + name + "' not found"); 1936 } 1937 return PyDialectDescriptor(self.getRef(), dialect); 1938 }, 1939 "Gets or loads a dialect by name, returning its descriptor object") 1940 .def_property( 1941 "allow_unregistered_dialects", 1942 [](PyMlirContext &self) -> bool { 1943 return mlirContextGetAllowUnregisteredDialects(self.get()); 1944 }, 1945 [](PyMlirContext &self, bool value) { 1946 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1947 }) 1948 .def("enable_multithreading", 1949 [](PyMlirContext &self, bool enable) { 1950 mlirContextEnableMultithreading(self.get(), enable); 1951 }) 1952 .def("is_registered_operation", 1953 [](PyMlirContext &self, std::string &name) { 1954 return mlirContextIsRegisteredOperation( 1955 self.get(), MlirStringRef{name.data(), name.size()}); 1956 }); 1957 1958 //---------------------------------------------------------------------------- 1959 // Mapping of PyDialectDescriptor 1960 //---------------------------------------------------------------------------- 1961 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1962 .def_property_readonly("namespace", 1963 [](PyDialectDescriptor &self) { 1964 MlirStringRef ns = 1965 mlirDialectGetNamespace(self.get()); 1966 return py::str(ns.data, ns.length); 1967 }) 1968 .def("__repr__", [](PyDialectDescriptor &self) { 1969 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1970 std::string repr("<DialectDescriptor "); 1971 repr.append(ns.data, ns.length); 1972 repr.append(">"); 1973 return repr; 1974 }); 1975 1976 //---------------------------------------------------------------------------- 1977 // Mapping of PyDialects 1978 //---------------------------------------------------------------------------- 1979 py::class_<PyDialects>(m, "Dialects", py::module_local()) 1980 .def("__getitem__", 1981 [=](PyDialects &self, std::string keyName) { 1982 MlirDialect dialect = 1983 self.getDialectForKey(keyName, /*attrError=*/false); 1984 py::object descriptor = 1985 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1986 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1987 }) 1988 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1989 MlirDialect dialect = 1990 self.getDialectForKey(attrName, /*attrError=*/true); 1991 py::object descriptor = 1992 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1993 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1994 }); 1995 1996 //---------------------------------------------------------------------------- 1997 // Mapping of PyDialect 1998 //---------------------------------------------------------------------------- 1999 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2000 .def(py::init<py::object>(), "descriptor") 2001 .def_property_readonly( 2002 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2003 .def("__repr__", [](py::object self) { 2004 auto clazz = self.attr("__class__"); 2005 return py::str("<Dialect ") + 2006 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2007 clazz.attr("__module__") + py::str(".") + 2008 clazz.attr("__name__") + py::str(")>"); 2009 }); 2010 2011 //---------------------------------------------------------------------------- 2012 // Mapping of Location 2013 //---------------------------------------------------------------------------- 2014 py::class_<PyLocation>(m, "Location", py::module_local()) 2015 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2016 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2017 .def("__enter__", &PyLocation::contextEnter) 2018 .def("__exit__", &PyLocation::contextExit) 2019 .def("__eq__", 2020 [](PyLocation &self, PyLocation &other) -> bool { 2021 return mlirLocationEqual(self, other); 2022 }) 2023 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2024 .def_property_readonly_static( 2025 "current", 2026 [](py::object & /*class*/) { 2027 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2028 if (!loc) 2029 throw SetPyError(PyExc_ValueError, "No current Location"); 2030 return loc; 2031 }, 2032 "Gets the Location bound to the current thread or raises ValueError") 2033 .def_static( 2034 "unknown", 2035 [](DefaultingPyMlirContext context) { 2036 return PyLocation(context->getRef(), 2037 mlirLocationUnknownGet(context->get())); 2038 }, 2039 py::arg("context") = py::none(), 2040 "Gets a Location representing an unknown location") 2041 .def_static( 2042 "callsite", 2043 [](PyLocation callee, const std::vector<PyLocation> &frames, 2044 DefaultingPyMlirContext context) { 2045 if (frames.empty()) 2046 throw py::value_error("No caller frames provided"); 2047 MlirLocation caller = frames.back().get(); 2048 for (const PyLocation &frame : 2049 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2050 caller = mlirLocationCallSiteGet(frame.get(), caller); 2051 return PyLocation(context->getRef(), 2052 mlirLocationCallSiteGet(callee.get(), caller)); 2053 }, 2054 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2055 kContextGetCallSiteLocationDocstring) 2056 .def_static( 2057 "file", 2058 [](std::string filename, int line, int col, 2059 DefaultingPyMlirContext context) { 2060 return PyLocation( 2061 context->getRef(), 2062 mlirLocationFileLineColGet( 2063 context->get(), toMlirStringRef(filename), line, col)); 2064 }, 2065 py::arg("filename"), py::arg("line"), py::arg("col"), 2066 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2067 .def_static( 2068 "name", 2069 [](std::string name, llvm::Optional<PyLocation> childLoc, 2070 DefaultingPyMlirContext context) { 2071 return PyLocation( 2072 context->getRef(), 2073 mlirLocationNameGet( 2074 context->get(), toMlirStringRef(name), 2075 childLoc ? childLoc->get() 2076 : mlirLocationUnknownGet(context->get()))); 2077 }, 2078 py::arg("name"), py::arg("childLoc") = py::none(), 2079 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2080 .def_property_readonly( 2081 "context", 2082 [](PyLocation &self) { return self.getContext().getObject(); }, 2083 "Context that owns the Location") 2084 .def("__repr__", [](PyLocation &self) { 2085 PyPrintAccumulator printAccum; 2086 mlirLocationPrint(self, printAccum.getCallback(), 2087 printAccum.getUserData()); 2088 return printAccum.join(); 2089 }); 2090 2091 //---------------------------------------------------------------------------- 2092 // Mapping of Module 2093 //---------------------------------------------------------------------------- 2094 py::class_<PyModule>(m, "Module", py::module_local()) 2095 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2096 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2097 .def_static( 2098 "parse", 2099 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2100 MlirModule module = mlirModuleCreateParse( 2101 context->get(), toMlirStringRef(moduleAsm)); 2102 // TODO: Rework error reporting once diagnostic engine is exposed 2103 // in C API. 2104 if (mlirModuleIsNull(module)) { 2105 throw SetPyError( 2106 PyExc_ValueError, 2107 "Unable to parse module assembly (see diagnostics)"); 2108 } 2109 return PyModule::forModule(module).releaseObject(); 2110 }, 2111 py::arg("asm"), py::arg("context") = py::none(), 2112 kModuleParseDocstring) 2113 .def_static( 2114 "create", 2115 [](DefaultingPyLocation loc) { 2116 MlirModule module = mlirModuleCreateEmpty(loc); 2117 return PyModule::forModule(module).releaseObject(); 2118 }, 2119 py::arg("loc") = py::none(), "Creates an empty module") 2120 .def_property_readonly( 2121 "context", 2122 [](PyModule &self) { return self.getContext().getObject(); }, 2123 "Context that created the Module") 2124 .def_property_readonly( 2125 "operation", 2126 [](PyModule &self) { 2127 return PyOperation::forOperation(self.getContext(), 2128 mlirModuleGetOperation(self.get()), 2129 self.getRef().releaseObject()) 2130 .releaseObject(); 2131 }, 2132 "Accesses the module as an operation") 2133 .def_property_readonly( 2134 "body", 2135 [](PyModule &self) { 2136 PyOperationRef module_op = PyOperation::forOperation( 2137 self.getContext(), mlirModuleGetOperation(self.get()), 2138 self.getRef().releaseObject()); 2139 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2140 return returnBlock; 2141 }, 2142 "Return the block for this module") 2143 .def( 2144 "dump", 2145 [](PyModule &self) { 2146 mlirOperationDump(mlirModuleGetOperation(self.get())); 2147 }, 2148 kDumpDocstring) 2149 .def( 2150 "__str__", 2151 [](PyModule &self) { 2152 MlirOperation operation = mlirModuleGetOperation(self.get()); 2153 PyPrintAccumulator printAccum; 2154 mlirOperationPrint(operation, printAccum.getCallback(), 2155 printAccum.getUserData()); 2156 return printAccum.join(); 2157 }, 2158 kOperationStrDunderDocstring); 2159 2160 //---------------------------------------------------------------------------- 2161 // Mapping of Operation. 2162 //---------------------------------------------------------------------------- 2163 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2164 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2165 [](PyOperationBase &self) { 2166 return self.getOperation().getCapsule(); 2167 }) 2168 .def("__eq__", 2169 [](PyOperationBase &self, PyOperationBase &other) { 2170 return &self.getOperation() == &other.getOperation(); 2171 }) 2172 .def("__eq__", 2173 [](PyOperationBase &self, py::object other) { return false; }) 2174 .def("__hash__", 2175 [](PyOperationBase &self) { 2176 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2177 }) 2178 .def_property_readonly("attributes", 2179 [](PyOperationBase &self) { 2180 return PyOpAttributeMap( 2181 self.getOperation().getRef()); 2182 }) 2183 .def_property_readonly("operands", 2184 [](PyOperationBase &self) { 2185 return PyOpOperandList( 2186 self.getOperation().getRef()); 2187 }) 2188 .def_property_readonly("regions", 2189 [](PyOperationBase &self) { 2190 return PyRegionList( 2191 self.getOperation().getRef()); 2192 }) 2193 .def_property_readonly( 2194 "results", 2195 [](PyOperationBase &self) { 2196 return PyOpResultList(self.getOperation().getRef()); 2197 }, 2198 "Returns the list of Operation results.") 2199 .def_property_readonly( 2200 "result", 2201 [](PyOperationBase &self) { 2202 auto &operation = self.getOperation(); 2203 auto numResults = mlirOperationGetNumResults(operation); 2204 if (numResults != 1) { 2205 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2206 throw SetPyError( 2207 PyExc_ValueError, 2208 Twine("Cannot call .result on operation ") + 2209 StringRef(name.data, name.length) + " which has " + 2210 Twine(numResults) + 2211 " results (it is only valid for operations with a " 2212 "single result)"); 2213 } 2214 return PyOpResult(operation.getRef(), 2215 mlirOperationGetResult(operation, 0)); 2216 }, 2217 "Shortcut to get an op result if it has only one (throws an error " 2218 "otherwise).") 2219 .def_property_readonly( 2220 "location", 2221 [](PyOperationBase &self) { 2222 PyOperation &operation = self.getOperation(); 2223 return PyLocation(operation.getContext(), 2224 mlirOperationGetLocation(operation.get())); 2225 }, 2226 "Returns the source location the operation was defined or derived " 2227 "from.") 2228 .def( 2229 "__str__", 2230 [](PyOperationBase &self) { 2231 return self.getAsm(/*binary=*/false, 2232 /*largeElementsLimit=*/llvm::None, 2233 /*enableDebugInfo=*/false, 2234 /*prettyDebugInfo=*/false, 2235 /*printGenericOpForm=*/false, 2236 /*useLocalScope=*/false); 2237 }, 2238 "Returns the assembly form of the operation.") 2239 .def("print", &PyOperationBase::print, 2240 // Careful: Lots of arguments must match up with print method. 2241 py::arg("file") = py::none(), py::arg("binary") = false, 2242 py::arg("large_elements_limit") = py::none(), 2243 py::arg("enable_debug_info") = false, 2244 py::arg("pretty_debug_info") = false, 2245 py::arg("print_generic_op_form") = false, 2246 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2247 .def("get_asm", &PyOperationBase::getAsm, 2248 // Careful: Lots of arguments must match up with get_asm method. 2249 py::arg("binary") = false, 2250 py::arg("large_elements_limit") = py::none(), 2251 py::arg("enable_debug_info") = false, 2252 py::arg("pretty_debug_info") = false, 2253 py::arg("print_generic_op_form") = false, 2254 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2255 .def( 2256 "verify", 2257 [](PyOperationBase &self) { 2258 return mlirOperationVerify(self.getOperation()); 2259 }, 2260 "Verify the operation and return true if it passes, false if it " 2261 "fails.") 2262 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2263 "Puts self immediately after the other operation in its parent " 2264 "block.") 2265 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2266 "Puts self immediately before the other operation in its parent " 2267 "block.") 2268 .def( 2269 "detach_from_parent", 2270 [](PyOperationBase &self) { 2271 PyOperation &operation = self.getOperation(); 2272 operation.checkValid(); 2273 if (!operation.isAttached()) 2274 throw py::value_error("Detached operation has no parent."); 2275 2276 operation.detachFromParent(); 2277 return operation.createOpView(); 2278 }, 2279 "Detaches the operation from its parent block."); 2280 2281 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2282 .def_static("create", &PyOperation::create, py::arg("name"), 2283 py::arg("results") = py::none(), 2284 py::arg("operands") = py::none(), 2285 py::arg("attributes") = py::none(), 2286 py::arg("successors") = py::none(), py::arg("regions") = 0, 2287 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2288 kOperationCreateDocstring) 2289 .def_property_readonly("parent", 2290 [](PyOperation &self) -> py::object { 2291 auto parent = self.getParentOperation(); 2292 if (parent) 2293 return parent->getObject(); 2294 return py::none(); 2295 }) 2296 .def("erase", &PyOperation::erase) 2297 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2298 &PyOperation::getCapsule) 2299 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2300 .def_property_readonly("name", 2301 [](PyOperation &self) { 2302 self.checkValid(); 2303 MlirOperation operation = self.get(); 2304 MlirStringRef name = mlirIdentifierStr( 2305 mlirOperationGetName(operation)); 2306 return py::str(name.data, name.length); 2307 }) 2308 .def_property_readonly( 2309 "context", 2310 [](PyOperation &self) { 2311 self.checkValid(); 2312 return self.getContext().getObject(); 2313 }, 2314 "Context that owns the Operation") 2315 .def_property_readonly("opview", &PyOperation::createOpView); 2316 2317 auto opViewClass = 2318 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2319 .def(py::init<py::object>()) 2320 .def_property_readonly("operation", &PyOpView::getOperationObject) 2321 .def_property_readonly( 2322 "context", 2323 [](PyOpView &self) { 2324 return self.getOperation().getContext().getObject(); 2325 }, 2326 "Context that owns the Operation") 2327 .def("__str__", [](PyOpView &self) { 2328 return py::str(self.getOperationObject()); 2329 }); 2330 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2331 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2332 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2333 opViewClass.attr("build_generic") = classmethod( 2334 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2335 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2336 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2337 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2338 "Builds a specific, generated OpView based on class level attributes."); 2339 2340 //---------------------------------------------------------------------------- 2341 // Mapping of PyRegion. 2342 //---------------------------------------------------------------------------- 2343 py::class_<PyRegion>(m, "Region", py::module_local()) 2344 .def_property_readonly( 2345 "blocks", 2346 [](PyRegion &self) { 2347 return PyBlockList(self.getParentOperation(), self.get()); 2348 }, 2349 "Returns a forward-optimized sequence of blocks.") 2350 .def_property_readonly( 2351 "owner", 2352 [](PyRegion &self) { 2353 return self.getParentOperation()->createOpView(); 2354 }, 2355 "Returns the operation owning this region.") 2356 .def( 2357 "__iter__", 2358 [](PyRegion &self) { 2359 self.checkValid(); 2360 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2361 return PyBlockIterator(self.getParentOperation(), firstBlock); 2362 }, 2363 "Iterates over blocks in the region.") 2364 .def("__eq__", 2365 [](PyRegion &self, PyRegion &other) { 2366 return self.get().ptr == other.get().ptr; 2367 }) 2368 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2369 2370 //---------------------------------------------------------------------------- 2371 // Mapping of PyBlock. 2372 //---------------------------------------------------------------------------- 2373 py::class_<PyBlock>(m, "Block", py::module_local()) 2374 .def_property_readonly( 2375 "owner", 2376 [](PyBlock &self) { 2377 return self.getParentOperation()->createOpView(); 2378 }, 2379 "Returns the owning operation of this block.") 2380 .def_property_readonly( 2381 "region", 2382 [](PyBlock &self) { 2383 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2384 return PyRegion(self.getParentOperation(), region); 2385 }, 2386 "Returns the owning region of this block.") 2387 .def_property_readonly( 2388 "arguments", 2389 [](PyBlock &self) { 2390 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2391 }, 2392 "Returns a list of block arguments.") 2393 .def_property_readonly( 2394 "operations", 2395 [](PyBlock &self) { 2396 return PyOperationList(self.getParentOperation(), self.get()); 2397 }, 2398 "Returns a forward-optimized sequence of operations.") 2399 .def_static( 2400 "create_at_start", 2401 [](PyRegion &parent, py::list pyArgTypes) { 2402 parent.checkValid(); 2403 llvm::SmallVector<MlirType, 4> argTypes; 2404 argTypes.reserve(pyArgTypes.size()); 2405 for (auto &pyArg : pyArgTypes) { 2406 argTypes.push_back(pyArg.cast<PyType &>()); 2407 } 2408 2409 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2410 mlirRegionInsertOwnedBlock(parent, 0, block); 2411 return PyBlock(parent.getParentOperation(), block); 2412 }, 2413 py::arg("parent"), py::arg("pyArgTypes") = py::list(), 2414 "Creates and returns a new Block at the beginning of the given " 2415 "region (with given argument types).") 2416 .def( 2417 "create_before", 2418 [](PyBlock &self, py::args pyArgTypes) { 2419 self.checkValid(); 2420 llvm::SmallVector<MlirType, 4> argTypes; 2421 argTypes.reserve(pyArgTypes.size()); 2422 for (auto &pyArg : pyArgTypes) { 2423 argTypes.push_back(pyArg.cast<PyType &>()); 2424 } 2425 2426 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2427 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2428 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2429 return PyBlock(self.getParentOperation(), block); 2430 }, 2431 "Creates and returns a new Block before this block " 2432 "(with given argument types).") 2433 .def( 2434 "create_after", 2435 [](PyBlock &self, py::args pyArgTypes) { 2436 self.checkValid(); 2437 llvm::SmallVector<MlirType, 4> argTypes; 2438 argTypes.reserve(pyArgTypes.size()); 2439 for (auto &pyArg : pyArgTypes) { 2440 argTypes.push_back(pyArg.cast<PyType &>()); 2441 } 2442 2443 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2444 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2445 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2446 return PyBlock(self.getParentOperation(), block); 2447 }, 2448 "Creates and returns a new Block after this block " 2449 "(with given argument types).") 2450 .def( 2451 "__iter__", 2452 [](PyBlock &self) { 2453 self.checkValid(); 2454 MlirOperation firstOperation = 2455 mlirBlockGetFirstOperation(self.get()); 2456 return PyOperationIterator(self.getParentOperation(), 2457 firstOperation); 2458 }, 2459 "Iterates over operations in the block.") 2460 .def("__eq__", 2461 [](PyBlock &self, PyBlock &other) { 2462 return self.get().ptr == other.get().ptr; 2463 }) 2464 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2465 .def( 2466 "__str__", 2467 [](PyBlock &self) { 2468 self.checkValid(); 2469 PyPrintAccumulator printAccum; 2470 mlirBlockPrint(self.get(), printAccum.getCallback(), 2471 printAccum.getUserData()); 2472 return printAccum.join(); 2473 }, 2474 "Returns the assembly form of the block.") 2475 .def( 2476 "append", 2477 [](PyBlock &self, PyOperationBase &operation) { 2478 if (operation.getOperation().isAttached()) 2479 operation.getOperation().detachFromParent(); 2480 2481 MlirOperation mlirOperation = operation.getOperation().get(); 2482 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2483 operation.getOperation().setAttached( 2484 self.getParentOperation().getObject()); 2485 }, 2486 "Appends an operation to this block. If the operation is currently " 2487 "in another block, it will be moved."); 2488 2489 //---------------------------------------------------------------------------- 2490 // Mapping of PyInsertionPoint. 2491 //---------------------------------------------------------------------------- 2492 2493 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2494 .def(py::init<PyBlock &>(), py::arg("block"), 2495 "Inserts after the last operation but still inside the block.") 2496 .def("__enter__", &PyInsertionPoint::contextEnter) 2497 .def("__exit__", &PyInsertionPoint::contextExit) 2498 .def_property_readonly_static( 2499 "current", 2500 [](py::object & /*class*/) { 2501 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2502 if (!ip) 2503 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2504 return ip; 2505 }, 2506 "Gets the InsertionPoint bound to the current thread or raises " 2507 "ValueError if none has been set") 2508 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2509 "Inserts before a referenced operation.") 2510 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2511 py::arg("block"), "Inserts at the beginning of the block.") 2512 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2513 py::arg("block"), "Inserts before the block terminator.") 2514 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2515 "Inserts an operation.") 2516 .def_property_readonly( 2517 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2518 "Returns the block that this InsertionPoint points to."); 2519 2520 //---------------------------------------------------------------------------- 2521 // Mapping of PyAttribute. 2522 //---------------------------------------------------------------------------- 2523 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2524 // Delegate to the PyAttribute copy constructor, which will also lifetime 2525 // extend the backing context which owns the MlirAttribute. 2526 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2527 "Casts the passed attribute to the generic Attribute") 2528 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2529 &PyAttribute::getCapsule) 2530 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2531 .def_static( 2532 "parse", 2533 [](std::string attrSpec, DefaultingPyMlirContext context) { 2534 MlirAttribute type = mlirAttributeParseGet( 2535 context->get(), toMlirStringRef(attrSpec)); 2536 // TODO: Rework error reporting once diagnostic engine is exposed 2537 // in C API. 2538 if (mlirAttributeIsNull(type)) { 2539 throw SetPyError(PyExc_ValueError, 2540 Twine("Unable to parse attribute: '") + 2541 attrSpec + "'"); 2542 } 2543 return PyAttribute(context->getRef(), type); 2544 }, 2545 py::arg("asm"), py::arg("context") = py::none(), 2546 "Parses an attribute from an assembly form") 2547 .def_property_readonly( 2548 "context", 2549 [](PyAttribute &self) { return self.getContext().getObject(); }, 2550 "Context that owns the Attribute") 2551 .def_property_readonly("type", 2552 [](PyAttribute &self) { 2553 return PyType(self.getContext()->getRef(), 2554 mlirAttributeGetType(self)); 2555 }) 2556 .def( 2557 "get_named", 2558 [](PyAttribute &self, std::string name) { 2559 return PyNamedAttribute(self, std::move(name)); 2560 }, 2561 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2562 .def("__eq__", 2563 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2564 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2565 .def("__hash__", 2566 [](PyAttribute &self) { 2567 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2568 }) 2569 .def( 2570 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2571 kDumpDocstring) 2572 .def( 2573 "__str__", 2574 [](PyAttribute &self) { 2575 PyPrintAccumulator printAccum; 2576 mlirAttributePrint(self, printAccum.getCallback(), 2577 printAccum.getUserData()); 2578 return printAccum.join(); 2579 }, 2580 "Returns the assembly form of the Attribute.") 2581 .def("__repr__", [](PyAttribute &self) { 2582 // Generally, assembly formats are not printed for __repr__ because 2583 // this can cause exceptionally long debug output and exceptions. 2584 // However, attribute values are generally considered useful and are 2585 // printed. This may need to be re-evaluated if debug dumps end up 2586 // being excessive. 2587 PyPrintAccumulator printAccum; 2588 printAccum.parts.append("Attribute("); 2589 mlirAttributePrint(self, printAccum.getCallback(), 2590 printAccum.getUserData()); 2591 printAccum.parts.append(")"); 2592 return printAccum.join(); 2593 }); 2594 2595 //---------------------------------------------------------------------------- 2596 // Mapping of PyNamedAttribute 2597 //---------------------------------------------------------------------------- 2598 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2599 .def("__repr__", 2600 [](PyNamedAttribute &self) { 2601 PyPrintAccumulator printAccum; 2602 printAccum.parts.append("NamedAttribute("); 2603 printAccum.parts.append( 2604 mlirIdentifierStr(self.namedAttr.name).data); 2605 printAccum.parts.append("="); 2606 mlirAttributePrint(self.namedAttr.attribute, 2607 printAccum.getCallback(), 2608 printAccum.getUserData()); 2609 printAccum.parts.append(")"); 2610 return printAccum.join(); 2611 }) 2612 .def_property_readonly( 2613 "name", 2614 [](PyNamedAttribute &self) { 2615 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2616 mlirIdentifierStr(self.namedAttr.name).length); 2617 }, 2618 "The name of the NamedAttribute binding") 2619 .def_property_readonly( 2620 "attr", 2621 [](PyNamedAttribute &self) { 2622 // TODO: When named attribute is removed/refactored, also remove 2623 // this constructor (it does an inefficient table lookup). 2624 auto contextRef = PyMlirContext::forContext( 2625 mlirAttributeGetContext(self.namedAttr.attribute)); 2626 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2627 }, 2628 py::keep_alive<0, 1>(), 2629 "The underlying generic attribute of the NamedAttribute binding"); 2630 2631 //---------------------------------------------------------------------------- 2632 // Mapping of PyType. 2633 //---------------------------------------------------------------------------- 2634 py::class_<PyType>(m, "Type", py::module_local()) 2635 // Delegate to the PyType copy constructor, which will also lifetime 2636 // extend the backing context which owns the MlirType. 2637 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2638 "Casts the passed type to the generic Type") 2639 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2640 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2641 .def_static( 2642 "parse", 2643 [](std::string typeSpec, DefaultingPyMlirContext context) { 2644 MlirType type = 2645 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2646 // TODO: Rework error reporting once diagnostic engine is exposed 2647 // in C API. 2648 if (mlirTypeIsNull(type)) { 2649 throw SetPyError(PyExc_ValueError, 2650 Twine("Unable to parse type: '") + typeSpec + 2651 "'"); 2652 } 2653 return PyType(context->getRef(), type); 2654 }, 2655 py::arg("asm"), py::arg("context") = py::none(), 2656 kContextParseTypeDocstring) 2657 .def_property_readonly( 2658 "context", [](PyType &self) { return self.getContext().getObject(); }, 2659 "Context that owns the Type") 2660 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2661 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2662 .def("__hash__", 2663 [](PyType &self) { 2664 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2665 }) 2666 .def( 2667 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2668 .def( 2669 "__str__", 2670 [](PyType &self) { 2671 PyPrintAccumulator printAccum; 2672 mlirTypePrint(self, printAccum.getCallback(), 2673 printAccum.getUserData()); 2674 return printAccum.join(); 2675 }, 2676 "Returns the assembly form of the type.") 2677 .def("__repr__", [](PyType &self) { 2678 // Generally, assembly formats are not printed for __repr__ because 2679 // this can cause exceptionally long debug output and exceptions. 2680 // However, types are an exception as they typically have compact 2681 // assembly forms and printing them is useful. 2682 PyPrintAccumulator printAccum; 2683 printAccum.parts.append("Type("); 2684 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2685 printAccum.parts.append(")"); 2686 return printAccum.join(); 2687 }); 2688 2689 //---------------------------------------------------------------------------- 2690 // Mapping of Value. 2691 //---------------------------------------------------------------------------- 2692 py::class_<PyValue>(m, "Value", py::module_local()) 2693 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2694 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2695 .def_property_readonly( 2696 "context", 2697 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2698 "Context in which the value lives.") 2699 .def( 2700 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2701 kDumpDocstring) 2702 .def_property_readonly( 2703 "owner", 2704 [](PyValue &self) { 2705 assert(mlirOperationEqual(self.getParentOperation()->get(), 2706 mlirOpResultGetOwner(self.get())) && 2707 "expected the owner of the value in Python to match that in " 2708 "the IR"); 2709 return self.getParentOperation().getObject(); 2710 }) 2711 .def("__eq__", 2712 [](PyValue &self, PyValue &other) { 2713 return self.get().ptr == other.get().ptr; 2714 }) 2715 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2716 .def("__hash__", 2717 [](PyValue &self) { 2718 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2719 }) 2720 .def( 2721 "__str__", 2722 [](PyValue &self) { 2723 PyPrintAccumulator printAccum; 2724 printAccum.parts.append("Value("); 2725 mlirValuePrint(self.get(), printAccum.getCallback(), 2726 printAccum.getUserData()); 2727 printAccum.parts.append(")"); 2728 return printAccum.join(); 2729 }, 2730 kValueDunderStrDocstring) 2731 .def_property_readonly("type", [](PyValue &self) { 2732 return PyType(self.getParentOperation()->getContext(), 2733 mlirValueGetType(self.get())); 2734 }); 2735 PyBlockArgument::bind(m); 2736 PyOpResult::bind(m); 2737 2738 //---------------------------------------------------------------------------- 2739 // Mapping of SymbolTable. 2740 //---------------------------------------------------------------------------- 2741 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 2742 .def(py::init<PyOperationBase &>()) 2743 .def("__getitem__", &PySymbolTable::dunderGetItem) 2744 .def("insert", &PySymbolTable::insert) 2745 .def("erase", &PySymbolTable::erase) 2746 .def("__delitem__", &PySymbolTable::dunderDel) 2747 .def("__contains__", [](PySymbolTable &table, const std::string &name) { 2748 return !mlirOperationIsNull(mlirSymbolTableLookup( 2749 table, mlirStringRefCreate(name.data(), name.length()))); 2750 }); 2751 2752 // Container bindings. 2753 PyBlockArgumentList::bind(m); 2754 PyBlockIterator::bind(m); 2755 PyBlockList::bind(m); 2756 PyOperationIterator::bind(m); 2757 PyOperationList::bind(m); 2758 PyOpAttributeMap::bind(m); 2759 PyOpOperandList::bind(m); 2760 PyOpResultList::bind(m); 2761 PyRegionIterator::bind(m); 2762 PyRegionList::bind(m); 2763 2764 // Debug bindings. 2765 PyGlobalDebugFlag::bind(m); 2766 } 2767