1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 namespace py = pybind11; 25 using namespace mlir; 26 using namespace mlir::python; 27 28 using llvm::SmallVector; 29 using llvm::StringRef; 30 using llvm::Twine; 31 32 //------------------------------------------------------------------------------ 33 // Docstrings (trivial, non-duplicated docstrings are included inline). 34 //------------------------------------------------------------------------------ 35 36 static const char kContextParseTypeDocstring[] = 37 R"(Parses the assembly form of a type. 38 39 Returns a Type object or raises a ValueError if the type cannot be parsed. 40 41 See also: https://mlir.llvm.org/docs/LangRef/#type-system 42 )"; 43 44 static const char kContextGetCallSiteLocationDocstring[] = 45 R"(Gets a Location representing a caller and callsite)"; 46 47 static const char kContextGetFileLocationDocstring[] = 48 R"(Gets a Location representing a file, line and column)"; 49 50 static const char kContextGetNameLocationDocString[] = 51 R"(Gets a Location representing a named location with optional child location)"; 52 53 static const char kModuleParseDocstring[] = 54 R"(Parses a module's assembly format from a string. 55 56 Returns a new MlirModule or raises a ValueError if the parsing fails. 57 58 See also: https://mlir.llvm.org/docs/LangRef/ 59 )"; 60 61 static const char kOperationCreateDocstring[] = 62 R"(Creates a new operation. 63 64 Args: 65 name: Operation name (e.g. "dialect.operation"). 66 results: Sequence of Type representing op result types. 67 attributes: Dict of str:Attribute. 68 successors: List of Block for the operation's successors. 69 regions: Number of regions to create. 70 location: A Location object (defaults to resolve from context manager). 71 ip: An InsertionPoint (defaults to resolve from context manager or set to 72 False to disable insertion, even with an insertion point set in the 73 context manager). 74 Returns: 75 A new "detached" Operation object. Detached operations can be added 76 to blocks, which causes them to become "attached." 77 )"; 78 79 static const char kOperationPrintDocstring[] = 80 R"(Prints the assembly form of the operation to a file like object. 81 82 Args: 83 file: The file like object to write to. Defaults to sys.stdout. 84 binary: Whether to write bytes (True) or str (False). Defaults to False. 85 large_elements_limit: Whether to elide elements attributes above this 86 number of elements. Defaults to None (no limit). 87 enable_debug_info: Whether to print debug/location information. Defaults 88 to False. 89 pretty_debug_info: Whether to format debug information for easier reading 90 by a human (warning: the result is unparseable). 91 print_generic_op_form: Whether to print the generic assembly forms of all 92 ops. Defaults to False. 93 use_local_Scope: Whether to print in a way that is more optimized for 94 multi-threaded access but may not be consistent with how the overall 95 module prints. 96 )"; 97 98 static const char kOperationGetAsmDocstring[] = 99 R"(Gets the assembly form of the operation with all options available. 100 101 Args: 102 binary: Whether to return a bytes (True) or str (False) object. Defaults to 103 False. 104 ... others ...: See the print() method for common keyword arguments for 105 configuring the printout. 106 Returns: 107 Either a bytes or str object, depending on the setting of the 'binary' 108 argument. 109 )"; 110 111 static const char kOperationStrDunderDocstring[] = 112 R"(Gets the assembly form of the operation with default options. 113 114 If more advanced control over the assembly formatting or I/O options is needed, 115 use the dedicated print or get_asm method, which supports keyword arguments to 116 customize behavior. 117 )"; 118 119 static const char kDumpDocstring[] = 120 R"(Dumps a debug representation of the object to stderr.)"; 121 122 static const char kAppendBlockDocstring[] = 123 R"(Appends a new block, with argument types as positional args. 124 125 Returns: 126 The created block. 127 )"; 128 129 static const char kValueDunderStrDocstring[] = 130 R"(Returns the string form of the value. 131 132 If the value is a block argument, this is the assembly form of its type and the 133 position in the argument list. If the value is an operation result, this is 134 equivalent to printing the operation that produced it. 135 )"; 136 137 //------------------------------------------------------------------------------ 138 // Utilities. 139 //------------------------------------------------------------------------------ 140 141 /// Helper for creating an @classmethod. 142 template <class Func, typename... Args> 143 py::object classmethod(Func f, Args... args) { 144 py::object cf = py::cpp_function(f, args...); 145 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 146 } 147 148 static py::object 149 createCustomDialectWrapper(const std::string &dialectNamespace, 150 py::object dialectDescriptor) { 151 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 152 if (!dialectClass) { 153 // Use the base class. 154 return py::cast(PyDialect(std::move(dialectDescriptor))); 155 } 156 157 // Create the custom implementation. 158 return (*dialectClass)(std::move(dialectDescriptor)); 159 } 160 161 static MlirStringRef toMlirStringRef(const std::string &s) { 162 return mlirStringRefCreate(s.data(), s.size()); 163 } 164 165 /// Wrapper for the global LLVM debugging flag. 166 struct PyGlobalDebugFlag { 167 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 168 169 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 170 171 static void bind(py::module &m) { 172 // Debug flags. 173 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 174 .def_property_static("flag", &PyGlobalDebugFlag::get, 175 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 176 } 177 }; 178 179 //------------------------------------------------------------------------------ 180 // Collections. 181 //------------------------------------------------------------------------------ 182 183 namespace { 184 185 class PyRegionIterator { 186 public: 187 PyRegionIterator(PyOperationRef operation) 188 : operation(std::move(operation)) {} 189 190 PyRegionIterator &dunderIter() { return *this; } 191 192 PyRegion dunderNext() { 193 operation->checkValid(); 194 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 195 throw py::stop_iteration(); 196 } 197 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 198 return PyRegion(operation, region); 199 } 200 201 static void bind(py::module &m) { 202 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 203 .def("__iter__", &PyRegionIterator::dunderIter) 204 .def("__next__", &PyRegionIterator::dunderNext); 205 } 206 207 private: 208 PyOperationRef operation; 209 int nextIndex = 0; 210 }; 211 212 /// Regions of an op are fixed length and indexed numerically so are represented 213 /// with a sequence-like container. 214 class PyRegionList { 215 public: 216 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 217 218 intptr_t dunderLen() { 219 operation->checkValid(); 220 return mlirOperationGetNumRegions(operation->get()); 221 } 222 223 PyRegion dunderGetItem(intptr_t index) { 224 // dunderLen checks validity. 225 if (index < 0 || index >= dunderLen()) { 226 throw SetPyError(PyExc_IndexError, 227 "attempt to access out of bounds region"); 228 } 229 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 230 return PyRegion(operation, region); 231 } 232 233 static void bind(py::module &m) { 234 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 235 .def("__len__", &PyRegionList::dunderLen) 236 .def("__getitem__", &PyRegionList::dunderGetItem); 237 } 238 239 private: 240 PyOperationRef operation; 241 }; 242 243 class PyBlockIterator { 244 public: 245 PyBlockIterator(PyOperationRef operation, MlirBlock next) 246 : operation(std::move(operation)), next(next) {} 247 248 PyBlockIterator &dunderIter() { return *this; } 249 250 PyBlock dunderNext() { 251 operation->checkValid(); 252 if (mlirBlockIsNull(next)) { 253 throw py::stop_iteration(); 254 } 255 256 PyBlock returnBlock(operation, next); 257 next = mlirBlockGetNextInRegion(next); 258 return returnBlock; 259 } 260 261 static void bind(py::module &m) { 262 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 263 .def("__iter__", &PyBlockIterator::dunderIter) 264 .def("__next__", &PyBlockIterator::dunderNext); 265 } 266 267 private: 268 PyOperationRef operation; 269 MlirBlock next; 270 }; 271 272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 273 /// we present them as a more full-featured list-like container but optimize 274 /// it for forward iteration. Blocks are always owned by a region. 275 class PyBlockList { 276 public: 277 PyBlockList(PyOperationRef operation, MlirRegion region) 278 : operation(std::move(operation)), region(region) {} 279 280 PyBlockIterator dunderIter() { 281 operation->checkValid(); 282 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 283 } 284 285 intptr_t dunderLen() { 286 operation->checkValid(); 287 intptr_t count = 0; 288 MlirBlock block = mlirRegionGetFirstBlock(region); 289 while (!mlirBlockIsNull(block)) { 290 count += 1; 291 block = mlirBlockGetNextInRegion(block); 292 } 293 return count; 294 } 295 296 PyBlock dunderGetItem(intptr_t index) { 297 operation->checkValid(); 298 if (index < 0) { 299 throw SetPyError(PyExc_IndexError, 300 "attempt to access out of bounds block"); 301 } 302 MlirBlock block = mlirRegionGetFirstBlock(region); 303 while (!mlirBlockIsNull(block)) { 304 if (index == 0) { 305 return PyBlock(operation, block); 306 } 307 block = mlirBlockGetNextInRegion(block); 308 index -= 1; 309 } 310 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 311 } 312 313 PyBlock appendBlock(py::args pyArgTypes) { 314 operation->checkValid(); 315 llvm::SmallVector<MlirType, 4> argTypes; 316 argTypes.reserve(pyArgTypes.size()); 317 for (auto &pyArg : pyArgTypes) { 318 argTypes.push_back(pyArg.cast<PyType &>()); 319 } 320 321 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 322 mlirRegionAppendOwnedBlock(region, block); 323 return PyBlock(operation, block); 324 } 325 326 static void bind(py::module &m) { 327 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 328 .def("__getitem__", &PyBlockList::dunderGetItem) 329 .def("__iter__", &PyBlockList::dunderIter) 330 .def("__len__", &PyBlockList::dunderLen) 331 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 332 } 333 334 private: 335 PyOperationRef operation; 336 MlirRegion region; 337 }; 338 339 class PyOperationIterator { 340 public: 341 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 342 : parentOperation(std::move(parentOperation)), next(next) {} 343 344 PyOperationIterator &dunderIter() { return *this; } 345 346 py::object dunderNext() { 347 parentOperation->checkValid(); 348 if (mlirOperationIsNull(next)) { 349 throw py::stop_iteration(); 350 } 351 352 PyOperationRef returnOperation = 353 PyOperation::forOperation(parentOperation->getContext(), next); 354 next = mlirOperationGetNextInBlock(next); 355 return returnOperation->createOpView(); 356 } 357 358 static void bind(py::module &m) { 359 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 360 .def("__iter__", &PyOperationIterator::dunderIter) 361 .def("__next__", &PyOperationIterator::dunderNext); 362 } 363 364 private: 365 PyOperationRef parentOperation; 366 MlirOperation next; 367 }; 368 369 /// Operations are exposed by the C-API as a forward-only linked list. In 370 /// Python, we present them as a more full-featured list-like container but 371 /// optimize it for forward iteration. Iterable operations are always owned 372 /// by a block. 373 class PyOperationList { 374 public: 375 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 376 : parentOperation(std::move(parentOperation)), block(block) {} 377 378 PyOperationIterator dunderIter() { 379 parentOperation->checkValid(); 380 return PyOperationIterator(parentOperation, 381 mlirBlockGetFirstOperation(block)); 382 } 383 384 intptr_t dunderLen() { 385 parentOperation->checkValid(); 386 intptr_t count = 0; 387 MlirOperation childOp = mlirBlockGetFirstOperation(block); 388 while (!mlirOperationIsNull(childOp)) { 389 count += 1; 390 childOp = mlirOperationGetNextInBlock(childOp); 391 } 392 return count; 393 } 394 395 py::object dunderGetItem(intptr_t index) { 396 parentOperation->checkValid(); 397 if (index < 0) { 398 throw SetPyError(PyExc_IndexError, 399 "attempt to access out of bounds operation"); 400 } 401 MlirOperation childOp = mlirBlockGetFirstOperation(block); 402 while (!mlirOperationIsNull(childOp)) { 403 if (index == 0) { 404 return PyOperation::forOperation(parentOperation->getContext(), childOp) 405 ->createOpView(); 406 } 407 childOp = mlirOperationGetNextInBlock(childOp); 408 index -= 1; 409 } 410 throw SetPyError(PyExc_IndexError, 411 "attempt to access out of bounds operation"); 412 } 413 414 static void bind(py::module &m) { 415 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 416 .def("__getitem__", &PyOperationList::dunderGetItem) 417 .def("__iter__", &PyOperationList::dunderIter) 418 .def("__len__", &PyOperationList::dunderLen); 419 } 420 421 private: 422 PyOperationRef parentOperation; 423 MlirBlock block; 424 }; 425 426 } // namespace 427 428 //------------------------------------------------------------------------------ 429 // PyMlirContext 430 //------------------------------------------------------------------------------ 431 432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 433 py::gil_scoped_acquire acquire; 434 auto &liveContexts = getLiveContexts(); 435 liveContexts[context.ptr] = this; 436 } 437 438 PyMlirContext::~PyMlirContext() { 439 // Note that the only public way to construct an instance is via the 440 // forContext method, which always puts the associated handle into 441 // liveContexts. 442 py::gil_scoped_acquire acquire; 443 getLiveContexts().erase(context.ptr); 444 mlirContextDestroy(context); 445 } 446 447 py::object PyMlirContext::getCapsule() { 448 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 449 } 450 451 py::object PyMlirContext::createFromCapsule(py::object capsule) { 452 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 453 if (mlirContextIsNull(rawContext)) 454 throw py::error_already_set(); 455 return forContext(rawContext).releaseObject(); 456 } 457 458 PyMlirContext *PyMlirContext::createNewContextForInit() { 459 MlirContext context = mlirContextCreate(); 460 mlirRegisterAllDialects(context); 461 return new PyMlirContext(context); 462 } 463 464 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 465 py::gil_scoped_acquire acquire; 466 auto &liveContexts = getLiveContexts(); 467 auto it = liveContexts.find(context.ptr); 468 if (it == liveContexts.end()) { 469 // Create. 470 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 471 py::object pyRef = py::cast(unownedContextWrapper); 472 assert(pyRef && "cast to py::object failed"); 473 liveContexts[context.ptr] = unownedContextWrapper; 474 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 475 } 476 // Use existing. 477 py::object pyRef = py::cast(it->second); 478 return PyMlirContextRef(it->second, std::move(pyRef)); 479 } 480 481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 482 static LiveContextMap liveContexts; 483 return liveContexts; 484 } 485 486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 487 488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 489 490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 491 492 pybind11::object PyMlirContext::contextEnter() { 493 return PyThreadContextEntry::pushContext(*this); 494 } 495 496 void PyMlirContext::contextExit(pybind11::object excType, 497 pybind11::object excVal, 498 pybind11::object excTb) { 499 PyThreadContextEntry::popContext(*this); 500 } 501 502 PyMlirContext &DefaultingPyMlirContext::resolve() { 503 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 504 if (!context) { 505 throw SetPyError( 506 PyExc_RuntimeError, 507 "An MLIR function requires a Context but none was provided in the call " 508 "or from the surrounding environment. Either pass to the function with " 509 "a 'context=' argument or establish a default using 'with Context():'"); 510 } 511 return *context; 512 } 513 514 //------------------------------------------------------------------------------ 515 // PyThreadContextEntry management 516 //------------------------------------------------------------------------------ 517 518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 519 static thread_local std::vector<PyThreadContextEntry> stack; 520 return stack; 521 } 522 523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 524 auto &stack = getStack(); 525 if (stack.empty()) 526 return nullptr; 527 return &stack.back(); 528 } 529 530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 531 py::object insertionPoint, 532 py::object location) { 533 auto &stack = getStack(); 534 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 535 std::move(location)); 536 // If the new stack has more than one entry and the context of the new top 537 // entry matches the previous, copy the insertionPoint and location from the 538 // previous entry if missing from the new top entry. 539 if (stack.size() > 1) { 540 auto &prev = *(stack.rbegin() + 1); 541 auto ¤t = stack.back(); 542 if (current.context.is(prev.context)) { 543 // Default non-context objects from the previous entry. 544 if (!current.insertionPoint) 545 current.insertionPoint = prev.insertionPoint; 546 if (!current.location) 547 current.location = prev.location; 548 } 549 } 550 } 551 552 PyMlirContext *PyThreadContextEntry::getContext() { 553 if (!context) 554 return nullptr; 555 return py::cast<PyMlirContext *>(context); 556 } 557 558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 559 if (!insertionPoint) 560 return nullptr; 561 return py::cast<PyInsertionPoint *>(insertionPoint); 562 } 563 564 PyLocation *PyThreadContextEntry::getLocation() { 565 if (!location) 566 return nullptr; 567 return py::cast<PyLocation *>(location); 568 } 569 570 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 571 auto *tos = getTopOfStack(); 572 return tos ? tos->getContext() : nullptr; 573 } 574 575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 576 auto *tos = getTopOfStack(); 577 return tos ? tos->getInsertionPoint() : nullptr; 578 } 579 580 PyLocation *PyThreadContextEntry::getDefaultLocation() { 581 auto *tos = getTopOfStack(); 582 return tos ? tos->getLocation() : nullptr; 583 } 584 585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 586 py::object contextObj = py::cast(context); 587 push(FrameKind::Context, /*context=*/contextObj, 588 /*insertionPoint=*/py::object(), 589 /*location=*/py::object()); 590 return contextObj; 591 } 592 593 void PyThreadContextEntry::popContext(PyMlirContext &context) { 594 auto &stack = getStack(); 595 if (stack.empty()) 596 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 597 auto &tos = stack.back(); 598 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 599 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 600 stack.pop_back(); 601 } 602 603 py::object 604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 605 py::object contextObj = 606 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 607 py::object insertionPointObj = py::cast(insertionPoint); 608 push(FrameKind::InsertionPoint, 609 /*context=*/contextObj, 610 /*insertionPoint=*/insertionPointObj, 611 /*location=*/py::object()); 612 return insertionPointObj; 613 } 614 615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 616 auto &stack = getStack(); 617 if (stack.empty()) 618 throw SetPyError(PyExc_RuntimeError, 619 "Unbalanced InsertionPoint enter/exit"); 620 auto &tos = stack.back(); 621 if (tos.frameKind != FrameKind::InsertionPoint && 622 tos.getInsertionPoint() != &insertionPoint) 623 throw SetPyError(PyExc_RuntimeError, 624 "Unbalanced InsertionPoint enter/exit"); 625 stack.pop_back(); 626 } 627 628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 629 py::object contextObj = location.getContext().getObject(); 630 py::object locationObj = py::cast(location); 631 push(FrameKind::Location, /*context=*/contextObj, 632 /*insertionPoint=*/py::object(), 633 /*location=*/locationObj); 634 return locationObj; 635 } 636 637 void PyThreadContextEntry::popLocation(PyLocation &location) { 638 auto &stack = getStack(); 639 if (stack.empty()) 640 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 641 auto &tos = stack.back(); 642 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 643 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 644 stack.pop_back(); 645 } 646 647 //------------------------------------------------------------------------------ 648 // PyDialect, PyDialectDescriptor, PyDialects 649 //------------------------------------------------------------------------------ 650 651 MlirDialect PyDialects::getDialectForKey(const std::string &key, 652 bool attrError) { 653 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 654 {key.data(), key.size()}); 655 if (mlirDialectIsNull(dialect)) { 656 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 657 Twine("Dialect '") + key + "' not found"); 658 } 659 return dialect; 660 } 661 662 //------------------------------------------------------------------------------ 663 // PyLocation 664 //------------------------------------------------------------------------------ 665 666 py::object PyLocation::getCapsule() { 667 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 668 } 669 670 PyLocation PyLocation::createFromCapsule(py::object capsule) { 671 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 672 if (mlirLocationIsNull(rawLoc)) 673 throw py::error_already_set(); 674 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 675 rawLoc); 676 } 677 678 py::object PyLocation::contextEnter() { 679 return PyThreadContextEntry::pushLocation(*this); 680 } 681 682 void PyLocation::contextExit(py::object excType, py::object excVal, 683 py::object excTb) { 684 PyThreadContextEntry::popLocation(*this); 685 } 686 687 PyLocation &DefaultingPyLocation::resolve() { 688 auto *location = PyThreadContextEntry::getDefaultLocation(); 689 if (!location) { 690 throw SetPyError( 691 PyExc_RuntimeError, 692 "An MLIR function requires a Location but none was provided in the " 693 "call or from the surrounding environment. Either pass to the function " 694 "with a 'loc=' argument or establish a default using 'with loc:'"); 695 } 696 return *location; 697 } 698 699 //------------------------------------------------------------------------------ 700 // PyModule 701 //------------------------------------------------------------------------------ 702 703 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 704 : BaseContextObject(std::move(contextRef)), module(module) {} 705 706 PyModule::~PyModule() { 707 py::gil_scoped_acquire acquire; 708 auto &liveModules = getContext()->liveModules; 709 assert(liveModules.count(module.ptr) == 1 && 710 "destroying module not in live map"); 711 liveModules.erase(module.ptr); 712 mlirModuleDestroy(module); 713 } 714 715 PyModuleRef PyModule::forModule(MlirModule module) { 716 MlirContext context = mlirModuleGetContext(module); 717 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 718 719 py::gil_scoped_acquire acquire; 720 auto &liveModules = contextRef->liveModules; 721 auto it = liveModules.find(module.ptr); 722 if (it == liveModules.end()) { 723 // Create. 724 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 725 // Note that the default return value policy on cast is automatic_reference, 726 // which does not take ownership (delete will not be called). 727 // Just be explicit. 728 py::object pyRef = 729 py::cast(unownedModule, py::return_value_policy::take_ownership); 730 unownedModule->handle = pyRef; 731 liveModules[module.ptr] = 732 std::make_pair(unownedModule->handle, unownedModule); 733 return PyModuleRef(unownedModule, std::move(pyRef)); 734 } 735 // Use existing. 736 PyModule *existing = it->second.second; 737 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 738 return PyModuleRef(existing, std::move(pyRef)); 739 } 740 741 py::object PyModule::createFromCapsule(py::object capsule) { 742 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 743 if (mlirModuleIsNull(rawModule)) 744 throw py::error_already_set(); 745 return forModule(rawModule).releaseObject(); 746 } 747 748 py::object PyModule::getCapsule() { 749 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 750 } 751 752 //------------------------------------------------------------------------------ 753 // PyOperation 754 //------------------------------------------------------------------------------ 755 756 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 757 : BaseContextObject(std::move(contextRef)), operation(operation) {} 758 759 PyOperation::~PyOperation() { 760 // If the operation has already been invalidated there is nothing to do. 761 if (!valid) 762 return; 763 auto &liveOperations = getContext()->liveOperations; 764 assert(liveOperations.count(operation.ptr) == 1 && 765 "destroying operation not in live map"); 766 liveOperations.erase(operation.ptr); 767 if (!isAttached()) { 768 mlirOperationDestroy(operation); 769 } 770 } 771 772 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 773 MlirOperation operation, 774 py::object parentKeepAlive) { 775 auto &liveOperations = contextRef->liveOperations; 776 // Create. 777 PyOperation *unownedOperation = 778 new PyOperation(std::move(contextRef), operation); 779 // Note that the default return value policy on cast is automatic_reference, 780 // which does not take ownership (delete will not be called). 781 // Just be explicit. 782 py::object pyRef = 783 py::cast(unownedOperation, py::return_value_policy::take_ownership); 784 unownedOperation->handle = pyRef; 785 if (parentKeepAlive) { 786 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 787 } 788 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 789 return PyOperationRef(unownedOperation, std::move(pyRef)); 790 } 791 792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 793 MlirOperation operation, 794 py::object parentKeepAlive) { 795 auto &liveOperations = contextRef->liveOperations; 796 auto it = liveOperations.find(operation.ptr); 797 if (it == liveOperations.end()) { 798 // Create. 799 return createInstance(std::move(contextRef), operation, 800 std::move(parentKeepAlive)); 801 } 802 // Use existing. 803 PyOperation *existing = it->second.second; 804 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 805 return PyOperationRef(existing, std::move(pyRef)); 806 } 807 808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 809 MlirOperation operation, 810 py::object parentKeepAlive) { 811 auto &liveOperations = contextRef->liveOperations; 812 assert(liveOperations.count(operation.ptr) == 0 && 813 "cannot create detached operation that already exists"); 814 (void)liveOperations; 815 816 PyOperationRef created = createInstance(std::move(contextRef), operation, 817 std::move(parentKeepAlive)); 818 created->attached = false; 819 return created; 820 } 821 822 void PyOperation::checkValid() const { 823 if (!valid) { 824 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 825 } 826 } 827 828 void PyOperationBase::print(py::object fileObject, bool binary, 829 llvm::Optional<int64_t> largeElementsLimit, 830 bool enableDebugInfo, bool prettyDebugInfo, 831 bool printGenericOpForm, bool useLocalScope) { 832 PyOperation &operation = getOperation(); 833 operation.checkValid(); 834 if (fileObject.is_none()) 835 fileObject = py::module::import("sys").attr("stdout"); 836 837 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 838 fileObject.attr("write")("// Verification failed, printing generic form\n"); 839 printGenericOpForm = true; 840 } 841 842 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 843 if (largeElementsLimit) 844 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 845 if (enableDebugInfo) 846 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 847 if (printGenericOpForm) 848 mlirOpPrintingFlagsPrintGenericOpForm(flags); 849 850 PyFileAccumulator accum(fileObject, binary); 851 py::gil_scoped_release(); 852 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 853 accum.getUserData()); 854 mlirOpPrintingFlagsDestroy(flags); 855 } 856 857 py::object PyOperationBase::getAsm(bool binary, 858 llvm::Optional<int64_t> largeElementsLimit, 859 bool enableDebugInfo, bool prettyDebugInfo, 860 bool printGenericOpForm, 861 bool useLocalScope) { 862 py::object fileObject; 863 if (binary) { 864 fileObject = py::module::import("io").attr("BytesIO")(); 865 } else { 866 fileObject = py::module::import("io").attr("StringIO")(); 867 } 868 print(fileObject, /*binary=*/binary, 869 /*largeElementsLimit=*/largeElementsLimit, 870 /*enableDebugInfo=*/enableDebugInfo, 871 /*prettyDebugInfo=*/prettyDebugInfo, 872 /*printGenericOpForm=*/printGenericOpForm, 873 /*useLocalScope=*/useLocalScope); 874 875 return fileObject.attr("getvalue")(); 876 } 877 878 void PyOperationBase::moveAfter(PyOperationBase &other) { 879 PyOperation &operation = getOperation(); 880 PyOperation &otherOp = other.getOperation(); 881 operation.checkValid(); 882 otherOp.checkValid(); 883 mlirOperationMoveAfter(operation, otherOp); 884 operation.parentKeepAlive = otherOp.parentKeepAlive; 885 } 886 887 void PyOperationBase::moveBefore(PyOperationBase &other) { 888 PyOperation &operation = getOperation(); 889 PyOperation &otherOp = other.getOperation(); 890 operation.checkValid(); 891 otherOp.checkValid(); 892 mlirOperationMoveBefore(operation, otherOp); 893 operation.parentKeepAlive = otherOp.parentKeepAlive; 894 } 895 896 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 897 checkValid(); 898 if (!isAttached()) 899 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 900 MlirOperation operation = mlirOperationGetParentOperation(get()); 901 if (mlirOperationIsNull(operation)) 902 return {}; 903 return PyOperation::forOperation(getContext(), operation); 904 } 905 906 PyBlock PyOperation::getBlock() { 907 checkValid(); 908 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 909 MlirBlock block = mlirOperationGetBlock(get()); 910 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 911 assert(parentOperation && "Operation has no parent"); 912 return PyBlock{std::move(*parentOperation), block}; 913 } 914 915 py::object PyOperation::getCapsule() { 916 checkValid(); 917 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 918 } 919 920 py::object PyOperation::createFromCapsule(py::object capsule) { 921 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 922 if (mlirOperationIsNull(rawOperation)) 923 throw py::error_already_set(); 924 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 925 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 926 .releaseObject(); 927 } 928 929 py::object PyOperation::create( 930 std::string name, llvm::Optional<std::vector<PyType *>> results, 931 llvm::Optional<std::vector<PyValue *>> operands, 932 llvm::Optional<py::dict> attributes, 933 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 934 DefaultingPyLocation location, py::object maybeIp) { 935 llvm::SmallVector<MlirValue, 4> mlirOperands; 936 llvm::SmallVector<MlirType, 4> mlirResults; 937 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 938 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 939 940 // General parameter validation. 941 if (regions < 0) 942 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 943 944 // Unpack/validate operands. 945 if (operands) { 946 mlirOperands.reserve(operands->size()); 947 for (PyValue *operand : *operands) { 948 if (!operand) 949 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 950 mlirOperands.push_back(operand->get()); 951 } 952 } 953 954 // Unpack/validate results. 955 if (results) { 956 mlirResults.reserve(results->size()); 957 for (PyType *result : *results) { 958 // TODO: Verify result type originate from the same context. 959 if (!result) 960 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 961 mlirResults.push_back(*result); 962 } 963 } 964 // Unpack/validate attributes. 965 if (attributes) { 966 mlirAttributes.reserve(attributes->size()); 967 for (auto &it : *attributes) { 968 std::string key; 969 try { 970 key = it.first.cast<std::string>(); 971 } catch (py::cast_error &err) { 972 std::string msg = "Invalid attribute key (not a string) when " 973 "attempting to create the operation \"" + 974 name + "\" (" + err.what() + ")"; 975 throw py::cast_error(msg); 976 } 977 try { 978 auto &attribute = it.second.cast<PyAttribute &>(); 979 // TODO: Verify attribute originates from the same context. 980 mlirAttributes.emplace_back(std::move(key), attribute); 981 } catch (py::reference_cast_error &) { 982 // This exception seems thrown when the value is "None". 983 std::string msg = 984 "Found an invalid (`None`?) attribute value for the key \"" + key + 985 "\" when attempting to create the operation \"" + name + "\""; 986 throw py::cast_error(msg); 987 } catch (py::cast_error &err) { 988 std::string msg = "Invalid attribute value for the key \"" + key + 989 "\" when attempting to create the operation \"" + 990 name + "\" (" + err.what() + ")"; 991 throw py::cast_error(msg); 992 } 993 } 994 } 995 // Unpack/validate successors. 996 if (successors) { 997 mlirSuccessors.reserve(successors->size()); 998 for (auto *successor : *successors) { 999 // TODO: Verify successor originate from the same context. 1000 if (!successor) 1001 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1002 mlirSuccessors.push_back(successor->get()); 1003 } 1004 } 1005 1006 // Apply unpacked/validated to the operation state. Beyond this 1007 // point, exceptions cannot be thrown or else the state will leak. 1008 MlirOperationState state = 1009 mlirOperationStateGet(toMlirStringRef(name), location); 1010 if (!mlirOperands.empty()) 1011 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1012 mlirOperands.data()); 1013 if (!mlirResults.empty()) 1014 mlirOperationStateAddResults(&state, mlirResults.size(), 1015 mlirResults.data()); 1016 if (!mlirAttributes.empty()) { 1017 // Note that the attribute names directly reference bytes in 1018 // mlirAttributes, so that vector must not be changed from here 1019 // on. 1020 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1021 mlirNamedAttributes.reserve(mlirAttributes.size()); 1022 for (auto &it : mlirAttributes) 1023 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1024 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1025 toMlirStringRef(it.first)), 1026 it.second)); 1027 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1028 mlirNamedAttributes.data()); 1029 } 1030 if (!mlirSuccessors.empty()) 1031 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1032 mlirSuccessors.data()); 1033 if (regions) { 1034 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1035 mlirRegions.resize(regions); 1036 for (int i = 0; i < regions; ++i) 1037 mlirRegions[i] = mlirRegionCreate(); 1038 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1039 mlirRegions.data()); 1040 } 1041 1042 // Construct the operation. 1043 MlirOperation operation = mlirOperationCreate(&state); 1044 PyOperationRef created = 1045 PyOperation::createDetached(location->getContext(), operation); 1046 1047 // InsertPoint active? 1048 if (!maybeIp.is(py::cast(false))) { 1049 PyInsertionPoint *ip; 1050 if (maybeIp.is_none()) { 1051 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1052 } else { 1053 ip = py::cast<PyInsertionPoint *>(maybeIp); 1054 } 1055 if (ip) 1056 ip->insert(*created.get()); 1057 } 1058 1059 return created->createOpView(); 1060 } 1061 1062 py::object PyOperation::createOpView() { 1063 checkValid(); 1064 MlirIdentifier ident = mlirOperationGetName(get()); 1065 MlirStringRef identStr = mlirIdentifierStr(ident); 1066 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1067 StringRef(identStr.data, identStr.length)); 1068 if (opViewClass) 1069 return (*opViewClass)(getRef().getObject()); 1070 return py::cast(PyOpView(getRef().getObject())); 1071 } 1072 1073 void PyOperation::erase() { 1074 checkValid(); 1075 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1076 // Python reference to a child operation is live. All children should also 1077 // have their `valid` bit set to false. 1078 auto &liveOperations = getContext()->liveOperations; 1079 if (liveOperations.count(operation.ptr)) 1080 liveOperations.erase(operation.ptr); 1081 mlirOperationDestroy(operation); 1082 valid = false; 1083 } 1084 1085 //------------------------------------------------------------------------------ 1086 // PyOpView 1087 //------------------------------------------------------------------------------ 1088 1089 py::object 1090 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1091 py::list operandList, 1092 llvm::Optional<py::dict> attributes, 1093 llvm::Optional<std::vector<PyBlock *>> successors, 1094 llvm::Optional<int> regions, 1095 DefaultingPyLocation location, py::object maybeIp) { 1096 PyMlirContextRef context = location->getContext(); 1097 // Class level operation construction metadata. 1098 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1099 // Operand and result segment specs are either none, which does no 1100 // variadic unpacking, or a list of ints with segment sizes, where each 1101 // element is either a positive number (typically 1 for a scalar) or -1 to 1102 // indicate that it is derived from the length of the same-indexed operand 1103 // or result (implying that it is a list at that position). 1104 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1105 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1106 1107 std::vector<uint32_t> operandSegmentLengths; 1108 std::vector<uint32_t> resultSegmentLengths; 1109 1110 // Validate/determine region count. 1111 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1112 int opMinRegionCount = std::get<0>(opRegionSpec); 1113 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1114 if (!regions) { 1115 regions = opMinRegionCount; 1116 } 1117 if (*regions < opMinRegionCount) { 1118 throw py::value_error( 1119 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1120 llvm::Twine(opMinRegionCount) + 1121 " regions but was built with regions=" + llvm::Twine(*regions)) 1122 .str()); 1123 } 1124 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1125 throw py::value_error( 1126 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1127 llvm::Twine(opMinRegionCount) + 1128 " regions but was built with regions=" + llvm::Twine(*regions)) 1129 .str()); 1130 } 1131 1132 // Unpack results. 1133 std::vector<PyType *> resultTypes; 1134 resultTypes.reserve(resultTypeList.size()); 1135 if (resultSegmentSpecObj.is_none()) { 1136 // Non-variadic result unpacking. 1137 for (auto it : llvm::enumerate(resultTypeList)) { 1138 try { 1139 resultTypes.push_back(py::cast<PyType *>(it.value())); 1140 if (!resultTypes.back()) 1141 throw py::cast_error(); 1142 } catch (py::cast_error &err) { 1143 throw py::value_error((llvm::Twine("Result ") + 1144 llvm::Twine(it.index()) + " of operation \"" + 1145 name + "\" must be a Type (" + err.what() + ")") 1146 .str()); 1147 } 1148 } 1149 } else { 1150 // Sized result unpacking. 1151 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1152 if (resultSegmentSpec.size() != resultTypeList.size()) { 1153 throw py::value_error((llvm::Twine("Operation \"") + name + 1154 "\" requires " + 1155 llvm::Twine(resultSegmentSpec.size()) + 1156 " result segments but was provided " + 1157 llvm::Twine(resultTypeList.size())) 1158 .str()); 1159 } 1160 resultSegmentLengths.reserve(resultTypeList.size()); 1161 for (auto it : 1162 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1163 int segmentSpec = std::get<1>(it.value()); 1164 if (segmentSpec == 1 || segmentSpec == 0) { 1165 // Unpack unary element. 1166 try { 1167 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1168 if (resultType) { 1169 resultTypes.push_back(resultType); 1170 resultSegmentLengths.push_back(1); 1171 } else if (segmentSpec == 0) { 1172 // Allowed to be optional. 1173 resultSegmentLengths.push_back(0); 1174 } else { 1175 throw py::cast_error("was None and result is not optional"); 1176 } 1177 } catch (py::cast_error &err) { 1178 throw py::value_error((llvm::Twine("Result ") + 1179 llvm::Twine(it.index()) + " of operation \"" + 1180 name + "\" must be a Type (" + err.what() + 1181 ")") 1182 .str()); 1183 } 1184 } else if (segmentSpec == -1) { 1185 // Unpack sequence by appending. 1186 try { 1187 if (std::get<0>(it.value()).is_none()) { 1188 // Treat it as an empty list. 1189 resultSegmentLengths.push_back(0); 1190 } else { 1191 // Unpack the list. 1192 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1193 for (py::object segmentItem : segment) { 1194 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1195 if (!resultTypes.back()) { 1196 throw py::cast_error("contained a None item"); 1197 } 1198 } 1199 resultSegmentLengths.push_back(segment.size()); 1200 } 1201 } catch (std::exception &err) { 1202 // NOTE: Sloppy to be using a catch-all here, but there are at least 1203 // three different unrelated exceptions that can be thrown in the 1204 // above "casts". Just keep the scope above small and catch them all. 1205 throw py::value_error((llvm::Twine("Result ") + 1206 llvm::Twine(it.index()) + " of operation \"" + 1207 name + "\" must be a Sequence of Types (" + 1208 err.what() + ")") 1209 .str()); 1210 } 1211 } else { 1212 throw py::value_error("Unexpected segment spec"); 1213 } 1214 } 1215 } 1216 1217 // Unpack operands. 1218 std::vector<PyValue *> operands; 1219 operands.reserve(operands.size()); 1220 if (operandSegmentSpecObj.is_none()) { 1221 // Non-sized operand unpacking. 1222 for (auto it : llvm::enumerate(operandList)) { 1223 try { 1224 operands.push_back(py::cast<PyValue *>(it.value())); 1225 if (!operands.back()) 1226 throw py::cast_error(); 1227 } catch (py::cast_error &err) { 1228 throw py::value_error((llvm::Twine("Operand ") + 1229 llvm::Twine(it.index()) + " of operation \"" + 1230 name + "\" must be a Value (" + err.what() + ")") 1231 .str()); 1232 } 1233 } 1234 } else { 1235 // Sized operand unpacking. 1236 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1237 if (operandSegmentSpec.size() != operandList.size()) { 1238 throw py::value_error((llvm::Twine("Operation \"") + name + 1239 "\" requires " + 1240 llvm::Twine(operandSegmentSpec.size()) + 1241 "operand segments but was provided " + 1242 llvm::Twine(operandList.size())) 1243 .str()); 1244 } 1245 operandSegmentLengths.reserve(operandList.size()); 1246 for (auto it : 1247 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1248 int segmentSpec = std::get<1>(it.value()); 1249 if (segmentSpec == 1 || segmentSpec == 0) { 1250 // Unpack unary element. 1251 try { 1252 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1253 if (operandValue) { 1254 operands.push_back(operandValue); 1255 operandSegmentLengths.push_back(1); 1256 } else if (segmentSpec == 0) { 1257 // Allowed to be optional. 1258 operandSegmentLengths.push_back(0); 1259 } else { 1260 throw py::cast_error("was None and operand is not optional"); 1261 } 1262 } catch (py::cast_error &err) { 1263 throw py::value_error((llvm::Twine("Operand ") + 1264 llvm::Twine(it.index()) + " of operation \"" + 1265 name + "\" must be a Value (" + err.what() + 1266 ")") 1267 .str()); 1268 } 1269 } else if (segmentSpec == -1) { 1270 // Unpack sequence by appending. 1271 try { 1272 if (std::get<0>(it.value()).is_none()) { 1273 // Treat it as an empty list. 1274 operandSegmentLengths.push_back(0); 1275 } else { 1276 // Unpack the list. 1277 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1278 for (py::object segmentItem : segment) { 1279 operands.push_back(py::cast<PyValue *>(segmentItem)); 1280 if (!operands.back()) { 1281 throw py::cast_error("contained a None item"); 1282 } 1283 } 1284 operandSegmentLengths.push_back(segment.size()); 1285 } 1286 } catch (std::exception &err) { 1287 // NOTE: Sloppy to be using a catch-all here, but there are at least 1288 // three different unrelated exceptions that can be thrown in the 1289 // above "casts". Just keep the scope above small and catch them all. 1290 throw py::value_error((llvm::Twine("Operand ") + 1291 llvm::Twine(it.index()) + " of operation \"" + 1292 name + "\" must be a Sequence of Values (" + 1293 err.what() + ")") 1294 .str()); 1295 } 1296 } else { 1297 throw py::value_error("Unexpected segment spec"); 1298 } 1299 } 1300 } 1301 1302 // Merge operand/result segment lengths into attributes if needed. 1303 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1304 // Dup. 1305 if (attributes) { 1306 attributes = py::dict(*attributes); 1307 } else { 1308 attributes = py::dict(); 1309 } 1310 if (attributes->contains("result_segment_sizes") || 1311 attributes->contains("operand_segment_sizes")) { 1312 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1313 "'operand_segment_sizes' attribute is unsupported. " 1314 "Use Operation.create for such low-level access."); 1315 } 1316 1317 // Add result_segment_sizes attribute. 1318 if (!resultSegmentLengths.empty()) { 1319 int64_t size = resultSegmentLengths.size(); 1320 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1321 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1322 resultSegmentLengths.size(), resultSegmentLengths.data()); 1323 (*attributes)["result_segment_sizes"] = 1324 PyAttribute(context, segmentLengthAttr); 1325 } 1326 1327 // Add operand_segment_sizes attribute. 1328 if (!operandSegmentLengths.empty()) { 1329 int64_t size = operandSegmentLengths.size(); 1330 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1331 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1332 operandSegmentLengths.size(), operandSegmentLengths.data()); 1333 (*attributes)["operand_segment_sizes"] = 1334 PyAttribute(context, segmentLengthAttr); 1335 } 1336 } 1337 1338 // Delegate to create. 1339 return PyOperation::create(std::move(name), 1340 /*results=*/std::move(resultTypes), 1341 /*operands=*/std::move(operands), 1342 /*attributes=*/std::move(attributes), 1343 /*successors=*/std::move(successors), 1344 /*regions=*/*regions, location, maybeIp); 1345 } 1346 1347 PyOpView::PyOpView(py::object operationObject) 1348 // Casting through the PyOperationBase base-class and then back to the 1349 // Operation lets us accept any PyOperationBase subclass. 1350 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1351 operationObject(operation.getRef().getObject()) {} 1352 1353 py::object PyOpView::createRawSubclass(py::object userClass) { 1354 // This is... a little gross. The typical pattern is to have a pure python 1355 // class that extends OpView like: 1356 // class AddFOp(_cext.ir.OpView): 1357 // def __init__(self, loc, lhs, rhs): 1358 // operation = loc.context.create_operation( 1359 // "addf", lhs, rhs, results=[lhs.type]) 1360 // super().__init__(operation) 1361 // 1362 // I.e. The goal of the user facing type is to provide a nice constructor 1363 // that has complete freedom for the op under construction. This is at odds 1364 // with our other desire to sometimes create this object by just passing an 1365 // operation (to initialize the base class). We could do *arg and **kwargs 1366 // munging to try to make it work, but instead, we synthesize a new class 1367 // on the fly which extends this user class (AddFOp in this example) and 1368 // *give it* the base class's __init__ method, thus bypassing the 1369 // intermediate subclass's __init__ method entirely. While slightly, 1370 // underhanded, this is safe/legal because the type hierarchy has not changed 1371 // (we just added a new leaf) and we aren't mucking around with __new__. 1372 // Typically, this new class will be stored on the original as "_Raw" and will 1373 // be used for casts and other things that need a variant of the class that 1374 // is initialized purely from an operation. 1375 py::object parentMetaclass = 1376 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1377 py::dict attributes; 1378 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1379 // now. 1380 // auto opViewType = py::type::of<PyOpView>(); 1381 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1382 attributes["__init__"] = opViewType.attr("__init__"); 1383 py::str origName = userClass.attr("__name__"); 1384 py::str newName = py::str("_") + origName; 1385 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1386 } 1387 1388 //------------------------------------------------------------------------------ 1389 // PyInsertionPoint. 1390 //------------------------------------------------------------------------------ 1391 1392 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1393 1394 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1395 : refOperation(beforeOperationBase.getOperation().getRef()), 1396 block((*refOperation)->getBlock()) {} 1397 1398 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1399 PyOperation &operation = operationBase.getOperation(); 1400 if (operation.isAttached()) 1401 throw SetPyError(PyExc_ValueError, 1402 "Attempt to insert operation that is already attached"); 1403 block.getParentOperation()->checkValid(); 1404 MlirOperation beforeOp = {nullptr}; 1405 if (refOperation) { 1406 // Insert before operation. 1407 (*refOperation)->checkValid(); 1408 beforeOp = (*refOperation)->get(); 1409 } else { 1410 // Insert at end (before null) is only valid if the block does not 1411 // already end in a known terminator (violating this will cause assertion 1412 // failures later). 1413 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1414 throw py::index_error("Cannot insert operation at the end of a block " 1415 "that already has a terminator. Did you mean to " 1416 "use 'InsertionPoint.at_block_terminator(block)' " 1417 "versus 'InsertionPoint(block)'?"); 1418 } 1419 } 1420 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1421 operation.setAttached(); 1422 } 1423 1424 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1425 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1426 if (mlirOperationIsNull(firstOp)) { 1427 // Just insert at end. 1428 return PyInsertionPoint(block); 1429 } 1430 1431 // Insert before first op. 1432 PyOperationRef firstOpRef = PyOperation::forOperation( 1433 block.getParentOperation()->getContext(), firstOp); 1434 return PyInsertionPoint{block, std::move(firstOpRef)}; 1435 } 1436 1437 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1438 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1439 if (mlirOperationIsNull(terminator)) 1440 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1441 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1442 block.getParentOperation()->getContext(), terminator); 1443 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1444 } 1445 1446 py::object PyInsertionPoint::contextEnter() { 1447 return PyThreadContextEntry::pushInsertionPoint(*this); 1448 } 1449 1450 void PyInsertionPoint::contextExit(pybind11::object excType, 1451 pybind11::object excVal, 1452 pybind11::object excTb) { 1453 PyThreadContextEntry::popInsertionPoint(*this); 1454 } 1455 1456 //------------------------------------------------------------------------------ 1457 // PyAttribute. 1458 //------------------------------------------------------------------------------ 1459 1460 bool PyAttribute::operator==(const PyAttribute &other) { 1461 return mlirAttributeEqual(attr, other.attr); 1462 } 1463 1464 py::object PyAttribute::getCapsule() { 1465 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1466 } 1467 1468 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1469 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1470 if (mlirAttributeIsNull(rawAttr)) 1471 throw py::error_already_set(); 1472 return PyAttribute( 1473 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1474 } 1475 1476 //------------------------------------------------------------------------------ 1477 // PyNamedAttribute. 1478 //------------------------------------------------------------------------------ 1479 1480 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1481 : ownedName(new std::string(std::move(ownedName))) { 1482 namedAttr = mlirNamedAttributeGet( 1483 mlirIdentifierGet(mlirAttributeGetContext(attr), 1484 toMlirStringRef(*this->ownedName)), 1485 attr); 1486 } 1487 1488 //------------------------------------------------------------------------------ 1489 // PyType. 1490 //------------------------------------------------------------------------------ 1491 1492 bool PyType::operator==(const PyType &other) { 1493 return mlirTypeEqual(type, other.type); 1494 } 1495 1496 py::object PyType::getCapsule() { 1497 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1498 } 1499 1500 PyType PyType::createFromCapsule(py::object capsule) { 1501 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1502 if (mlirTypeIsNull(rawType)) 1503 throw py::error_already_set(); 1504 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1505 rawType); 1506 } 1507 1508 //------------------------------------------------------------------------------ 1509 // PyValue and subclases. 1510 //------------------------------------------------------------------------------ 1511 1512 pybind11::object PyValue::getCapsule() { 1513 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1514 } 1515 1516 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1517 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1518 if (mlirValueIsNull(value)) 1519 throw py::error_already_set(); 1520 MlirOperation owner; 1521 if (mlirValueIsAOpResult(value)) 1522 owner = mlirOpResultGetOwner(value); 1523 if (mlirValueIsABlockArgument(value)) 1524 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1525 if (mlirOperationIsNull(owner)) 1526 throw py::error_already_set(); 1527 MlirContext ctx = mlirOperationGetContext(owner); 1528 PyOperationRef ownerRef = 1529 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1530 return PyValue(ownerRef, value); 1531 } 1532 1533 //------------------------------------------------------------------------------ 1534 // PySymbolTable. 1535 //------------------------------------------------------------------------------ 1536 1537 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1538 : operation(operation.getOperation().getRef()) { 1539 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1540 if (mlirSymbolTableIsNull(symbolTable)) { 1541 throw py::cast_error("Operation is not a Symbol Table."); 1542 } 1543 } 1544 1545 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1546 operation->checkValid(); 1547 MlirOperation symbol = mlirSymbolTableLookup( 1548 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1549 if (mlirOperationIsNull(symbol)) 1550 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1551 1552 return PyOperation::forOperation(operation->getContext(), symbol, 1553 operation.getObject()) 1554 ->createOpView(); 1555 } 1556 1557 void PySymbolTable::erase(PyOperationBase &symbol) { 1558 operation->checkValid(); 1559 symbol.getOperation().checkValid(); 1560 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1561 // The operation is also erased, so we must invalidate it. There may be Python 1562 // references to this operation so we don't want to delete it from the list of 1563 // live operations here. 1564 symbol.getOperation().valid = false; 1565 } 1566 1567 void PySymbolTable::dunderDel(const std::string &name) { 1568 py::object operation = dunderGetItem(name); 1569 erase(py::cast<PyOperationBase &>(operation)); 1570 } 1571 1572 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1573 operation->checkValid(); 1574 symbol.getOperation().checkValid(); 1575 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1576 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1577 if (mlirAttributeIsNull(symbolAttr)) 1578 throw py::value_error("Expected operation to have a symbol name."); 1579 return PyAttribute( 1580 symbol.getOperation().getContext(), 1581 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1582 } 1583 1584 namespace { 1585 /// CRTP base class for Python MLIR values that subclass Value and should be 1586 /// castable from it. The value hierarchy is one level deep and is not supposed 1587 /// to accommodate other levels unless core MLIR changes. 1588 template <typename DerivedTy> 1589 class PyConcreteValue : public PyValue { 1590 public: 1591 // Derived classes must define statics for: 1592 // IsAFunctionTy isaFunction 1593 // const char *pyClassName 1594 // and redefine bindDerived. 1595 using ClassTy = py::class_<DerivedTy, PyValue>; 1596 using IsAFunctionTy = bool (*)(MlirValue); 1597 1598 PyConcreteValue() = default; 1599 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1600 : PyValue(operationRef, value) {} 1601 PyConcreteValue(PyValue &orig) 1602 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1603 1604 /// Attempts to cast the original value to the derived type and throws on 1605 /// type mismatches. 1606 static MlirValue castFrom(PyValue &orig) { 1607 if (!DerivedTy::isaFunction(orig.get())) { 1608 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1609 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1610 DerivedTy::pyClassName + 1611 " (from " + origRepr + ")"); 1612 } 1613 return orig.get(); 1614 } 1615 1616 /// Binds the Python module objects to functions of this class. 1617 static void bind(py::module &m) { 1618 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1619 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1620 cls.def_static("isinstance", [](PyValue &otherValue) -> bool { 1621 return DerivedTy::isaFunction(otherValue); 1622 }); 1623 DerivedTy::bindDerived(cls); 1624 } 1625 1626 /// Implemented by derived classes to add methods to the Python subclass. 1627 static void bindDerived(ClassTy &m) {} 1628 }; 1629 1630 /// Python wrapper for MlirBlockArgument. 1631 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1632 public: 1633 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1634 static constexpr const char *pyClassName = "BlockArgument"; 1635 using PyConcreteValue::PyConcreteValue; 1636 1637 static void bindDerived(ClassTy &c) { 1638 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1639 return PyBlock(self.getParentOperation(), 1640 mlirBlockArgumentGetOwner(self.get())); 1641 }); 1642 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1643 return mlirBlockArgumentGetArgNumber(self.get()); 1644 }); 1645 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1646 return mlirBlockArgumentSetType(self.get(), type); 1647 }); 1648 } 1649 }; 1650 1651 /// Python wrapper for MlirOpResult. 1652 class PyOpResult : public PyConcreteValue<PyOpResult> { 1653 public: 1654 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1655 static constexpr const char *pyClassName = "OpResult"; 1656 using PyConcreteValue::PyConcreteValue; 1657 1658 static void bindDerived(ClassTy &c) { 1659 c.def_property_readonly("owner", [](PyOpResult &self) { 1660 assert( 1661 mlirOperationEqual(self.getParentOperation()->get(), 1662 mlirOpResultGetOwner(self.get())) && 1663 "expected the owner of the value in Python to match that in the IR"); 1664 return self.getParentOperation().getObject(); 1665 }); 1666 c.def_property_readonly("result_number", [](PyOpResult &self) { 1667 return mlirOpResultGetResultNumber(self.get()); 1668 }); 1669 } 1670 }; 1671 1672 /// Returns the list of types of the values held by container. 1673 template <typename Container> 1674 static std::vector<PyType> getValueTypes(Container &container, 1675 PyMlirContextRef &context) { 1676 std::vector<PyType> result; 1677 result.reserve(container.getNumElements()); 1678 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1679 result.push_back( 1680 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1681 } 1682 return result; 1683 } 1684 1685 /// A list of block arguments. Internally, these are stored as consecutive 1686 /// elements, random access is cheap. The argument list is associated with the 1687 /// operation that contains the block (detached blocks are not allowed in 1688 /// Python bindings) and extends its lifetime. 1689 class PyBlockArgumentList 1690 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1691 public: 1692 static constexpr const char *pyClassName = "BlockArgumentList"; 1693 1694 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1695 intptr_t startIndex = 0, intptr_t length = -1, 1696 intptr_t step = 1) 1697 : Sliceable(startIndex, 1698 length == -1 ? mlirBlockGetNumArguments(block) : length, 1699 step), 1700 operation(std::move(operation)), block(block) {} 1701 1702 /// Returns the number of arguments in the list. 1703 intptr_t getNumElements() { 1704 operation->checkValid(); 1705 return mlirBlockGetNumArguments(block); 1706 } 1707 1708 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1709 PyBlockArgument getElement(intptr_t pos) { 1710 MlirValue argument = mlirBlockGetArgument(block, pos); 1711 return PyBlockArgument(operation, argument); 1712 } 1713 1714 /// Returns a sublist of this list. 1715 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1716 intptr_t step) { 1717 return PyBlockArgumentList(operation, block, startIndex, length, step); 1718 } 1719 1720 static void bindDerived(ClassTy &c) { 1721 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1722 return getValueTypes(self, self.operation->getContext()); 1723 }); 1724 } 1725 1726 private: 1727 PyOperationRef operation; 1728 MlirBlock block; 1729 }; 1730 1731 /// A list of operation operands. Internally, these are stored as consecutive 1732 /// elements, random access is cheap. The result list is associated with the 1733 /// operation whose results these are, and extends the lifetime of this 1734 /// operation. 1735 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1736 public: 1737 static constexpr const char *pyClassName = "OpOperandList"; 1738 1739 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1740 intptr_t length = -1, intptr_t step = 1) 1741 : Sliceable(startIndex, 1742 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1743 : length, 1744 step), 1745 operation(operation) {} 1746 1747 intptr_t getNumElements() { 1748 operation->checkValid(); 1749 return mlirOperationGetNumOperands(operation->get()); 1750 } 1751 1752 PyValue getElement(intptr_t pos) { 1753 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1754 MlirOperation owner; 1755 if (mlirValueIsAOpResult(operand)) 1756 owner = mlirOpResultGetOwner(operand); 1757 else if (mlirValueIsABlockArgument(operand)) 1758 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1759 else 1760 assert(false && "Value must be an block arg or op result."); 1761 PyOperationRef pyOwner = 1762 PyOperation::forOperation(operation->getContext(), owner); 1763 return PyValue(pyOwner, operand); 1764 } 1765 1766 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1767 return PyOpOperandList(operation, startIndex, length, step); 1768 } 1769 1770 void dunderSetItem(intptr_t index, PyValue value) { 1771 index = wrapIndex(index); 1772 mlirOperationSetOperand(operation->get(), index, value.get()); 1773 } 1774 1775 static void bindDerived(ClassTy &c) { 1776 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1777 } 1778 1779 private: 1780 PyOperationRef operation; 1781 }; 1782 1783 /// A list of operation results. Internally, these are stored as consecutive 1784 /// elements, random access is cheap. The result list is associated with the 1785 /// operation whose results these are, and extends the lifetime of this 1786 /// operation. 1787 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1788 public: 1789 static constexpr const char *pyClassName = "OpResultList"; 1790 1791 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1792 intptr_t length = -1, intptr_t step = 1) 1793 : Sliceable(startIndex, 1794 length == -1 ? mlirOperationGetNumResults(operation->get()) 1795 : length, 1796 step), 1797 operation(operation) {} 1798 1799 intptr_t getNumElements() { 1800 operation->checkValid(); 1801 return mlirOperationGetNumResults(operation->get()); 1802 } 1803 1804 PyOpResult getElement(intptr_t index) { 1805 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1806 return PyOpResult(value); 1807 } 1808 1809 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1810 return PyOpResultList(operation, startIndex, length, step); 1811 } 1812 1813 static void bindDerived(ClassTy &c) { 1814 c.def_property_readonly("types", [](PyOpResultList &self) { 1815 return getValueTypes(self, self.operation->getContext()); 1816 }); 1817 } 1818 1819 private: 1820 PyOperationRef operation; 1821 }; 1822 1823 /// A list of operation attributes. Can be indexed by name, producing 1824 /// attributes, or by index, producing named attributes. 1825 class PyOpAttributeMap { 1826 public: 1827 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1828 1829 PyAttribute dunderGetItemNamed(const std::string &name) { 1830 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1831 toMlirStringRef(name)); 1832 if (mlirAttributeIsNull(attr)) { 1833 throw SetPyError(PyExc_KeyError, 1834 "attempt to access a non-existent attribute"); 1835 } 1836 return PyAttribute(operation->getContext(), attr); 1837 } 1838 1839 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1840 if (index < 0 || index >= dunderLen()) { 1841 throw SetPyError(PyExc_IndexError, 1842 "attempt to access out of bounds attribute"); 1843 } 1844 MlirNamedAttribute namedAttr = 1845 mlirOperationGetAttribute(operation->get(), index); 1846 return PyNamedAttribute( 1847 namedAttr.attribute, 1848 std::string(mlirIdentifierStr(namedAttr.name).data, 1849 mlirIdentifierStr(namedAttr.name).length)); 1850 } 1851 1852 void dunderSetItem(const std::string &name, PyAttribute attr) { 1853 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1854 attr); 1855 } 1856 1857 void dunderDelItem(const std::string &name) { 1858 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1859 toMlirStringRef(name)); 1860 if (!removed) 1861 throw SetPyError(PyExc_KeyError, 1862 "attempt to delete a non-existent attribute"); 1863 } 1864 1865 intptr_t dunderLen() { 1866 return mlirOperationGetNumAttributes(operation->get()); 1867 } 1868 1869 bool dunderContains(const std::string &name) { 1870 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1871 operation->get(), toMlirStringRef(name))); 1872 } 1873 1874 static void bind(py::module &m) { 1875 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 1876 .def("__contains__", &PyOpAttributeMap::dunderContains) 1877 .def("__len__", &PyOpAttributeMap::dunderLen) 1878 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1879 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1880 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1881 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1882 } 1883 1884 private: 1885 PyOperationRef operation; 1886 }; 1887 1888 } // end namespace 1889 1890 //------------------------------------------------------------------------------ 1891 // Populates the core exports of the 'ir' submodule. 1892 //------------------------------------------------------------------------------ 1893 1894 void mlir::python::populateIRCore(py::module &m) { 1895 //---------------------------------------------------------------------------- 1896 // Mapping of MlirContext. 1897 //---------------------------------------------------------------------------- 1898 py::class_<PyMlirContext>(m, "Context", py::module_local()) 1899 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1900 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1901 .def("_get_context_again", 1902 [](PyMlirContext &self) { 1903 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1904 return ref.releaseObject(); 1905 }) 1906 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1907 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1908 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1909 &PyMlirContext::getCapsule) 1910 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1911 .def("__enter__", &PyMlirContext::contextEnter) 1912 .def("__exit__", &PyMlirContext::contextExit) 1913 .def_property_readonly_static( 1914 "current", 1915 [](py::object & /*class*/) { 1916 auto *context = PyThreadContextEntry::getDefaultContext(); 1917 if (!context) 1918 throw SetPyError(PyExc_ValueError, "No current Context"); 1919 return context; 1920 }, 1921 "Gets the Context bound to the current thread or raises ValueError") 1922 .def_property_readonly( 1923 "dialects", 1924 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1925 "Gets a container for accessing dialects by name") 1926 .def_property_readonly( 1927 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1928 "Alias for 'dialect'") 1929 .def( 1930 "get_dialect_descriptor", 1931 [=](PyMlirContext &self, std::string &name) { 1932 MlirDialect dialect = mlirContextGetOrLoadDialect( 1933 self.get(), {name.data(), name.size()}); 1934 if (mlirDialectIsNull(dialect)) { 1935 throw SetPyError(PyExc_ValueError, 1936 Twine("Dialect '") + name + "' not found"); 1937 } 1938 return PyDialectDescriptor(self.getRef(), dialect); 1939 }, 1940 "Gets or loads a dialect by name, returning its descriptor object") 1941 .def_property( 1942 "allow_unregistered_dialects", 1943 [](PyMlirContext &self) -> bool { 1944 return mlirContextGetAllowUnregisteredDialects(self.get()); 1945 }, 1946 [](PyMlirContext &self, bool value) { 1947 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1948 }) 1949 .def("enable_multithreading", 1950 [](PyMlirContext &self, bool enable) { 1951 mlirContextEnableMultithreading(self.get(), enable); 1952 }) 1953 .def("is_registered_operation", 1954 [](PyMlirContext &self, std::string &name) { 1955 return mlirContextIsRegisteredOperation( 1956 self.get(), MlirStringRef{name.data(), name.size()}); 1957 }); 1958 1959 //---------------------------------------------------------------------------- 1960 // Mapping of PyDialectDescriptor 1961 //---------------------------------------------------------------------------- 1962 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 1963 .def_property_readonly("namespace", 1964 [](PyDialectDescriptor &self) { 1965 MlirStringRef ns = 1966 mlirDialectGetNamespace(self.get()); 1967 return py::str(ns.data, ns.length); 1968 }) 1969 .def("__repr__", [](PyDialectDescriptor &self) { 1970 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1971 std::string repr("<DialectDescriptor "); 1972 repr.append(ns.data, ns.length); 1973 repr.append(">"); 1974 return repr; 1975 }); 1976 1977 //---------------------------------------------------------------------------- 1978 // Mapping of PyDialects 1979 //---------------------------------------------------------------------------- 1980 py::class_<PyDialects>(m, "Dialects", py::module_local()) 1981 .def("__getitem__", 1982 [=](PyDialects &self, std::string keyName) { 1983 MlirDialect dialect = 1984 self.getDialectForKey(keyName, /*attrError=*/false); 1985 py::object descriptor = 1986 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1987 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1988 }) 1989 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1990 MlirDialect dialect = 1991 self.getDialectForKey(attrName, /*attrError=*/true); 1992 py::object descriptor = 1993 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1994 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1995 }); 1996 1997 //---------------------------------------------------------------------------- 1998 // Mapping of PyDialect 1999 //---------------------------------------------------------------------------- 2000 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2001 .def(py::init<py::object>(), "descriptor") 2002 .def_property_readonly( 2003 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2004 .def("__repr__", [](py::object self) { 2005 auto clazz = self.attr("__class__"); 2006 return py::str("<Dialect ") + 2007 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2008 clazz.attr("__module__") + py::str(".") + 2009 clazz.attr("__name__") + py::str(")>"); 2010 }); 2011 2012 //---------------------------------------------------------------------------- 2013 // Mapping of Location 2014 //---------------------------------------------------------------------------- 2015 py::class_<PyLocation>(m, "Location", py::module_local()) 2016 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2017 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2018 .def("__enter__", &PyLocation::contextEnter) 2019 .def("__exit__", &PyLocation::contextExit) 2020 .def("__eq__", 2021 [](PyLocation &self, PyLocation &other) -> bool { 2022 return mlirLocationEqual(self, other); 2023 }) 2024 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2025 .def_property_readonly_static( 2026 "current", 2027 [](py::object & /*class*/) { 2028 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2029 if (!loc) 2030 throw SetPyError(PyExc_ValueError, "No current Location"); 2031 return loc; 2032 }, 2033 "Gets the Location bound to the current thread or raises ValueError") 2034 .def_static( 2035 "unknown", 2036 [](DefaultingPyMlirContext context) { 2037 return PyLocation(context->getRef(), 2038 mlirLocationUnknownGet(context->get())); 2039 }, 2040 py::arg("context") = py::none(), 2041 "Gets a Location representing an unknown location") 2042 .def_static( 2043 "callsite", 2044 [](PyLocation callee, const std::vector<PyLocation> &frames, 2045 DefaultingPyMlirContext context) { 2046 if (frames.empty()) 2047 throw py::value_error("No caller frames provided"); 2048 MlirLocation caller = frames.back().get(); 2049 for (const PyLocation &frame : 2050 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2051 caller = mlirLocationCallSiteGet(frame.get(), caller); 2052 return PyLocation(context->getRef(), 2053 mlirLocationCallSiteGet(callee.get(), caller)); 2054 }, 2055 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2056 kContextGetCallSiteLocationDocstring) 2057 .def_static( 2058 "file", 2059 [](std::string filename, int line, int col, 2060 DefaultingPyMlirContext context) { 2061 return PyLocation( 2062 context->getRef(), 2063 mlirLocationFileLineColGet( 2064 context->get(), toMlirStringRef(filename), line, col)); 2065 }, 2066 py::arg("filename"), py::arg("line"), py::arg("col"), 2067 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2068 .def_static( 2069 "name", 2070 [](std::string name, llvm::Optional<PyLocation> childLoc, 2071 DefaultingPyMlirContext context) { 2072 return PyLocation( 2073 context->getRef(), 2074 mlirLocationNameGet( 2075 context->get(), toMlirStringRef(name), 2076 childLoc ? childLoc->get() 2077 : mlirLocationUnknownGet(context->get()))); 2078 }, 2079 py::arg("name"), py::arg("childLoc") = py::none(), 2080 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2081 .def_property_readonly( 2082 "context", 2083 [](PyLocation &self) { return self.getContext().getObject(); }, 2084 "Context that owns the Location") 2085 .def("__repr__", [](PyLocation &self) { 2086 PyPrintAccumulator printAccum; 2087 mlirLocationPrint(self, printAccum.getCallback(), 2088 printAccum.getUserData()); 2089 return printAccum.join(); 2090 }); 2091 2092 //---------------------------------------------------------------------------- 2093 // Mapping of Module 2094 //---------------------------------------------------------------------------- 2095 py::class_<PyModule>(m, "Module", py::module_local()) 2096 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2097 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2098 .def_static( 2099 "parse", 2100 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2101 MlirModule module = mlirModuleCreateParse( 2102 context->get(), toMlirStringRef(moduleAsm)); 2103 // TODO: Rework error reporting once diagnostic engine is exposed 2104 // in C API. 2105 if (mlirModuleIsNull(module)) { 2106 throw SetPyError( 2107 PyExc_ValueError, 2108 "Unable to parse module assembly (see diagnostics)"); 2109 } 2110 return PyModule::forModule(module).releaseObject(); 2111 }, 2112 py::arg("asm"), py::arg("context") = py::none(), 2113 kModuleParseDocstring) 2114 .def_static( 2115 "create", 2116 [](DefaultingPyLocation loc) { 2117 MlirModule module = mlirModuleCreateEmpty(loc); 2118 return PyModule::forModule(module).releaseObject(); 2119 }, 2120 py::arg("loc") = py::none(), "Creates an empty module") 2121 .def_property_readonly( 2122 "context", 2123 [](PyModule &self) { return self.getContext().getObject(); }, 2124 "Context that created the Module") 2125 .def_property_readonly( 2126 "operation", 2127 [](PyModule &self) { 2128 return PyOperation::forOperation(self.getContext(), 2129 mlirModuleGetOperation(self.get()), 2130 self.getRef().releaseObject()) 2131 .releaseObject(); 2132 }, 2133 "Accesses the module as an operation") 2134 .def_property_readonly( 2135 "body", 2136 [](PyModule &self) { 2137 PyOperationRef module_op = PyOperation::forOperation( 2138 self.getContext(), mlirModuleGetOperation(self.get()), 2139 self.getRef().releaseObject()); 2140 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2141 return returnBlock; 2142 }, 2143 "Return the block for this module") 2144 .def( 2145 "dump", 2146 [](PyModule &self) { 2147 mlirOperationDump(mlirModuleGetOperation(self.get())); 2148 }, 2149 kDumpDocstring) 2150 .def( 2151 "__str__", 2152 [](PyModule &self) { 2153 MlirOperation operation = mlirModuleGetOperation(self.get()); 2154 PyPrintAccumulator printAccum; 2155 mlirOperationPrint(operation, printAccum.getCallback(), 2156 printAccum.getUserData()); 2157 return printAccum.join(); 2158 }, 2159 kOperationStrDunderDocstring); 2160 2161 //---------------------------------------------------------------------------- 2162 // Mapping of Operation. 2163 //---------------------------------------------------------------------------- 2164 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2165 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2166 [](PyOperationBase &self) { 2167 return self.getOperation().getCapsule(); 2168 }) 2169 .def("__eq__", 2170 [](PyOperationBase &self, PyOperationBase &other) { 2171 return &self.getOperation() == &other.getOperation(); 2172 }) 2173 .def("__eq__", 2174 [](PyOperationBase &self, py::object other) { return false; }) 2175 .def("__hash__", 2176 [](PyOperationBase &self) { 2177 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2178 }) 2179 .def_property_readonly("attributes", 2180 [](PyOperationBase &self) { 2181 return PyOpAttributeMap( 2182 self.getOperation().getRef()); 2183 }) 2184 .def_property_readonly("operands", 2185 [](PyOperationBase &self) { 2186 return PyOpOperandList( 2187 self.getOperation().getRef()); 2188 }) 2189 .def_property_readonly("regions", 2190 [](PyOperationBase &self) { 2191 return PyRegionList( 2192 self.getOperation().getRef()); 2193 }) 2194 .def_property_readonly( 2195 "results", 2196 [](PyOperationBase &self) { 2197 return PyOpResultList(self.getOperation().getRef()); 2198 }, 2199 "Returns the list of Operation results.") 2200 .def_property_readonly( 2201 "result", 2202 [](PyOperationBase &self) { 2203 auto &operation = self.getOperation(); 2204 auto numResults = mlirOperationGetNumResults(operation); 2205 if (numResults != 1) { 2206 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2207 throw SetPyError( 2208 PyExc_ValueError, 2209 Twine("Cannot call .result on operation ") + 2210 StringRef(name.data, name.length) + " which has " + 2211 Twine(numResults) + 2212 " results (it is only valid for operations with a " 2213 "single result)"); 2214 } 2215 return PyOpResult(operation.getRef(), 2216 mlirOperationGetResult(operation, 0)); 2217 }, 2218 "Shortcut to get an op result if it has only one (throws an error " 2219 "otherwise).") 2220 .def_property_readonly( 2221 "location", 2222 [](PyOperationBase &self) { 2223 PyOperation &operation = self.getOperation(); 2224 return PyLocation(operation.getContext(), 2225 mlirOperationGetLocation(operation.get())); 2226 }, 2227 "Returns the source location the operation was defined or derived " 2228 "from.") 2229 .def( 2230 "__str__", 2231 [](PyOperationBase &self) { 2232 return self.getAsm(/*binary=*/false, 2233 /*largeElementsLimit=*/llvm::None, 2234 /*enableDebugInfo=*/false, 2235 /*prettyDebugInfo=*/false, 2236 /*printGenericOpForm=*/false, 2237 /*useLocalScope=*/false); 2238 }, 2239 "Returns the assembly form of the operation.") 2240 .def("print", &PyOperationBase::print, 2241 // Careful: Lots of arguments must match up with print method. 2242 py::arg("file") = py::none(), py::arg("binary") = false, 2243 py::arg("large_elements_limit") = py::none(), 2244 py::arg("enable_debug_info") = false, 2245 py::arg("pretty_debug_info") = false, 2246 py::arg("print_generic_op_form") = false, 2247 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2248 .def("get_asm", &PyOperationBase::getAsm, 2249 // Careful: Lots of arguments must match up with get_asm method. 2250 py::arg("binary") = false, 2251 py::arg("large_elements_limit") = py::none(), 2252 py::arg("enable_debug_info") = false, 2253 py::arg("pretty_debug_info") = false, 2254 py::arg("print_generic_op_form") = false, 2255 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2256 .def( 2257 "verify", 2258 [](PyOperationBase &self) { 2259 return mlirOperationVerify(self.getOperation()); 2260 }, 2261 "Verify the operation and return true if it passes, false if it " 2262 "fails.") 2263 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2264 "Puts self immediately after the other operation in its parent " 2265 "block.") 2266 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2267 "Puts self immediately before the other operation in its parent " 2268 "block.") 2269 .def( 2270 "detach_from_parent", 2271 [](PyOperationBase &self) { 2272 PyOperation &operation = self.getOperation(); 2273 operation.checkValid(); 2274 if (!operation.isAttached()) 2275 throw py::value_error("Detached operation has no parent."); 2276 2277 operation.detachFromParent(); 2278 return operation.createOpView(); 2279 }, 2280 "Detaches the operation from its parent block."); 2281 2282 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2283 .def_static("create", &PyOperation::create, py::arg("name"), 2284 py::arg("results") = py::none(), 2285 py::arg("operands") = py::none(), 2286 py::arg("attributes") = py::none(), 2287 py::arg("successors") = py::none(), py::arg("regions") = 0, 2288 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2289 kOperationCreateDocstring) 2290 .def_property_readonly("parent", 2291 [](PyOperation &self) -> py::object { 2292 auto parent = self.getParentOperation(); 2293 if (parent) 2294 return parent->getObject(); 2295 return py::none(); 2296 }) 2297 .def("erase", &PyOperation::erase) 2298 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2299 &PyOperation::getCapsule) 2300 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2301 .def_property_readonly("name", 2302 [](PyOperation &self) { 2303 self.checkValid(); 2304 MlirOperation operation = self.get(); 2305 MlirStringRef name = mlirIdentifierStr( 2306 mlirOperationGetName(operation)); 2307 return py::str(name.data, name.length); 2308 }) 2309 .def_property_readonly( 2310 "context", 2311 [](PyOperation &self) { 2312 self.checkValid(); 2313 return self.getContext().getObject(); 2314 }, 2315 "Context that owns the Operation") 2316 .def_property_readonly("opview", &PyOperation::createOpView); 2317 2318 auto opViewClass = 2319 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2320 .def(py::init<py::object>()) 2321 .def_property_readonly("operation", &PyOpView::getOperationObject) 2322 .def_property_readonly( 2323 "context", 2324 [](PyOpView &self) { 2325 return self.getOperation().getContext().getObject(); 2326 }, 2327 "Context that owns the Operation") 2328 .def("__str__", [](PyOpView &self) { 2329 return py::str(self.getOperationObject()); 2330 }); 2331 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2332 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2333 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2334 opViewClass.attr("build_generic") = classmethod( 2335 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2336 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2337 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2338 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2339 "Builds a specific, generated OpView based on class level attributes."); 2340 2341 //---------------------------------------------------------------------------- 2342 // Mapping of PyRegion. 2343 //---------------------------------------------------------------------------- 2344 py::class_<PyRegion>(m, "Region", py::module_local()) 2345 .def_property_readonly( 2346 "blocks", 2347 [](PyRegion &self) { 2348 return PyBlockList(self.getParentOperation(), self.get()); 2349 }, 2350 "Returns a forward-optimized sequence of blocks.") 2351 .def_property_readonly( 2352 "owner", 2353 [](PyRegion &self) { 2354 return self.getParentOperation()->createOpView(); 2355 }, 2356 "Returns the operation owning this region.") 2357 .def( 2358 "__iter__", 2359 [](PyRegion &self) { 2360 self.checkValid(); 2361 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2362 return PyBlockIterator(self.getParentOperation(), firstBlock); 2363 }, 2364 "Iterates over blocks in the region.") 2365 .def("__eq__", 2366 [](PyRegion &self, PyRegion &other) { 2367 return self.get().ptr == other.get().ptr; 2368 }) 2369 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2370 2371 //---------------------------------------------------------------------------- 2372 // Mapping of PyBlock. 2373 //---------------------------------------------------------------------------- 2374 py::class_<PyBlock>(m, "Block", py::module_local()) 2375 .def_property_readonly( 2376 "owner", 2377 [](PyBlock &self) { 2378 return self.getParentOperation()->createOpView(); 2379 }, 2380 "Returns the owning operation of this block.") 2381 .def_property_readonly( 2382 "region", 2383 [](PyBlock &self) { 2384 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2385 return PyRegion(self.getParentOperation(), region); 2386 }, 2387 "Returns the owning region of this block.") 2388 .def_property_readonly( 2389 "arguments", 2390 [](PyBlock &self) { 2391 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2392 }, 2393 "Returns a list of block arguments.") 2394 .def_property_readonly( 2395 "operations", 2396 [](PyBlock &self) { 2397 return PyOperationList(self.getParentOperation(), self.get()); 2398 }, 2399 "Returns a forward-optimized sequence of operations.") 2400 .def_static( 2401 "create_at_start", 2402 [](PyRegion &parent, py::list pyArgTypes) { 2403 parent.checkValid(); 2404 llvm::SmallVector<MlirType, 4> argTypes; 2405 argTypes.reserve(pyArgTypes.size()); 2406 for (auto &pyArg : pyArgTypes) { 2407 argTypes.push_back(pyArg.cast<PyType &>()); 2408 } 2409 2410 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2411 mlirRegionInsertOwnedBlock(parent, 0, block); 2412 return PyBlock(parent.getParentOperation(), block); 2413 }, 2414 py::arg("parent"), py::arg("pyArgTypes") = py::list(), 2415 "Creates and returns a new Block at the beginning of the given " 2416 "region (with given argument types).") 2417 .def( 2418 "create_before", 2419 [](PyBlock &self, py::args pyArgTypes) { 2420 self.checkValid(); 2421 llvm::SmallVector<MlirType, 4> argTypes; 2422 argTypes.reserve(pyArgTypes.size()); 2423 for (auto &pyArg : pyArgTypes) { 2424 argTypes.push_back(pyArg.cast<PyType &>()); 2425 } 2426 2427 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2428 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2429 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2430 return PyBlock(self.getParentOperation(), block); 2431 }, 2432 "Creates and returns a new Block before this block " 2433 "(with given argument types).") 2434 .def( 2435 "create_after", 2436 [](PyBlock &self, py::args pyArgTypes) { 2437 self.checkValid(); 2438 llvm::SmallVector<MlirType, 4> argTypes; 2439 argTypes.reserve(pyArgTypes.size()); 2440 for (auto &pyArg : pyArgTypes) { 2441 argTypes.push_back(pyArg.cast<PyType &>()); 2442 } 2443 2444 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2445 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2446 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2447 return PyBlock(self.getParentOperation(), block); 2448 }, 2449 "Creates and returns a new Block after this block " 2450 "(with given argument types).") 2451 .def( 2452 "__iter__", 2453 [](PyBlock &self) { 2454 self.checkValid(); 2455 MlirOperation firstOperation = 2456 mlirBlockGetFirstOperation(self.get()); 2457 return PyOperationIterator(self.getParentOperation(), 2458 firstOperation); 2459 }, 2460 "Iterates over operations in the block.") 2461 .def("__eq__", 2462 [](PyBlock &self, PyBlock &other) { 2463 return self.get().ptr == other.get().ptr; 2464 }) 2465 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2466 .def( 2467 "__str__", 2468 [](PyBlock &self) { 2469 self.checkValid(); 2470 PyPrintAccumulator printAccum; 2471 mlirBlockPrint(self.get(), printAccum.getCallback(), 2472 printAccum.getUserData()); 2473 return printAccum.join(); 2474 }, 2475 "Returns the assembly form of the block.") 2476 .def( 2477 "append", 2478 [](PyBlock &self, PyOperationBase &operation) { 2479 if (operation.getOperation().isAttached()) 2480 operation.getOperation().detachFromParent(); 2481 2482 MlirOperation mlirOperation = operation.getOperation().get(); 2483 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2484 operation.getOperation().setAttached( 2485 self.getParentOperation().getObject()); 2486 }, 2487 "Appends an operation to this block. If the operation is currently " 2488 "in another block, it will be moved."); 2489 2490 //---------------------------------------------------------------------------- 2491 // Mapping of PyInsertionPoint. 2492 //---------------------------------------------------------------------------- 2493 2494 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2495 .def(py::init<PyBlock &>(), py::arg("block"), 2496 "Inserts after the last operation but still inside the block.") 2497 .def("__enter__", &PyInsertionPoint::contextEnter) 2498 .def("__exit__", &PyInsertionPoint::contextExit) 2499 .def_property_readonly_static( 2500 "current", 2501 [](py::object & /*class*/) { 2502 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2503 if (!ip) 2504 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2505 return ip; 2506 }, 2507 "Gets the InsertionPoint bound to the current thread or raises " 2508 "ValueError if none has been set") 2509 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2510 "Inserts before a referenced operation.") 2511 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2512 py::arg("block"), "Inserts at the beginning of the block.") 2513 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2514 py::arg("block"), "Inserts before the block terminator.") 2515 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2516 "Inserts an operation.") 2517 .def_property_readonly( 2518 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2519 "Returns the block that this InsertionPoint points to."); 2520 2521 //---------------------------------------------------------------------------- 2522 // Mapping of PyAttribute. 2523 //---------------------------------------------------------------------------- 2524 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2525 // Delegate to the PyAttribute copy constructor, which will also lifetime 2526 // extend the backing context which owns the MlirAttribute. 2527 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2528 "Casts the passed attribute to the generic Attribute") 2529 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2530 &PyAttribute::getCapsule) 2531 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2532 .def_static( 2533 "parse", 2534 [](std::string attrSpec, DefaultingPyMlirContext context) { 2535 MlirAttribute type = mlirAttributeParseGet( 2536 context->get(), toMlirStringRef(attrSpec)); 2537 // TODO: Rework error reporting once diagnostic engine is exposed 2538 // in C API. 2539 if (mlirAttributeIsNull(type)) { 2540 throw SetPyError(PyExc_ValueError, 2541 Twine("Unable to parse attribute: '") + 2542 attrSpec + "'"); 2543 } 2544 return PyAttribute(context->getRef(), type); 2545 }, 2546 py::arg("asm"), py::arg("context") = py::none(), 2547 "Parses an attribute from an assembly form") 2548 .def_property_readonly( 2549 "context", 2550 [](PyAttribute &self) { return self.getContext().getObject(); }, 2551 "Context that owns the Attribute") 2552 .def_property_readonly("type", 2553 [](PyAttribute &self) { 2554 return PyType(self.getContext()->getRef(), 2555 mlirAttributeGetType(self)); 2556 }) 2557 .def( 2558 "get_named", 2559 [](PyAttribute &self, std::string name) { 2560 return PyNamedAttribute(self, std::move(name)); 2561 }, 2562 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2563 .def("__eq__", 2564 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2565 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2566 .def("__hash__", 2567 [](PyAttribute &self) { 2568 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2569 }) 2570 .def( 2571 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2572 kDumpDocstring) 2573 .def( 2574 "__str__", 2575 [](PyAttribute &self) { 2576 PyPrintAccumulator printAccum; 2577 mlirAttributePrint(self, printAccum.getCallback(), 2578 printAccum.getUserData()); 2579 return printAccum.join(); 2580 }, 2581 "Returns the assembly form of the Attribute.") 2582 .def("__repr__", [](PyAttribute &self) { 2583 // Generally, assembly formats are not printed for __repr__ because 2584 // this can cause exceptionally long debug output and exceptions. 2585 // However, attribute values are generally considered useful and are 2586 // printed. This may need to be re-evaluated if debug dumps end up 2587 // being excessive. 2588 PyPrintAccumulator printAccum; 2589 printAccum.parts.append("Attribute("); 2590 mlirAttributePrint(self, printAccum.getCallback(), 2591 printAccum.getUserData()); 2592 printAccum.parts.append(")"); 2593 return printAccum.join(); 2594 }); 2595 2596 //---------------------------------------------------------------------------- 2597 // Mapping of PyNamedAttribute 2598 //---------------------------------------------------------------------------- 2599 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2600 .def("__repr__", 2601 [](PyNamedAttribute &self) { 2602 PyPrintAccumulator printAccum; 2603 printAccum.parts.append("NamedAttribute("); 2604 printAccum.parts.append( 2605 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2606 mlirIdentifierStr(self.namedAttr.name).length)); 2607 printAccum.parts.append("="); 2608 mlirAttributePrint(self.namedAttr.attribute, 2609 printAccum.getCallback(), 2610 printAccum.getUserData()); 2611 printAccum.parts.append(")"); 2612 return printAccum.join(); 2613 }) 2614 .def_property_readonly( 2615 "name", 2616 [](PyNamedAttribute &self) { 2617 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2618 mlirIdentifierStr(self.namedAttr.name).length); 2619 }, 2620 "The name of the NamedAttribute binding") 2621 .def_property_readonly( 2622 "attr", 2623 [](PyNamedAttribute &self) { 2624 // TODO: When named attribute is removed/refactored, also remove 2625 // this constructor (it does an inefficient table lookup). 2626 auto contextRef = PyMlirContext::forContext( 2627 mlirAttributeGetContext(self.namedAttr.attribute)); 2628 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2629 }, 2630 py::keep_alive<0, 1>(), 2631 "The underlying generic attribute of the NamedAttribute binding"); 2632 2633 //---------------------------------------------------------------------------- 2634 // Mapping of PyType. 2635 //---------------------------------------------------------------------------- 2636 py::class_<PyType>(m, "Type", py::module_local()) 2637 // Delegate to the PyType copy constructor, which will also lifetime 2638 // extend the backing context which owns the MlirType. 2639 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2640 "Casts the passed type to the generic Type") 2641 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2642 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2643 .def_static( 2644 "parse", 2645 [](std::string typeSpec, DefaultingPyMlirContext context) { 2646 MlirType type = 2647 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2648 // TODO: Rework error reporting once diagnostic engine is exposed 2649 // in C API. 2650 if (mlirTypeIsNull(type)) { 2651 throw SetPyError(PyExc_ValueError, 2652 Twine("Unable to parse type: '") + typeSpec + 2653 "'"); 2654 } 2655 return PyType(context->getRef(), type); 2656 }, 2657 py::arg("asm"), py::arg("context") = py::none(), 2658 kContextParseTypeDocstring) 2659 .def_property_readonly( 2660 "context", [](PyType &self) { return self.getContext().getObject(); }, 2661 "Context that owns the Type") 2662 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2663 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2664 .def("__hash__", 2665 [](PyType &self) { 2666 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2667 }) 2668 .def( 2669 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2670 .def( 2671 "__str__", 2672 [](PyType &self) { 2673 PyPrintAccumulator printAccum; 2674 mlirTypePrint(self, printAccum.getCallback(), 2675 printAccum.getUserData()); 2676 return printAccum.join(); 2677 }, 2678 "Returns the assembly form of the type.") 2679 .def("__repr__", [](PyType &self) { 2680 // Generally, assembly formats are not printed for __repr__ because 2681 // this can cause exceptionally long debug output and exceptions. 2682 // However, types are an exception as they typically have compact 2683 // assembly forms and printing them is useful. 2684 PyPrintAccumulator printAccum; 2685 printAccum.parts.append("Type("); 2686 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2687 printAccum.parts.append(")"); 2688 return printAccum.join(); 2689 }); 2690 2691 //---------------------------------------------------------------------------- 2692 // Mapping of Value. 2693 //---------------------------------------------------------------------------- 2694 py::class_<PyValue>(m, "Value", py::module_local()) 2695 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2696 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2697 .def_property_readonly( 2698 "context", 2699 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2700 "Context in which the value lives.") 2701 .def( 2702 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2703 kDumpDocstring) 2704 .def_property_readonly( 2705 "owner", 2706 [](PyValue &self) { 2707 assert(mlirOperationEqual(self.getParentOperation()->get(), 2708 mlirOpResultGetOwner(self.get())) && 2709 "expected the owner of the value in Python to match that in " 2710 "the IR"); 2711 return self.getParentOperation().getObject(); 2712 }) 2713 .def("__eq__", 2714 [](PyValue &self, PyValue &other) { 2715 return self.get().ptr == other.get().ptr; 2716 }) 2717 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2718 .def("__hash__", 2719 [](PyValue &self) { 2720 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2721 }) 2722 .def( 2723 "__str__", 2724 [](PyValue &self) { 2725 PyPrintAccumulator printAccum; 2726 printAccum.parts.append("Value("); 2727 mlirValuePrint(self.get(), printAccum.getCallback(), 2728 printAccum.getUserData()); 2729 printAccum.parts.append(")"); 2730 return printAccum.join(); 2731 }, 2732 kValueDunderStrDocstring) 2733 .def_property_readonly("type", [](PyValue &self) { 2734 return PyType(self.getParentOperation()->getContext(), 2735 mlirValueGetType(self.get())); 2736 }); 2737 PyBlockArgument::bind(m); 2738 PyOpResult::bind(m); 2739 2740 //---------------------------------------------------------------------------- 2741 // Mapping of SymbolTable. 2742 //---------------------------------------------------------------------------- 2743 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 2744 .def(py::init<PyOperationBase &>()) 2745 .def("__getitem__", &PySymbolTable::dunderGetItem) 2746 .def("insert", &PySymbolTable::insert) 2747 .def("erase", &PySymbolTable::erase) 2748 .def("__delitem__", &PySymbolTable::dunderDel) 2749 .def("__contains__", [](PySymbolTable &table, const std::string &name) { 2750 return !mlirOperationIsNull(mlirSymbolTableLookup( 2751 table, mlirStringRefCreate(name.data(), name.length()))); 2752 }); 2753 2754 // Container bindings. 2755 PyBlockArgumentList::bind(m); 2756 PyBlockIterator::bind(m); 2757 PyBlockList::bind(m); 2758 PyOperationIterator::bind(m); 2759 PyOperationList::bind(m); 2760 PyOpAttributeMap::bind(m); 2761 PyOpOperandList::bind(m); 2762 PyOpResultList::bind(m); 2763 PyRegionIterator::bind(m); 2764 PyRegionList::bind(m); 2765 2766 // Debug bindings. 2767 PyGlobalDebugFlag::bind(m); 2768 } 2769