1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Registration.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include <pybind11/stl.h> 20 21 namespace py = pybind11; 22 using namespace mlir; 23 using namespace mlir::python; 24 25 using llvm::SmallVector; 26 using llvm::StringRef; 27 using llvm::Twine; 28 29 //------------------------------------------------------------------------------ 30 // Docstrings (trivial, non-duplicated docstrings are included inline). 31 //------------------------------------------------------------------------------ 32 33 static const char kContextParseTypeDocstring[] = 34 R"(Parses the assembly form of a type. 35 36 Returns a Type object or raises a ValueError if the type cannot be parsed. 37 38 See also: https://mlir.llvm.org/docs/LangRef/#type-system 39 )"; 40 41 static const char kContextGetFileLocationDocstring[] = 42 R"(Gets a Location representing a file, line and column)"; 43 44 static const char kModuleParseDocstring[] = 45 R"(Parses a module's assembly format from a string. 46 47 Returns a new MlirModule or raises a ValueError if the parsing fails. 48 49 See also: https://mlir.llvm.org/docs/LangRef/ 50 )"; 51 52 static const char kOperationCreateDocstring[] = 53 R"(Creates a new operation. 54 55 Args: 56 name: Operation name (e.g. "dialect.operation"). 57 results: Sequence of Type representing op result types. 58 attributes: Dict of str:Attribute. 59 successors: List of Block for the operation's successors. 60 regions: Number of regions to create. 61 location: A Location object (defaults to resolve from context manager). 62 ip: An InsertionPoint (defaults to resolve from context manager or set to 63 False to disable insertion, even with an insertion point set in the 64 context manager). 65 Returns: 66 A new "detached" Operation object. Detached operations can be added 67 to blocks, which causes them to become "attached." 68 )"; 69 70 static const char kOperationPrintDocstring[] = 71 R"(Prints the assembly form of the operation to a file like object. 72 73 Args: 74 file: The file like object to write to. Defaults to sys.stdout. 75 binary: Whether to write bytes (True) or str (False). Defaults to False. 76 large_elements_limit: Whether to elide elements attributes above this 77 number of elements. Defaults to None (no limit). 78 enable_debug_info: Whether to print debug/location information. Defaults 79 to False. 80 pretty_debug_info: Whether to format debug information for easier reading 81 by a human (warning: the result is unparseable). 82 print_generic_op_form: Whether to print the generic assembly forms of all 83 ops. Defaults to False. 84 use_local_Scope: Whether to print in a way that is more optimized for 85 multi-threaded access but may not be consistent with how the overall 86 module prints. 87 )"; 88 89 static const char kOperationGetAsmDocstring[] = 90 R"(Gets the assembly form of the operation with all options available. 91 92 Args: 93 binary: Whether to return a bytes (True) or str (False) object. Defaults to 94 False. 95 ... others ...: See the print() method for common keyword arguments for 96 configuring the printout. 97 Returns: 98 Either a bytes or str object, depending on the setting of the 'binary' 99 argument. 100 )"; 101 102 static const char kOperationStrDunderDocstring[] = 103 R"(Gets the assembly form of the operation with default options. 104 105 If more advanced control over the assembly formatting or I/O options is needed, 106 use the dedicated print or get_asm method, which supports keyword arguments to 107 customize behavior. 108 )"; 109 110 static const char kDumpDocstring[] = 111 R"(Dumps a debug representation of the object to stderr.)"; 112 113 static const char kAppendBlockDocstring[] = 114 R"(Appends a new block, with argument types as positional args. 115 116 Returns: 117 The created block. 118 )"; 119 120 static const char kValueDunderStrDocstring[] = 121 R"(Returns the string form of the value. 122 123 If the value is a block argument, this is the assembly form of its type and the 124 position in the argument list. If the value is an operation result, this is 125 equivalent to printing the operation that produced it. 126 )"; 127 128 //------------------------------------------------------------------------------ 129 // Utilities. 130 //------------------------------------------------------------------------------ 131 132 // Helper for creating an @classmethod. 133 template <class Func, typename... Args> 134 py::object classmethod(Func f, Args... args) { 135 py::object cf = py::cpp_function(f, args...); 136 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 137 } 138 139 static py::object 140 createCustomDialectWrapper(const std::string &dialectNamespace, 141 py::object dialectDescriptor) { 142 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 143 if (!dialectClass) { 144 // Use the base class. 145 return py::cast(PyDialect(std::move(dialectDescriptor))); 146 } 147 148 // Create the custom implementation. 149 return (*dialectClass)(std::move(dialectDescriptor)); 150 } 151 152 static MlirStringRef toMlirStringRef(const std::string &s) { 153 return mlirStringRefCreate(s.data(), s.size()); 154 } 155 156 //------------------------------------------------------------------------------ 157 // Collections. 158 //------------------------------------------------------------------------------ 159 160 namespace { 161 162 class PyRegionIterator { 163 public: 164 PyRegionIterator(PyOperationRef operation) 165 : operation(std::move(operation)) {} 166 167 PyRegionIterator &dunderIter() { return *this; } 168 169 PyRegion dunderNext() { 170 operation->checkValid(); 171 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 172 throw py::stop_iteration(); 173 } 174 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 175 return PyRegion(operation, region); 176 } 177 178 static void bind(py::module &m) { 179 py::class_<PyRegionIterator>(m, "RegionIterator") 180 .def("__iter__", &PyRegionIterator::dunderIter) 181 .def("__next__", &PyRegionIterator::dunderNext); 182 } 183 184 private: 185 PyOperationRef operation; 186 int nextIndex = 0; 187 }; 188 189 /// Regions of an op are fixed length and indexed numerically so are represented 190 /// with a sequence-like container. 191 class PyRegionList { 192 public: 193 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 194 195 intptr_t dunderLen() { 196 operation->checkValid(); 197 return mlirOperationGetNumRegions(operation->get()); 198 } 199 200 PyRegion dunderGetItem(intptr_t index) { 201 // dunderLen checks validity. 202 if (index < 0 || index >= dunderLen()) { 203 throw SetPyError(PyExc_IndexError, 204 "attempt to access out of bounds region"); 205 } 206 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 207 return PyRegion(operation, region); 208 } 209 210 static void bind(py::module &m) { 211 py::class_<PyRegionList>(m, "RegionSequence") 212 .def("__len__", &PyRegionList::dunderLen) 213 .def("__getitem__", &PyRegionList::dunderGetItem); 214 } 215 216 private: 217 PyOperationRef operation; 218 }; 219 220 class PyBlockIterator { 221 public: 222 PyBlockIterator(PyOperationRef operation, MlirBlock next) 223 : operation(std::move(operation)), next(next) {} 224 225 PyBlockIterator &dunderIter() { return *this; } 226 227 PyBlock dunderNext() { 228 operation->checkValid(); 229 if (mlirBlockIsNull(next)) { 230 throw py::stop_iteration(); 231 } 232 233 PyBlock returnBlock(operation, next); 234 next = mlirBlockGetNextInRegion(next); 235 return returnBlock; 236 } 237 238 static void bind(py::module &m) { 239 py::class_<PyBlockIterator>(m, "BlockIterator") 240 .def("__iter__", &PyBlockIterator::dunderIter) 241 .def("__next__", &PyBlockIterator::dunderNext); 242 } 243 244 private: 245 PyOperationRef operation; 246 MlirBlock next; 247 }; 248 249 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 250 /// we present them as a more full-featured list-like container but optimize 251 /// it for forward iteration. Blocks are always owned by a region. 252 class PyBlockList { 253 public: 254 PyBlockList(PyOperationRef operation, MlirRegion region) 255 : operation(std::move(operation)), region(region) {} 256 257 PyBlockIterator dunderIter() { 258 operation->checkValid(); 259 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 260 } 261 262 intptr_t dunderLen() { 263 operation->checkValid(); 264 intptr_t count = 0; 265 MlirBlock block = mlirRegionGetFirstBlock(region); 266 while (!mlirBlockIsNull(block)) { 267 count += 1; 268 block = mlirBlockGetNextInRegion(block); 269 } 270 return count; 271 } 272 273 PyBlock dunderGetItem(intptr_t index) { 274 operation->checkValid(); 275 if (index < 0) { 276 throw SetPyError(PyExc_IndexError, 277 "attempt to access out of bounds block"); 278 } 279 MlirBlock block = mlirRegionGetFirstBlock(region); 280 while (!mlirBlockIsNull(block)) { 281 if (index == 0) { 282 return PyBlock(operation, block); 283 } 284 block = mlirBlockGetNextInRegion(block); 285 index -= 1; 286 } 287 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 288 } 289 290 PyBlock appendBlock(py::args pyArgTypes) { 291 operation->checkValid(); 292 llvm::SmallVector<MlirType, 4> argTypes; 293 argTypes.reserve(pyArgTypes.size()); 294 for (auto &pyArg : pyArgTypes) { 295 argTypes.push_back(pyArg.cast<PyType &>()); 296 } 297 298 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 299 mlirRegionAppendOwnedBlock(region, block); 300 return PyBlock(operation, block); 301 } 302 303 static void bind(py::module &m) { 304 py::class_<PyBlockList>(m, "BlockList") 305 .def("__getitem__", &PyBlockList::dunderGetItem) 306 .def("__iter__", &PyBlockList::dunderIter) 307 .def("__len__", &PyBlockList::dunderLen) 308 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 309 } 310 311 private: 312 PyOperationRef operation; 313 MlirRegion region; 314 }; 315 316 class PyOperationIterator { 317 public: 318 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 319 : parentOperation(std::move(parentOperation)), next(next) {} 320 321 PyOperationIterator &dunderIter() { return *this; } 322 323 py::object dunderNext() { 324 parentOperation->checkValid(); 325 if (mlirOperationIsNull(next)) { 326 throw py::stop_iteration(); 327 } 328 329 PyOperationRef returnOperation = 330 PyOperation::forOperation(parentOperation->getContext(), next); 331 next = mlirOperationGetNextInBlock(next); 332 return returnOperation->createOpView(); 333 } 334 335 static void bind(py::module &m) { 336 py::class_<PyOperationIterator>(m, "OperationIterator") 337 .def("__iter__", &PyOperationIterator::dunderIter) 338 .def("__next__", &PyOperationIterator::dunderNext); 339 } 340 341 private: 342 PyOperationRef parentOperation; 343 MlirOperation next; 344 }; 345 346 /// Operations are exposed by the C-API as a forward-only linked list. In 347 /// Python, we present them as a more full-featured list-like container but 348 /// optimize it for forward iteration. Iterable operations are always owned 349 /// by a block. 350 class PyOperationList { 351 public: 352 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 353 : parentOperation(std::move(parentOperation)), block(block) {} 354 355 PyOperationIterator dunderIter() { 356 parentOperation->checkValid(); 357 return PyOperationIterator(parentOperation, 358 mlirBlockGetFirstOperation(block)); 359 } 360 361 intptr_t dunderLen() { 362 parentOperation->checkValid(); 363 intptr_t count = 0; 364 MlirOperation childOp = mlirBlockGetFirstOperation(block); 365 while (!mlirOperationIsNull(childOp)) { 366 count += 1; 367 childOp = mlirOperationGetNextInBlock(childOp); 368 } 369 return count; 370 } 371 372 py::object dunderGetItem(intptr_t index) { 373 parentOperation->checkValid(); 374 if (index < 0) { 375 throw SetPyError(PyExc_IndexError, 376 "attempt to access out of bounds operation"); 377 } 378 MlirOperation childOp = mlirBlockGetFirstOperation(block); 379 while (!mlirOperationIsNull(childOp)) { 380 if (index == 0) { 381 return PyOperation::forOperation(parentOperation->getContext(), childOp) 382 ->createOpView(); 383 } 384 childOp = mlirOperationGetNextInBlock(childOp); 385 index -= 1; 386 } 387 throw SetPyError(PyExc_IndexError, 388 "attempt to access out of bounds operation"); 389 } 390 391 static void bind(py::module &m) { 392 py::class_<PyOperationList>(m, "OperationList") 393 .def("__getitem__", &PyOperationList::dunderGetItem) 394 .def("__iter__", &PyOperationList::dunderIter) 395 .def("__len__", &PyOperationList::dunderLen); 396 } 397 398 private: 399 PyOperationRef parentOperation; 400 MlirBlock block; 401 }; 402 403 } // namespace 404 405 //------------------------------------------------------------------------------ 406 // PyMlirContext 407 //------------------------------------------------------------------------------ 408 409 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 410 py::gil_scoped_acquire acquire; 411 auto &liveContexts = getLiveContexts(); 412 liveContexts[context.ptr] = this; 413 } 414 415 PyMlirContext::~PyMlirContext() { 416 // Note that the only public way to construct an instance is via the 417 // forContext method, which always puts the associated handle into 418 // liveContexts. 419 py::gil_scoped_acquire acquire; 420 getLiveContexts().erase(context.ptr); 421 mlirContextDestroy(context); 422 } 423 424 py::object PyMlirContext::getCapsule() { 425 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 426 } 427 428 py::object PyMlirContext::createFromCapsule(py::object capsule) { 429 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 430 if (mlirContextIsNull(rawContext)) 431 throw py::error_already_set(); 432 return forContext(rawContext).releaseObject(); 433 } 434 435 PyMlirContext *PyMlirContext::createNewContextForInit() { 436 MlirContext context = mlirContextCreate(); 437 mlirRegisterAllDialects(context); 438 return new PyMlirContext(context); 439 } 440 441 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 442 py::gil_scoped_acquire acquire; 443 auto &liveContexts = getLiveContexts(); 444 auto it = liveContexts.find(context.ptr); 445 if (it == liveContexts.end()) { 446 // Create. 447 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 448 py::object pyRef = py::cast(unownedContextWrapper); 449 assert(pyRef && "cast to py::object failed"); 450 liveContexts[context.ptr] = unownedContextWrapper; 451 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 452 } 453 // Use existing. 454 py::object pyRef = py::cast(it->second); 455 return PyMlirContextRef(it->second, std::move(pyRef)); 456 } 457 458 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 459 static LiveContextMap liveContexts; 460 return liveContexts; 461 } 462 463 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 464 465 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 466 467 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 468 469 pybind11::object PyMlirContext::contextEnter() { 470 return PyThreadContextEntry::pushContext(*this); 471 } 472 473 void PyMlirContext::contextExit(pybind11::object excType, 474 pybind11::object excVal, 475 pybind11::object excTb) { 476 PyThreadContextEntry::popContext(*this); 477 } 478 479 PyMlirContext &DefaultingPyMlirContext::resolve() { 480 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 481 if (!context) { 482 throw SetPyError( 483 PyExc_RuntimeError, 484 "An MLIR function requires a Context but none was provided in the call " 485 "or from the surrounding environment. Either pass to the function with " 486 "a 'context=' argument or establish a default using 'with Context():'"); 487 } 488 return *context; 489 } 490 491 //------------------------------------------------------------------------------ 492 // PyThreadContextEntry management 493 //------------------------------------------------------------------------------ 494 495 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 496 static thread_local std::vector<PyThreadContextEntry> stack; 497 return stack; 498 } 499 500 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 501 auto &stack = getStack(); 502 if (stack.empty()) 503 return nullptr; 504 return &stack.back(); 505 } 506 507 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 508 py::object insertionPoint, 509 py::object location) { 510 auto &stack = getStack(); 511 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 512 std::move(location)); 513 // If the new stack has more than one entry and the context of the new top 514 // entry matches the previous, copy the insertionPoint and location from the 515 // previous entry if missing from the new top entry. 516 if (stack.size() > 1) { 517 auto &prev = *(stack.rbegin() + 1); 518 auto ¤t = stack.back(); 519 if (current.context.is(prev.context)) { 520 // Default non-context objects from the previous entry. 521 if (!current.insertionPoint) 522 current.insertionPoint = prev.insertionPoint; 523 if (!current.location) 524 current.location = prev.location; 525 } 526 } 527 } 528 529 PyMlirContext *PyThreadContextEntry::getContext() { 530 if (!context) 531 return nullptr; 532 return py::cast<PyMlirContext *>(context); 533 } 534 535 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 536 if (!insertionPoint) 537 return nullptr; 538 return py::cast<PyInsertionPoint *>(insertionPoint); 539 } 540 541 PyLocation *PyThreadContextEntry::getLocation() { 542 if (!location) 543 return nullptr; 544 return py::cast<PyLocation *>(location); 545 } 546 547 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 548 auto *tos = getTopOfStack(); 549 return tos ? tos->getContext() : nullptr; 550 } 551 552 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 553 auto *tos = getTopOfStack(); 554 return tos ? tos->getInsertionPoint() : nullptr; 555 } 556 557 PyLocation *PyThreadContextEntry::getDefaultLocation() { 558 auto *tos = getTopOfStack(); 559 return tos ? tos->getLocation() : nullptr; 560 } 561 562 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 563 py::object contextObj = py::cast(context); 564 push(FrameKind::Context, /*context=*/contextObj, 565 /*insertionPoint=*/py::object(), 566 /*location=*/py::object()); 567 return contextObj; 568 } 569 570 void PyThreadContextEntry::popContext(PyMlirContext &context) { 571 auto &stack = getStack(); 572 if (stack.empty()) 573 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 574 auto &tos = stack.back(); 575 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 576 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 577 stack.pop_back(); 578 } 579 580 py::object 581 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 582 py::object contextObj = 583 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 584 py::object insertionPointObj = py::cast(insertionPoint); 585 push(FrameKind::InsertionPoint, 586 /*context=*/contextObj, 587 /*insertionPoint=*/insertionPointObj, 588 /*location=*/py::object()); 589 return insertionPointObj; 590 } 591 592 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 593 auto &stack = getStack(); 594 if (stack.empty()) 595 throw SetPyError(PyExc_RuntimeError, 596 "Unbalanced InsertionPoint enter/exit"); 597 auto &tos = stack.back(); 598 if (tos.frameKind != FrameKind::InsertionPoint && 599 tos.getInsertionPoint() != &insertionPoint) 600 throw SetPyError(PyExc_RuntimeError, 601 "Unbalanced InsertionPoint enter/exit"); 602 stack.pop_back(); 603 } 604 605 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 606 py::object contextObj = location.getContext().getObject(); 607 py::object locationObj = py::cast(location); 608 push(FrameKind::Location, /*context=*/contextObj, 609 /*insertionPoint=*/py::object(), 610 /*location=*/locationObj); 611 return locationObj; 612 } 613 614 void PyThreadContextEntry::popLocation(PyLocation &location) { 615 auto &stack = getStack(); 616 if (stack.empty()) 617 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 618 auto &tos = stack.back(); 619 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 620 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 621 stack.pop_back(); 622 } 623 624 //------------------------------------------------------------------------------ 625 // PyDialect, PyDialectDescriptor, PyDialects 626 //------------------------------------------------------------------------------ 627 628 MlirDialect PyDialects::getDialectForKey(const std::string &key, 629 bool attrError) { 630 // If the "std" dialect was asked for, substitute the empty namespace :( 631 static const std::string emptyKey; 632 const std::string *canonKey = key == "std" ? &emptyKey : &key; 633 MlirDialect dialect = mlirContextGetOrLoadDialect( 634 getContext()->get(), {canonKey->data(), canonKey->size()}); 635 if (mlirDialectIsNull(dialect)) { 636 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 637 Twine("Dialect '") + key + "' not found"); 638 } 639 return dialect; 640 } 641 642 //------------------------------------------------------------------------------ 643 // PyLocation 644 //------------------------------------------------------------------------------ 645 646 py::object PyLocation::getCapsule() { 647 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 648 } 649 650 PyLocation PyLocation::createFromCapsule(py::object capsule) { 651 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 652 if (mlirLocationIsNull(rawLoc)) 653 throw py::error_already_set(); 654 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 655 rawLoc); 656 } 657 658 py::object PyLocation::contextEnter() { 659 return PyThreadContextEntry::pushLocation(*this); 660 } 661 662 void PyLocation::contextExit(py::object excType, py::object excVal, 663 py::object excTb) { 664 PyThreadContextEntry::popLocation(*this); 665 } 666 667 PyLocation &DefaultingPyLocation::resolve() { 668 auto *location = PyThreadContextEntry::getDefaultLocation(); 669 if (!location) { 670 throw SetPyError( 671 PyExc_RuntimeError, 672 "An MLIR function requires a Location but none was provided in the " 673 "call or from the surrounding environment. Either pass to the function " 674 "with a 'loc=' argument or establish a default using 'with loc:'"); 675 } 676 return *location; 677 } 678 679 //------------------------------------------------------------------------------ 680 // PyModule 681 //------------------------------------------------------------------------------ 682 683 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 684 : BaseContextObject(std::move(contextRef)), module(module) {} 685 686 PyModule::~PyModule() { 687 py::gil_scoped_acquire acquire; 688 auto &liveModules = getContext()->liveModules; 689 assert(liveModules.count(module.ptr) == 1 && 690 "destroying module not in live map"); 691 liveModules.erase(module.ptr); 692 mlirModuleDestroy(module); 693 } 694 695 PyModuleRef PyModule::forModule(MlirModule module) { 696 MlirContext context = mlirModuleGetContext(module); 697 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 698 699 py::gil_scoped_acquire acquire; 700 auto &liveModules = contextRef->liveModules; 701 auto it = liveModules.find(module.ptr); 702 if (it == liveModules.end()) { 703 // Create. 704 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 705 // Note that the default return value policy on cast is automatic_reference, 706 // which does not take ownership (delete will not be called). 707 // Just be explicit. 708 py::object pyRef = 709 py::cast(unownedModule, py::return_value_policy::take_ownership); 710 unownedModule->handle = pyRef; 711 liveModules[module.ptr] = 712 std::make_pair(unownedModule->handle, unownedModule); 713 return PyModuleRef(unownedModule, std::move(pyRef)); 714 } 715 // Use existing. 716 PyModule *existing = it->second.second; 717 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 718 return PyModuleRef(existing, std::move(pyRef)); 719 } 720 721 py::object PyModule::createFromCapsule(py::object capsule) { 722 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 723 if (mlirModuleIsNull(rawModule)) 724 throw py::error_already_set(); 725 return forModule(rawModule).releaseObject(); 726 } 727 728 py::object PyModule::getCapsule() { 729 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 730 } 731 732 //------------------------------------------------------------------------------ 733 // PyOperation 734 //------------------------------------------------------------------------------ 735 736 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 737 : BaseContextObject(std::move(contextRef)), operation(operation) {} 738 739 PyOperation::~PyOperation() { 740 auto &liveOperations = getContext()->liveOperations; 741 assert(liveOperations.count(operation.ptr) == 1 && 742 "destroying operation not in live map"); 743 liveOperations.erase(operation.ptr); 744 if (!isAttached()) { 745 mlirOperationDestroy(operation); 746 } 747 } 748 749 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 750 MlirOperation operation, 751 py::object parentKeepAlive) { 752 auto &liveOperations = contextRef->liveOperations; 753 // Create. 754 PyOperation *unownedOperation = 755 new PyOperation(std::move(contextRef), operation); 756 // Note that the default return value policy on cast is automatic_reference, 757 // which does not take ownership (delete will not be called). 758 // Just be explicit. 759 py::object pyRef = 760 py::cast(unownedOperation, py::return_value_policy::take_ownership); 761 unownedOperation->handle = pyRef; 762 if (parentKeepAlive) { 763 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 764 } 765 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 766 return PyOperationRef(unownedOperation, std::move(pyRef)); 767 } 768 769 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 770 MlirOperation operation, 771 py::object parentKeepAlive) { 772 auto &liveOperations = contextRef->liveOperations; 773 auto it = liveOperations.find(operation.ptr); 774 if (it == liveOperations.end()) { 775 // Create. 776 return createInstance(std::move(contextRef), operation, 777 std::move(parentKeepAlive)); 778 } 779 // Use existing. 780 PyOperation *existing = it->second.second; 781 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 782 return PyOperationRef(existing, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 assert(liveOperations.count(operation.ptr) == 0 && 790 "cannot create detached operation that already exists"); 791 (void)liveOperations; 792 793 PyOperationRef created = createInstance(std::move(contextRef), operation, 794 std::move(parentKeepAlive)); 795 created->attached = false; 796 return created; 797 } 798 799 void PyOperation::checkValid() const { 800 if (!valid) { 801 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 802 } 803 } 804 805 void PyOperationBase::print(py::object fileObject, bool binary, 806 llvm::Optional<int64_t> largeElementsLimit, 807 bool enableDebugInfo, bool prettyDebugInfo, 808 bool printGenericOpForm, bool useLocalScope) { 809 PyOperation &operation = getOperation(); 810 operation.checkValid(); 811 if (fileObject.is_none()) 812 fileObject = py::module::import("sys").attr("stdout"); 813 814 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 815 fileObject.attr("write")("// Verification failed, printing generic form\n"); 816 printGenericOpForm = true; 817 } 818 819 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 820 if (largeElementsLimit) 821 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 822 if (enableDebugInfo) 823 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 824 if (printGenericOpForm) 825 mlirOpPrintingFlagsPrintGenericOpForm(flags); 826 827 PyFileAccumulator accum(fileObject, binary); 828 py::gil_scoped_release(); 829 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 830 accum.getUserData()); 831 mlirOpPrintingFlagsDestroy(flags); 832 } 833 834 py::object PyOperationBase::getAsm(bool binary, 835 llvm::Optional<int64_t> largeElementsLimit, 836 bool enableDebugInfo, bool prettyDebugInfo, 837 bool printGenericOpForm, 838 bool useLocalScope) { 839 py::object fileObject; 840 if (binary) { 841 fileObject = py::module::import("io").attr("BytesIO")(); 842 } else { 843 fileObject = py::module::import("io").attr("StringIO")(); 844 } 845 print(fileObject, /*binary=*/binary, 846 /*largeElementsLimit=*/largeElementsLimit, 847 /*enableDebugInfo=*/enableDebugInfo, 848 /*prettyDebugInfo=*/prettyDebugInfo, 849 /*printGenericOpForm=*/printGenericOpForm, 850 /*useLocalScope=*/useLocalScope); 851 852 return fileObject.attr("getvalue")(); 853 } 854 855 PyOperationRef PyOperation::getParentOperation() { 856 if (!isAttached()) 857 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 858 MlirOperation operation = mlirOperationGetParentOperation(get()); 859 if (mlirOperationIsNull(operation)) 860 throw SetPyError(PyExc_ValueError, "Operation has no parent."); 861 return PyOperation::forOperation(getContext(), operation); 862 } 863 864 PyBlock PyOperation::getBlock() { 865 PyOperationRef parentOperation = getParentOperation(); 866 MlirBlock block = mlirOperationGetBlock(get()); 867 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 868 return PyBlock{std::move(parentOperation), block}; 869 } 870 871 py::object PyOperation::create( 872 std::string name, llvm::Optional<std::vector<PyType *>> results, 873 llvm::Optional<std::vector<PyValue *>> operands, 874 llvm::Optional<py::dict> attributes, 875 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 876 DefaultingPyLocation location, py::object maybeIp) { 877 llvm::SmallVector<MlirValue, 4> mlirOperands; 878 llvm::SmallVector<MlirType, 4> mlirResults; 879 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 880 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 881 882 // General parameter validation. 883 if (regions < 0) 884 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 885 886 // Unpack/validate operands. 887 if (operands) { 888 mlirOperands.reserve(operands->size()); 889 for (PyValue *operand : *operands) { 890 if (!operand) 891 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 892 mlirOperands.push_back(operand->get()); 893 } 894 } 895 896 // Unpack/validate results. 897 if (results) { 898 mlirResults.reserve(results->size()); 899 for (PyType *result : *results) { 900 // TODO: Verify result type originate from the same context. 901 if (!result) 902 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 903 mlirResults.push_back(*result); 904 } 905 } 906 // Unpack/validate attributes. 907 if (attributes) { 908 mlirAttributes.reserve(attributes->size()); 909 for (auto &it : *attributes) { 910 std::string key; 911 try { 912 key = it.first.cast<std::string>(); 913 } catch (py::cast_error &err) { 914 std::string msg = "Invalid attribute key (not a string) when " 915 "attempting to create the operation \"" + 916 name + "\" (" + err.what() + ")"; 917 throw py::cast_error(msg); 918 } 919 try { 920 auto &attribute = it.second.cast<PyAttribute &>(); 921 // TODO: Verify attribute originates from the same context. 922 mlirAttributes.emplace_back(std::move(key), attribute); 923 } catch (py::reference_cast_error &) { 924 // This exception seems thrown when the value is "None". 925 std::string msg = 926 "Found an invalid (`None`?) attribute value for the key \"" + key + 927 "\" when attempting to create the operation \"" + name + "\""; 928 throw py::cast_error(msg); 929 } catch (py::cast_error &err) { 930 std::string msg = "Invalid attribute value for the key \"" + key + 931 "\" when attempting to create the operation \"" + 932 name + "\" (" + err.what() + ")"; 933 throw py::cast_error(msg); 934 } 935 } 936 } 937 // Unpack/validate successors. 938 if (successors) { 939 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 940 mlirSuccessors.reserve(successors->size()); 941 for (auto *successor : *successors) { 942 // TODO: Verify successor originate from the same context. 943 if (!successor) 944 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 945 mlirSuccessors.push_back(successor->get()); 946 } 947 } 948 949 // Apply unpacked/validated to the operation state. Beyond this 950 // point, exceptions cannot be thrown or else the state will leak. 951 MlirOperationState state = 952 mlirOperationStateGet(toMlirStringRef(name), location); 953 if (!mlirOperands.empty()) 954 mlirOperationStateAddOperands(&state, mlirOperands.size(), 955 mlirOperands.data()); 956 if (!mlirResults.empty()) 957 mlirOperationStateAddResults(&state, mlirResults.size(), 958 mlirResults.data()); 959 if (!mlirAttributes.empty()) { 960 // Note that the attribute names directly reference bytes in 961 // mlirAttributes, so that vector must not be changed from here 962 // on. 963 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 964 mlirNamedAttributes.reserve(mlirAttributes.size()); 965 for (auto &it : mlirAttributes) 966 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 967 mlirIdentifierGet(mlirAttributeGetContext(it.second), 968 toMlirStringRef(it.first)), 969 it.second)); 970 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 971 mlirNamedAttributes.data()); 972 } 973 if (!mlirSuccessors.empty()) 974 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 975 mlirSuccessors.data()); 976 if (regions) { 977 llvm::SmallVector<MlirRegion, 4> mlirRegions; 978 mlirRegions.resize(regions); 979 for (int i = 0; i < regions; ++i) 980 mlirRegions[i] = mlirRegionCreate(); 981 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 982 mlirRegions.data()); 983 } 984 985 // Construct the operation. 986 MlirOperation operation = mlirOperationCreate(&state); 987 PyOperationRef created = 988 PyOperation::createDetached(location->getContext(), operation); 989 990 // InsertPoint active? 991 if (!maybeIp.is(py::cast(false))) { 992 PyInsertionPoint *ip; 993 if (maybeIp.is_none()) { 994 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 995 } else { 996 ip = py::cast<PyInsertionPoint *>(maybeIp); 997 } 998 if (ip) 999 ip->insert(*created.get()); 1000 } 1001 1002 return created->createOpView(); 1003 } 1004 1005 py::object PyOperation::createOpView() { 1006 MlirIdentifier ident = mlirOperationGetName(get()); 1007 MlirStringRef identStr = mlirIdentifierStr(ident); 1008 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1009 StringRef(identStr.data, identStr.length)); 1010 if (opViewClass) 1011 return (*opViewClass)(getRef().getObject()); 1012 return py::cast(PyOpView(getRef().getObject())); 1013 } 1014 1015 //------------------------------------------------------------------------------ 1016 // PyOpView 1017 //------------------------------------------------------------------------------ 1018 1019 py::object 1020 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1021 py::list operandList, 1022 llvm::Optional<py::dict> attributes, 1023 llvm::Optional<std::vector<PyBlock *>> successors, 1024 llvm::Optional<int> regions, 1025 DefaultingPyLocation location, py::object maybeIp) { 1026 PyMlirContextRef context = location->getContext(); 1027 // Class level operation construction metadata. 1028 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1029 // Operand and result segment specs are either none, which does no 1030 // variadic unpacking, or a list of ints with segment sizes, where each 1031 // element is either a positive number (typically 1 for a scalar) or -1 to 1032 // indicate that it is derived from the length of the same-indexed operand 1033 // or result (implying that it is a list at that position). 1034 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1035 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1036 1037 std::vector<uint32_t> operandSegmentLengths; 1038 std::vector<uint32_t> resultSegmentLengths; 1039 1040 // Validate/determine region count. 1041 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1042 int opMinRegionCount = std::get<0>(opRegionSpec); 1043 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1044 if (!regions) { 1045 regions = opMinRegionCount; 1046 } 1047 if (*regions < opMinRegionCount) { 1048 throw py::value_error( 1049 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1050 llvm::Twine(opMinRegionCount) + 1051 " regions but was built with regions=" + llvm::Twine(*regions)) 1052 .str()); 1053 } 1054 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1055 throw py::value_error( 1056 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1057 llvm::Twine(opMinRegionCount) + 1058 " regions but was built with regions=" + llvm::Twine(*regions)) 1059 .str()); 1060 } 1061 1062 // Unpack results. 1063 std::vector<PyType *> resultTypes; 1064 resultTypes.reserve(resultTypeList.size()); 1065 if (resultSegmentSpecObj.is_none()) { 1066 // Non-variadic result unpacking. 1067 for (auto it : llvm::enumerate(resultTypeList)) { 1068 try { 1069 resultTypes.push_back(py::cast<PyType *>(it.value())); 1070 if (!resultTypes.back()) 1071 throw py::cast_error(); 1072 } catch (py::cast_error &err) { 1073 throw py::value_error((llvm::Twine("Result ") + 1074 llvm::Twine(it.index()) + " of operation \"" + 1075 name + "\" must be a Type (" + err.what() + ")") 1076 .str()); 1077 } 1078 } 1079 } else { 1080 // Sized result unpacking. 1081 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1082 if (resultSegmentSpec.size() != resultTypeList.size()) { 1083 throw py::value_error((llvm::Twine("Operation \"") + name + 1084 "\" requires " + 1085 llvm::Twine(resultSegmentSpec.size()) + 1086 "result segments but was provided " + 1087 llvm::Twine(resultTypeList.size())) 1088 .str()); 1089 } 1090 resultSegmentLengths.reserve(resultTypeList.size()); 1091 for (auto it : 1092 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1093 int segmentSpec = std::get<1>(it.value()); 1094 if (segmentSpec == 1 || segmentSpec == 0) { 1095 // Unpack unary element. 1096 try { 1097 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1098 if (resultType) { 1099 resultTypes.push_back(resultType); 1100 resultSegmentLengths.push_back(1); 1101 } else if (segmentSpec == 0) { 1102 // Allowed to be optional. 1103 resultSegmentLengths.push_back(0); 1104 } else { 1105 throw py::cast_error("was None and result is not optional"); 1106 } 1107 } catch (py::cast_error &err) { 1108 throw py::value_error((llvm::Twine("Result ") + 1109 llvm::Twine(it.index()) + " of operation \"" + 1110 name + "\" must be a Type (" + err.what() + 1111 ")") 1112 .str()); 1113 } 1114 } else if (segmentSpec == -1) { 1115 // Unpack sequence by appending. 1116 try { 1117 if (std::get<0>(it.value()).is_none()) { 1118 // Treat it as an empty list. 1119 resultSegmentLengths.push_back(0); 1120 } else { 1121 // Unpack the list. 1122 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1123 for (py::object segmentItem : segment) { 1124 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1125 if (!resultTypes.back()) { 1126 throw py::cast_error("contained a None item"); 1127 } 1128 } 1129 resultSegmentLengths.push_back(segment.size()); 1130 } 1131 } catch (std::exception &err) { 1132 // NOTE: Sloppy to be using a catch-all here, but there are at least 1133 // three different unrelated exceptions that can be thrown in the 1134 // above "casts". Just keep the scope above small and catch them all. 1135 throw py::value_error((llvm::Twine("Result ") + 1136 llvm::Twine(it.index()) + " of operation \"" + 1137 name + "\" must be a Sequence of Types (" + 1138 err.what() + ")") 1139 .str()); 1140 } 1141 } else { 1142 throw py::value_error("Unexpected segment spec"); 1143 } 1144 } 1145 } 1146 1147 // Unpack operands. 1148 std::vector<PyValue *> operands; 1149 operands.reserve(operands.size()); 1150 if (operandSegmentSpecObj.is_none()) { 1151 // Non-sized operand unpacking. 1152 for (auto it : llvm::enumerate(operandList)) { 1153 try { 1154 operands.push_back(py::cast<PyValue *>(it.value())); 1155 if (!operands.back()) 1156 throw py::cast_error(); 1157 } catch (py::cast_error &err) { 1158 throw py::value_error((llvm::Twine("Operand ") + 1159 llvm::Twine(it.index()) + " of operation \"" + 1160 name + "\" must be a Value (" + err.what() + ")") 1161 .str()); 1162 } 1163 } 1164 } else { 1165 // Sized operand unpacking. 1166 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1167 if (operandSegmentSpec.size() != operandList.size()) { 1168 throw py::value_error((llvm::Twine("Operation \"") + name + 1169 "\" requires " + 1170 llvm::Twine(operandSegmentSpec.size()) + 1171 "operand segments but was provided " + 1172 llvm::Twine(operandList.size())) 1173 .str()); 1174 } 1175 operandSegmentLengths.reserve(operandList.size()); 1176 for (auto it : 1177 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1178 int segmentSpec = std::get<1>(it.value()); 1179 if (segmentSpec == 1 || segmentSpec == 0) { 1180 // Unpack unary element. 1181 try { 1182 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1183 if (operandValue) { 1184 operands.push_back(operandValue); 1185 operandSegmentLengths.push_back(1); 1186 } else if (segmentSpec == 0) { 1187 // Allowed to be optional. 1188 operandSegmentLengths.push_back(0); 1189 } else { 1190 throw py::cast_error("was None and operand is not optional"); 1191 } 1192 } catch (py::cast_error &err) { 1193 throw py::value_error((llvm::Twine("Operand ") + 1194 llvm::Twine(it.index()) + " of operation \"" + 1195 name + "\" must be a Value (" + err.what() + 1196 ")") 1197 .str()); 1198 } 1199 } else if (segmentSpec == -1) { 1200 // Unpack sequence by appending. 1201 try { 1202 if (std::get<0>(it.value()).is_none()) { 1203 // Treat it as an empty list. 1204 operandSegmentLengths.push_back(0); 1205 } else { 1206 // Unpack the list. 1207 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1208 for (py::object segmentItem : segment) { 1209 operands.push_back(py::cast<PyValue *>(segmentItem)); 1210 if (!operands.back()) { 1211 throw py::cast_error("contained a None item"); 1212 } 1213 } 1214 operandSegmentLengths.push_back(segment.size()); 1215 } 1216 } catch (std::exception &err) { 1217 // NOTE: Sloppy to be using a catch-all here, but there are at least 1218 // three different unrelated exceptions that can be thrown in the 1219 // above "casts". Just keep the scope above small and catch them all. 1220 throw py::value_error((llvm::Twine("Operand ") + 1221 llvm::Twine(it.index()) + " of operation \"" + 1222 name + "\" must be a Sequence of Values (" + 1223 err.what() + ")") 1224 .str()); 1225 } 1226 } else { 1227 throw py::value_error("Unexpected segment spec"); 1228 } 1229 } 1230 } 1231 1232 // Merge operand/result segment lengths into attributes if needed. 1233 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1234 // Dup. 1235 if (attributes) { 1236 attributes = py::dict(*attributes); 1237 } else { 1238 attributes = py::dict(); 1239 } 1240 if (attributes->contains("result_segment_sizes") || 1241 attributes->contains("operand_segment_sizes")) { 1242 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1243 "'operand_segment_sizes' attribute is unsupported. " 1244 "Use Operation.create for such low-level access."); 1245 } 1246 1247 // Add result_segment_sizes attribute. 1248 if (!resultSegmentLengths.empty()) { 1249 int64_t size = resultSegmentLengths.size(); 1250 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1251 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1252 resultSegmentLengths.size(), resultSegmentLengths.data()); 1253 (*attributes)["result_segment_sizes"] = 1254 PyAttribute(context, segmentLengthAttr); 1255 } 1256 1257 // Add operand_segment_sizes attribute. 1258 if (!operandSegmentLengths.empty()) { 1259 int64_t size = operandSegmentLengths.size(); 1260 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1261 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1262 operandSegmentLengths.size(), operandSegmentLengths.data()); 1263 (*attributes)["operand_segment_sizes"] = 1264 PyAttribute(context, segmentLengthAttr); 1265 } 1266 } 1267 1268 // Delegate to create. 1269 return PyOperation::create(std::move(name), 1270 /*results=*/std::move(resultTypes), 1271 /*operands=*/std::move(operands), 1272 /*attributes=*/std::move(attributes), 1273 /*successors=*/std::move(successors), 1274 /*regions=*/*regions, location, maybeIp); 1275 } 1276 1277 PyOpView::PyOpView(py::object operationObject) 1278 // Casting through the PyOperationBase base-class and then back to the 1279 // Operation lets us accept any PyOperationBase subclass. 1280 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1281 operationObject(operation.getRef().getObject()) {} 1282 1283 py::object PyOpView::createRawSubclass(py::object userClass) { 1284 // This is... a little gross. The typical pattern is to have a pure python 1285 // class that extends OpView like: 1286 // class AddFOp(_cext.ir.OpView): 1287 // def __init__(self, loc, lhs, rhs): 1288 // operation = loc.context.create_operation( 1289 // "addf", lhs, rhs, results=[lhs.type]) 1290 // super().__init__(operation) 1291 // 1292 // I.e. The goal of the user facing type is to provide a nice constructor 1293 // that has complete freedom for the op under construction. This is at odds 1294 // with our other desire to sometimes create this object by just passing an 1295 // operation (to initialize the base class). We could do *arg and **kwargs 1296 // munging to try to make it work, but instead, we synthesize a new class 1297 // on the fly which extends this user class (AddFOp in this example) and 1298 // *give it* the base class's __init__ method, thus bypassing the 1299 // intermediate subclass's __init__ method entirely. While slightly, 1300 // underhanded, this is safe/legal because the type hierarchy has not changed 1301 // (we just added a new leaf) and we aren't mucking around with __new__. 1302 // Typically, this new class will be stored on the original as "_Raw" and will 1303 // be used for casts and other things that need a variant of the class that 1304 // is initialized purely from an operation. 1305 py::object parentMetaclass = 1306 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1307 py::dict attributes; 1308 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1309 // now. 1310 // auto opViewType = py::type::of<PyOpView>(); 1311 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1312 attributes["__init__"] = opViewType.attr("__init__"); 1313 py::str origName = userClass.attr("__name__"); 1314 py::str newName = py::str("_") + origName; 1315 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1316 } 1317 1318 //------------------------------------------------------------------------------ 1319 // PyInsertionPoint. 1320 //------------------------------------------------------------------------------ 1321 1322 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1323 1324 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1325 : refOperation(beforeOperationBase.getOperation().getRef()), 1326 block((*refOperation)->getBlock()) {} 1327 1328 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1329 PyOperation &operation = operationBase.getOperation(); 1330 if (operation.isAttached()) 1331 throw SetPyError(PyExc_ValueError, 1332 "Attempt to insert operation that is already attached"); 1333 block.getParentOperation()->checkValid(); 1334 MlirOperation beforeOp = {nullptr}; 1335 if (refOperation) { 1336 // Insert before operation. 1337 (*refOperation)->checkValid(); 1338 beforeOp = (*refOperation)->get(); 1339 } else { 1340 // Insert at end (before null) is only valid if the block does not 1341 // already end in a known terminator (violating this will cause assertion 1342 // failures later). 1343 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1344 throw py::index_error("Cannot insert operation at the end of a block " 1345 "that already has a terminator. Did you mean to " 1346 "use 'InsertionPoint.at_block_terminator(block)' " 1347 "versus 'InsertionPoint(block)'?"); 1348 } 1349 } 1350 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1351 operation.setAttached(); 1352 } 1353 1354 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1355 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1356 if (mlirOperationIsNull(firstOp)) { 1357 // Just insert at end. 1358 return PyInsertionPoint(block); 1359 } 1360 1361 // Insert before first op. 1362 PyOperationRef firstOpRef = PyOperation::forOperation( 1363 block.getParentOperation()->getContext(), firstOp); 1364 return PyInsertionPoint{block, std::move(firstOpRef)}; 1365 } 1366 1367 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1368 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1369 if (mlirOperationIsNull(terminator)) 1370 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1371 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1372 block.getParentOperation()->getContext(), terminator); 1373 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1374 } 1375 1376 py::object PyInsertionPoint::contextEnter() { 1377 return PyThreadContextEntry::pushInsertionPoint(*this); 1378 } 1379 1380 void PyInsertionPoint::contextExit(pybind11::object excType, 1381 pybind11::object excVal, 1382 pybind11::object excTb) { 1383 PyThreadContextEntry::popInsertionPoint(*this); 1384 } 1385 1386 //------------------------------------------------------------------------------ 1387 // PyAttribute. 1388 //------------------------------------------------------------------------------ 1389 1390 bool PyAttribute::operator==(const PyAttribute &other) { 1391 return mlirAttributeEqual(attr, other.attr); 1392 } 1393 1394 py::object PyAttribute::getCapsule() { 1395 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1396 } 1397 1398 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1399 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1400 if (mlirAttributeIsNull(rawAttr)) 1401 throw py::error_already_set(); 1402 return PyAttribute( 1403 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1404 } 1405 1406 //------------------------------------------------------------------------------ 1407 // PyNamedAttribute. 1408 //------------------------------------------------------------------------------ 1409 1410 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1411 : ownedName(new std::string(std::move(ownedName))) { 1412 namedAttr = mlirNamedAttributeGet( 1413 mlirIdentifierGet(mlirAttributeGetContext(attr), 1414 toMlirStringRef(*this->ownedName)), 1415 attr); 1416 } 1417 1418 //------------------------------------------------------------------------------ 1419 // PyType. 1420 //------------------------------------------------------------------------------ 1421 1422 bool PyType::operator==(const PyType &other) { 1423 return mlirTypeEqual(type, other.type); 1424 } 1425 1426 py::object PyType::getCapsule() { 1427 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1428 } 1429 1430 PyType PyType::createFromCapsule(py::object capsule) { 1431 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1432 if (mlirTypeIsNull(rawType)) 1433 throw py::error_already_set(); 1434 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1435 rawType); 1436 } 1437 1438 //------------------------------------------------------------------------------ 1439 // PyValue and subclases. 1440 //------------------------------------------------------------------------------ 1441 1442 namespace { 1443 /// CRTP base class for Python MLIR values that subclass Value and should be 1444 /// castable from it. The value hierarchy is one level deep and is not supposed 1445 /// to accommodate other levels unless core MLIR changes. 1446 template <typename DerivedTy> 1447 class PyConcreteValue : public PyValue { 1448 public: 1449 // Derived classes must define statics for: 1450 // IsAFunctionTy isaFunction 1451 // const char *pyClassName 1452 // and redefine bindDerived. 1453 using ClassTy = py::class_<DerivedTy, PyValue>; 1454 using IsAFunctionTy = bool (*)(MlirValue); 1455 1456 PyConcreteValue() = default; 1457 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1458 : PyValue(operationRef, value) {} 1459 PyConcreteValue(PyValue &orig) 1460 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1461 1462 /// Attempts to cast the original value to the derived type and throws on 1463 /// type mismatches. 1464 static MlirValue castFrom(PyValue &orig) { 1465 if (!DerivedTy::isaFunction(orig.get())) { 1466 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1467 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1468 DerivedTy::pyClassName + 1469 " (from " + origRepr + ")"); 1470 } 1471 return orig.get(); 1472 } 1473 1474 /// Binds the Python module objects to functions of this class. 1475 static void bind(py::module &m) { 1476 auto cls = ClassTy(m, DerivedTy::pyClassName); 1477 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1478 DerivedTy::bindDerived(cls); 1479 } 1480 1481 /// Implemented by derived classes to add methods to the Python subclass. 1482 static void bindDerived(ClassTy &m) {} 1483 }; 1484 1485 /// Python wrapper for MlirBlockArgument. 1486 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1487 public: 1488 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1489 static constexpr const char *pyClassName = "BlockArgument"; 1490 using PyConcreteValue::PyConcreteValue; 1491 1492 static void bindDerived(ClassTy &c) { 1493 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1494 return PyBlock(self.getParentOperation(), 1495 mlirBlockArgumentGetOwner(self.get())); 1496 }); 1497 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1498 return mlirBlockArgumentGetArgNumber(self.get()); 1499 }); 1500 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1501 return mlirBlockArgumentSetType(self.get(), type); 1502 }); 1503 } 1504 }; 1505 1506 /// Python wrapper for MlirOpResult. 1507 class PyOpResult : public PyConcreteValue<PyOpResult> { 1508 public: 1509 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1510 static constexpr const char *pyClassName = "OpResult"; 1511 using PyConcreteValue::PyConcreteValue; 1512 1513 static void bindDerived(ClassTy &c) { 1514 c.def_property_readonly("owner", [](PyOpResult &self) { 1515 assert( 1516 mlirOperationEqual(self.getParentOperation()->get(), 1517 mlirOpResultGetOwner(self.get())) && 1518 "expected the owner of the value in Python to match that in the IR"); 1519 return self.getParentOperation(); 1520 }); 1521 c.def_property_readonly("result_number", [](PyOpResult &self) { 1522 return mlirOpResultGetResultNumber(self.get()); 1523 }); 1524 } 1525 }; 1526 1527 /// A list of block arguments. Internally, these are stored as consecutive 1528 /// elements, random access is cheap. The argument list is associated with the 1529 /// operation that contains the block (detached blocks are not allowed in 1530 /// Python bindings) and extends its lifetime. 1531 class PyBlockArgumentList { 1532 public: 1533 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1534 : operation(std::move(operation)), block(block) {} 1535 1536 /// Returns the length of the block argument list. 1537 intptr_t dunderLen() { 1538 operation->checkValid(); 1539 return mlirBlockGetNumArguments(block); 1540 } 1541 1542 /// Returns `index`-th element of the block argument list. 1543 PyBlockArgument dunderGetItem(intptr_t index) { 1544 if (index < 0 || index >= dunderLen()) { 1545 throw SetPyError(PyExc_IndexError, 1546 "attempt to access out of bounds region"); 1547 } 1548 PyValue value(operation, mlirBlockGetArgument(block, index)); 1549 return PyBlockArgument(value); 1550 } 1551 1552 /// Defines a Python class in the bindings. 1553 static void bind(py::module &m) { 1554 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1555 .def("__len__", &PyBlockArgumentList::dunderLen) 1556 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1557 } 1558 1559 private: 1560 PyOperationRef operation; 1561 MlirBlock block; 1562 }; 1563 1564 /// A list of operation operands. Internally, these are stored as consecutive 1565 /// elements, random access is cheap. The result list is associated with the 1566 /// operation whose results these are, and extends the lifetime of this 1567 /// operation. 1568 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1569 public: 1570 static constexpr const char *pyClassName = "OpOperandList"; 1571 1572 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1573 intptr_t length = -1, intptr_t step = 1) 1574 : Sliceable(startIndex, 1575 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1576 : length, 1577 step), 1578 operation(operation) {} 1579 1580 intptr_t getNumElements() { 1581 operation->checkValid(); 1582 return mlirOperationGetNumOperands(operation->get()); 1583 } 1584 1585 PyValue getElement(intptr_t pos) { 1586 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); 1587 } 1588 1589 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1590 return PyOpOperandList(operation, startIndex, length, step); 1591 } 1592 1593 private: 1594 PyOperationRef operation; 1595 }; 1596 1597 /// A list of operation results. Internally, these are stored as consecutive 1598 /// elements, random access is cheap. The result list is associated with the 1599 /// operation whose results these are, and extends the lifetime of this 1600 /// operation. 1601 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1602 public: 1603 static constexpr const char *pyClassName = "OpResultList"; 1604 1605 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1606 intptr_t length = -1, intptr_t step = 1) 1607 : Sliceable(startIndex, 1608 length == -1 ? mlirOperationGetNumResults(operation->get()) 1609 : length, 1610 step), 1611 operation(operation) {} 1612 1613 intptr_t getNumElements() { 1614 operation->checkValid(); 1615 return mlirOperationGetNumResults(operation->get()); 1616 } 1617 1618 PyOpResult getElement(intptr_t index) { 1619 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1620 return PyOpResult(value); 1621 } 1622 1623 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1624 return PyOpResultList(operation, startIndex, length, step); 1625 } 1626 1627 private: 1628 PyOperationRef operation; 1629 }; 1630 1631 /// A list of operation attributes. Can be indexed by name, producing 1632 /// attributes, or by index, producing named attributes. 1633 class PyOpAttributeMap { 1634 public: 1635 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1636 1637 PyAttribute dunderGetItemNamed(const std::string &name) { 1638 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1639 toMlirStringRef(name)); 1640 if (mlirAttributeIsNull(attr)) { 1641 throw SetPyError(PyExc_KeyError, 1642 "attempt to access a non-existent attribute"); 1643 } 1644 return PyAttribute(operation->getContext(), attr); 1645 } 1646 1647 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1648 if (index < 0 || index >= dunderLen()) { 1649 throw SetPyError(PyExc_IndexError, 1650 "attempt to access out of bounds attribute"); 1651 } 1652 MlirNamedAttribute namedAttr = 1653 mlirOperationGetAttribute(operation->get(), index); 1654 return PyNamedAttribute( 1655 namedAttr.attribute, 1656 std::string(mlirIdentifierStr(namedAttr.name).data)); 1657 } 1658 1659 void dunderSetItem(const std::string &name, PyAttribute attr) { 1660 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1661 attr); 1662 } 1663 1664 void dunderDelItem(const std::string &name) { 1665 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1666 toMlirStringRef(name)); 1667 if (!removed) 1668 throw SetPyError(PyExc_KeyError, 1669 "attempt to delete a non-existent attribute"); 1670 } 1671 1672 intptr_t dunderLen() { 1673 return mlirOperationGetNumAttributes(operation->get()); 1674 } 1675 1676 bool dunderContains(const std::string &name) { 1677 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1678 operation->get(), toMlirStringRef(name))); 1679 } 1680 1681 static void bind(py::module &m) { 1682 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1683 .def("__contains__", &PyOpAttributeMap::dunderContains) 1684 .def("__len__", &PyOpAttributeMap::dunderLen) 1685 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1686 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1687 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1688 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1689 } 1690 1691 private: 1692 PyOperationRef operation; 1693 }; 1694 1695 } // end namespace 1696 1697 //------------------------------------------------------------------------------ 1698 // Populates the core exports of the 'ir' submodule. 1699 //------------------------------------------------------------------------------ 1700 1701 void mlir::python::populateIRCore(py::module &m) { 1702 //---------------------------------------------------------------------------- 1703 // Mapping of MlirContext 1704 //---------------------------------------------------------------------------- 1705 py::class_<PyMlirContext>(m, "Context") 1706 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1707 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1708 .def("_get_context_again", 1709 [](PyMlirContext &self) { 1710 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1711 return ref.releaseObject(); 1712 }) 1713 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1714 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1715 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1716 &PyMlirContext::getCapsule) 1717 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1718 .def("__enter__", &PyMlirContext::contextEnter) 1719 .def("__exit__", &PyMlirContext::contextExit) 1720 .def_property_readonly_static( 1721 "current", 1722 [](py::object & /*class*/) { 1723 auto *context = PyThreadContextEntry::getDefaultContext(); 1724 if (!context) 1725 throw SetPyError(PyExc_ValueError, "No current Context"); 1726 return context; 1727 }, 1728 "Gets the Context bound to the current thread or raises ValueError") 1729 .def_property_readonly( 1730 "dialects", 1731 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1732 "Gets a container for accessing dialects by name") 1733 .def_property_readonly( 1734 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1735 "Alias for 'dialect'") 1736 .def( 1737 "get_dialect_descriptor", 1738 [=](PyMlirContext &self, std::string &name) { 1739 MlirDialect dialect = mlirContextGetOrLoadDialect( 1740 self.get(), {name.data(), name.size()}); 1741 if (mlirDialectIsNull(dialect)) { 1742 throw SetPyError(PyExc_ValueError, 1743 Twine("Dialect '") + name + "' not found"); 1744 } 1745 return PyDialectDescriptor(self.getRef(), dialect); 1746 }, 1747 "Gets or loads a dialect by name, returning its descriptor object") 1748 .def_property( 1749 "allow_unregistered_dialects", 1750 [](PyMlirContext &self) -> bool { 1751 return mlirContextGetAllowUnregisteredDialects(self.get()); 1752 }, 1753 [](PyMlirContext &self, bool value) { 1754 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1755 }); 1756 1757 //---------------------------------------------------------------------------- 1758 // Mapping of PyDialectDescriptor 1759 //---------------------------------------------------------------------------- 1760 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1761 .def_property_readonly("namespace", 1762 [](PyDialectDescriptor &self) { 1763 MlirStringRef ns = 1764 mlirDialectGetNamespace(self.get()); 1765 return py::str(ns.data, ns.length); 1766 }) 1767 .def("__repr__", [](PyDialectDescriptor &self) { 1768 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1769 std::string repr("<DialectDescriptor "); 1770 repr.append(ns.data, ns.length); 1771 repr.append(">"); 1772 return repr; 1773 }); 1774 1775 //---------------------------------------------------------------------------- 1776 // Mapping of PyDialects 1777 //---------------------------------------------------------------------------- 1778 py::class_<PyDialects>(m, "Dialects") 1779 .def("__getitem__", 1780 [=](PyDialects &self, std::string keyName) { 1781 MlirDialect dialect = 1782 self.getDialectForKey(keyName, /*attrError=*/false); 1783 py::object descriptor = 1784 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1785 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1786 }) 1787 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1788 MlirDialect dialect = 1789 self.getDialectForKey(attrName, /*attrError=*/true); 1790 py::object descriptor = 1791 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1792 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1793 }); 1794 1795 //---------------------------------------------------------------------------- 1796 // Mapping of PyDialect 1797 //---------------------------------------------------------------------------- 1798 py::class_<PyDialect>(m, "Dialect") 1799 .def(py::init<py::object>(), "descriptor") 1800 .def_property_readonly( 1801 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1802 .def("__repr__", [](py::object self) { 1803 auto clazz = self.attr("__class__"); 1804 return py::str("<Dialect ") + 1805 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1806 clazz.attr("__module__") + py::str(".") + 1807 clazz.attr("__name__") + py::str(")>"); 1808 }); 1809 1810 //---------------------------------------------------------------------------- 1811 // Mapping of Location 1812 //---------------------------------------------------------------------------- 1813 py::class_<PyLocation>(m, "Location") 1814 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1815 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1816 .def("__enter__", &PyLocation::contextEnter) 1817 .def("__exit__", &PyLocation::contextExit) 1818 .def("__eq__", 1819 [](PyLocation &self, PyLocation &other) -> bool { 1820 return mlirLocationEqual(self, other); 1821 }) 1822 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1823 .def_property_readonly_static( 1824 "current", 1825 [](py::object & /*class*/) { 1826 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1827 if (!loc) 1828 throw SetPyError(PyExc_ValueError, "No current Location"); 1829 return loc; 1830 }, 1831 "Gets the Location bound to the current thread or raises ValueError") 1832 .def_static( 1833 "unknown", 1834 [](DefaultingPyMlirContext context) { 1835 return PyLocation(context->getRef(), 1836 mlirLocationUnknownGet(context->get())); 1837 }, 1838 py::arg("context") = py::none(), 1839 "Gets a Location representing an unknown location") 1840 .def_static( 1841 "file", 1842 [](std::string filename, int line, int col, 1843 DefaultingPyMlirContext context) { 1844 return PyLocation( 1845 context->getRef(), 1846 mlirLocationFileLineColGet( 1847 context->get(), toMlirStringRef(filename), line, col)); 1848 }, 1849 py::arg("filename"), py::arg("line"), py::arg("col"), 1850 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1851 .def_property_readonly( 1852 "context", 1853 [](PyLocation &self) { return self.getContext().getObject(); }, 1854 "Context that owns the Location") 1855 .def("__repr__", [](PyLocation &self) { 1856 PyPrintAccumulator printAccum; 1857 mlirLocationPrint(self, printAccum.getCallback(), 1858 printAccum.getUserData()); 1859 return printAccum.join(); 1860 }); 1861 1862 //---------------------------------------------------------------------------- 1863 // Mapping of Module 1864 //---------------------------------------------------------------------------- 1865 py::class_<PyModule>(m, "Module") 1866 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1867 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1868 .def_static( 1869 "parse", 1870 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1871 MlirModule module = mlirModuleCreateParse( 1872 context->get(), toMlirStringRef(moduleAsm)); 1873 // TODO: Rework error reporting once diagnostic engine is exposed 1874 // in C API. 1875 if (mlirModuleIsNull(module)) { 1876 throw SetPyError( 1877 PyExc_ValueError, 1878 "Unable to parse module assembly (see diagnostics)"); 1879 } 1880 return PyModule::forModule(module).releaseObject(); 1881 }, 1882 py::arg("asm"), py::arg("context") = py::none(), 1883 kModuleParseDocstring) 1884 .def_static( 1885 "create", 1886 [](DefaultingPyLocation loc) { 1887 MlirModule module = mlirModuleCreateEmpty(loc); 1888 return PyModule::forModule(module).releaseObject(); 1889 }, 1890 py::arg("loc") = py::none(), "Creates an empty module") 1891 .def_property_readonly( 1892 "context", 1893 [](PyModule &self) { return self.getContext().getObject(); }, 1894 "Context that created the Module") 1895 .def_property_readonly( 1896 "operation", 1897 [](PyModule &self) { 1898 return PyOperation::forOperation(self.getContext(), 1899 mlirModuleGetOperation(self.get()), 1900 self.getRef().releaseObject()) 1901 .releaseObject(); 1902 }, 1903 "Accesses the module as an operation") 1904 .def_property_readonly( 1905 "body", 1906 [](PyModule &self) { 1907 PyOperationRef module_op = PyOperation::forOperation( 1908 self.getContext(), mlirModuleGetOperation(self.get()), 1909 self.getRef().releaseObject()); 1910 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 1911 return returnBlock; 1912 }, 1913 "Return the block for this module") 1914 .def( 1915 "dump", 1916 [](PyModule &self) { 1917 mlirOperationDump(mlirModuleGetOperation(self.get())); 1918 }, 1919 kDumpDocstring) 1920 .def( 1921 "__str__", 1922 [](PyModule &self) { 1923 MlirOperation operation = mlirModuleGetOperation(self.get()); 1924 PyPrintAccumulator printAccum; 1925 mlirOperationPrint(operation, printAccum.getCallback(), 1926 printAccum.getUserData()); 1927 return printAccum.join(); 1928 }, 1929 kOperationStrDunderDocstring); 1930 1931 //---------------------------------------------------------------------------- 1932 // Mapping of Operation. 1933 //---------------------------------------------------------------------------- 1934 py::class_<PyOperationBase>(m, "_OperationBase") 1935 .def("__eq__", 1936 [](PyOperationBase &self, PyOperationBase &other) { 1937 return &self.getOperation() == &other.getOperation(); 1938 }) 1939 .def("__eq__", 1940 [](PyOperationBase &self, py::object other) { return false; }) 1941 .def_property_readonly("attributes", 1942 [](PyOperationBase &self) { 1943 return PyOpAttributeMap( 1944 self.getOperation().getRef()); 1945 }) 1946 .def_property_readonly("operands", 1947 [](PyOperationBase &self) { 1948 return PyOpOperandList( 1949 self.getOperation().getRef()); 1950 }) 1951 .def_property_readonly("regions", 1952 [](PyOperationBase &self) { 1953 return PyRegionList( 1954 self.getOperation().getRef()); 1955 }) 1956 .def_property_readonly( 1957 "results", 1958 [](PyOperationBase &self) { 1959 return PyOpResultList(self.getOperation().getRef()); 1960 }, 1961 "Returns the list of Operation results.") 1962 .def_property_readonly( 1963 "result", 1964 [](PyOperationBase &self) { 1965 auto &operation = self.getOperation(); 1966 auto numResults = mlirOperationGetNumResults(operation); 1967 if (numResults != 1) { 1968 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 1969 throw SetPyError( 1970 PyExc_ValueError, 1971 Twine("Cannot call .result on operation ") + 1972 StringRef(name.data, name.length) + " which has " + 1973 Twine(numResults) + 1974 " results (it is only valid for operations with a " 1975 "single result)"); 1976 } 1977 return PyOpResult(operation.getRef(), 1978 mlirOperationGetResult(operation, 0)); 1979 }, 1980 "Shortcut to get an op result if it has only one (throws an error " 1981 "otherwise).") 1982 .def("__iter__", 1983 [](PyOperationBase &self) { 1984 return PyRegionIterator(self.getOperation().getRef()); 1985 }) 1986 .def( 1987 "__str__", 1988 [](PyOperationBase &self) { 1989 return self.getAsm(/*binary=*/false, 1990 /*largeElementsLimit=*/llvm::None, 1991 /*enableDebugInfo=*/false, 1992 /*prettyDebugInfo=*/false, 1993 /*printGenericOpForm=*/false, 1994 /*useLocalScope=*/false); 1995 }, 1996 "Returns the assembly form of the operation.") 1997 .def("print", &PyOperationBase::print, 1998 // Careful: Lots of arguments must match up with print method. 1999 py::arg("file") = py::none(), py::arg("binary") = false, 2000 py::arg("large_elements_limit") = py::none(), 2001 py::arg("enable_debug_info") = false, 2002 py::arg("pretty_debug_info") = false, 2003 py::arg("print_generic_op_form") = false, 2004 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2005 .def("get_asm", &PyOperationBase::getAsm, 2006 // Careful: Lots of arguments must match up with get_asm method. 2007 py::arg("binary") = false, 2008 py::arg("large_elements_limit") = py::none(), 2009 py::arg("enable_debug_info") = false, 2010 py::arg("pretty_debug_info") = false, 2011 py::arg("print_generic_op_form") = false, 2012 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2013 .def( 2014 "verify", 2015 [](PyOperationBase &self) { 2016 return mlirOperationVerify(self.getOperation()); 2017 }, 2018 "Verify the operation and return true if it passes, false if it " 2019 "fails."); 2020 2021 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2022 .def_static("create", &PyOperation::create, py::arg("name"), 2023 py::arg("results") = py::none(), 2024 py::arg("operands") = py::none(), 2025 py::arg("attributes") = py::none(), 2026 py::arg("successors") = py::none(), py::arg("regions") = 0, 2027 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2028 kOperationCreateDocstring) 2029 .def_property_readonly("name", 2030 [](PyOperation &self) { 2031 MlirOperation operation = self.get(); 2032 MlirStringRef name = mlirIdentifierStr( 2033 mlirOperationGetName(operation)); 2034 return py::str(name.data, name.length); 2035 }) 2036 .def_property_readonly( 2037 "context", 2038 [](PyOperation &self) { return self.getContext().getObject(); }, 2039 "Context that owns the Operation") 2040 .def_property_readonly("opview", &PyOperation::createOpView); 2041 2042 auto opViewClass = 2043 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2044 .def(py::init<py::object>()) 2045 .def_property_readonly("operation", &PyOpView::getOperationObject) 2046 .def_property_readonly( 2047 "context", 2048 [](PyOpView &self) { 2049 return self.getOperation().getContext().getObject(); 2050 }, 2051 "Context that owns the Operation") 2052 .def("__str__", [](PyOpView &self) { 2053 return py::str(self.getOperationObject()); 2054 }); 2055 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2056 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2057 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2058 opViewClass.attr("build_generic") = classmethod( 2059 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2060 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2061 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2062 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2063 "Builds a specific, generated OpView based on class level attributes."); 2064 2065 //---------------------------------------------------------------------------- 2066 // Mapping of PyRegion. 2067 //---------------------------------------------------------------------------- 2068 py::class_<PyRegion>(m, "Region") 2069 .def_property_readonly( 2070 "blocks", 2071 [](PyRegion &self) { 2072 return PyBlockList(self.getParentOperation(), self.get()); 2073 }, 2074 "Returns a forward-optimized sequence of blocks.") 2075 .def( 2076 "__iter__", 2077 [](PyRegion &self) { 2078 self.checkValid(); 2079 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2080 return PyBlockIterator(self.getParentOperation(), firstBlock); 2081 }, 2082 "Iterates over blocks in the region.") 2083 .def("__eq__", 2084 [](PyRegion &self, PyRegion &other) { 2085 return self.get().ptr == other.get().ptr; 2086 }) 2087 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2088 2089 //---------------------------------------------------------------------------- 2090 // Mapping of PyBlock. 2091 //---------------------------------------------------------------------------- 2092 py::class_<PyBlock>(m, "Block") 2093 .def_property_readonly( 2094 "arguments", 2095 [](PyBlock &self) { 2096 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2097 }, 2098 "Returns a list of block arguments.") 2099 .def_property_readonly( 2100 "operations", 2101 [](PyBlock &self) { 2102 return PyOperationList(self.getParentOperation(), self.get()); 2103 }, 2104 "Returns a forward-optimized sequence of operations.") 2105 .def( 2106 "__iter__", 2107 [](PyBlock &self) { 2108 self.checkValid(); 2109 MlirOperation firstOperation = 2110 mlirBlockGetFirstOperation(self.get()); 2111 return PyOperationIterator(self.getParentOperation(), 2112 firstOperation); 2113 }, 2114 "Iterates over operations in the block.") 2115 .def("__eq__", 2116 [](PyBlock &self, PyBlock &other) { 2117 return self.get().ptr == other.get().ptr; 2118 }) 2119 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2120 .def( 2121 "__str__", 2122 [](PyBlock &self) { 2123 self.checkValid(); 2124 PyPrintAccumulator printAccum; 2125 mlirBlockPrint(self.get(), printAccum.getCallback(), 2126 printAccum.getUserData()); 2127 return printAccum.join(); 2128 }, 2129 "Returns the assembly form of the block."); 2130 2131 //---------------------------------------------------------------------------- 2132 // Mapping of PyInsertionPoint. 2133 //---------------------------------------------------------------------------- 2134 2135 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2136 .def(py::init<PyBlock &>(), py::arg("block"), 2137 "Inserts after the last operation but still inside the block.") 2138 .def("__enter__", &PyInsertionPoint::contextEnter) 2139 .def("__exit__", &PyInsertionPoint::contextExit) 2140 .def_property_readonly_static( 2141 "current", 2142 [](py::object & /*class*/) { 2143 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2144 if (!ip) 2145 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2146 return ip; 2147 }, 2148 "Gets the InsertionPoint bound to the current thread or raises " 2149 "ValueError if none has been set") 2150 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2151 "Inserts before a referenced operation.") 2152 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2153 py::arg("block"), "Inserts at the beginning of the block.") 2154 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2155 py::arg("block"), "Inserts before the block terminator.") 2156 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2157 "Inserts an operation."); 2158 2159 //---------------------------------------------------------------------------- 2160 // Mapping of PyAttribute. 2161 //---------------------------------------------------------------------------- 2162 py::class_<PyAttribute>(m, "Attribute") 2163 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2164 &PyAttribute::getCapsule) 2165 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2166 .def_static( 2167 "parse", 2168 [](std::string attrSpec, DefaultingPyMlirContext context) { 2169 MlirAttribute type = mlirAttributeParseGet( 2170 context->get(), toMlirStringRef(attrSpec)); 2171 // TODO: Rework error reporting once diagnostic engine is exposed 2172 // in C API. 2173 if (mlirAttributeIsNull(type)) { 2174 throw SetPyError(PyExc_ValueError, 2175 Twine("Unable to parse attribute: '") + 2176 attrSpec + "'"); 2177 } 2178 return PyAttribute(context->getRef(), type); 2179 }, 2180 py::arg("asm"), py::arg("context") = py::none(), 2181 "Parses an attribute from an assembly form") 2182 .def_property_readonly( 2183 "context", 2184 [](PyAttribute &self) { return self.getContext().getObject(); }, 2185 "Context that owns the Attribute") 2186 .def_property_readonly("type", 2187 [](PyAttribute &self) { 2188 return PyType(self.getContext()->getRef(), 2189 mlirAttributeGetType(self)); 2190 }) 2191 .def( 2192 "get_named", 2193 [](PyAttribute &self, std::string name) { 2194 return PyNamedAttribute(self, std::move(name)); 2195 }, 2196 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2197 .def("__eq__", 2198 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2199 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2200 .def( 2201 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2202 kDumpDocstring) 2203 .def( 2204 "__str__", 2205 [](PyAttribute &self) { 2206 PyPrintAccumulator printAccum; 2207 mlirAttributePrint(self, printAccum.getCallback(), 2208 printAccum.getUserData()); 2209 return printAccum.join(); 2210 }, 2211 "Returns the assembly form of the Attribute.") 2212 .def("__repr__", [](PyAttribute &self) { 2213 // Generally, assembly formats are not printed for __repr__ because 2214 // this can cause exceptionally long debug output and exceptions. 2215 // However, attribute values are generally considered useful and are 2216 // printed. This may need to be re-evaluated if debug dumps end up 2217 // being excessive. 2218 PyPrintAccumulator printAccum; 2219 printAccum.parts.append("Attribute("); 2220 mlirAttributePrint(self, printAccum.getCallback(), 2221 printAccum.getUserData()); 2222 printAccum.parts.append(")"); 2223 return printAccum.join(); 2224 }); 2225 2226 //---------------------------------------------------------------------------- 2227 // Mapping of PyNamedAttribute 2228 //---------------------------------------------------------------------------- 2229 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2230 .def("__repr__", 2231 [](PyNamedAttribute &self) { 2232 PyPrintAccumulator printAccum; 2233 printAccum.parts.append("NamedAttribute("); 2234 printAccum.parts.append( 2235 mlirIdentifierStr(self.namedAttr.name).data); 2236 printAccum.parts.append("="); 2237 mlirAttributePrint(self.namedAttr.attribute, 2238 printAccum.getCallback(), 2239 printAccum.getUserData()); 2240 printAccum.parts.append(")"); 2241 return printAccum.join(); 2242 }) 2243 .def_property_readonly( 2244 "name", 2245 [](PyNamedAttribute &self) { 2246 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2247 mlirIdentifierStr(self.namedAttr.name).length); 2248 }, 2249 "The name of the NamedAttribute binding") 2250 .def_property_readonly( 2251 "attr", 2252 [](PyNamedAttribute &self) { 2253 // TODO: When named attribute is removed/refactored, also remove 2254 // this constructor (it does an inefficient table lookup). 2255 auto contextRef = PyMlirContext::forContext( 2256 mlirAttributeGetContext(self.namedAttr.attribute)); 2257 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2258 }, 2259 py::keep_alive<0, 1>(), 2260 "The underlying generic attribute of the NamedAttribute binding"); 2261 2262 //---------------------------------------------------------------------------- 2263 // Mapping of PyType. 2264 //---------------------------------------------------------------------------- 2265 py::class_<PyType>(m, "Type") 2266 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2267 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2268 .def_static( 2269 "parse", 2270 [](std::string typeSpec, DefaultingPyMlirContext context) { 2271 MlirType type = 2272 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2273 // TODO: Rework error reporting once diagnostic engine is exposed 2274 // in C API. 2275 if (mlirTypeIsNull(type)) { 2276 throw SetPyError(PyExc_ValueError, 2277 Twine("Unable to parse type: '") + typeSpec + 2278 "'"); 2279 } 2280 return PyType(context->getRef(), type); 2281 }, 2282 py::arg("asm"), py::arg("context") = py::none(), 2283 kContextParseTypeDocstring) 2284 .def_property_readonly( 2285 "context", [](PyType &self) { return self.getContext().getObject(); }, 2286 "Context that owns the Type") 2287 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2288 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2289 .def( 2290 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2291 .def( 2292 "__str__", 2293 [](PyType &self) { 2294 PyPrintAccumulator printAccum; 2295 mlirTypePrint(self, printAccum.getCallback(), 2296 printAccum.getUserData()); 2297 return printAccum.join(); 2298 }, 2299 "Returns the assembly form of the type.") 2300 .def("__repr__", [](PyType &self) { 2301 // Generally, assembly formats are not printed for __repr__ because 2302 // this can cause exceptionally long debug output and exceptions. 2303 // However, types are an exception as they typically have compact 2304 // assembly forms and printing them is useful. 2305 PyPrintAccumulator printAccum; 2306 printAccum.parts.append("Type("); 2307 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2308 printAccum.parts.append(")"); 2309 return printAccum.join(); 2310 }); 2311 2312 //---------------------------------------------------------------------------- 2313 // Mapping of Value. 2314 //---------------------------------------------------------------------------- 2315 py::class_<PyValue>(m, "Value") 2316 .def_property_readonly( 2317 "context", 2318 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2319 "Context in which the value lives.") 2320 .def( 2321 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2322 kDumpDocstring) 2323 .def("__eq__", 2324 [](PyValue &self, PyValue &other) { 2325 return self.get().ptr == other.get().ptr; 2326 }) 2327 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2328 .def( 2329 "__str__", 2330 [](PyValue &self) { 2331 PyPrintAccumulator printAccum; 2332 printAccum.parts.append("Value("); 2333 mlirValuePrint(self.get(), printAccum.getCallback(), 2334 printAccum.getUserData()); 2335 printAccum.parts.append(")"); 2336 return printAccum.join(); 2337 }, 2338 kValueDunderStrDocstring) 2339 .def_property_readonly("type", [](PyValue &self) { 2340 return PyType(self.getParentOperation()->getContext(), 2341 mlirValueGetType(self.get())); 2342 }); 2343 PyBlockArgument::bind(m); 2344 PyOpResult::bind(m); 2345 2346 // Container bindings. 2347 PyBlockArgumentList::bind(m); 2348 PyBlockIterator::bind(m); 2349 PyBlockList::bind(m); 2350 PyOperationIterator::bind(m); 2351 PyOperationList::bind(m); 2352 PyOpAttributeMap::bind(m); 2353 PyOpOperandList::bind(m); 2354 PyOpResultList::bind(m); 2355 PyRegionIterator::bind(m); 2356 PyRegionList::bind(m); 2357 } 2358