1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Registration.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include <pybind11/stl.h> 20 21 namespace py = pybind11; 22 using namespace mlir; 23 using namespace mlir::python; 24 25 using llvm::SmallVector; 26 using llvm::StringRef; 27 using llvm::Twine; 28 29 //------------------------------------------------------------------------------ 30 // Docstrings (trivial, non-duplicated docstrings are included inline). 31 //------------------------------------------------------------------------------ 32 33 static const char kContextParseTypeDocstring[] = 34 R"(Parses the assembly form of a type. 35 36 Returns a Type object or raises a ValueError if the type cannot be parsed. 37 38 See also: https://mlir.llvm.org/docs/LangRef/#type-system 39 )"; 40 41 static const char kContextGetFileLocationDocstring[] = 42 R"(Gets a Location representing a file, line and column)"; 43 44 static const char kModuleParseDocstring[] = 45 R"(Parses a module's assembly format from a string. 46 47 Returns a new MlirModule or raises a ValueError if the parsing fails. 48 49 See also: https://mlir.llvm.org/docs/LangRef/ 50 )"; 51 52 static const char kOperationCreateDocstring[] = 53 R"(Creates a new operation. 54 55 Args: 56 name: Operation name (e.g. "dialect.operation"). 57 results: Sequence of Type representing op result types. 58 attributes: Dict of str:Attribute. 59 successors: List of Block for the operation's successors. 60 regions: Number of regions to create. 61 location: A Location object (defaults to resolve from context manager). 62 ip: An InsertionPoint (defaults to resolve from context manager or set to 63 False to disable insertion, even with an insertion point set in the 64 context manager). 65 Returns: 66 A new "detached" Operation object. Detached operations can be added 67 to blocks, which causes them to become "attached." 68 )"; 69 70 static const char kOperationPrintDocstring[] = 71 R"(Prints the assembly form of the operation to a file like object. 72 73 Args: 74 file: The file like object to write to. Defaults to sys.stdout. 75 binary: Whether to write bytes (True) or str (False). Defaults to False. 76 large_elements_limit: Whether to elide elements attributes above this 77 number of elements. Defaults to None (no limit). 78 enable_debug_info: Whether to print debug/location information. Defaults 79 to False. 80 pretty_debug_info: Whether to format debug information for easier reading 81 by a human (warning: the result is unparseable). 82 print_generic_op_form: Whether to print the generic assembly forms of all 83 ops. Defaults to False. 84 use_local_Scope: Whether to print in a way that is more optimized for 85 multi-threaded access but may not be consistent with how the overall 86 module prints. 87 )"; 88 89 static const char kOperationGetAsmDocstring[] = 90 R"(Gets the assembly form of the operation with all options available. 91 92 Args: 93 binary: Whether to return a bytes (True) or str (False) object. Defaults to 94 False. 95 ... others ...: See the print() method for common keyword arguments for 96 configuring the printout. 97 Returns: 98 Either a bytes or str object, depending on the setting of the 'binary' 99 argument. 100 )"; 101 102 static const char kOperationStrDunderDocstring[] = 103 R"(Gets the assembly form of the operation with default options. 104 105 If more advanced control over the assembly formatting or I/O options is needed, 106 use the dedicated print or get_asm method, which supports keyword arguments to 107 customize behavior. 108 )"; 109 110 static const char kDumpDocstring[] = 111 R"(Dumps a debug representation of the object to stderr.)"; 112 113 static const char kAppendBlockDocstring[] = 114 R"(Appends a new block, with argument types as positional args. 115 116 Returns: 117 The created block. 118 )"; 119 120 static const char kValueDunderStrDocstring[] = 121 R"(Returns the string form of the value. 122 123 If the value is a block argument, this is the assembly form of its type and the 124 position in the argument list. If the value is an operation result, this is 125 equivalent to printing the operation that produced it. 126 )"; 127 128 //------------------------------------------------------------------------------ 129 // Utilities. 130 //------------------------------------------------------------------------------ 131 132 // Helper for creating an @classmethod. 133 template <class Func, typename... Args> 134 py::object classmethod(Func f, Args... args) { 135 py::object cf = py::cpp_function(f, args...); 136 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 137 } 138 139 static py::object 140 createCustomDialectWrapper(const std::string &dialectNamespace, 141 py::object dialectDescriptor) { 142 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 143 if (!dialectClass) { 144 // Use the base class. 145 return py::cast(PyDialect(std::move(dialectDescriptor))); 146 } 147 148 // Create the custom implementation. 149 return (*dialectClass)(std::move(dialectDescriptor)); 150 } 151 152 static MlirStringRef toMlirStringRef(const std::string &s) { 153 return mlirStringRefCreate(s.data(), s.size()); 154 } 155 156 //------------------------------------------------------------------------------ 157 // Collections. 158 //------------------------------------------------------------------------------ 159 160 namespace { 161 162 class PyRegionIterator { 163 public: 164 PyRegionIterator(PyOperationRef operation) 165 : operation(std::move(operation)) {} 166 167 PyRegionIterator &dunderIter() { return *this; } 168 169 PyRegion dunderNext() { 170 operation->checkValid(); 171 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 172 throw py::stop_iteration(); 173 } 174 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 175 return PyRegion(operation, region); 176 } 177 178 static void bind(py::module &m) { 179 py::class_<PyRegionIterator>(m, "RegionIterator") 180 .def("__iter__", &PyRegionIterator::dunderIter) 181 .def("__next__", &PyRegionIterator::dunderNext); 182 } 183 184 private: 185 PyOperationRef operation; 186 int nextIndex = 0; 187 }; 188 189 /// Regions of an op are fixed length and indexed numerically so are represented 190 /// with a sequence-like container. 191 class PyRegionList { 192 public: 193 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 194 195 intptr_t dunderLen() { 196 operation->checkValid(); 197 return mlirOperationGetNumRegions(operation->get()); 198 } 199 200 PyRegion dunderGetItem(intptr_t index) { 201 // dunderLen checks validity. 202 if (index < 0 || index >= dunderLen()) { 203 throw SetPyError(PyExc_IndexError, 204 "attempt to access out of bounds region"); 205 } 206 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 207 return PyRegion(operation, region); 208 } 209 210 static void bind(py::module &m) { 211 py::class_<PyRegionList>(m, "RegionSequence") 212 .def("__len__", &PyRegionList::dunderLen) 213 .def("__getitem__", &PyRegionList::dunderGetItem); 214 } 215 216 private: 217 PyOperationRef operation; 218 }; 219 220 class PyBlockIterator { 221 public: 222 PyBlockIterator(PyOperationRef operation, MlirBlock next) 223 : operation(std::move(operation)), next(next) {} 224 225 PyBlockIterator &dunderIter() { return *this; } 226 227 PyBlock dunderNext() { 228 operation->checkValid(); 229 if (mlirBlockIsNull(next)) { 230 throw py::stop_iteration(); 231 } 232 233 PyBlock returnBlock(operation, next); 234 next = mlirBlockGetNextInRegion(next); 235 return returnBlock; 236 } 237 238 static void bind(py::module &m) { 239 py::class_<PyBlockIterator>(m, "BlockIterator") 240 .def("__iter__", &PyBlockIterator::dunderIter) 241 .def("__next__", &PyBlockIterator::dunderNext); 242 } 243 244 private: 245 PyOperationRef operation; 246 MlirBlock next; 247 }; 248 249 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 250 /// we present them as a more full-featured list-like container but optimize 251 /// it for forward iteration. Blocks are always owned by a region. 252 class PyBlockList { 253 public: 254 PyBlockList(PyOperationRef operation, MlirRegion region) 255 : operation(std::move(operation)), region(region) {} 256 257 PyBlockIterator dunderIter() { 258 operation->checkValid(); 259 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 260 } 261 262 intptr_t dunderLen() { 263 operation->checkValid(); 264 intptr_t count = 0; 265 MlirBlock block = mlirRegionGetFirstBlock(region); 266 while (!mlirBlockIsNull(block)) { 267 count += 1; 268 block = mlirBlockGetNextInRegion(block); 269 } 270 return count; 271 } 272 273 PyBlock dunderGetItem(intptr_t index) { 274 operation->checkValid(); 275 if (index < 0) { 276 throw SetPyError(PyExc_IndexError, 277 "attempt to access out of bounds block"); 278 } 279 MlirBlock block = mlirRegionGetFirstBlock(region); 280 while (!mlirBlockIsNull(block)) { 281 if (index == 0) { 282 return PyBlock(operation, block); 283 } 284 block = mlirBlockGetNextInRegion(block); 285 index -= 1; 286 } 287 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 288 } 289 290 PyBlock appendBlock(py::args pyArgTypes) { 291 operation->checkValid(); 292 llvm::SmallVector<MlirType, 4> argTypes; 293 argTypes.reserve(pyArgTypes.size()); 294 for (auto &pyArg : pyArgTypes) { 295 argTypes.push_back(pyArg.cast<PyType &>()); 296 } 297 298 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 299 mlirRegionAppendOwnedBlock(region, block); 300 return PyBlock(operation, block); 301 } 302 303 static void bind(py::module &m) { 304 py::class_<PyBlockList>(m, "BlockList") 305 .def("__getitem__", &PyBlockList::dunderGetItem) 306 .def("__iter__", &PyBlockList::dunderIter) 307 .def("__len__", &PyBlockList::dunderLen) 308 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 309 } 310 311 private: 312 PyOperationRef operation; 313 MlirRegion region; 314 }; 315 316 class PyOperationIterator { 317 public: 318 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 319 : parentOperation(std::move(parentOperation)), next(next) {} 320 321 PyOperationIterator &dunderIter() { return *this; } 322 323 py::object dunderNext() { 324 parentOperation->checkValid(); 325 if (mlirOperationIsNull(next)) { 326 throw py::stop_iteration(); 327 } 328 329 PyOperationRef returnOperation = 330 PyOperation::forOperation(parentOperation->getContext(), next); 331 next = mlirOperationGetNextInBlock(next); 332 return returnOperation->createOpView(); 333 } 334 335 static void bind(py::module &m) { 336 py::class_<PyOperationIterator>(m, "OperationIterator") 337 .def("__iter__", &PyOperationIterator::dunderIter) 338 .def("__next__", &PyOperationIterator::dunderNext); 339 } 340 341 private: 342 PyOperationRef parentOperation; 343 MlirOperation next; 344 }; 345 346 /// Operations are exposed by the C-API as a forward-only linked list. In 347 /// Python, we present them as a more full-featured list-like container but 348 /// optimize it for forward iteration. Iterable operations are always owned 349 /// by a block. 350 class PyOperationList { 351 public: 352 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 353 : parentOperation(std::move(parentOperation)), block(block) {} 354 355 PyOperationIterator dunderIter() { 356 parentOperation->checkValid(); 357 return PyOperationIterator(parentOperation, 358 mlirBlockGetFirstOperation(block)); 359 } 360 361 intptr_t dunderLen() { 362 parentOperation->checkValid(); 363 intptr_t count = 0; 364 MlirOperation childOp = mlirBlockGetFirstOperation(block); 365 while (!mlirOperationIsNull(childOp)) { 366 count += 1; 367 childOp = mlirOperationGetNextInBlock(childOp); 368 } 369 return count; 370 } 371 372 py::object dunderGetItem(intptr_t index) { 373 parentOperation->checkValid(); 374 if (index < 0) { 375 throw SetPyError(PyExc_IndexError, 376 "attempt to access out of bounds operation"); 377 } 378 MlirOperation childOp = mlirBlockGetFirstOperation(block); 379 while (!mlirOperationIsNull(childOp)) { 380 if (index == 0) { 381 return PyOperation::forOperation(parentOperation->getContext(), childOp) 382 ->createOpView(); 383 } 384 childOp = mlirOperationGetNextInBlock(childOp); 385 index -= 1; 386 } 387 throw SetPyError(PyExc_IndexError, 388 "attempt to access out of bounds operation"); 389 } 390 391 static void bind(py::module &m) { 392 py::class_<PyOperationList>(m, "OperationList") 393 .def("__getitem__", &PyOperationList::dunderGetItem) 394 .def("__iter__", &PyOperationList::dunderIter) 395 .def("__len__", &PyOperationList::dunderLen); 396 } 397 398 private: 399 PyOperationRef parentOperation; 400 MlirBlock block; 401 }; 402 403 } // namespace 404 405 //------------------------------------------------------------------------------ 406 // PyMlirContext 407 //------------------------------------------------------------------------------ 408 409 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 410 py::gil_scoped_acquire acquire; 411 auto &liveContexts = getLiveContexts(); 412 liveContexts[context.ptr] = this; 413 } 414 415 PyMlirContext::~PyMlirContext() { 416 // Note that the only public way to construct an instance is via the 417 // forContext method, which always puts the associated handle into 418 // liveContexts. 419 py::gil_scoped_acquire acquire; 420 getLiveContexts().erase(context.ptr); 421 mlirContextDestroy(context); 422 } 423 424 py::object PyMlirContext::getCapsule() { 425 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 426 } 427 428 py::object PyMlirContext::createFromCapsule(py::object capsule) { 429 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 430 if (mlirContextIsNull(rawContext)) 431 throw py::error_already_set(); 432 return forContext(rawContext).releaseObject(); 433 } 434 435 PyMlirContext *PyMlirContext::createNewContextForInit() { 436 MlirContext context = mlirContextCreate(); 437 mlirRegisterAllDialects(context); 438 return new PyMlirContext(context); 439 } 440 441 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 442 py::gil_scoped_acquire acquire; 443 auto &liveContexts = getLiveContexts(); 444 auto it = liveContexts.find(context.ptr); 445 if (it == liveContexts.end()) { 446 // Create. 447 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 448 py::object pyRef = py::cast(unownedContextWrapper); 449 assert(pyRef && "cast to py::object failed"); 450 liveContexts[context.ptr] = unownedContextWrapper; 451 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 452 } 453 // Use existing. 454 py::object pyRef = py::cast(it->second); 455 return PyMlirContextRef(it->second, std::move(pyRef)); 456 } 457 458 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 459 static LiveContextMap liveContexts; 460 return liveContexts; 461 } 462 463 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 464 465 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 466 467 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 468 469 pybind11::object PyMlirContext::contextEnter() { 470 return PyThreadContextEntry::pushContext(*this); 471 } 472 473 void PyMlirContext::contextExit(pybind11::object excType, 474 pybind11::object excVal, 475 pybind11::object excTb) { 476 PyThreadContextEntry::popContext(*this); 477 } 478 479 PyMlirContext &DefaultingPyMlirContext::resolve() { 480 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 481 if (!context) { 482 throw SetPyError( 483 PyExc_RuntimeError, 484 "An MLIR function requires a Context but none was provided in the call " 485 "or from the surrounding environment. Either pass to the function with " 486 "a 'context=' argument or establish a default using 'with Context():'"); 487 } 488 return *context; 489 } 490 491 //------------------------------------------------------------------------------ 492 // PyThreadContextEntry management 493 //------------------------------------------------------------------------------ 494 495 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 496 static thread_local std::vector<PyThreadContextEntry> stack; 497 return stack; 498 } 499 500 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 501 auto &stack = getStack(); 502 if (stack.empty()) 503 return nullptr; 504 return &stack.back(); 505 } 506 507 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 508 py::object insertionPoint, 509 py::object location) { 510 auto &stack = getStack(); 511 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 512 std::move(location)); 513 // If the new stack has more than one entry and the context of the new top 514 // entry matches the previous, copy the insertionPoint and location from the 515 // previous entry if missing from the new top entry. 516 if (stack.size() > 1) { 517 auto &prev = *(stack.rbegin() + 1); 518 auto ¤t = stack.back(); 519 if (current.context.is(prev.context)) { 520 // Default non-context objects from the previous entry. 521 if (!current.insertionPoint) 522 current.insertionPoint = prev.insertionPoint; 523 if (!current.location) 524 current.location = prev.location; 525 } 526 } 527 } 528 529 PyMlirContext *PyThreadContextEntry::getContext() { 530 if (!context) 531 return nullptr; 532 return py::cast<PyMlirContext *>(context); 533 } 534 535 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 536 if (!insertionPoint) 537 return nullptr; 538 return py::cast<PyInsertionPoint *>(insertionPoint); 539 } 540 541 PyLocation *PyThreadContextEntry::getLocation() { 542 if (!location) 543 return nullptr; 544 return py::cast<PyLocation *>(location); 545 } 546 547 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 548 auto *tos = getTopOfStack(); 549 return tos ? tos->getContext() : nullptr; 550 } 551 552 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 553 auto *tos = getTopOfStack(); 554 return tos ? tos->getInsertionPoint() : nullptr; 555 } 556 557 PyLocation *PyThreadContextEntry::getDefaultLocation() { 558 auto *tos = getTopOfStack(); 559 return tos ? tos->getLocation() : nullptr; 560 } 561 562 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 563 py::object contextObj = py::cast(context); 564 push(FrameKind::Context, /*context=*/contextObj, 565 /*insertionPoint=*/py::object(), 566 /*location=*/py::object()); 567 return contextObj; 568 } 569 570 void PyThreadContextEntry::popContext(PyMlirContext &context) { 571 auto &stack = getStack(); 572 if (stack.empty()) 573 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 574 auto &tos = stack.back(); 575 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 576 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 577 stack.pop_back(); 578 } 579 580 py::object 581 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 582 py::object contextObj = 583 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 584 py::object insertionPointObj = py::cast(insertionPoint); 585 push(FrameKind::InsertionPoint, 586 /*context=*/contextObj, 587 /*insertionPoint=*/insertionPointObj, 588 /*location=*/py::object()); 589 return insertionPointObj; 590 } 591 592 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 593 auto &stack = getStack(); 594 if (stack.empty()) 595 throw SetPyError(PyExc_RuntimeError, 596 "Unbalanced InsertionPoint enter/exit"); 597 auto &tos = stack.back(); 598 if (tos.frameKind != FrameKind::InsertionPoint && 599 tos.getInsertionPoint() != &insertionPoint) 600 throw SetPyError(PyExc_RuntimeError, 601 "Unbalanced InsertionPoint enter/exit"); 602 stack.pop_back(); 603 } 604 605 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 606 py::object contextObj = location.getContext().getObject(); 607 py::object locationObj = py::cast(location); 608 push(FrameKind::Location, /*context=*/contextObj, 609 /*insertionPoint=*/py::object(), 610 /*location=*/locationObj); 611 return locationObj; 612 } 613 614 void PyThreadContextEntry::popLocation(PyLocation &location) { 615 auto &stack = getStack(); 616 if (stack.empty()) 617 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 618 auto &tos = stack.back(); 619 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 620 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 621 stack.pop_back(); 622 } 623 624 //------------------------------------------------------------------------------ 625 // PyDialect, PyDialectDescriptor, PyDialects 626 //------------------------------------------------------------------------------ 627 628 MlirDialect PyDialects::getDialectForKey(const std::string &key, 629 bool attrError) { 630 // If the "std" dialect was asked for, substitute the empty namespace :( 631 static const std::string emptyKey; 632 const std::string *canonKey = key == "std" ? &emptyKey : &key; 633 MlirDialect dialect = mlirContextGetOrLoadDialect( 634 getContext()->get(), {canonKey->data(), canonKey->size()}); 635 if (mlirDialectIsNull(dialect)) { 636 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 637 Twine("Dialect '") + key + "' not found"); 638 } 639 return dialect; 640 } 641 642 //------------------------------------------------------------------------------ 643 // PyLocation 644 //------------------------------------------------------------------------------ 645 646 py::object PyLocation::getCapsule() { 647 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 648 } 649 650 PyLocation PyLocation::createFromCapsule(py::object capsule) { 651 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 652 if (mlirLocationIsNull(rawLoc)) 653 throw py::error_already_set(); 654 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 655 rawLoc); 656 } 657 658 py::object PyLocation::contextEnter() { 659 return PyThreadContextEntry::pushLocation(*this); 660 } 661 662 void PyLocation::contextExit(py::object excType, py::object excVal, 663 py::object excTb) { 664 PyThreadContextEntry::popLocation(*this); 665 } 666 667 PyLocation &DefaultingPyLocation::resolve() { 668 auto *location = PyThreadContextEntry::getDefaultLocation(); 669 if (!location) { 670 throw SetPyError( 671 PyExc_RuntimeError, 672 "An MLIR function requires a Location but none was provided in the " 673 "call or from the surrounding environment. Either pass to the function " 674 "with a 'loc=' argument or establish a default using 'with loc:'"); 675 } 676 return *location; 677 } 678 679 //------------------------------------------------------------------------------ 680 // PyModule 681 //------------------------------------------------------------------------------ 682 683 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 684 : BaseContextObject(std::move(contextRef)), module(module) {} 685 686 PyModule::~PyModule() { 687 py::gil_scoped_acquire acquire; 688 auto &liveModules = getContext()->liveModules; 689 assert(liveModules.count(module.ptr) == 1 && 690 "destroying module not in live map"); 691 liveModules.erase(module.ptr); 692 mlirModuleDestroy(module); 693 } 694 695 PyModuleRef PyModule::forModule(MlirModule module) { 696 MlirContext context = mlirModuleGetContext(module); 697 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 698 699 py::gil_scoped_acquire acquire; 700 auto &liveModules = contextRef->liveModules; 701 auto it = liveModules.find(module.ptr); 702 if (it == liveModules.end()) { 703 // Create. 704 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 705 // Note that the default return value policy on cast is automatic_reference, 706 // which does not take ownership (delete will not be called). 707 // Just be explicit. 708 py::object pyRef = 709 py::cast(unownedModule, py::return_value_policy::take_ownership); 710 unownedModule->handle = pyRef; 711 liveModules[module.ptr] = 712 std::make_pair(unownedModule->handle, unownedModule); 713 return PyModuleRef(unownedModule, std::move(pyRef)); 714 } 715 // Use existing. 716 PyModule *existing = it->second.second; 717 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 718 return PyModuleRef(existing, std::move(pyRef)); 719 } 720 721 py::object PyModule::createFromCapsule(py::object capsule) { 722 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 723 if (mlirModuleIsNull(rawModule)) 724 throw py::error_already_set(); 725 return forModule(rawModule).releaseObject(); 726 } 727 728 py::object PyModule::getCapsule() { 729 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 730 } 731 732 //------------------------------------------------------------------------------ 733 // PyOperation 734 //------------------------------------------------------------------------------ 735 736 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 737 : BaseContextObject(std::move(contextRef)), operation(operation) {} 738 739 PyOperation::~PyOperation() { 740 auto &liveOperations = getContext()->liveOperations; 741 assert(liveOperations.count(operation.ptr) == 1 && 742 "destroying operation not in live map"); 743 liveOperations.erase(operation.ptr); 744 if (!isAttached()) { 745 mlirOperationDestroy(operation); 746 } 747 } 748 749 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 750 MlirOperation operation, 751 py::object parentKeepAlive) { 752 auto &liveOperations = contextRef->liveOperations; 753 // Create. 754 PyOperation *unownedOperation = 755 new PyOperation(std::move(contextRef), operation); 756 // Note that the default return value policy on cast is automatic_reference, 757 // which does not take ownership (delete will not be called). 758 // Just be explicit. 759 py::object pyRef = 760 py::cast(unownedOperation, py::return_value_policy::take_ownership); 761 unownedOperation->handle = pyRef; 762 if (parentKeepAlive) { 763 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 764 } 765 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 766 return PyOperationRef(unownedOperation, std::move(pyRef)); 767 } 768 769 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 770 MlirOperation operation, 771 py::object parentKeepAlive) { 772 auto &liveOperations = contextRef->liveOperations; 773 auto it = liveOperations.find(operation.ptr); 774 if (it == liveOperations.end()) { 775 // Create. 776 return createInstance(std::move(contextRef), operation, 777 std::move(parentKeepAlive)); 778 } 779 // Use existing. 780 PyOperation *existing = it->second.second; 781 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 782 return PyOperationRef(existing, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 assert(liveOperations.count(operation.ptr) == 0 && 790 "cannot create detached operation that already exists"); 791 (void)liveOperations; 792 793 PyOperationRef created = createInstance(std::move(contextRef), operation, 794 std::move(parentKeepAlive)); 795 created->attached = false; 796 return created; 797 } 798 799 void PyOperation::checkValid() const { 800 if (!valid) { 801 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 802 } 803 } 804 805 void PyOperationBase::print(py::object fileObject, bool binary, 806 llvm::Optional<int64_t> largeElementsLimit, 807 bool enableDebugInfo, bool prettyDebugInfo, 808 bool printGenericOpForm, bool useLocalScope) { 809 PyOperation &operation = getOperation(); 810 operation.checkValid(); 811 if (fileObject.is_none()) 812 fileObject = py::module::import("sys").attr("stdout"); 813 814 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 815 fileObject.attr("write")("// Verification failed, printing generic form\n"); 816 printGenericOpForm = true; 817 } 818 819 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 820 if (largeElementsLimit) 821 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 822 if (enableDebugInfo) 823 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 824 if (printGenericOpForm) 825 mlirOpPrintingFlagsPrintGenericOpForm(flags); 826 827 PyFileAccumulator accum(fileObject, binary); 828 py::gil_scoped_release(); 829 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 830 accum.getUserData()); 831 mlirOpPrintingFlagsDestroy(flags); 832 } 833 834 py::object PyOperationBase::getAsm(bool binary, 835 llvm::Optional<int64_t> largeElementsLimit, 836 bool enableDebugInfo, bool prettyDebugInfo, 837 bool printGenericOpForm, 838 bool useLocalScope) { 839 py::object fileObject; 840 if (binary) { 841 fileObject = py::module::import("io").attr("BytesIO")(); 842 } else { 843 fileObject = py::module::import("io").attr("StringIO")(); 844 } 845 print(fileObject, /*binary=*/binary, 846 /*largeElementsLimit=*/largeElementsLimit, 847 /*enableDebugInfo=*/enableDebugInfo, 848 /*prettyDebugInfo=*/prettyDebugInfo, 849 /*printGenericOpForm=*/printGenericOpForm, 850 /*useLocalScope=*/useLocalScope); 851 852 return fileObject.attr("getvalue")(); 853 } 854 855 PyOperationRef PyOperation::getParentOperation() { 856 if (!isAttached()) 857 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 858 MlirOperation operation = mlirOperationGetParentOperation(get()); 859 if (mlirOperationIsNull(operation)) 860 throw SetPyError(PyExc_ValueError, "Operation has no parent."); 861 return PyOperation::forOperation(getContext(), operation); 862 } 863 864 PyBlock PyOperation::getBlock() { 865 PyOperationRef parentOperation = getParentOperation(); 866 MlirBlock block = mlirOperationGetBlock(get()); 867 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 868 return PyBlock{std::move(parentOperation), block}; 869 } 870 871 py::object PyOperation::getCapsule() { 872 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 873 } 874 875 py::object PyOperation::createFromCapsule(py::object capsule) { 876 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 877 if (mlirOperationIsNull(rawOperation)) 878 throw py::error_already_set(); 879 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 880 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 881 .releaseObject(); 882 } 883 884 py::object PyOperation::create( 885 std::string name, llvm::Optional<std::vector<PyType *>> results, 886 llvm::Optional<std::vector<PyValue *>> operands, 887 llvm::Optional<py::dict> attributes, 888 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 889 DefaultingPyLocation location, py::object maybeIp) { 890 llvm::SmallVector<MlirValue, 4> mlirOperands; 891 llvm::SmallVector<MlirType, 4> mlirResults; 892 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 893 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 894 895 // General parameter validation. 896 if (regions < 0) 897 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 898 899 // Unpack/validate operands. 900 if (operands) { 901 mlirOperands.reserve(operands->size()); 902 for (PyValue *operand : *operands) { 903 if (!operand) 904 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 905 mlirOperands.push_back(operand->get()); 906 } 907 } 908 909 // Unpack/validate results. 910 if (results) { 911 mlirResults.reserve(results->size()); 912 for (PyType *result : *results) { 913 // TODO: Verify result type originate from the same context. 914 if (!result) 915 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 916 mlirResults.push_back(*result); 917 } 918 } 919 // Unpack/validate attributes. 920 if (attributes) { 921 mlirAttributes.reserve(attributes->size()); 922 for (auto &it : *attributes) { 923 std::string key; 924 try { 925 key = it.first.cast<std::string>(); 926 } catch (py::cast_error &err) { 927 std::string msg = "Invalid attribute key (not a string) when " 928 "attempting to create the operation \"" + 929 name + "\" (" + err.what() + ")"; 930 throw py::cast_error(msg); 931 } 932 try { 933 auto &attribute = it.second.cast<PyAttribute &>(); 934 // TODO: Verify attribute originates from the same context. 935 mlirAttributes.emplace_back(std::move(key), attribute); 936 } catch (py::reference_cast_error &) { 937 // This exception seems thrown when the value is "None". 938 std::string msg = 939 "Found an invalid (`None`?) attribute value for the key \"" + key + 940 "\" when attempting to create the operation \"" + name + "\""; 941 throw py::cast_error(msg); 942 } catch (py::cast_error &err) { 943 std::string msg = "Invalid attribute value for the key \"" + key + 944 "\" when attempting to create the operation \"" + 945 name + "\" (" + err.what() + ")"; 946 throw py::cast_error(msg); 947 } 948 } 949 } 950 // Unpack/validate successors. 951 if (successors) { 952 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 953 mlirSuccessors.reserve(successors->size()); 954 for (auto *successor : *successors) { 955 // TODO: Verify successor originate from the same context. 956 if (!successor) 957 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 958 mlirSuccessors.push_back(successor->get()); 959 } 960 } 961 962 // Apply unpacked/validated to the operation state. Beyond this 963 // point, exceptions cannot be thrown or else the state will leak. 964 MlirOperationState state = 965 mlirOperationStateGet(toMlirStringRef(name), location); 966 if (!mlirOperands.empty()) 967 mlirOperationStateAddOperands(&state, mlirOperands.size(), 968 mlirOperands.data()); 969 if (!mlirResults.empty()) 970 mlirOperationStateAddResults(&state, mlirResults.size(), 971 mlirResults.data()); 972 if (!mlirAttributes.empty()) { 973 // Note that the attribute names directly reference bytes in 974 // mlirAttributes, so that vector must not be changed from here 975 // on. 976 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 977 mlirNamedAttributes.reserve(mlirAttributes.size()); 978 for (auto &it : mlirAttributes) 979 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 980 mlirIdentifierGet(mlirAttributeGetContext(it.second), 981 toMlirStringRef(it.first)), 982 it.second)); 983 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 984 mlirNamedAttributes.data()); 985 } 986 if (!mlirSuccessors.empty()) 987 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 988 mlirSuccessors.data()); 989 if (regions) { 990 llvm::SmallVector<MlirRegion, 4> mlirRegions; 991 mlirRegions.resize(regions); 992 for (int i = 0; i < regions; ++i) 993 mlirRegions[i] = mlirRegionCreate(); 994 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 995 mlirRegions.data()); 996 } 997 998 // Construct the operation. 999 MlirOperation operation = mlirOperationCreate(&state); 1000 PyOperationRef created = 1001 PyOperation::createDetached(location->getContext(), operation); 1002 1003 // InsertPoint active? 1004 if (!maybeIp.is(py::cast(false))) { 1005 PyInsertionPoint *ip; 1006 if (maybeIp.is_none()) { 1007 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1008 } else { 1009 ip = py::cast<PyInsertionPoint *>(maybeIp); 1010 } 1011 if (ip) 1012 ip->insert(*created.get()); 1013 } 1014 1015 return created->createOpView(); 1016 } 1017 1018 py::object PyOperation::createOpView() { 1019 MlirIdentifier ident = mlirOperationGetName(get()); 1020 MlirStringRef identStr = mlirIdentifierStr(ident); 1021 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1022 StringRef(identStr.data, identStr.length)); 1023 if (opViewClass) 1024 return (*opViewClass)(getRef().getObject()); 1025 return py::cast(PyOpView(getRef().getObject())); 1026 } 1027 1028 //------------------------------------------------------------------------------ 1029 // PyOpView 1030 //------------------------------------------------------------------------------ 1031 1032 py::object 1033 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1034 py::list operandList, 1035 llvm::Optional<py::dict> attributes, 1036 llvm::Optional<std::vector<PyBlock *>> successors, 1037 llvm::Optional<int> regions, 1038 DefaultingPyLocation location, py::object maybeIp) { 1039 PyMlirContextRef context = location->getContext(); 1040 // Class level operation construction metadata. 1041 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1042 // Operand and result segment specs are either none, which does no 1043 // variadic unpacking, or a list of ints with segment sizes, where each 1044 // element is either a positive number (typically 1 for a scalar) or -1 to 1045 // indicate that it is derived from the length of the same-indexed operand 1046 // or result (implying that it is a list at that position). 1047 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1048 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1049 1050 std::vector<uint32_t> operandSegmentLengths; 1051 std::vector<uint32_t> resultSegmentLengths; 1052 1053 // Validate/determine region count. 1054 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1055 int opMinRegionCount = std::get<0>(opRegionSpec); 1056 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1057 if (!regions) { 1058 regions = opMinRegionCount; 1059 } 1060 if (*regions < opMinRegionCount) { 1061 throw py::value_error( 1062 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1063 llvm::Twine(opMinRegionCount) + 1064 " regions but was built with regions=" + llvm::Twine(*regions)) 1065 .str()); 1066 } 1067 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1068 throw py::value_error( 1069 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1070 llvm::Twine(opMinRegionCount) + 1071 " regions but was built with regions=" + llvm::Twine(*regions)) 1072 .str()); 1073 } 1074 1075 // Unpack results. 1076 std::vector<PyType *> resultTypes; 1077 resultTypes.reserve(resultTypeList.size()); 1078 if (resultSegmentSpecObj.is_none()) { 1079 // Non-variadic result unpacking. 1080 for (auto it : llvm::enumerate(resultTypeList)) { 1081 try { 1082 resultTypes.push_back(py::cast<PyType *>(it.value())); 1083 if (!resultTypes.back()) 1084 throw py::cast_error(); 1085 } catch (py::cast_error &err) { 1086 throw py::value_error((llvm::Twine("Result ") + 1087 llvm::Twine(it.index()) + " of operation \"" + 1088 name + "\" must be a Type (" + err.what() + ")") 1089 .str()); 1090 } 1091 } 1092 } else { 1093 // Sized result unpacking. 1094 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1095 if (resultSegmentSpec.size() != resultTypeList.size()) { 1096 throw py::value_error((llvm::Twine("Operation \"") + name + 1097 "\" requires " + 1098 llvm::Twine(resultSegmentSpec.size()) + 1099 "result segments but was provided " + 1100 llvm::Twine(resultTypeList.size())) 1101 .str()); 1102 } 1103 resultSegmentLengths.reserve(resultTypeList.size()); 1104 for (auto it : 1105 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1106 int segmentSpec = std::get<1>(it.value()); 1107 if (segmentSpec == 1 || segmentSpec == 0) { 1108 // Unpack unary element. 1109 try { 1110 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1111 if (resultType) { 1112 resultTypes.push_back(resultType); 1113 resultSegmentLengths.push_back(1); 1114 } else if (segmentSpec == 0) { 1115 // Allowed to be optional. 1116 resultSegmentLengths.push_back(0); 1117 } else { 1118 throw py::cast_error("was None and result is not optional"); 1119 } 1120 } catch (py::cast_error &err) { 1121 throw py::value_error((llvm::Twine("Result ") + 1122 llvm::Twine(it.index()) + " of operation \"" + 1123 name + "\" must be a Type (" + err.what() + 1124 ")") 1125 .str()); 1126 } 1127 } else if (segmentSpec == -1) { 1128 // Unpack sequence by appending. 1129 try { 1130 if (std::get<0>(it.value()).is_none()) { 1131 // Treat it as an empty list. 1132 resultSegmentLengths.push_back(0); 1133 } else { 1134 // Unpack the list. 1135 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1136 for (py::object segmentItem : segment) { 1137 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1138 if (!resultTypes.back()) { 1139 throw py::cast_error("contained a None item"); 1140 } 1141 } 1142 resultSegmentLengths.push_back(segment.size()); 1143 } 1144 } catch (std::exception &err) { 1145 // NOTE: Sloppy to be using a catch-all here, but there are at least 1146 // three different unrelated exceptions that can be thrown in the 1147 // above "casts". Just keep the scope above small and catch them all. 1148 throw py::value_error((llvm::Twine("Result ") + 1149 llvm::Twine(it.index()) + " of operation \"" + 1150 name + "\" must be a Sequence of Types (" + 1151 err.what() + ")") 1152 .str()); 1153 } 1154 } else { 1155 throw py::value_error("Unexpected segment spec"); 1156 } 1157 } 1158 } 1159 1160 // Unpack operands. 1161 std::vector<PyValue *> operands; 1162 operands.reserve(operands.size()); 1163 if (operandSegmentSpecObj.is_none()) { 1164 // Non-sized operand unpacking. 1165 for (auto it : llvm::enumerate(operandList)) { 1166 try { 1167 operands.push_back(py::cast<PyValue *>(it.value())); 1168 if (!operands.back()) 1169 throw py::cast_error(); 1170 } catch (py::cast_error &err) { 1171 throw py::value_error((llvm::Twine("Operand ") + 1172 llvm::Twine(it.index()) + " of operation \"" + 1173 name + "\" must be a Value (" + err.what() + ")") 1174 .str()); 1175 } 1176 } 1177 } else { 1178 // Sized operand unpacking. 1179 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1180 if (operandSegmentSpec.size() != operandList.size()) { 1181 throw py::value_error((llvm::Twine("Operation \"") + name + 1182 "\" requires " + 1183 llvm::Twine(operandSegmentSpec.size()) + 1184 "operand segments but was provided " + 1185 llvm::Twine(operandList.size())) 1186 .str()); 1187 } 1188 operandSegmentLengths.reserve(operandList.size()); 1189 for (auto it : 1190 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1191 int segmentSpec = std::get<1>(it.value()); 1192 if (segmentSpec == 1 || segmentSpec == 0) { 1193 // Unpack unary element. 1194 try { 1195 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1196 if (operandValue) { 1197 operands.push_back(operandValue); 1198 operandSegmentLengths.push_back(1); 1199 } else if (segmentSpec == 0) { 1200 // Allowed to be optional. 1201 operandSegmentLengths.push_back(0); 1202 } else { 1203 throw py::cast_error("was None and operand is not optional"); 1204 } 1205 } catch (py::cast_error &err) { 1206 throw py::value_error((llvm::Twine("Operand ") + 1207 llvm::Twine(it.index()) + " of operation \"" + 1208 name + "\" must be a Value (" + err.what() + 1209 ")") 1210 .str()); 1211 } 1212 } else if (segmentSpec == -1) { 1213 // Unpack sequence by appending. 1214 try { 1215 if (std::get<0>(it.value()).is_none()) { 1216 // Treat it as an empty list. 1217 operandSegmentLengths.push_back(0); 1218 } else { 1219 // Unpack the list. 1220 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1221 for (py::object segmentItem : segment) { 1222 operands.push_back(py::cast<PyValue *>(segmentItem)); 1223 if (!operands.back()) { 1224 throw py::cast_error("contained a None item"); 1225 } 1226 } 1227 operandSegmentLengths.push_back(segment.size()); 1228 } 1229 } catch (std::exception &err) { 1230 // NOTE: Sloppy to be using a catch-all here, but there are at least 1231 // three different unrelated exceptions that can be thrown in the 1232 // above "casts". Just keep the scope above small and catch them all. 1233 throw py::value_error((llvm::Twine("Operand ") + 1234 llvm::Twine(it.index()) + " of operation \"" + 1235 name + "\" must be a Sequence of Values (" + 1236 err.what() + ")") 1237 .str()); 1238 } 1239 } else { 1240 throw py::value_error("Unexpected segment spec"); 1241 } 1242 } 1243 } 1244 1245 // Merge operand/result segment lengths into attributes if needed. 1246 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1247 // Dup. 1248 if (attributes) { 1249 attributes = py::dict(*attributes); 1250 } else { 1251 attributes = py::dict(); 1252 } 1253 if (attributes->contains("result_segment_sizes") || 1254 attributes->contains("operand_segment_sizes")) { 1255 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1256 "'operand_segment_sizes' attribute is unsupported. " 1257 "Use Operation.create for such low-level access."); 1258 } 1259 1260 // Add result_segment_sizes attribute. 1261 if (!resultSegmentLengths.empty()) { 1262 int64_t size = resultSegmentLengths.size(); 1263 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1264 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1265 resultSegmentLengths.size(), resultSegmentLengths.data()); 1266 (*attributes)["result_segment_sizes"] = 1267 PyAttribute(context, segmentLengthAttr); 1268 } 1269 1270 // Add operand_segment_sizes attribute. 1271 if (!operandSegmentLengths.empty()) { 1272 int64_t size = operandSegmentLengths.size(); 1273 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1274 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1275 operandSegmentLengths.size(), operandSegmentLengths.data()); 1276 (*attributes)["operand_segment_sizes"] = 1277 PyAttribute(context, segmentLengthAttr); 1278 } 1279 } 1280 1281 // Delegate to create. 1282 return PyOperation::create(std::move(name), 1283 /*results=*/std::move(resultTypes), 1284 /*operands=*/std::move(operands), 1285 /*attributes=*/std::move(attributes), 1286 /*successors=*/std::move(successors), 1287 /*regions=*/*regions, location, maybeIp); 1288 } 1289 1290 PyOpView::PyOpView(py::object operationObject) 1291 // Casting through the PyOperationBase base-class and then back to the 1292 // Operation lets us accept any PyOperationBase subclass. 1293 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1294 operationObject(operation.getRef().getObject()) {} 1295 1296 py::object PyOpView::createRawSubclass(py::object userClass) { 1297 // This is... a little gross. The typical pattern is to have a pure python 1298 // class that extends OpView like: 1299 // class AddFOp(_cext.ir.OpView): 1300 // def __init__(self, loc, lhs, rhs): 1301 // operation = loc.context.create_operation( 1302 // "addf", lhs, rhs, results=[lhs.type]) 1303 // super().__init__(operation) 1304 // 1305 // I.e. The goal of the user facing type is to provide a nice constructor 1306 // that has complete freedom for the op under construction. This is at odds 1307 // with our other desire to sometimes create this object by just passing an 1308 // operation (to initialize the base class). We could do *arg and **kwargs 1309 // munging to try to make it work, but instead, we synthesize a new class 1310 // on the fly which extends this user class (AddFOp in this example) and 1311 // *give it* the base class's __init__ method, thus bypassing the 1312 // intermediate subclass's __init__ method entirely. While slightly, 1313 // underhanded, this is safe/legal because the type hierarchy has not changed 1314 // (we just added a new leaf) and we aren't mucking around with __new__. 1315 // Typically, this new class will be stored on the original as "_Raw" and will 1316 // be used for casts and other things that need a variant of the class that 1317 // is initialized purely from an operation. 1318 py::object parentMetaclass = 1319 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1320 py::dict attributes; 1321 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1322 // now. 1323 // auto opViewType = py::type::of<PyOpView>(); 1324 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1325 attributes["__init__"] = opViewType.attr("__init__"); 1326 py::str origName = userClass.attr("__name__"); 1327 py::str newName = py::str("_") + origName; 1328 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1329 } 1330 1331 //------------------------------------------------------------------------------ 1332 // PyInsertionPoint. 1333 //------------------------------------------------------------------------------ 1334 1335 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1336 1337 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1338 : refOperation(beforeOperationBase.getOperation().getRef()), 1339 block((*refOperation)->getBlock()) {} 1340 1341 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1342 PyOperation &operation = operationBase.getOperation(); 1343 if (operation.isAttached()) 1344 throw SetPyError(PyExc_ValueError, 1345 "Attempt to insert operation that is already attached"); 1346 block.getParentOperation()->checkValid(); 1347 MlirOperation beforeOp = {nullptr}; 1348 if (refOperation) { 1349 // Insert before operation. 1350 (*refOperation)->checkValid(); 1351 beforeOp = (*refOperation)->get(); 1352 } else { 1353 // Insert at end (before null) is only valid if the block does not 1354 // already end in a known terminator (violating this will cause assertion 1355 // failures later). 1356 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1357 throw py::index_error("Cannot insert operation at the end of a block " 1358 "that already has a terminator. Did you mean to " 1359 "use 'InsertionPoint.at_block_terminator(block)' " 1360 "versus 'InsertionPoint(block)'?"); 1361 } 1362 } 1363 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1364 operation.setAttached(); 1365 } 1366 1367 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1368 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1369 if (mlirOperationIsNull(firstOp)) { 1370 // Just insert at end. 1371 return PyInsertionPoint(block); 1372 } 1373 1374 // Insert before first op. 1375 PyOperationRef firstOpRef = PyOperation::forOperation( 1376 block.getParentOperation()->getContext(), firstOp); 1377 return PyInsertionPoint{block, std::move(firstOpRef)}; 1378 } 1379 1380 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1381 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1382 if (mlirOperationIsNull(terminator)) 1383 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1384 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1385 block.getParentOperation()->getContext(), terminator); 1386 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1387 } 1388 1389 py::object PyInsertionPoint::contextEnter() { 1390 return PyThreadContextEntry::pushInsertionPoint(*this); 1391 } 1392 1393 void PyInsertionPoint::contextExit(pybind11::object excType, 1394 pybind11::object excVal, 1395 pybind11::object excTb) { 1396 PyThreadContextEntry::popInsertionPoint(*this); 1397 } 1398 1399 //------------------------------------------------------------------------------ 1400 // PyAttribute. 1401 //------------------------------------------------------------------------------ 1402 1403 bool PyAttribute::operator==(const PyAttribute &other) { 1404 return mlirAttributeEqual(attr, other.attr); 1405 } 1406 1407 py::object PyAttribute::getCapsule() { 1408 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1409 } 1410 1411 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1412 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1413 if (mlirAttributeIsNull(rawAttr)) 1414 throw py::error_already_set(); 1415 return PyAttribute( 1416 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1417 } 1418 1419 //------------------------------------------------------------------------------ 1420 // PyNamedAttribute. 1421 //------------------------------------------------------------------------------ 1422 1423 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1424 : ownedName(new std::string(std::move(ownedName))) { 1425 namedAttr = mlirNamedAttributeGet( 1426 mlirIdentifierGet(mlirAttributeGetContext(attr), 1427 toMlirStringRef(*this->ownedName)), 1428 attr); 1429 } 1430 1431 //------------------------------------------------------------------------------ 1432 // PyType. 1433 //------------------------------------------------------------------------------ 1434 1435 bool PyType::operator==(const PyType &other) { 1436 return mlirTypeEqual(type, other.type); 1437 } 1438 1439 py::object PyType::getCapsule() { 1440 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1441 } 1442 1443 PyType PyType::createFromCapsule(py::object capsule) { 1444 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1445 if (mlirTypeIsNull(rawType)) 1446 throw py::error_already_set(); 1447 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1448 rawType); 1449 } 1450 1451 //------------------------------------------------------------------------------ 1452 // PyValue and subclases. 1453 //------------------------------------------------------------------------------ 1454 1455 namespace { 1456 /// CRTP base class for Python MLIR values that subclass Value and should be 1457 /// castable from it. The value hierarchy is one level deep and is not supposed 1458 /// to accommodate other levels unless core MLIR changes. 1459 template <typename DerivedTy> 1460 class PyConcreteValue : public PyValue { 1461 public: 1462 // Derived classes must define statics for: 1463 // IsAFunctionTy isaFunction 1464 // const char *pyClassName 1465 // and redefine bindDerived. 1466 using ClassTy = py::class_<DerivedTy, PyValue>; 1467 using IsAFunctionTy = bool (*)(MlirValue); 1468 1469 PyConcreteValue() = default; 1470 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1471 : PyValue(operationRef, value) {} 1472 PyConcreteValue(PyValue &orig) 1473 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1474 1475 /// Attempts to cast the original value to the derived type and throws on 1476 /// type mismatches. 1477 static MlirValue castFrom(PyValue &orig) { 1478 if (!DerivedTy::isaFunction(orig.get())) { 1479 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1480 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1481 DerivedTy::pyClassName + 1482 " (from " + origRepr + ")"); 1483 } 1484 return orig.get(); 1485 } 1486 1487 /// Binds the Python module objects to functions of this class. 1488 static void bind(py::module &m) { 1489 auto cls = ClassTy(m, DerivedTy::pyClassName); 1490 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1491 DerivedTy::bindDerived(cls); 1492 } 1493 1494 /// Implemented by derived classes to add methods to the Python subclass. 1495 static void bindDerived(ClassTy &m) {} 1496 }; 1497 1498 /// Python wrapper for MlirBlockArgument. 1499 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1500 public: 1501 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1502 static constexpr const char *pyClassName = "BlockArgument"; 1503 using PyConcreteValue::PyConcreteValue; 1504 1505 static void bindDerived(ClassTy &c) { 1506 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1507 return PyBlock(self.getParentOperation(), 1508 mlirBlockArgumentGetOwner(self.get())); 1509 }); 1510 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1511 return mlirBlockArgumentGetArgNumber(self.get()); 1512 }); 1513 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1514 return mlirBlockArgumentSetType(self.get(), type); 1515 }); 1516 } 1517 }; 1518 1519 /// Python wrapper for MlirOpResult. 1520 class PyOpResult : public PyConcreteValue<PyOpResult> { 1521 public: 1522 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1523 static constexpr const char *pyClassName = "OpResult"; 1524 using PyConcreteValue::PyConcreteValue; 1525 1526 static void bindDerived(ClassTy &c) { 1527 c.def_property_readonly("owner", [](PyOpResult &self) { 1528 assert( 1529 mlirOperationEqual(self.getParentOperation()->get(), 1530 mlirOpResultGetOwner(self.get())) && 1531 "expected the owner of the value in Python to match that in the IR"); 1532 return self.getParentOperation(); 1533 }); 1534 c.def_property_readonly("result_number", [](PyOpResult &self) { 1535 return mlirOpResultGetResultNumber(self.get()); 1536 }); 1537 } 1538 }; 1539 1540 /// A list of block arguments. Internally, these are stored as consecutive 1541 /// elements, random access is cheap. The argument list is associated with the 1542 /// operation that contains the block (detached blocks are not allowed in 1543 /// Python bindings) and extends its lifetime. 1544 class PyBlockArgumentList { 1545 public: 1546 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1547 : operation(std::move(operation)), block(block) {} 1548 1549 /// Returns the length of the block argument list. 1550 intptr_t dunderLen() { 1551 operation->checkValid(); 1552 return mlirBlockGetNumArguments(block); 1553 } 1554 1555 /// Returns `index`-th element of the block argument list. 1556 PyBlockArgument dunderGetItem(intptr_t index) { 1557 if (index < 0 || index >= dunderLen()) { 1558 throw SetPyError(PyExc_IndexError, 1559 "attempt to access out of bounds region"); 1560 } 1561 PyValue value(operation, mlirBlockGetArgument(block, index)); 1562 return PyBlockArgument(value); 1563 } 1564 1565 /// Defines a Python class in the bindings. 1566 static void bind(py::module &m) { 1567 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1568 .def("__len__", &PyBlockArgumentList::dunderLen) 1569 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1570 } 1571 1572 private: 1573 PyOperationRef operation; 1574 MlirBlock block; 1575 }; 1576 1577 /// A list of operation operands. Internally, these are stored as consecutive 1578 /// elements, random access is cheap. The result list is associated with the 1579 /// operation whose results these are, and extends the lifetime of this 1580 /// operation. 1581 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1582 public: 1583 static constexpr const char *pyClassName = "OpOperandList"; 1584 1585 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1586 intptr_t length = -1, intptr_t step = 1) 1587 : Sliceable(startIndex, 1588 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1589 : length, 1590 step), 1591 operation(operation) {} 1592 1593 intptr_t getNumElements() { 1594 operation->checkValid(); 1595 return mlirOperationGetNumOperands(operation->get()); 1596 } 1597 1598 PyValue getElement(intptr_t pos) { 1599 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); 1600 } 1601 1602 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1603 return PyOpOperandList(operation, startIndex, length, step); 1604 } 1605 1606 private: 1607 PyOperationRef operation; 1608 }; 1609 1610 /// A list of operation results. Internally, these are stored as consecutive 1611 /// elements, random access is cheap. The result list is associated with the 1612 /// operation whose results these are, and extends the lifetime of this 1613 /// operation. 1614 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1615 public: 1616 static constexpr const char *pyClassName = "OpResultList"; 1617 1618 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1619 intptr_t length = -1, intptr_t step = 1) 1620 : Sliceable(startIndex, 1621 length == -1 ? mlirOperationGetNumResults(operation->get()) 1622 : length, 1623 step), 1624 operation(operation) {} 1625 1626 intptr_t getNumElements() { 1627 operation->checkValid(); 1628 return mlirOperationGetNumResults(operation->get()); 1629 } 1630 1631 PyOpResult getElement(intptr_t index) { 1632 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1633 return PyOpResult(value); 1634 } 1635 1636 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1637 return PyOpResultList(operation, startIndex, length, step); 1638 } 1639 1640 private: 1641 PyOperationRef operation; 1642 }; 1643 1644 /// A list of operation attributes. Can be indexed by name, producing 1645 /// attributes, or by index, producing named attributes. 1646 class PyOpAttributeMap { 1647 public: 1648 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1649 1650 PyAttribute dunderGetItemNamed(const std::string &name) { 1651 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1652 toMlirStringRef(name)); 1653 if (mlirAttributeIsNull(attr)) { 1654 throw SetPyError(PyExc_KeyError, 1655 "attempt to access a non-existent attribute"); 1656 } 1657 return PyAttribute(operation->getContext(), attr); 1658 } 1659 1660 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1661 if (index < 0 || index >= dunderLen()) { 1662 throw SetPyError(PyExc_IndexError, 1663 "attempt to access out of bounds attribute"); 1664 } 1665 MlirNamedAttribute namedAttr = 1666 mlirOperationGetAttribute(operation->get(), index); 1667 return PyNamedAttribute( 1668 namedAttr.attribute, 1669 std::string(mlirIdentifierStr(namedAttr.name).data)); 1670 } 1671 1672 void dunderSetItem(const std::string &name, PyAttribute attr) { 1673 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1674 attr); 1675 } 1676 1677 void dunderDelItem(const std::string &name) { 1678 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1679 toMlirStringRef(name)); 1680 if (!removed) 1681 throw SetPyError(PyExc_KeyError, 1682 "attempt to delete a non-existent attribute"); 1683 } 1684 1685 intptr_t dunderLen() { 1686 return mlirOperationGetNumAttributes(operation->get()); 1687 } 1688 1689 bool dunderContains(const std::string &name) { 1690 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1691 operation->get(), toMlirStringRef(name))); 1692 } 1693 1694 static void bind(py::module &m) { 1695 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1696 .def("__contains__", &PyOpAttributeMap::dunderContains) 1697 .def("__len__", &PyOpAttributeMap::dunderLen) 1698 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1699 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1700 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1701 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1702 } 1703 1704 private: 1705 PyOperationRef operation; 1706 }; 1707 1708 } // end namespace 1709 1710 //------------------------------------------------------------------------------ 1711 // Populates the core exports of the 'ir' submodule. 1712 //------------------------------------------------------------------------------ 1713 1714 void mlir::python::populateIRCore(py::module &m) { 1715 //---------------------------------------------------------------------------- 1716 // Mapping of Global functions 1717 //---------------------------------------------------------------------------- 1718 m.def("_enable_debug", [](bool enable) { mlirEnableGlobalDebug(enable); }); 1719 1720 //---------------------------------------------------------------------------- 1721 // Mapping of MlirContext 1722 //---------------------------------------------------------------------------- 1723 py::class_<PyMlirContext>(m, "Context") 1724 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1725 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1726 .def("_get_context_again", 1727 [](PyMlirContext &self) { 1728 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1729 return ref.releaseObject(); 1730 }) 1731 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1732 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1733 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1734 &PyMlirContext::getCapsule) 1735 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1736 .def("__enter__", &PyMlirContext::contextEnter) 1737 .def("__exit__", &PyMlirContext::contextExit) 1738 .def_property_readonly_static( 1739 "current", 1740 [](py::object & /*class*/) { 1741 auto *context = PyThreadContextEntry::getDefaultContext(); 1742 if (!context) 1743 throw SetPyError(PyExc_ValueError, "No current Context"); 1744 return context; 1745 }, 1746 "Gets the Context bound to the current thread or raises ValueError") 1747 .def_property_readonly( 1748 "dialects", 1749 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1750 "Gets a container for accessing dialects by name") 1751 .def_property_readonly( 1752 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1753 "Alias for 'dialect'") 1754 .def( 1755 "get_dialect_descriptor", 1756 [=](PyMlirContext &self, std::string &name) { 1757 MlirDialect dialect = mlirContextGetOrLoadDialect( 1758 self.get(), {name.data(), name.size()}); 1759 if (mlirDialectIsNull(dialect)) { 1760 throw SetPyError(PyExc_ValueError, 1761 Twine("Dialect '") + name + "' not found"); 1762 } 1763 return PyDialectDescriptor(self.getRef(), dialect); 1764 }, 1765 "Gets or loads a dialect by name, returning its descriptor object") 1766 .def_property( 1767 "allow_unregistered_dialects", 1768 [](PyMlirContext &self) -> bool { 1769 return mlirContextGetAllowUnregisteredDialects(self.get()); 1770 }, 1771 [](PyMlirContext &self, bool value) { 1772 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1773 }) 1774 .def("enable_multithreading", 1775 [](PyMlirContext &self, bool enable) { 1776 mlirContextEnableMultithreading(self.get(), enable); 1777 }) 1778 .def("is_registered_operation", 1779 [](PyMlirContext &self, std::string &name) { 1780 return mlirContextIsRegisteredOperation( 1781 self.get(), MlirStringRef{name.data(), name.size()}); 1782 }); 1783 1784 //---------------------------------------------------------------------------- 1785 // Mapping of PyDialectDescriptor 1786 //---------------------------------------------------------------------------- 1787 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1788 .def_property_readonly("namespace", 1789 [](PyDialectDescriptor &self) { 1790 MlirStringRef ns = 1791 mlirDialectGetNamespace(self.get()); 1792 return py::str(ns.data, ns.length); 1793 }) 1794 .def("__repr__", [](PyDialectDescriptor &self) { 1795 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1796 std::string repr("<DialectDescriptor "); 1797 repr.append(ns.data, ns.length); 1798 repr.append(">"); 1799 return repr; 1800 }); 1801 1802 //---------------------------------------------------------------------------- 1803 // Mapping of PyDialects 1804 //---------------------------------------------------------------------------- 1805 py::class_<PyDialects>(m, "Dialects") 1806 .def("__getitem__", 1807 [=](PyDialects &self, std::string keyName) { 1808 MlirDialect dialect = 1809 self.getDialectForKey(keyName, /*attrError=*/false); 1810 py::object descriptor = 1811 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1812 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1813 }) 1814 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1815 MlirDialect dialect = 1816 self.getDialectForKey(attrName, /*attrError=*/true); 1817 py::object descriptor = 1818 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1819 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1820 }); 1821 1822 //---------------------------------------------------------------------------- 1823 // Mapping of PyDialect 1824 //---------------------------------------------------------------------------- 1825 py::class_<PyDialect>(m, "Dialect") 1826 .def(py::init<py::object>(), "descriptor") 1827 .def_property_readonly( 1828 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1829 .def("__repr__", [](py::object self) { 1830 auto clazz = self.attr("__class__"); 1831 return py::str("<Dialect ") + 1832 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1833 clazz.attr("__module__") + py::str(".") + 1834 clazz.attr("__name__") + py::str(")>"); 1835 }); 1836 1837 //---------------------------------------------------------------------------- 1838 // Mapping of Location 1839 //---------------------------------------------------------------------------- 1840 py::class_<PyLocation>(m, "Location") 1841 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1842 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1843 .def("__enter__", &PyLocation::contextEnter) 1844 .def("__exit__", &PyLocation::contextExit) 1845 .def("__eq__", 1846 [](PyLocation &self, PyLocation &other) -> bool { 1847 return mlirLocationEqual(self, other); 1848 }) 1849 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1850 .def_property_readonly_static( 1851 "current", 1852 [](py::object & /*class*/) { 1853 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1854 if (!loc) 1855 throw SetPyError(PyExc_ValueError, "No current Location"); 1856 return loc; 1857 }, 1858 "Gets the Location bound to the current thread or raises ValueError") 1859 .def_static( 1860 "unknown", 1861 [](DefaultingPyMlirContext context) { 1862 return PyLocation(context->getRef(), 1863 mlirLocationUnknownGet(context->get())); 1864 }, 1865 py::arg("context") = py::none(), 1866 "Gets a Location representing an unknown location") 1867 .def_static( 1868 "file", 1869 [](std::string filename, int line, int col, 1870 DefaultingPyMlirContext context) { 1871 return PyLocation( 1872 context->getRef(), 1873 mlirLocationFileLineColGet( 1874 context->get(), toMlirStringRef(filename), line, col)); 1875 }, 1876 py::arg("filename"), py::arg("line"), py::arg("col"), 1877 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1878 .def_property_readonly( 1879 "context", 1880 [](PyLocation &self) { return self.getContext().getObject(); }, 1881 "Context that owns the Location") 1882 .def("__repr__", [](PyLocation &self) { 1883 PyPrintAccumulator printAccum; 1884 mlirLocationPrint(self, printAccum.getCallback(), 1885 printAccum.getUserData()); 1886 return printAccum.join(); 1887 }); 1888 1889 //---------------------------------------------------------------------------- 1890 // Mapping of Module 1891 //---------------------------------------------------------------------------- 1892 py::class_<PyModule>(m, "Module") 1893 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1894 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1895 .def_static( 1896 "parse", 1897 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1898 MlirModule module = mlirModuleCreateParse( 1899 context->get(), toMlirStringRef(moduleAsm)); 1900 // TODO: Rework error reporting once diagnostic engine is exposed 1901 // in C API. 1902 if (mlirModuleIsNull(module)) { 1903 throw SetPyError( 1904 PyExc_ValueError, 1905 "Unable to parse module assembly (see diagnostics)"); 1906 } 1907 return PyModule::forModule(module).releaseObject(); 1908 }, 1909 py::arg("asm"), py::arg("context") = py::none(), 1910 kModuleParseDocstring) 1911 .def_static( 1912 "create", 1913 [](DefaultingPyLocation loc) { 1914 MlirModule module = mlirModuleCreateEmpty(loc); 1915 return PyModule::forModule(module).releaseObject(); 1916 }, 1917 py::arg("loc") = py::none(), "Creates an empty module") 1918 .def_property_readonly( 1919 "context", 1920 [](PyModule &self) { return self.getContext().getObject(); }, 1921 "Context that created the Module") 1922 .def_property_readonly( 1923 "operation", 1924 [](PyModule &self) { 1925 return PyOperation::forOperation(self.getContext(), 1926 mlirModuleGetOperation(self.get()), 1927 self.getRef().releaseObject()) 1928 .releaseObject(); 1929 }, 1930 "Accesses the module as an operation") 1931 .def_property_readonly( 1932 "body", 1933 [](PyModule &self) { 1934 PyOperationRef module_op = PyOperation::forOperation( 1935 self.getContext(), mlirModuleGetOperation(self.get()), 1936 self.getRef().releaseObject()); 1937 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 1938 return returnBlock; 1939 }, 1940 "Return the block for this module") 1941 .def( 1942 "dump", 1943 [](PyModule &self) { 1944 mlirOperationDump(mlirModuleGetOperation(self.get())); 1945 }, 1946 kDumpDocstring) 1947 .def( 1948 "__str__", 1949 [](PyModule &self) { 1950 MlirOperation operation = mlirModuleGetOperation(self.get()); 1951 PyPrintAccumulator printAccum; 1952 mlirOperationPrint(operation, printAccum.getCallback(), 1953 printAccum.getUserData()); 1954 return printAccum.join(); 1955 }, 1956 kOperationStrDunderDocstring); 1957 1958 //---------------------------------------------------------------------------- 1959 // Mapping of Operation. 1960 //---------------------------------------------------------------------------- 1961 py::class_<PyOperationBase>(m, "_OperationBase") 1962 .def("__eq__", 1963 [](PyOperationBase &self, PyOperationBase &other) { 1964 return &self.getOperation() == &other.getOperation(); 1965 }) 1966 .def("__eq__", 1967 [](PyOperationBase &self, py::object other) { return false; }) 1968 .def_property_readonly("attributes", 1969 [](PyOperationBase &self) { 1970 return PyOpAttributeMap( 1971 self.getOperation().getRef()); 1972 }) 1973 .def_property_readonly("operands", 1974 [](PyOperationBase &self) { 1975 return PyOpOperandList( 1976 self.getOperation().getRef()); 1977 }) 1978 .def_property_readonly("regions", 1979 [](PyOperationBase &self) { 1980 return PyRegionList( 1981 self.getOperation().getRef()); 1982 }) 1983 .def_property_readonly( 1984 "results", 1985 [](PyOperationBase &self) { 1986 return PyOpResultList(self.getOperation().getRef()); 1987 }, 1988 "Returns the list of Operation results.") 1989 .def_property_readonly( 1990 "result", 1991 [](PyOperationBase &self) { 1992 auto &operation = self.getOperation(); 1993 auto numResults = mlirOperationGetNumResults(operation); 1994 if (numResults != 1) { 1995 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 1996 throw SetPyError( 1997 PyExc_ValueError, 1998 Twine("Cannot call .result on operation ") + 1999 StringRef(name.data, name.length) + " which has " + 2000 Twine(numResults) + 2001 " results (it is only valid for operations with a " 2002 "single result)"); 2003 } 2004 return PyOpResult(operation.getRef(), 2005 mlirOperationGetResult(operation, 0)); 2006 }, 2007 "Shortcut to get an op result if it has only one (throws an error " 2008 "otherwise).") 2009 .def("__iter__", 2010 [](PyOperationBase &self) { 2011 return PyRegionIterator(self.getOperation().getRef()); 2012 }) 2013 .def( 2014 "__str__", 2015 [](PyOperationBase &self) { 2016 return self.getAsm(/*binary=*/false, 2017 /*largeElementsLimit=*/llvm::None, 2018 /*enableDebugInfo=*/false, 2019 /*prettyDebugInfo=*/false, 2020 /*printGenericOpForm=*/false, 2021 /*useLocalScope=*/false); 2022 }, 2023 "Returns the assembly form of the operation.") 2024 .def("print", &PyOperationBase::print, 2025 // Careful: Lots of arguments must match up with print method. 2026 py::arg("file") = py::none(), py::arg("binary") = false, 2027 py::arg("large_elements_limit") = py::none(), 2028 py::arg("enable_debug_info") = false, 2029 py::arg("pretty_debug_info") = false, 2030 py::arg("print_generic_op_form") = false, 2031 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2032 .def("get_asm", &PyOperationBase::getAsm, 2033 // Careful: Lots of arguments must match up with get_asm method. 2034 py::arg("binary") = false, 2035 py::arg("large_elements_limit") = py::none(), 2036 py::arg("enable_debug_info") = false, 2037 py::arg("pretty_debug_info") = false, 2038 py::arg("print_generic_op_form") = false, 2039 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2040 .def( 2041 "verify", 2042 [](PyOperationBase &self) { 2043 return mlirOperationVerify(self.getOperation()); 2044 }, 2045 "Verify the operation and return true if it passes, false if it " 2046 "fails."); 2047 2048 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2049 .def_static("create", &PyOperation::create, py::arg("name"), 2050 py::arg("results") = py::none(), 2051 py::arg("operands") = py::none(), 2052 py::arg("attributes") = py::none(), 2053 py::arg("successors") = py::none(), py::arg("regions") = 0, 2054 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2055 kOperationCreateDocstring) 2056 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2057 &PyOperation::getCapsule) 2058 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2059 .def_property_readonly("name", 2060 [](PyOperation &self) { 2061 MlirOperation operation = self.get(); 2062 MlirStringRef name = mlirIdentifierStr( 2063 mlirOperationGetName(operation)); 2064 return py::str(name.data, name.length); 2065 }) 2066 .def_property_readonly( 2067 "context", 2068 [](PyOperation &self) { return self.getContext().getObject(); }, 2069 "Context that owns the Operation") 2070 .def_property_readonly("opview", &PyOperation::createOpView); 2071 2072 auto opViewClass = 2073 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2074 .def(py::init<py::object>()) 2075 .def_property_readonly("operation", &PyOpView::getOperationObject) 2076 .def_property_readonly( 2077 "context", 2078 [](PyOpView &self) { 2079 return self.getOperation().getContext().getObject(); 2080 }, 2081 "Context that owns the Operation") 2082 .def("__str__", [](PyOpView &self) { 2083 return py::str(self.getOperationObject()); 2084 }); 2085 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2086 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2087 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2088 opViewClass.attr("build_generic") = classmethod( 2089 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2090 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2091 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2092 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2093 "Builds a specific, generated OpView based on class level attributes."); 2094 2095 //---------------------------------------------------------------------------- 2096 // Mapping of PyRegion. 2097 //---------------------------------------------------------------------------- 2098 py::class_<PyRegion>(m, "Region") 2099 .def_property_readonly( 2100 "blocks", 2101 [](PyRegion &self) { 2102 return PyBlockList(self.getParentOperation(), self.get()); 2103 }, 2104 "Returns a forward-optimized sequence of blocks.") 2105 .def( 2106 "__iter__", 2107 [](PyRegion &self) { 2108 self.checkValid(); 2109 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2110 return PyBlockIterator(self.getParentOperation(), firstBlock); 2111 }, 2112 "Iterates over blocks in the region.") 2113 .def("__eq__", 2114 [](PyRegion &self, PyRegion &other) { 2115 return self.get().ptr == other.get().ptr; 2116 }) 2117 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2118 2119 //---------------------------------------------------------------------------- 2120 // Mapping of PyBlock. 2121 //---------------------------------------------------------------------------- 2122 py::class_<PyBlock>(m, "Block") 2123 .def_property_readonly( 2124 "arguments", 2125 [](PyBlock &self) { 2126 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2127 }, 2128 "Returns a list of block arguments.") 2129 .def_property_readonly( 2130 "operations", 2131 [](PyBlock &self) { 2132 return PyOperationList(self.getParentOperation(), self.get()); 2133 }, 2134 "Returns a forward-optimized sequence of operations.") 2135 .def( 2136 "__iter__", 2137 [](PyBlock &self) { 2138 self.checkValid(); 2139 MlirOperation firstOperation = 2140 mlirBlockGetFirstOperation(self.get()); 2141 return PyOperationIterator(self.getParentOperation(), 2142 firstOperation); 2143 }, 2144 "Iterates over operations in the block.") 2145 .def("__eq__", 2146 [](PyBlock &self, PyBlock &other) { 2147 return self.get().ptr == other.get().ptr; 2148 }) 2149 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2150 .def( 2151 "__str__", 2152 [](PyBlock &self) { 2153 self.checkValid(); 2154 PyPrintAccumulator printAccum; 2155 mlirBlockPrint(self.get(), printAccum.getCallback(), 2156 printAccum.getUserData()); 2157 return printAccum.join(); 2158 }, 2159 "Returns the assembly form of the block."); 2160 2161 //---------------------------------------------------------------------------- 2162 // Mapping of PyInsertionPoint. 2163 //---------------------------------------------------------------------------- 2164 2165 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2166 .def(py::init<PyBlock &>(), py::arg("block"), 2167 "Inserts after the last operation but still inside the block.") 2168 .def("__enter__", &PyInsertionPoint::contextEnter) 2169 .def("__exit__", &PyInsertionPoint::contextExit) 2170 .def_property_readonly_static( 2171 "current", 2172 [](py::object & /*class*/) { 2173 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2174 if (!ip) 2175 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2176 return ip; 2177 }, 2178 "Gets the InsertionPoint bound to the current thread or raises " 2179 "ValueError if none has been set") 2180 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2181 "Inserts before a referenced operation.") 2182 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2183 py::arg("block"), "Inserts at the beginning of the block.") 2184 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2185 py::arg("block"), "Inserts before the block terminator.") 2186 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2187 "Inserts an operation."); 2188 2189 //---------------------------------------------------------------------------- 2190 // Mapping of PyAttribute. 2191 //---------------------------------------------------------------------------- 2192 py::class_<PyAttribute>(m, "Attribute") 2193 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2194 &PyAttribute::getCapsule) 2195 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2196 .def_static( 2197 "parse", 2198 [](std::string attrSpec, DefaultingPyMlirContext context) { 2199 MlirAttribute type = mlirAttributeParseGet( 2200 context->get(), toMlirStringRef(attrSpec)); 2201 // TODO: Rework error reporting once diagnostic engine is exposed 2202 // in C API. 2203 if (mlirAttributeIsNull(type)) { 2204 throw SetPyError(PyExc_ValueError, 2205 Twine("Unable to parse attribute: '") + 2206 attrSpec + "'"); 2207 } 2208 return PyAttribute(context->getRef(), type); 2209 }, 2210 py::arg("asm"), py::arg("context") = py::none(), 2211 "Parses an attribute from an assembly form") 2212 .def_property_readonly( 2213 "context", 2214 [](PyAttribute &self) { return self.getContext().getObject(); }, 2215 "Context that owns the Attribute") 2216 .def_property_readonly("type", 2217 [](PyAttribute &self) { 2218 return PyType(self.getContext()->getRef(), 2219 mlirAttributeGetType(self)); 2220 }) 2221 .def( 2222 "get_named", 2223 [](PyAttribute &self, std::string name) { 2224 return PyNamedAttribute(self, std::move(name)); 2225 }, 2226 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2227 .def("__eq__", 2228 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2229 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2230 .def( 2231 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2232 kDumpDocstring) 2233 .def( 2234 "__str__", 2235 [](PyAttribute &self) { 2236 PyPrintAccumulator printAccum; 2237 mlirAttributePrint(self, printAccum.getCallback(), 2238 printAccum.getUserData()); 2239 return printAccum.join(); 2240 }, 2241 "Returns the assembly form of the Attribute.") 2242 .def("__repr__", [](PyAttribute &self) { 2243 // Generally, assembly formats are not printed for __repr__ because 2244 // this can cause exceptionally long debug output and exceptions. 2245 // However, attribute values are generally considered useful and are 2246 // printed. This may need to be re-evaluated if debug dumps end up 2247 // being excessive. 2248 PyPrintAccumulator printAccum; 2249 printAccum.parts.append("Attribute("); 2250 mlirAttributePrint(self, printAccum.getCallback(), 2251 printAccum.getUserData()); 2252 printAccum.parts.append(")"); 2253 return printAccum.join(); 2254 }); 2255 2256 //---------------------------------------------------------------------------- 2257 // Mapping of PyNamedAttribute 2258 //---------------------------------------------------------------------------- 2259 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2260 .def("__repr__", 2261 [](PyNamedAttribute &self) { 2262 PyPrintAccumulator printAccum; 2263 printAccum.parts.append("NamedAttribute("); 2264 printAccum.parts.append( 2265 mlirIdentifierStr(self.namedAttr.name).data); 2266 printAccum.parts.append("="); 2267 mlirAttributePrint(self.namedAttr.attribute, 2268 printAccum.getCallback(), 2269 printAccum.getUserData()); 2270 printAccum.parts.append(")"); 2271 return printAccum.join(); 2272 }) 2273 .def_property_readonly( 2274 "name", 2275 [](PyNamedAttribute &self) { 2276 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2277 mlirIdentifierStr(self.namedAttr.name).length); 2278 }, 2279 "The name of the NamedAttribute binding") 2280 .def_property_readonly( 2281 "attr", 2282 [](PyNamedAttribute &self) { 2283 // TODO: When named attribute is removed/refactored, also remove 2284 // this constructor (it does an inefficient table lookup). 2285 auto contextRef = PyMlirContext::forContext( 2286 mlirAttributeGetContext(self.namedAttr.attribute)); 2287 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2288 }, 2289 py::keep_alive<0, 1>(), 2290 "The underlying generic attribute of the NamedAttribute binding"); 2291 2292 //---------------------------------------------------------------------------- 2293 // Mapping of PyType. 2294 //---------------------------------------------------------------------------- 2295 py::class_<PyType>(m, "Type") 2296 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2297 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2298 .def_static( 2299 "parse", 2300 [](std::string typeSpec, DefaultingPyMlirContext context) { 2301 MlirType type = 2302 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2303 // TODO: Rework error reporting once diagnostic engine is exposed 2304 // in C API. 2305 if (mlirTypeIsNull(type)) { 2306 throw SetPyError(PyExc_ValueError, 2307 Twine("Unable to parse type: '") + typeSpec + 2308 "'"); 2309 } 2310 return PyType(context->getRef(), type); 2311 }, 2312 py::arg("asm"), py::arg("context") = py::none(), 2313 kContextParseTypeDocstring) 2314 .def_property_readonly( 2315 "context", [](PyType &self) { return self.getContext().getObject(); }, 2316 "Context that owns the Type") 2317 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2318 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2319 .def( 2320 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2321 .def( 2322 "__str__", 2323 [](PyType &self) { 2324 PyPrintAccumulator printAccum; 2325 mlirTypePrint(self, printAccum.getCallback(), 2326 printAccum.getUserData()); 2327 return printAccum.join(); 2328 }, 2329 "Returns the assembly form of the type.") 2330 .def("__repr__", [](PyType &self) { 2331 // Generally, assembly formats are not printed for __repr__ because 2332 // this can cause exceptionally long debug output and exceptions. 2333 // However, types are an exception as they typically have compact 2334 // assembly forms and printing them is useful. 2335 PyPrintAccumulator printAccum; 2336 printAccum.parts.append("Type("); 2337 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2338 printAccum.parts.append(")"); 2339 return printAccum.join(); 2340 }); 2341 2342 //---------------------------------------------------------------------------- 2343 // Mapping of Value. 2344 //---------------------------------------------------------------------------- 2345 py::class_<PyValue>(m, "Value") 2346 .def_property_readonly( 2347 "context", 2348 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2349 "Context in which the value lives.") 2350 .def( 2351 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2352 kDumpDocstring) 2353 .def("__eq__", 2354 [](PyValue &self, PyValue &other) { 2355 return self.get().ptr == other.get().ptr; 2356 }) 2357 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2358 .def( 2359 "__str__", 2360 [](PyValue &self) { 2361 PyPrintAccumulator printAccum; 2362 printAccum.parts.append("Value("); 2363 mlirValuePrint(self.get(), printAccum.getCallback(), 2364 printAccum.getUserData()); 2365 printAccum.parts.append(")"); 2366 return printAccum.join(); 2367 }, 2368 kValueDunderStrDocstring) 2369 .def_property_readonly("type", [](PyValue &self) { 2370 return PyType(self.getParentOperation()->getContext(), 2371 mlirValueGetType(self.get())); 2372 }); 2373 PyBlockArgument::bind(m); 2374 PyOpResult::bind(m); 2375 2376 // Container bindings. 2377 PyBlockArgumentList::bind(m); 2378 PyBlockIterator::bind(m); 2379 PyBlockList::bind(m); 2380 PyOperationIterator::bind(m); 2381 PyOperationList::bind(m); 2382 PyOpAttributeMap::bind(m); 2383 PyOpOperandList::bind(m); 2384 PyOpResultList::bind(m); 2385 PyRegionIterator::bind(m); 2386 PyRegionList::bind(m); 2387 } 2388