1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Registration.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include <pybind11/stl.h> 20 21 namespace py = pybind11; 22 using namespace mlir; 23 using namespace mlir::python; 24 25 using llvm::SmallVector; 26 using llvm::StringRef; 27 using llvm::Twine; 28 29 //------------------------------------------------------------------------------ 30 // Docstrings (trivial, non-duplicated docstrings are included inline). 31 //------------------------------------------------------------------------------ 32 33 static const char kContextParseTypeDocstring[] = 34 R"(Parses the assembly form of a type. 35 36 Returns a Type object or raises a ValueError if the type cannot be parsed. 37 38 See also: https://mlir.llvm.org/docs/LangRef/#type-system 39 )"; 40 41 static const char kContextGetFileLocationDocstring[] = 42 R"(Gets a Location representing a file, line and column)"; 43 44 static const char kModuleParseDocstring[] = 45 R"(Parses a module's assembly format from a string. 46 47 Returns a new MlirModule or raises a ValueError if the parsing fails. 48 49 See also: https://mlir.llvm.org/docs/LangRef/ 50 )"; 51 52 static const char kOperationCreateDocstring[] = 53 R"(Creates a new operation. 54 55 Args: 56 name: Operation name (e.g. "dialect.operation"). 57 results: Sequence of Type representing op result types. 58 attributes: Dict of str:Attribute. 59 successors: List of Block for the operation's successors. 60 regions: Number of regions to create. 61 location: A Location object (defaults to resolve from context manager). 62 ip: An InsertionPoint (defaults to resolve from context manager or set to 63 False to disable insertion, even with an insertion point set in the 64 context manager). 65 Returns: 66 A new "detached" Operation object. Detached operations can be added 67 to blocks, which causes them to become "attached." 68 )"; 69 70 static const char kOperationPrintDocstring[] = 71 R"(Prints the assembly form of the operation to a file like object. 72 73 Args: 74 file: The file like object to write to. Defaults to sys.stdout. 75 binary: Whether to write bytes (True) or str (False). Defaults to False. 76 large_elements_limit: Whether to elide elements attributes above this 77 number of elements. Defaults to None (no limit). 78 enable_debug_info: Whether to print debug/location information. Defaults 79 to False. 80 pretty_debug_info: Whether to format debug information for easier reading 81 by a human (warning: the result is unparseable). 82 print_generic_op_form: Whether to print the generic assembly forms of all 83 ops. Defaults to False. 84 use_local_Scope: Whether to print in a way that is more optimized for 85 multi-threaded access but may not be consistent with how the overall 86 module prints. 87 )"; 88 89 static const char kOperationGetAsmDocstring[] = 90 R"(Gets the assembly form of the operation with all options available. 91 92 Args: 93 binary: Whether to return a bytes (True) or str (False) object. Defaults to 94 False. 95 ... others ...: See the print() method for common keyword arguments for 96 configuring the printout. 97 Returns: 98 Either a bytes or str object, depending on the setting of the 'binary' 99 argument. 100 )"; 101 102 static const char kOperationStrDunderDocstring[] = 103 R"(Gets the assembly form of the operation with default options. 104 105 If more advanced control over the assembly formatting or I/O options is needed, 106 use the dedicated print or get_asm method, which supports keyword arguments to 107 customize behavior. 108 )"; 109 110 static const char kDumpDocstring[] = 111 R"(Dumps a debug representation of the object to stderr.)"; 112 113 static const char kAppendBlockDocstring[] = 114 R"(Appends a new block, with argument types as positional args. 115 116 Returns: 117 The created block. 118 )"; 119 120 static const char kValueDunderStrDocstring[] = 121 R"(Returns the string form of the value. 122 123 If the value is a block argument, this is the assembly form of its type and the 124 position in the argument list. If the value is an operation result, this is 125 equivalent to printing the operation that produced it. 126 )"; 127 128 //------------------------------------------------------------------------------ 129 // Utilities. 130 //------------------------------------------------------------------------------ 131 132 // Helper for creating an @classmethod. 133 template <class Func, typename... Args> 134 py::object classmethod(Func f, Args... args) { 135 py::object cf = py::cpp_function(f, args...); 136 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 137 } 138 139 static py::object 140 createCustomDialectWrapper(const std::string &dialectNamespace, 141 py::object dialectDescriptor) { 142 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 143 if (!dialectClass) { 144 // Use the base class. 145 return py::cast(PyDialect(std::move(dialectDescriptor))); 146 } 147 148 // Create the custom implementation. 149 return (*dialectClass)(std::move(dialectDescriptor)); 150 } 151 152 static MlirStringRef toMlirStringRef(const std::string &s) { 153 return mlirStringRefCreate(s.data(), s.size()); 154 } 155 156 //------------------------------------------------------------------------------ 157 // Collections. 158 //------------------------------------------------------------------------------ 159 160 namespace { 161 162 class PyRegionIterator { 163 public: 164 PyRegionIterator(PyOperationRef operation) 165 : operation(std::move(operation)) {} 166 167 PyRegionIterator &dunderIter() { return *this; } 168 169 PyRegion dunderNext() { 170 operation->checkValid(); 171 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 172 throw py::stop_iteration(); 173 } 174 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 175 return PyRegion(operation, region); 176 } 177 178 static void bind(py::module &m) { 179 py::class_<PyRegionIterator>(m, "RegionIterator") 180 .def("__iter__", &PyRegionIterator::dunderIter) 181 .def("__next__", &PyRegionIterator::dunderNext); 182 } 183 184 private: 185 PyOperationRef operation; 186 int nextIndex = 0; 187 }; 188 189 /// Regions of an op are fixed length and indexed numerically so are represented 190 /// with a sequence-like container. 191 class PyRegionList { 192 public: 193 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 194 195 intptr_t dunderLen() { 196 operation->checkValid(); 197 return mlirOperationGetNumRegions(operation->get()); 198 } 199 200 PyRegion dunderGetItem(intptr_t index) { 201 // dunderLen checks validity. 202 if (index < 0 || index >= dunderLen()) { 203 throw SetPyError(PyExc_IndexError, 204 "attempt to access out of bounds region"); 205 } 206 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 207 return PyRegion(operation, region); 208 } 209 210 static void bind(py::module &m) { 211 py::class_<PyRegionList>(m, "RegionSequence") 212 .def("__len__", &PyRegionList::dunderLen) 213 .def("__getitem__", &PyRegionList::dunderGetItem); 214 } 215 216 private: 217 PyOperationRef operation; 218 }; 219 220 class PyBlockIterator { 221 public: 222 PyBlockIterator(PyOperationRef operation, MlirBlock next) 223 : operation(std::move(operation)), next(next) {} 224 225 PyBlockIterator &dunderIter() { return *this; } 226 227 PyBlock dunderNext() { 228 operation->checkValid(); 229 if (mlirBlockIsNull(next)) { 230 throw py::stop_iteration(); 231 } 232 233 PyBlock returnBlock(operation, next); 234 next = mlirBlockGetNextInRegion(next); 235 return returnBlock; 236 } 237 238 static void bind(py::module &m) { 239 py::class_<PyBlockIterator>(m, "BlockIterator") 240 .def("__iter__", &PyBlockIterator::dunderIter) 241 .def("__next__", &PyBlockIterator::dunderNext); 242 } 243 244 private: 245 PyOperationRef operation; 246 MlirBlock next; 247 }; 248 249 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 250 /// we present them as a more full-featured list-like container but optimize 251 /// it for forward iteration. Blocks are always owned by a region. 252 class PyBlockList { 253 public: 254 PyBlockList(PyOperationRef operation, MlirRegion region) 255 : operation(std::move(operation)), region(region) {} 256 257 PyBlockIterator dunderIter() { 258 operation->checkValid(); 259 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 260 } 261 262 intptr_t dunderLen() { 263 operation->checkValid(); 264 intptr_t count = 0; 265 MlirBlock block = mlirRegionGetFirstBlock(region); 266 while (!mlirBlockIsNull(block)) { 267 count += 1; 268 block = mlirBlockGetNextInRegion(block); 269 } 270 return count; 271 } 272 273 PyBlock dunderGetItem(intptr_t index) { 274 operation->checkValid(); 275 if (index < 0) { 276 throw SetPyError(PyExc_IndexError, 277 "attempt to access out of bounds block"); 278 } 279 MlirBlock block = mlirRegionGetFirstBlock(region); 280 while (!mlirBlockIsNull(block)) { 281 if (index == 0) { 282 return PyBlock(operation, block); 283 } 284 block = mlirBlockGetNextInRegion(block); 285 index -= 1; 286 } 287 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 288 } 289 290 PyBlock appendBlock(py::args pyArgTypes) { 291 operation->checkValid(); 292 llvm::SmallVector<MlirType, 4> argTypes; 293 argTypes.reserve(pyArgTypes.size()); 294 for (auto &pyArg : pyArgTypes) { 295 argTypes.push_back(pyArg.cast<PyType &>()); 296 } 297 298 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 299 mlirRegionAppendOwnedBlock(region, block); 300 return PyBlock(operation, block); 301 } 302 303 static void bind(py::module &m) { 304 py::class_<PyBlockList>(m, "BlockList") 305 .def("__getitem__", &PyBlockList::dunderGetItem) 306 .def("__iter__", &PyBlockList::dunderIter) 307 .def("__len__", &PyBlockList::dunderLen) 308 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 309 } 310 311 private: 312 PyOperationRef operation; 313 MlirRegion region; 314 }; 315 316 class PyOperationIterator { 317 public: 318 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 319 : parentOperation(std::move(parentOperation)), next(next) {} 320 321 PyOperationIterator &dunderIter() { return *this; } 322 323 py::object dunderNext() { 324 parentOperation->checkValid(); 325 if (mlirOperationIsNull(next)) { 326 throw py::stop_iteration(); 327 } 328 329 PyOperationRef returnOperation = 330 PyOperation::forOperation(parentOperation->getContext(), next); 331 next = mlirOperationGetNextInBlock(next); 332 return returnOperation->createOpView(); 333 } 334 335 static void bind(py::module &m) { 336 py::class_<PyOperationIterator>(m, "OperationIterator") 337 .def("__iter__", &PyOperationIterator::dunderIter) 338 .def("__next__", &PyOperationIterator::dunderNext); 339 } 340 341 private: 342 PyOperationRef parentOperation; 343 MlirOperation next; 344 }; 345 346 /// Operations are exposed by the C-API as a forward-only linked list. In 347 /// Python, we present them as a more full-featured list-like container but 348 /// optimize it for forward iteration. Iterable operations are always owned 349 /// by a block. 350 class PyOperationList { 351 public: 352 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 353 : parentOperation(std::move(parentOperation)), block(block) {} 354 355 PyOperationIterator dunderIter() { 356 parentOperation->checkValid(); 357 return PyOperationIterator(parentOperation, 358 mlirBlockGetFirstOperation(block)); 359 } 360 361 intptr_t dunderLen() { 362 parentOperation->checkValid(); 363 intptr_t count = 0; 364 MlirOperation childOp = mlirBlockGetFirstOperation(block); 365 while (!mlirOperationIsNull(childOp)) { 366 count += 1; 367 childOp = mlirOperationGetNextInBlock(childOp); 368 } 369 return count; 370 } 371 372 py::object dunderGetItem(intptr_t index) { 373 parentOperation->checkValid(); 374 if (index < 0) { 375 throw SetPyError(PyExc_IndexError, 376 "attempt to access out of bounds operation"); 377 } 378 MlirOperation childOp = mlirBlockGetFirstOperation(block); 379 while (!mlirOperationIsNull(childOp)) { 380 if (index == 0) { 381 return PyOperation::forOperation(parentOperation->getContext(), childOp) 382 ->createOpView(); 383 } 384 childOp = mlirOperationGetNextInBlock(childOp); 385 index -= 1; 386 } 387 throw SetPyError(PyExc_IndexError, 388 "attempt to access out of bounds operation"); 389 } 390 391 static void bind(py::module &m) { 392 py::class_<PyOperationList>(m, "OperationList") 393 .def("__getitem__", &PyOperationList::dunderGetItem) 394 .def("__iter__", &PyOperationList::dunderIter) 395 .def("__len__", &PyOperationList::dunderLen); 396 } 397 398 private: 399 PyOperationRef parentOperation; 400 MlirBlock block; 401 }; 402 403 } // namespace 404 405 //------------------------------------------------------------------------------ 406 // PyMlirContext 407 //------------------------------------------------------------------------------ 408 409 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 410 py::gil_scoped_acquire acquire; 411 auto &liveContexts = getLiveContexts(); 412 liveContexts[context.ptr] = this; 413 } 414 415 PyMlirContext::~PyMlirContext() { 416 // Note that the only public way to construct an instance is via the 417 // forContext method, which always puts the associated handle into 418 // liveContexts. 419 py::gil_scoped_acquire acquire; 420 getLiveContexts().erase(context.ptr); 421 mlirContextDestroy(context); 422 } 423 424 py::object PyMlirContext::getCapsule() { 425 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 426 } 427 428 py::object PyMlirContext::createFromCapsule(py::object capsule) { 429 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 430 if (mlirContextIsNull(rawContext)) 431 throw py::error_already_set(); 432 return forContext(rawContext).releaseObject(); 433 } 434 435 PyMlirContext *PyMlirContext::createNewContextForInit() { 436 MlirContext context = mlirContextCreate(); 437 mlirRegisterAllDialects(context); 438 return new PyMlirContext(context); 439 } 440 441 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 442 py::gil_scoped_acquire acquire; 443 auto &liveContexts = getLiveContexts(); 444 auto it = liveContexts.find(context.ptr); 445 if (it == liveContexts.end()) { 446 // Create. 447 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 448 py::object pyRef = py::cast(unownedContextWrapper); 449 assert(pyRef && "cast to py::object failed"); 450 liveContexts[context.ptr] = unownedContextWrapper; 451 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 452 } 453 // Use existing. 454 py::object pyRef = py::cast(it->second); 455 return PyMlirContextRef(it->second, std::move(pyRef)); 456 } 457 458 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 459 static LiveContextMap liveContexts; 460 return liveContexts; 461 } 462 463 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 464 465 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 466 467 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 468 469 pybind11::object PyMlirContext::contextEnter() { 470 return PyThreadContextEntry::pushContext(*this); 471 } 472 473 void PyMlirContext::contextExit(pybind11::object excType, 474 pybind11::object excVal, 475 pybind11::object excTb) { 476 PyThreadContextEntry::popContext(*this); 477 } 478 479 PyMlirContext &DefaultingPyMlirContext::resolve() { 480 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 481 if (!context) { 482 throw SetPyError( 483 PyExc_RuntimeError, 484 "An MLIR function requires a Context but none was provided in the call " 485 "or from the surrounding environment. Either pass to the function with " 486 "a 'context=' argument or establish a default using 'with Context():'"); 487 } 488 return *context; 489 } 490 491 //------------------------------------------------------------------------------ 492 // PyThreadContextEntry management 493 //------------------------------------------------------------------------------ 494 495 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 496 static thread_local std::vector<PyThreadContextEntry> stack; 497 return stack; 498 } 499 500 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 501 auto &stack = getStack(); 502 if (stack.empty()) 503 return nullptr; 504 return &stack.back(); 505 } 506 507 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 508 py::object insertionPoint, 509 py::object location) { 510 auto &stack = getStack(); 511 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 512 std::move(location)); 513 // If the new stack has more than one entry and the context of the new top 514 // entry matches the previous, copy the insertionPoint and location from the 515 // previous entry if missing from the new top entry. 516 if (stack.size() > 1) { 517 auto &prev = *(stack.rbegin() + 1); 518 auto ¤t = stack.back(); 519 if (current.context.is(prev.context)) { 520 // Default non-context objects from the previous entry. 521 if (!current.insertionPoint) 522 current.insertionPoint = prev.insertionPoint; 523 if (!current.location) 524 current.location = prev.location; 525 } 526 } 527 } 528 529 PyMlirContext *PyThreadContextEntry::getContext() { 530 if (!context) 531 return nullptr; 532 return py::cast<PyMlirContext *>(context); 533 } 534 535 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 536 if (!insertionPoint) 537 return nullptr; 538 return py::cast<PyInsertionPoint *>(insertionPoint); 539 } 540 541 PyLocation *PyThreadContextEntry::getLocation() { 542 if (!location) 543 return nullptr; 544 return py::cast<PyLocation *>(location); 545 } 546 547 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 548 auto *tos = getTopOfStack(); 549 return tos ? tos->getContext() : nullptr; 550 } 551 552 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 553 auto *tos = getTopOfStack(); 554 return tos ? tos->getInsertionPoint() : nullptr; 555 } 556 557 PyLocation *PyThreadContextEntry::getDefaultLocation() { 558 auto *tos = getTopOfStack(); 559 return tos ? tos->getLocation() : nullptr; 560 } 561 562 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 563 py::object contextObj = py::cast(context); 564 push(FrameKind::Context, /*context=*/contextObj, 565 /*insertionPoint=*/py::object(), 566 /*location=*/py::object()); 567 return contextObj; 568 } 569 570 void PyThreadContextEntry::popContext(PyMlirContext &context) { 571 auto &stack = getStack(); 572 if (stack.empty()) 573 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 574 auto &tos = stack.back(); 575 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 576 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 577 stack.pop_back(); 578 } 579 580 py::object 581 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 582 py::object contextObj = 583 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 584 py::object insertionPointObj = py::cast(insertionPoint); 585 push(FrameKind::InsertionPoint, 586 /*context=*/contextObj, 587 /*insertionPoint=*/insertionPointObj, 588 /*location=*/py::object()); 589 return insertionPointObj; 590 } 591 592 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 593 auto &stack = getStack(); 594 if (stack.empty()) 595 throw SetPyError(PyExc_RuntimeError, 596 "Unbalanced InsertionPoint enter/exit"); 597 auto &tos = stack.back(); 598 if (tos.frameKind != FrameKind::InsertionPoint && 599 tos.getInsertionPoint() != &insertionPoint) 600 throw SetPyError(PyExc_RuntimeError, 601 "Unbalanced InsertionPoint enter/exit"); 602 stack.pop_back(); 603 } 604 605 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 606 py::object contextObj = location.getContext().getObject(); 607 py::object locationObj = py::cast(location); 608 push(FrameKind::Location, /*context=*/contextObj, 609 /*insertionPoint=*/py::object(), 610 /*location=*/locationObj); 611 return locationObj; 612 } 613 614 void PyThreadContextEntry::popLocation(PyLocation &location) { 615 auto &stack = getStack(); 616 if (stack.empty()) 617 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 618 auto &tos = stack.back(); 619 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 620 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 621 stack.pop_back(); 622 } 623 624 //------------------------------------------------------------------------------ 625 // PyDialect, PyDialectDescriptor, PyDialects 626 //------------------------------------------------------------------------------ 627 628 MlirDialect PyDialects::getDialectForKey(const std::string &key, 629 bool attrError) { 630 // If the "std" dialect was asked for, substitute the empty namespace :( 631 static const std::string emptyKey; 632 const std::string *canonKey = key == "std" ? &emptyKey : &key; 633 MlirDialect dialect = mlirContextGetOrLoadDialect( 634 getContext()->get(), {canonKey->data(), canonKey->size()}); 635 if (mlirDialectIsNull(dialect)) { 636 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 637 Twine("Dialect '") + key + "' not found"); 638 } 639 return dialect; 640 } 641 642 //------------------------------------------------------------------------------ 643 // PyLocation 644 //------------------------------------------------------------------------------ 645 646 py::object PyLocation::getCapsule() { 647 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 648 } 649 650 PyLocation PyLocation::createFromCapsule(py::object capsule) { 651 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 652 if (mlirLocationIsNull(rawLoc)) 653 throw py::error_already_set(); 654 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 655 rawLoc); 656 } 657 658 py::object PyLocation::contextEnter() { 659 return PyThreadContextEntry::pushLocation(*this); 660 } 661 662 void PyLocation::contextExit(py::object excType, py::object excVal, 663 py::object excTb) { 664 PyThreadContextEntry::popLocation(*this); 665 } 666 667 PyLocation &DefaultingPyLocation::resolve() { 668 auto *location = PyThreadContextEntry::getDefaultLocation(); 669 if (!location) { 670 throw SetPyError( 671 PyExc_RuntimeError, 672 "An MLIR function requires a Location but none was provided in the " 673 "call or from the surrounding environment. Either pass to the function " 674 "with a 'loc=' argument or establish a default using 'with loc:'"); 675 } 676 return *location; 677 } 678 679 //------------------------------------------------------------------------------ 680 // PyModule 681 //------------------------------------------------------------------------------ 682 683 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 684 : BaseContextObject(std::move(contextRef)), module(module) {} 685 686 PyModule::~PyModule() { 687 py::gil_scoped_acquire acquire; 688 auto &liveModules = getContext()->liveModules; 689 assert(liveModules.count(module.ptr) == 1 && 690 "destroying module not in live map"); 691 liveModules.erase(module.ptr); 692 mlirModuleDestroy(module); 693 } 694 695 PyModuleRef PyModule::forModule(MlirModule module) { 696 MlirContext context = mlirModuleGetContext(module); 697 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 698 699 py::gil_scoped_acquire acquire; 700 auto &liveModules = contextRef->liveModules; 701 auto it = liveModules.find(module.ptr); 702 if (it == liveModules.end()) { 703 // Create. 704 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 705 // Note that the default return value policy on cast is automatic_reference, 706 // which does not take ownership (delete will not be called). 707 // Just be explicit. 708 py::object pyRef = 709 py::cast(unownedModule, py::return_value_policy::take_ownership); 710 unownedModule->handle = pyRef; 711 liveModules[module.ptr] = 712 std::make_pair(unownedModule->handle, unownedModule); 713 return PyModuleRef(unownedModule, std::move(pyRef)); 714 } 715 // Use existing. 716 PyModule *existing = it->second.second; 717 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 718 return PyModuleRef(existing, std::move(pyRef)); 719 } 720 721 py::object PyModule::createFromCapsule(py::object capsule) { 722 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 723 if (mlirModuleIsNull(rawModule)) 724 throw py::error_already_set(); 725 return forModule(rawModule).releaseObject(); 726 } 727 728 py::object PyModule::getCapsule() { 729 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 730 } 731 732 //------------------------------------------------------------------------------ 733 // PyOperation 734 //------------------------------------------------------------------------------ 735 736 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 737 : BaseContextObject(std::move(contextRef)), operation(operation) {} 738 739 PyOperation::~PyOperation() { 740 auto &liveOperations = getContext()->liveOperations; 741 assert(liveOperations.count(operation.ptr) == 1 && 742 "destroying operation not in live map"); 743 liveOperations.erase(operation.ptr); 744 if (!isAttached()) { 745 mlirOperationDestroy(operation); 746 } 747 } 748 749 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 750 MlirOperation operation, 751 py::object parentKeepAlive) { 752 auto &liveOperations = contextRef->liveOperations; 753 // Create. 754 PyOperation *unownedOperation = 755 new PyOperation(std::move(contextRef), operation); 756 // Note that the default return value policy on cast is automatic_reference, 757 // which does not take ownership (delete will not be called). 758 // Just be explicit. 759 py::object pyRef = 760 py::cast(unownedOperation, py::return_value_policy::take_ownership); 761 unownedOperation->handle = pyRef; 762 if (parentKeepAlive) { 763 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 764 } 765 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 766 return PyOperationRef(unownedOperation, std::move(pyRef)); 767 } 768 769 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 770 MlirOperation operation, 771 py::object parentKeepAlive) { 772 auto &liveOperations = contextRef->liveOperations; 773 auto it = liveOperations.find(operation.ptr); 774 if (it == liveOperations.end()) { 775 // Create. 776 return createInstance(std::move(contextRef), operation, 777 std::move(parentKeepAlive)); 778 } 779 // Use existing. 780 PyOperation *existing = it->second.second; 781 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 782 return PyOperationRef(existing, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 assert(liveOperations.count(operation.ptr) == 0 && 790 "cannot create detached operation that already exists"); 791 (void)liveOperations; 792 793 PyOperationRef created = createInstance(std::move(contextRef), operation, 794 std::move(parentKeepAlive)); 795 created->attached = false; 796 return created; 797 } 798 799 void PyOperation::checkValid() const { 800 if (!valid) { 801 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 802 } 803 } 804 805 void PyOperationBase::print(py::object fileObject, bool binary, 806 llvm::Optional<int64_t> largeElementsLimit, 807 bool enableDebugInfo, bool prettyDebugInfo, 808 bool printGenericOpForm, bool useLocalScope) { 809 PyOperation &operation = getOperation(); 810 operation.checkValid(); 811 if (fileObject.is_none()) 812 fileObject = py::module::import("sys").attr("stdout"); 813 814 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 815 fileObject.attr("write")("// Verification failed, printing generic form\n"); 816 printGenericOpForm = true; 817 } 818 819 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 820 if (largeElementsLimit) 821 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 822 if (enableDebugInfo) 823 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 824 if (printGenericOpForm) 825 mlirOpPrintingFlagsPrintGenericOpForm(flags); 826 827 PyFileAccumulator accum(fileObject, binary); 828 py::gil_scoped_release(); 829 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 830 accum.getUserData()); 831 mlirOpPrintingFlagsDestroy(flags); 832 } 833 834 py::object PyOperationBase::getAsm(bool binary, 835 llvm::Optional<int64_t> largeElementsLimit, 836 bool enableDebugInfo, bool prettyDebugInfo, 837 bool printGenericOpForm, 838 bool useLocalScope) { 839 py::object fileObject; 840 if (binary) { 841 fileObject = py::module::import("io").attr("BytesIO")(); 842 } else { 843 fileObject = py::module::import("io").attr("StringIO")(); 844 } 845 print(fileObject, /*binary=*/binary, 846 /*largeElementsLimit=*/largeElementsLimit, 847 /*enableDebugInfo=*/enableDebugInfo, 848 /*prettyDebugInfo=*/prettyDebugInfo, 849 /*printGenericOpForm=*/printGenericOpForm, 850 /*useLocalScope=*/useLocalScope); 851 852 return fileObject.attr("getvalue")(); 853 } 854 855 PyOperationRef PyOperation::getParentOperation() { 856 if (!isAttached()) 857 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 858 MlirOperation operation = mlirOperationGetParentOperation(get()); 859 if (mlirOperationIsNull(operation)) 860 throw SetPyError(PyExc_ValueError, "Operation has no parent."); 861 return PyOperation::forOperation(getContext(), operation); 862 } 863 864 PyBlock PyOperation::getBlock() { 865 PyOperationRef parentOperation = getParentOperation(); 866 MlirBlock block = mlirOperationGetBlock(get()); 867 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 868 return PyBlock{std::move(parentOperation), block}; 869 } 870 871 py::object PyOperation::create( 872 std::string name, llvm::Optional<std::vector<PyType *>> results, 873 llvm::Optional<std::vector<PyValue *>> operands, 874 llvm::Optional<py::dict> attributes, 875 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 876 DefaultingPyLocation location, py::object maybeIp) { 877 llvm::SmallVector<MlirValue, 4> mlirOperands; 878 llvm::SmallVector<MlirType, 4> mlirResults; 879 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 880 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 881 882 // General parameter validation. 883 if (regions < 0) 884 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 885 886 // Unpack/validate operands. 887 if (operands) { 888 mlirOperands.reserve(operands->size()); 889 for (PyValue *operand : *operands) { 890 if (!operand) 891 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 892 mlirOperands.push_back(operand->get()); 893 } 894 } 895 896 // Unpack/validate results. 897 if (results) { 898 mlirResults.reserve(results->size()); 899 for (PyType *result : *results) { 900 // TODO: Verify result type originate from the same context. 901 if (!result) 902 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 903 mlirResults.push_back(*result); 904 } 905 } 906 // Unpack/validate attributes. 907 if (attributes) { 908 mlirAttributes.reserve(attributes->size()); 909 for (auto &it : *attributes) { 910 std::string key; 911 try { 912 key = it.first.cast<std::string>(); 913 } catch (py::cast_error &err) { 914 std::string msg = "Invalid attribute key (not a string) when " 915 "attempting to create the operation \"" + 916 name + "\" (" + err.what() + ")"; 917 throw py::cast_error(msg); 918 } 919 try { 920 auto &attribute = it.second.cast<PyAttribute &>(); 921 // TODO: Verify attribute originates from the same context. 922 mlirAttributes.emplace_back(std::move(key), attribute); 923 } catch (py::reference_cast_error &) { 924 // This exception seems thrown when the value is "None". 925 std::string msg = 926 "Found an invalid (`None`?) attribute value for the key \"" + key + 927 "\" when attempting to create the operation \"" + name + "\""; 928 throw py::cast_error(msg); 929 } catch (py::cast_error &err) { 930 std::string msg = "Invalid attribute value for the key \"" + key + 931 "\" when attempting to create the operation \"" + 932 name + "\" (" + err.what() + ")"; 933 throw py::cast_error(msg); 934 } 935 } 936 } 937 // Unpack/validate successors. 938 if (successors) { 939 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 940 mlirSuccessors.reserve(successors->size()); 941 for (auto *successor : *successors) { 942 // TODO: Verify successor originate from the same context. 943 if (!successor) 944 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 945 mlirSuccessors.push_back(successor->get()); 946 } 947 } 948 949 // Apply unpacked/validated to the operation state. Beyond this 950 // point, exceptions cannot be thrown or else the state will leak. 951 MlirOperationState state = 952 mlirOperationStateGet(toMlirStringRef(name), location); 953 if (!mlirOperands.empty()) 954 mlirOperationStateAddOperands(&state, mlirOperands.size(), 955 mlirOperands.data()); 956 if (!mlirResults.empty()) 957 mlirOperationStateAddResults(&state, mlirResults.size(), 958 mlirResults.data()); 959 if (!mlirAttributes.empty()) { 960 // Note that the attribute names directly reference bytes in 961 // mlirAttributes, so that vector must not be changed from here 962 // on. 963 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 964 mlirNamedAttributes.reserve(mlirAttributes.size()); 965 for (auto &it : mlirAttributes) 966 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 967 mlirIdentifierGet(mlirAttributeGetContext(it.second), 968 toMlirStringRef(it.first)), 969 it.second)); 970 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 971 mlirNamedAttributes.data()); 972 } 973 if (!mlirSuccessors.empty()) 974 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 975 mlirSuccessors.data()); 976 if (regions) { 977 llvm::SmallVector<MlirRegion, 4> mlirRegions; 978 mlirRegions.resize(regions); 979 for (int i = 0; i < regions; ++i) 980 mlirRegions[i] = mlirRegionCreate(); 981 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 982 mlirRegions.data()); 983 } 984 985 // Construct the operation. 986 MlirOperation operation = mlirOperationCreate(&state); 987 PyOperationRef created = 988 PyOperation::createDetached(location->getContext(), operation); 989 990 // InsertPoint active? 991 if (!maybeIp.is(py::cast(false))) { 992 PyInsertionPoint *ip; 993 if (maybeIp.is_none()) { 994 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 995 } else { 996 ip = py::cast<PyInsertionPoint *>(maybeIp); 997 } 998 if (ip) 999 ip->insert(*created.get()); 1000 } 1001 1002 return created->createOpView(); 1003 } 1004 1005 py::object PyOperation::createOpView() { 1006 MlirIdentifier ident = mlirOperationGetName(get()); 1007 MlirStringRef identStr = mlirIdentifierStr(ident); 1008 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1009 StringRef(identStr.data, identStr.length)); 1010 if (opViewClass) 1011 return (*opViewClass)(getRef().getObject()); 1012 return py::cast(PyOpView(getRef().getObject())); 1013 } 1014 1015 //------------------------------------------------------------------------------ 1016 // PyOpView 1017 //------------------------------------------------------------------------------ 1018 1019 py::object 1020 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1021 py::list operandList, 1022 llvm::Optional<py::dict> attributes, 1023 llvm::Optional<std::vector<PyBlock *>> successors, 1024 llvm::Optional<int> regions, 1025 DefaultingPyLocation location, py::object maybeIp) { 1026 PyMlirContextRef context = location->getContext(); 1027 // Class level operation construction metadata. 1028 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1029 // Operand and result segment specs are either none, which does no 1030 // variadic unpacking, or a list of ints with segment sizes, where each 1031 // element is either a positive number (typically 1 for a scalar) or -1 to 1032 // indicate that it is derived from the length of the same-indexed operand 1033 // or result (implying that it is a list at that position). 1034 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1035 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1036 1037 std::vector<uint32_t> operandSegmentLengths; 1038 std::vector<uint32_t> resultSegmentLengths; 1039 1040 // Validate/determine region count. 1041 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1042 int opMinRegionCount = std::get<0>(opRegionSpec); 1043 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1044 if (!regions) { 1045 regions = opMinRegionCount; 1046 } 1047 if (*regions < opMinRegionCount) { 1048 throw py::value_error( 1049 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1050 llvm::Twine(opMinRegionCount) + 1051 " regions but was built with regions=" + llvm::Twine(*regions)) 1052 .str()); 1053 } 1054 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1055 throw py::value_error( 1056 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1057 llvm::Twine(opMinRegionCount) + 1058 " regions but was built with regions=" + llvm::Twine(*regions)) 1059 .str()); 1060 } 1061 1062 // Unpack results. 1063 std::vector<PyType *> resultTypes; 1064 resultTypes.reserve(resultTypeList.size()); 1065 if (resultSegmentSpecObj.is_none()) { 1066 // Non-variadic result unpacking. 1067 for (auto it : llvm::enumerate(resultTypeList)) { 1068 try { 1069 resultTypes.push_back(py::cast<PyType *>(it.value())); 1070 if (!resultTypes.back()) 1071 throw py::cast_error(); 1072 } catch (py::cast_error &err) { 1073 throw py::value_error((llvm::Twine("Result ") + 1074 llvm::Twine(it.index()) + " of operation \"" + 1075 name + "\" must be a Type (" + err.what() + ")") 1076 .str()); 1077 } 1078 } 1079 } else { 1080 // Sized result unpacking. 1081 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1082 if (resultSegmentSpec.size() != resultTypeList.size()) { 1083 throw py::value_error((llvm::Twine("Operation \"") + name + 1084 "\" requires " + 1085 llvm::Twine(resultSegmentSpec.size()) + 1086 "result segments but was provided " + 1087 llvm::Twine(resultTypeList.size())) 1088 .str()); 1089 } 1090 resultSegmentLengths.reserve(resultTypeList.size()); 1091 for (auto it : 1092 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1093 int segmentSpec = std::get<1>(it.value()); 1094 if (segmentSpec == 1 || segmentSpec == 0) { 1095 // Unpack unary element. 1096 try { 1097 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1098 if (resultType) { 1099 resultTypes.push_back(resultType); 1100 resultSegmentLengths.push_back(1); 1101 } else if (segmentSpec == 0) { 1102 // Allowed to be optional. 1103 resultSegmentLengths.push_back(0); 1104 } else { 1105 throw py::cast_error("was None and result is not optional"); 1106 } 1107 } catch (py::cast_error &err) { 1108 throw py::value_error((llvm::Twine("Result ") + 1109 llvm::Twine(it.index()) + " of operation \"" + 1110 name + "\" must be a Type (" + err.what() + 1111 ")") 1112 .str()); 1113 } 1114 } else if (segmentSpec == -1) { 1115 // Unpack sequence by appending. 1116 try { 1117 if (std::get<0>(it.value()).is_none()) { 1118 // Treat it as an empty list. 1119 resultSegmentLengths.push_back(0); 1120 } else { 1121 // Unpack the list. 1122 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1123 for (py::object segmentItem : segment) { 1124 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1125 if (!resultTypes.back()) { 1126 throw py::cast_error("contained a None item"); 1127 } 1128 } 1129 resultSegmentLengths.push_back(segment.size()); 1130 } 1131 } catch (std::exception &err) { 1132 // NOTE: Sloppy to be using a catch-all here, but there are at least 1133 // three different unrelated exceptions that can be thrown in the 1134 // above "casts". Just keep the scope above small and catch them all. 1135 throw py::value_error((llvm::Twine("Result ") + 1136 llvm::Twine(it.index()) + " of operation \"" + 1137 name + "\" must be a Sequence of Types (" + 1138 err.what() + ")") 1139 .str()); 1140 } 1141 } else { 1142 throw py::value_error("Unexpected segment spec"); 1143 } 1144 } 1145 } 1146 1147 // Unpack operands. 1148 std::vector<PyValue *> operands; 1149 operands.reserve(operands.size()); 1150 if (operandSegmentSpecObj.is_none()) { 1151 // Non-sized operand unpacking. 1152 for (auto it : llvm::enumerate(operandList)) { 1153 try { 1154 operands.push_back(py::cast<PyValue *>(it.value())); 1155 if (!operands.back()) 1156 throw py::cast_error(); 1157 } catch (py::cast_error &err) { 1158 throw py::value_error((llvm::Twine("Operand ") + 1159 llvm::Twine(it.index()) + " of operation \"" + 1160 name + "\" must be a Value (" + err.what() + ")") 1161 .str()); 1162 } 1163 } 1164 } else { 1165 // Sized operand unpacking. 1166 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1167 if (operandSegmentSpec.size() != operandList.size()) { 1168 throw py::value_error((llvm::Twine("Operation \"") + name + 1169 "\" requires " + 1170 llvm::Twine(operandSegmentSpec.size()) + 1171 "operand segments but was provided " + 1172 llvm::Twine(operandList.size())) 1173 .str()); 1174 } 1175 operandSegmentLengths.reserve(operandList.size()); 1176 for (auto it : 1177 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1178 int segmentSpec = std::get<1>(it.value()); 1179 if (segmentSpec == 1 || segmentSpec == 0) { 1180 // Unpack unary element. 1181 try { 1182 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1183 if (operandValue) { 1184 operands.push_back(operandValue); 1185 operandSegmentLengths.push_back(1); 1186 } else if (segmentSpec == 0) { 1187 // Allowed to be optional. 1188 operandSegmentLengths.push_back(0); 1189 } else { 1190 throw py::cast_error("was None and operand is not optional"); 1191 } 1192 } catch (py::cast_error &err) { 1193 throw py::value_error((llvm::Twine("Operand ") + 1194 llvm::Twine(it.index()) + " of operation \"" + 1195 name + "\" must be a Value (" + err.what() + 1196 ")") 1197 .str()); 1198 } 1199 } else if (segmentSpec == -1) { 1200 // Unpack sequence by appending. 1201 try { 1202 if (std::get<0>(it.value()).is_none()) { 1203 // Treat it as an empty list. 1204 operandSegmentLengths.push_back(0); 1205 } else { 1206 // Unpack the list. 1207 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1208 for (py::object segmentItem : segment) { 1209 operands.push_back(py::cast<PyValue *>(segmentItem)); 1210 if (!operands.back()) { 1211 throw py::cast_error("contained a None item"); 1212 } 1213 } 1214 operandSegmentLengths.push_back(segment.size()); 1215 } 1216 } catch (std::exception &err) { 1217 // NOTE: Sloppy to be using a catch-all here, but there are at least 1218 // three different unrelated exceptions that can be thrown in the 1219 // above "casts". Just keep the scope above small and catch them all. 1220 throw py::value_error((llvm::Twine("Operand ") + 1221 llvm::Twine(it.index()) + " of operation \"" + 1222 name + "\" must be a Sequence of Values (" + 1223 err.what() + ")") 1224 .str()); 1225 } 1226 } else { 1227 throw py::value_error("Unexpected segment spec"); 1228 } 1229 } 1230 } 1231 1232 // Merge operand/result segment lengths into attributes if needed. 1233 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1234 // Dup. 1235 if (attributes) { 1236 attributes = py::dict(*attributes); 1237 } else { 1238 attributes = py::dict(); 1239 } 1240 if (attributes->contains("result_segment_sizes") || 1241 attributes->contains("operand_segment_sizes")) { 1242 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1243 "'operand_segment_sizes' attribute is unsupported. " 1244 "Use Operation.create for such low-level access."); 1245 } 1246 1247 // Add result_segment_sizes attribute. 1248 if (!resultSegmentLengths.empty()) { 1249 int64_t size = resultSegmentLengths.size(); 1250 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1251 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1252 resultSegmentLengths.size(), resultSegmentLengths.data()); 1253 (*attributes)["result_segment_sizes"] = 1254 PyAttribute(context, segmentLengthAttr); 1255 } 1256 1257 // Add operand_segment_sizes attribute. 1258 if (!operandSegmentLengths.empty()) { 1259 int64_t size = operandSegmentLengths.size(); 1260 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1261 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1262 operandSegmentLengths.size(), operandSegmentLengths.data()); 1263 (*attributes)["operand_segment_sizes"] = 1264 PyAttribute(context, segmentLengthAttr); 1265 } 1266 } 1267 1268 // Delegate to create. 1269 return PyOperation::create(std::move(name), 1270 /*results=*/std::move(resultTypes), 1271 /*operands=*/std::move(operands), 1272 /*attributes=*/std::move(attributes), 1273 /*successors=*/std::move(successors), 1274 /*regions=*/*regions, location, maybeIp); 1275 } 1276 1277 PyOpView::PyOpView(py::object operationObject) 1278 // Casting through the PyOperationBase base-class and then back to the 1279 // Operation lets us accept any PyOperationBase subclass. 1280 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1281 operationObject(operation.getRef().getObject()) {} 1282 1283 py::object PyOpView::createRawSubclass(py::object userClass) { 1284 // This is... a little gross. The typical pattern is to have a pure python 1285 // class that extends OpView like: 1286 // class AddFOp(_cext.ir.OpView): 1287 // def __init__(self, loc, lhs, rhs): 1288 // operation = loc.context.create_operation( 1289 // "addf", lhs, rhs, results=[lhs.type]) 1290 // super().__init__(operation) 1291 // 1292 // I.e. The goal of the user facing type is to provide a nice constructor 1293 // that has complete freedom for the op under construction. This is at odds 1294 // with our other desire to sometimes create this object by just passing an 1295 // operation (to initialize the base class). We could do *arg and **kwargs 1296 // munging to try to make it work, but instead, we synthesize a new class 1297 // on the fly which extends this user class (AddFOp in this example) and 1298 // *give it* the base class's __init__ method, thus bypassing the 1299 // intermediate subclass's __init__ method entirely. While slightly, 1300 // underhanded, this is safe/legal because the type hierarchy has not changed 1301 // (we just added a new leaf) and we aren't mucking around with __new__. 1302 // Typically, this new class will be stored on the original as "_Raw" and will 1303 // be used for casts and other things that need a variant of the class that 1304 // is initialized purely from an operation. 1305 py::object parentMetaclass = 1306 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1307 py::dict attributes; 1308 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1309 // now. 1310 // auto opViewType = py::type::of<PyOpView>(); 1311 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1312 attributes["__init__"] = opViewType.attr("__init__"); 1313 py::str origName = userClass.attr("__name__"); 1314 py::str newName = py::str("_") + origName; 1315 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1316 } 1317 1318 //------------------------------------------------------------------------------ 1319 // PyInsertionPoint. 1320 //------------------------------------------------------------------------------ 1321 1322 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1323 1324 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1325 : refOperation(beforeOperationBase.getOperation().getRef()), 1326 block((*refOperation)->getBlock()) {} 1327 1328 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1329 PyOperation &operation = operationBase.getOperation(); 1330 if (operation.isAttached()) 1331 throw SetPyError(PyExc_ValueError, 1332 "Attempt to insert operation that is already attached"); 1333 block.getParentOperation()->checkValid(); 1334 MlirOperation beforeOp = {nullptr}; 1335 if (refOperation) { 1336 // Insert before operation. 1337 (*refOperation)->checkValid(); 1338 beforeOp = (*refOperation)->get(); 1339 } else { 1340 // Insert at end (before null) is only valid if the block does not 1341 // already end in a known terminator (violating this will cause assertion 1342 // failures later). 1343 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1344 throw py::index_error("Cannot insert operation at the end of a block " 1345 "that already has a terminator. Did you mean to " 1346 "use 'InsertionPoint.at_block_terminator(block)' " 1347 "versus 'InsertionPoint(block)'?"); 1348 } 1349 } 1350 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1351 operation.setAttached(); 1352 } 1353 1354 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1355 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1356 if (mlirOperationIsNull(firstOp)) { 1357 // Just insert at end. 1358 return PyInsertionPoint(block); 1359 } 1360 1361 // Insert before first op. 1362 PyOperationRef firstOpRef = PyOperation::forOperation( 1363 block.getParentOperation()->getContext(), firstOp); 1364 return PyInsertionPoint{block, std::move(firstOpRef)}; 1365 } 1366 1367 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1368 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1369 if (mlirOperationIsNull(terminator)) 1370 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1371 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1372 block.getParentOperation()->getContext(), terminator); 1373 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1374 } 1375 1376 py::object PyInsertionPoint::contextEnter() { 1377 return PyThreadContextEntry::pushInsertionPoint(*this); 1378 } 1379 1380 void PyInsertionPoint::contextExit(pybind11::object excType, 1381 pybind11::object excVal, 1382 pybind11::object excTb) { 1383 PyThreadContextEntry::popInsertionPoint(*this); 1384 } 1385 1386 //------------------------------------------------------------------------------ 1387 // PyAttribute. 1388 //------------------------------------------------------------------------------ 1389 1390 bool PyAttribute::operator==(const PyAttribute &other) { 1391 return mlirAttributeEqual(attr, other.attr); 1392 } 1393 1394 py::object PyAttribute::getCapsule() { 1395 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1396 } 1397 1398 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1399 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1400 if (mlirAttributeIsNull(rawAttr)) 1401 throw py::error_already_set(); 1402 return PyAttribute( 1403 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1404 } 1405 1406 //------------------------------------------------------------------------------ 1407 // PyNamedAttribute. 1408 //------------------------------------------------------------------------------ 1409 1410 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1411 : ownedName(new std::string(std::move(ownedName))) { 1412 namedAttr = mlirNamedAttributeGet( 1413 mlirIdentifierGet(mlirAttributeGetContext(attr), 1414 toMlirStringRef(*this->ownedName)), 1415 attr); 1416 } 1417 1418 //------------------------------------------------------------------------------ 1419 // PyType. 1420 //------------------------------------------------------------------------------ 1421 1422 bool PyType::operator==(const PyType &other) { 1423 return mlirTypeEqual(type, other.type); 1424 } 1425 1426 py::object PyType::getCapsule() { 1427 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1428 } 1429 1430 PyType PyType::createFromCapsule(py::object capsule) { 1431 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1432 if (mlirTypeIsNull(rawType)) 1433 throw py::error_already_set(); 1434 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1435 rawType); 1436 } 1437 1438 //------------------------------------------------------------------------------ 1439 // PyValue and subclases. 1440 //------------------------------------------------------------------------------ 1441 1442 namespace { 1443 /// CRTP base class for Python MLIR values that subclass Value and should be 1444 /// castable from it. The value hierarchy is one level deep and is not supposed 1445 /// to accommodate other levels unless core MLIR changes. 1446 template <typename DerivedTy> 1447 class PyConcreteValue : public PyValue { 1448 public: 1449 // Derived classes must define statics for: 1450 // IsAFunctionTy isaFunction 1451 // const char *pyClassName 1452 // and redefine bindDerived. 1453 using ClassTy = py::class_<DerivedTy, PyValue>; 1454 using IsAFunctionTy = bool (*)(MlirValue); 1455 1456 PyConcreteValue() = default; 1457 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1458 : PyValue(operationRef, value) {} 1459 PyConcreteValue(PyValue &orig) 1460 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1461 1462 /// Attempts to cast the original value to the derived type and throws on 1463 /// type mismatches. 1464 static MlirValue castFrom(PyValue &orig) { 1465 if (!DerivedTy::isaFunction(orig.get())) { 1466 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1467 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1468 DerivedTy::pyClassName + 1469 " (from " + origRepr + ")"); 1470 } 1471 return orig.get(); 1472 } 1473 1474 /// Binds the Python module objects to functions of this class. 1475 static void bind(py::module &m) { 1476 auto cls = ClassTy(m, DerivedTy::pyClassName); 1477 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1478 DerivedTy::bindDerived(cls); 1479 } 1480 1481 /// Implemented by derived classes to add methods to the Python subclass. 1482 static void bindDerived(ClassTy &m) {} 1483 }; 1484 1485 /// Python wrapper for MlirBlockArgument. 1486 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1487 public: 1488 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1489 static constexpr const char *pyClassName = "BlockArgument"; 1490 using PyConcreteValue::PyConcreteValue; 1491 1492 static void bindDerived(ClassTy &c) { 1493 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1494 return PyBlock(self.getParentOperation(), 1495 mlirBlockArgumentGetOwner(self.get())); 1496 }); 1497 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1498 return mlirBlockArgumentGetArgNumber(self.get()); 1499 }); 1500 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1501 return mlirBlockArgumentSetType(self.get(), type); 1502 }); 1503 } 1504 }; 1505 1506 /// Python wrapper for MlirOpResult. 1507 class PyOpResult : public PyConcreteValue<PyOpResult> { 1508 public: 1509 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1510 static constexpr const char *pyClassName = "OpResult"; 1511 using PyConcreteValue::PyConcreteValue; 1512 1513 static void bindDerived(ClassTy &c) { 1514 c.def_property_readonly("owner", [](PyOpResult &self) { 1515 assert( 1516 mlirOperationEqual(self.getParentOperation()->get(), 1517 mlirOpResultGetOwner(self.get())) && 1518 "expected the owner of the value in Python to match that in the IR"); 1519 return self.getParentOperation(); 1520 }); 1521 c.def_property_readonly("result_number", [](PyOpResult &self) { 1522 return mlirOpResultGetResultNumber(self.get()); 1523 }); 1524 } 1525 }; 1526 1527 /// A list of block arguments. Internally, these are stored as consecutive 1528 /// elements, random access is cheap. The argument list is associated with the 1529 /// operation that contains the block (detached blocks are not allowed in 1530 /// Python bindings) and extends its lifetime. 1531 class PyBlockArgumentList { 1532 public: 1533 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1534 : operation(std::move(operation)), block(block) {} 1535 1536 /// Returns the length of the block argument list. 1537 intptr_t dunderLen() { 1538 operation->checkValid(); 1539 return mlirBlockGetNumArguments(block); 1540 } 1541 1542 /// Returns `index`-th element of the block argument list. 1543 PyBlockArgument dunderGetItem(intptr_t index) { 1544 if (index < 0 || index >= dunderLen()) { 1545 throw SetPyError(PyExc_IndexError, 1546 "attempt to access out of bounds region"); 1547 } 1548 PyValue value(operation, mlirBlockGetArgument(block, index)); 1549 return PyBlockArgument(value); 1550 } 1551 1552 /// Defines a Python class in the bindings. 1553 static void bind(py::module &m) { 1554 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1555 .def("__len__", &PyBlockArgumentList::dunderLen) 1556 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1557 } 1558 1559 private: 1560 PyOperationRef operation; 1561 MlirBlock block; 1562 }; 1563 1564 /// A list of operation operands. Internally, these are stored as consecutive 1565 /// elements, random access is cheap. The result list is associated with the 1566 /// operation whose results these are, and extends the lifetime of this 1567 /// operation. 1568 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1569 public: 1570 static constexpr const char *pyClassName = "OpOperandList"; 1571 1572 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1573 intptr_t length = -1, intptr_t step = 1) 1574 : Sliceable(startIndex, 1575 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1576 : length, 1577 step), 1578 operation(operation) {} 1579 1580 intptr_t getNumElements() { 1581 operation->checkValid(); 1582 return mlirOperationGetNumOperands(operation->get()); 1583 } 1584 1585 PyValue getElement(intptr_t pos) { 1586 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); 1587 } 1588 1589 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1590 return PyOpOperandList(operation, startIndex, length, step); 1591 } 1592 1593 private: 1594 PyOperationRef operation; 1595 }; 1596 1597 /// A list of operation results. Internally, these are stored as consecutive 1598 /// elements, random access is cheap. The result list is associated with the 1599 /// operation whose results these are, and extends the lifetime of this 1600 /// operation. 1601 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1602 public: 1603 static constexpr const char *pyClassName = "OpResultList"; 1604 1605 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1606 intptr_t length = -1, intptr_t step = 1) 1607 : Sliceable(startIndex, 1608 length == -1 ? mlirOperationGetNumResults(operation->get()) 1609 : length, 1610 step), 1611 operation(operation) {} 1612 1613 intptr_t getNumElements() { 1614 operation->checkValid(); 1615 return mlirOperationGetNumResults(operation->get()); 1616 } 1617 1618 PyOpResult getElement(intptr_t index) { 1619 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1620 return PyOpResult(value); 1621 } 1622 1623 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1624 return PyOpResultList(operation, startIndex, length, step); 1625 } 1626 1627 private: 1628 PyOperationRef operation; 1629 }; 1630 1631 /// A list of operation attributes. Can be indexed by name, producing 1632 /// attributes, or by index, producing named attributes. 1633 class PyOpAttributeMap { 1634 public: 1635 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1636 1637 PyAttribute dunderGetItemNamed(const std::string &name) { 1638 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1639 toMlirStringRef(name)); 1640 if (mlirAttributeIsNull(attr)) { 1641 throw SetPyError(PyExc_KeyError, 1642 "attempt to access a non-existent attribute"); 1643 } 1644 return PyAttribute(operation->getContext(), attr); 1645 } 1646 1647 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1648 if (index < 0 || index >= dunderLen()) { 1649 throw SetPyError(PyExc_IndexError, 1650 "attempt to access out of bounds attribute"); 1651 } 1652 MlirNamedAttribute namedAttr = 1653 mlirOperationGetAttribute(operation->get(), index); 1654 return PyNamedAttribute( 1655 namedAttr.attribute, 1656 std::string(mlirIdentifierStr(namedAttr.name).data)); 1657 } 1658 1659 void dunderSetItem(const std::string &name, PyAttribute attr) { 1660 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1661 attr); 1662 } 1663 1664 void dunderDelItem(const std::string &name) { 1665 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1666 toMlirStringRef(name)); 1667 if (!removed) 1668 throw SetPyError(PyExc_KeyError, 1669 "attempt to delete a non-existent attribute"); 1670 } 1671 1672 intptr_t dunderLen() { 1673 return mlirOperationGetNumAttributes(operation->get()); 1674 } 1675 1676 bool dunderContains(const std::string &name) { 1677 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1678 operation->get(), toMlirStringRef(name))); 1679 } 1680 1681 static void bind(py::module &m) { 1682 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1683 .def("__contains__", &PyOpAttributeMap::dunderContains) 1684 .def("__len__", &PyOpAttributeMap::dunderLen) 1685 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1686 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1687 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1688 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1689 } 1690 1691 private: 1692 PyOperationRef operation; 1693 }; 1694 1695 } // end namespace 1696 1697 //------------------------------------------------------------------------------ 1698 // Populates the core exports of the 'ir' submodule. 1699 //------------------------------------------------------------------------------ 1700 1701 void mlir::python::populateIRCore(py::module &m) { 1702 //---------------------------------------------------------------------------- 1703 // Mapping of MlirContext 1704 //---------------------------------------------------------------------------- 1705 py::class_<PyMlirContext>(m, "Context") 1706 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1707 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1708 .def("_get_context_again", 1709 [](PyMlirContext &self) { 1710 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1711 return ref.releaseObject(); 1712 }) 1713 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1714 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1715 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1716 &PyMlirContext::getCapsule) 1717 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1718 .def("__enter__", &PyMlirContext::contextEnter) 1719 .def("__exit__", &PyMlirContext::contextExit) 1720 .def_property_readonly_static( 1721 "current", 1722 [](py::object & /*class*/) { 1723 auto *context = PyThreadContextEntry::getDefaultContext(); 1724 if (!context) 1725 throw SetPyError(PyExc_ValueError, "No current Context"); 1726 return context; 1727 }, 1728 "Gets the Context bound to the current thread or raises ValueError") 1729 .def_property_readonly( 1730 "dialects", 1731 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1732 "Gets a container for accessing dialects by name") 1733 .def_property_readonly( 1734 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1735 "Alias for 'dialect'") 1736 .def( 1737 "get_dialect_descriptor", 1738 [=](PyMlirContext &self, std::string &name) { 1739 MlirDialect dialect = mlirContextGetOrLoadDialect( 1740 self.get(), {name.data(), name.size()}); 1741 if (mlirDialectIsNull(dialect)) { 1742 throw SetPyError(PyExc_ValueError, 1743 Twine("Dialect '") + name + "' not found"); 1744 } 1745 return PyDialectDescriptor(self.getRef(), dialect); 1746 }, 1747 "Gets or loads a dialect by name, returning its descriptor object") 1748 .def_property( 1749 "allow_unregistered_dialects", 1750 [](PyMlirContext &self) -> bool { 1751 return mlirContextGetAllowUnregisteredDialects(self.get()); 1752 }, 1753 [](PyMlirContext &self, bool value) { 1754 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1755 }) 1756 .def("is_registered_operation", 1757 [](PyMlirContext &self, std::string &name) { 1758 return mlirContextIsRegisteredOperation( 1759 self.get(), MlirStringRef{name.data(), name.size()}); 1760 }); 1761 1762 //---------------------------------------------------------------------------- 1763 // Mapping of PyDialectDescriptor 1764 //---------------------------------------------------------------------------- 1765 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1766 .def_property_readonly("namespace", 1767 [](PyDialectDescriptor &self) { 1768 MlirStringRef ns = 1769 mlirDialectGetNamespace(self.get()); 1770 return py::str(ns.data, ns.length); 1771 }) 1772 .def("__repr__", [](PyDialectDescriptor &self) { 1773 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1774 std::string repr("<DialectDescriptor "); 1775 repr.append(ns.data, ns.length); 1776 repr.append(">"); 1777 return repr; 1778 }); 1779 1780 //---------------------------------------------------------------------------- 1781 // Mapping of PyDialects 1782 //---------------------------------------------------------------------------- 1783 py::class_<PyDialects>(m, "Dialects") 1784 .def("__getitem__", 1785 [=](PyDialects &self, std::string keyName) { 1786 MlirDialect dialect = 1787 self.getDialectForKey(keyName, /*attrError=*/false); 1788 py::object descriptor = 1789 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1790 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1791 }) 1792 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1793 MlirDialect dialect = 1794 self.getDialectForKey(attrName, /*attrError=*/true); 1795 py::object descriptor = 1796 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1797 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1798 }); 1799 1800 //---------------------------------------------------------------------------- 1801 // Mapping of PyDialect 1802 //---------------------------------------------------------------------------- 1803 py::class_<PyDialect>(m, "Dialect") 1804 .def(py::init<py::object>(), "descriptor") 1805 .def_property_readonly( 1806 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1807 .def("__repr__", [](py::object self) { 1808 auto clazz = self.attr("__class__"); 1809 return py::str("<Dialect ") + 1810 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1811 clazz.attr("__module__") + py::str(".") + 1812 clazz.attr("__name__") + py::str(")>"); 1813 }); 1814 1815 //---------------------------------------------------------------------------- 1816 // Mapping of Location 1817 //---------------------------------------------------------------------------- 1818 py::class_<PyLocation>(m, "Location") 1819 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1820 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1821 .def("__enter__", &PyLocation::contextEnter) 1822 .def("__exit__", &PyLocation::contextExit) 1823 .def("__eq__", 1824 [](PyLocation &self, PyLocation &other) -> bool { 1825 return mlirLocationEqual(self, other); 1826 }) 1827 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1828 .def_property_readonly_static( 1829 "current", 1830 [](py::object & /*class*/) { 1831 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1832 if (!loc) 1833 throw SetPyError(PyExc_ValueError, "No current Location"); 1834 return loc; 1835 }, 1836 "Gets the Location bound to the current thread or raises ValueError") 1837 .def_static( 1838 "unknown", 1839 [](DefaultingPyMlirContext context) { 1840 return PyLocation(context->getRef(), 1841 mlirLocationUnknownGet(context->get())); 1842 }, 1843 py::arg("context") = py::none(), 1844 "Gets a Location representing an unknown location") 1845 .def_static( 1846 "file", 1847 [](std::string filename, int line, int col, 1848 DefaultingPyMlirContext context) { 1849 return PyLocation( 1850 context->getRef(), 1851 mlirLocationFileLineColGet( 1852 context->get(), toMlirStringRef(filename), line, col)); 1853 }, 1854 py::arg("filename"), py::arg("line"), py::arg("col"), 1855 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1856 .def_property_readonly( 1857 "context", 1858 [](PyLocation &self) { return self.getContext().getObject(); }, 1859 "Context that owns the Location") 1860 .def("__repr__", [](PyLocation &self) { 1861 PyPrintAccumulator printAccum; 1862 mlirLocationPrint(self, printAccum.getCallback(), 1863 printAccum.getUserData()); 1864 return printAccum.join(); 1865 }); 1866 1867 //---------------------------------------------------------------------------- 1868 // Mapping of Module 1869 //---------------------------------------------------------------------------- 1870 py::class_<PyModule>(m, "Module") 1871 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1872 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1873 .def_static( 1874 "parse", 1875 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1876 MlirModule module = mlirModuleCreateParse( 1877 context->get(), toMlirStringRef(moduleAsm)); 1878 // TODO: Rework error reporting once diagnostic engine is exposed 1879 // in C API. 1880 if (mlirModuleIsNull(module)) { 1881 throw SetPyError( 1882 PyExc_ValueError, 1883 "Unable to parse module assembly (see diagnostics)"); 1884 } 1885 return PyModule::forModule(module).releaseObject(); 1886 }, 1887 py::arg("asm"), py::arg("context") = py::none(), 1888 kModuleParseDocstring) 1889 .def_static( 1890 "create", 1891 [](DefaultingPyLocation loc) { 1892 MlirModule module = mlirModuleCreateEmpty(loc); 1893 return PyModule::forModule(module).releaseObject(); 1894 }, 1895 py::arg("loc") = py::none(), "Creates an empty module") 1896 .def_property_readonly( 1897 "context", 1898 [](PyModule &self) { return self.getContext().getObject(); }, 1899 "Context that created the Module") 1900 .def_property_readonly( 1901 "operation", 1902 [](PyModule &self) { 1903 return PyOperation::forOperation(self.getContext(), 1904 mlirModuleGetOperation(self.get()), 1905 self.getRef().releaseObject()) 1906 .releaseObject(); 1907 }, 1908 "Accesses the module as an operation") 1909 .def_property_readonly( 1910 "body", 1911 [](PyModule &self) { 1912 PyOperationRef module_op = PyOperation::forOperation( 1913 self.getContext(), mlirModuleGetOperation(self.get()), 1914 self.getRef().releaseObject()); 1915 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 1916 return returnBlock; 1917 }, 1918 "Return the block for this module") 1919 .def( 1920 "dump", 1921 [](PyModule &self) { 1922 mlirOperationDump(mlirModuleGetOperation(self.get())); 1923 }, 1924 kDumpDocstring) 1925 .def( 1926 "__str__", 1927 [](PyModule &self) { 1928 MlirOperation operation = mlirModuleGetOperation(self.get()); 1929 PyPrintAccumulator printAccum; 1930 mlirOperationPrint(operation, printAccum.getCallback(), 1931 printAccum.getUserData()); 1932 return printAccum.join(); 1933 }, 1934 kOperationStrDunderDocstring); 1935 1936 //---------------------------------------------------------------------------- 1937 // Mapping of Operation. 1938 //---------------------------------------------------------------------------- 1939 py::class_<PyOperationBase>(m, "_OperationBase") 1940 .def("__eq__", 1941 [](PyOperationBase &self, PyOperationBase &other) { 1942 return &self.getOperation() == &other.getOperation(); 1943 }) 1944 .def("__eq__", 1945 [](PyOperationBase &self, py::object other) { return false; }) 1946 .def_property_readonly("attributes", 1947 [](PyOperationBase &self) { 1948 return PyOpAttributeMap( 1949 self.getOperation().getRef()); 1950 }) 1951 .def_property_readonly("operands", 1952 [](PyOperationBase &self) { 1953 return PyOpOperandList( 1954 self.getOperation().getRef()); 1955 }) 1956 .def_property_readonly("regions", 1957 [](PyOperationBase &self) { 1958 return PyRegionList( 1959 self.getOperation().getRef()); 1960 }) 1961 .def_property_readonly( 1962 "results", 1963 [](PyOperationBase &self) { 1964 return PyOpResultList(self.getOperation().getRef()); 1965 }, 1966 "Returns the list of Operation results.") 1967 .def_property_readonly( 1968 "result", 1969 [](PyOperationBase &self) { 1970 auto &operation = self.getOperation(); 1971 auto numResults = mlirOperationGetNumResults(operation); 1972 if (numResults != 1) { 1973 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 1974 throw SetPyError( 1975 PyExc_ValueError, 1976 Twine("Cannot call .result on operation ") + 1977 StringRef(name.data, name.length) + " which has " + 1978 Twine(numResults) + 1979 " results (it is only valid for operations with a " 1980 "single result)"); 1981 } 1982 return PyOpResult(operation.getRef(), 1983 mlirOperationGetResult(operation, 0)); 1984 }, 1985 "Shortcut to get an op result if it has only one (throws an error " 1986 "otherwise).") 1987 .def("__iter__", 1988 [](PyOperationBase &self) { 1989 return PyRegionIterator(self.getOperation().getRef()); 1990 }) 1991 .def( 1992 "__str__", 1993 [](PyOperationBase &self) { 1994 return self.getAsm(/*binary=*/false, 1995 /*largeElementsLimit=*/llvm::None, 1996 /*enableDebugInfo=*/false, 1997 /*prettyDebugInfo=*/false, 1998 /*printGenericOpForm=*/false, 1999 /*useLocalScope=*/false); 2000 }, 2001 "Returns the assembly form of the operation.") 2002 .def("print", &PyOperationBase::print, 2003 // Careful: Lots of arguments must match up with print method. 2004 py::arg("file") = py::none(), py::arg("binary") = false, 2005 py::arg("large_elements_limit") = py::none(), 2006 py::arg("enable_debug_info") = false, 2007 py::arg("pretty_debug_info") = false, 2008 py::arg("print_generic_op_form") = false, 2009 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2010 .def("get_asm", &PyOperationBase::getAsm, 2011 // Careful: Lots of arguments must match up with get_asm method. 2012 py::arg("binary") = false, 2013 py::arg("large_elements_limit") = py::none(), 2014 py::arg("enable_debug_info") = false, 2015 py::arg("pretty_debug_info") = false, 2016 py::arg("print_generic_op_form") = false, 2017 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2018 .def( 2019 "verify", 2020 [](PyOperationBase &self) { 2021 return mlirOperationVerify(self.getOperation()); 2022 }, 2023 "Verify the operation and return true if it passes, false if it " 2024 "fails."); 2025 2026 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2027 .def_static("create", &PyOperation::create, py::arg("name"), 2028 py::arg("results") = py::none(), 2029 py::arg("operands") = py::none(), 2030 py::arg("attributes") = py::none(), 2031 py::arg("successors") = py::none(), py::arg("regions") = 0, 2032 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2033 kOperationCreateDocstring) 2034 .def_property_readonly("name", 2035 [](PyOperation &self) { 2036 MlirOperation operation = self.get(); 2037 MlirStringRef name = mlirIdentifierStr( 2038 mlirOperationGetName(operation)); 2039 return py::str(name.data, name.length); 2040 }) 2041 .def_property_readonly( 2042 "context", 2043 [](PyOperation &self) { return self.getContext().getObject(); }, 2044 "Context that owns the Operation") 2045 .def_property_readonly("opview", &PyOperation::createOpView); 2046 2047 auto opViewClass = 2048 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2049 .def(py::init<py::object>()) 2050 .def_property_readonly("operation", &PyOpView::getOperationObject) 2051 .def_property_readonly( 2052 "context", 2053 [](PyOpView &self) { 2054 return self.getOperation().getContext().getObject(); 2055 }, 2056 "Context that owns the Operation") 2057 .def("__str__", [](PyOpView &self) { 2058 return py::str(self.getOperationObject()); 2059 }); 2060 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2061 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2062 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2063 opViewClass.attr("build_generic") = classmethod( 2064 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2065 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2066 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2067 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2068 "Builds a specific, generated OpView based on class level attributes."); 2069 2070 //---------------------------------------------------------------------------- 2071 // Mapping of PyRegion. 2072 //---------------------------------------------------------------------------- 2073 py::class_<PyRegion>(m, "Region") 2074 .def_property_readonly( 2075 "blocks", 2076 [](PyRegion &self) { 2077 return PyBlockList(self.getParentOperation(), self.get()); 2078 }, 2079 "Returns a forward-optimized sequence of blocks.") 2080 .def( 2081 "__iter__", 2082 [](PyRegion &self) { 2083 self.checkValid(); 2084 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2085 return PyBlockIterator(self.getParentOperation(), firstBlock); 2086 }, 2087 "Iterates over blocks in the region.") 2088 .def("__eq__", 2089 [](PyRegion &self, PyRegion &other) { 2090 return self.get().ptr == other.get().ptr; 2091 }) 2092 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2093 2094 //---------------------------------------------------------------------------- 2095 // Mapping of PyBlock. 2096 //---------------------------------------------------------------------------- 2097 py::class_<PyBlock>(m, "Block") 2098 .def_property_readonly( 2099 "arguments", 2100 [](PyBlock &self) { 2101 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2102 }, 2103 "Returns a list of block arguments.") 2104 .def_property_readonly( 2105 "operations", 2106 [](PyBlock &self) { 2107 return PyOperationList(self.getParentOperation(), self.get()); 2108 }, 2109 "Returns a forward-optimized sequence of operations.") 2110 .def( 2111 "__iter__", 2112 [](PyBlock &self) { 2113 self.checkValid(); 2114 MlirOperation firstOperation = 2115 mlirBlockGetFirstOperation(self.get()); 2116 return PyOperationIterator(self.getParentOperation(), 2117 firstOperation); 2118 }, 2119 "Iterates over operations in the block.") 2120 .def("__eq__", 2121 [](PyBlock &self, PyBlock &other) { 2122 return self.get().ptr == other.get().ptr; 2123 }) 2124 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2125 .def( 2126 "__str__", 2127 [](PyBlock &self) { 2128 self.checkValid(); 2129 PyPrintAccumulator printAccum; 2130 mlirBlockPrint(self.get(), printAccum.getCallback(), 2131 printAccum.getUserData()); 2132 return printAccum.join(); 2133 }, 2134 "Returns the assembly form of the block."); 2135 2136 //---------------------------------------------------------------------------- 2137 // Mapping of PyInsertionPoint. 2138 //---------------------------------------------------------------------------- 2139 2140 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2141 .def(py::init<PyBlock &>(), py::arg("block"), 2142 "Inserts after the last operation but still inside the block.") 2143 .def("__enter__", &PyInsertionPoint::contextEnter) 2144 .def("__exit__", &PyInsertionPoint::contextExit) 2145 .def_property_readonly_static( 2146 "current", 2147 [](py::object & /*class*/) { 2148 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2149 if (!ip) 2150 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2151 return ip; 2152 }, 2153 "Gets the InsertionPoint bound to the current thread or raises " 2154 "ValueError if none has been set") 2155 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2156 "Inserts before a referenced operation.") 2157 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2158 py::arg("block"), "Inserts at the beginning of the block.") 2159 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2160 py::arg("block"), "Inserts before the block terminator.") 2161 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2162 "Inserts an operation."); 2163 2164 //---------------------------------------------------------------------------- 2165 // Mapping of PyAttribute. 2166 //---------------------------------------------------------------------------- 2167 py::class_<PyAttribute>(m, "Attribute") 2168 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2169 &PyAttribute::getCapsule) 2170 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2171 .def_static( 2172 "parse", 2173 [](std::string attrSpec, DefaultingPyMlirContext context) { 2174 MlirAttribute type = mlirAttributeParseGet( 2175 context->get(), toMlirStringRef(attrSpec)); 2176 // TODO: Rework error reporting once diagnostic engine is exposed 2177 // in C API. 2178 if (mlirAttributeIsNull(type)) { 2179 throw SetPyError(PyExc_ValueError, 2180 Twine("Unable to parse attribute: '") + 2181 attrSpec + "'"); 2182 } 2183 return PyAttribute(context->getRef(), type); 2184 }, 2185 py::arg("asm"), py::arg("context") = py::none(), 2186 "Parses an attribute from an assembly form") 2187 .def_property_readonly( 2188 "context", 2189 [](PyAttribute &self) { return self.getContext().getObject(); }, 2190 "Context that owns the Attribute") 2191 .def_property_readonly("type", 2192 [](PyAttribute &self) { 2193 return PyType(self.getContext()->getRef(), 2194 mlirAttributeGetType(self)); 2195 }) 2196 .def( 2197 "get_named", 2198 [](PyAttribute &self, std::string name) { 2199 return PyNamedAttribute(self, std::move(name)); 2200 }, 2201 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2202 .def("__eq__", 2203 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2204 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2205 .def( 2206 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2207 kDumpDocstring) 2208 .def( 2209 "__str__", 2210 [](PyAttribute &self) { 2211 PyPrintAccumulator printAccum; 2212 mlirAttributePrint(self, printAccum.getCallback(), 2213 printAccum.getUserData()); 2214 return printAccum.join(); 2215 }, 2216 "Returns the assembly form of the Attribute.") 2217 .def("__repr__", [](PyAttribute &self) { 2218 // Generally, assembly formats are not printed for __repr__ because 2219 // this can cause exceptionally long debug output and exceptions. 2220 // However, attribute values are generally considered useful and are 2221 // printed. This may need to be re-evaluated if debug dumps end up 2222 // being excessive. 2223 PyPrintAccumulator printAccum; 2224 printAccum.parts.append("Attribute("); 2225 mlirAttributePrint(self, printAccum.getCallback(), 2226 printAccum.getUserData()); 2227 printAccum.parts.append(")"); 2228 return printAccum.join(); 2229 }); 2230 2231 //---------------------------------------------------------------------------- 2232 // Mapping of PyNamedAttribute 2233 //---------------------------------------------------------------------------- 2234 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2235 .def("__repr__", 2236 [](PyNamedAttribute &self) { 2237 PyPrintAccumulator printAccum; 2238 printAccum.parts.append("NamedAttribute("); 2239 printAccum.parts.append( 2240 mlirIdentifierStr(self.namedAttr.name).data); 2241 printAccum.parts.append("="); 2242 mlirAttributePrint(self.namedAttr.attribute, 2243 printAccum.getCallback(), 2244 printAccum.getUserData()); 2245 printAccum.parts.append(")"); 2246 return printAccum.join(); 2247 }) 2248 .def_property_readonly( 2249 "name", 2250 [](PyNamedAttribute &self) { 2251 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2252 mlirIdentifierStr(self.namedAttr.name).length); 2253 }, 2254 "The name of the NamedAttribute binding") 2255 .def_property_readonly( 2256 "attr", 2257 [](PyNamedAttribute &self) { 2258 // TODO: When named attribute is removed/refactored, also remove 2259 // this constructor (it does an inefficient table lookup). 2260 auto contextRef = PyMlirContext::forContext( 2261 mlirAttributeGetContext(self.namedAttr.attribute)); 2262 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2263 }, 2264 py::keep_alive<0, 1>(), 2265 "The underlying generic attribute of the NamedAttribute binding"); 2266 2267 //---------------------------------------------------------------------------- 2268 // Mapping of PyType. 2269 //---------------------------------------------------------------------------- 2270 py::class_<PyType>(m, "Type") 2271 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2272 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2273 .def_static( 2274 "parse", 2275 [](std::string typeSpec, DefaultingPyMlirContext context) { 2276 MlirType type = 2277 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2278 // TODO: Rework error reporting once diagnostic engine is exposed 2279 // in C API. 2280 if (mlirTypeIsNull(type)) { 2281 throw SetPyError(PyExc_ValueError, 2282 Twine("Unable to parse type: '") + typeSpec + 2283 "'"); 2284 } 2285 return PyType(context->getRef(), type); 2286 }, 2287 py::arg("asm"), py::arg("context") = py::none(), 2288 kContextParseTypeDocstring) 2289 .def_property_readonly( 2290 "context", [](PyType &self) { return self.getContext().getObject(); }, 2291 "Context that owns the Type") 2292 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2293 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2294 .def( 2295 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2296 .def( 2297 "__str__", 2298 [](PyType &self) { 2299 PyPrintAccumulator printAccum; 2300 mlirTypePrint(self, printAccum.getCallback(), 2301 printAccum.getUserData()); 2302 return printAccum.join(); 2303 }, 2304 "Returns the assembly form of the type.") 2305 .def("__repr__", [](PyType &self) { 2306 // Generally, assembly formats are not printed for __repr__ because 2307 // this can cause exceptionally long debug output and exceptions. 2308 // However, types are an exception as they typically have compact 2309 // assembly forms and printing them is useful. 2310 PyPrintAccumulator printAccum; 2311 printAccum.parts.append("Type("); 2312 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2313 printAccum.parts.append(")"); 2314 return printAccum.join(); 2315 }); 2316 2317 //---------------------------------------------------------------------------- 2318 // Mapping of Value. 2319 //---------------------------------------------------------------------------- 2320 py::class_<PyValue>(m, "Value") 2321 .def_property_readonly( 2322 "context", 2323 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2324 "Context in which the value lives.") 2325 .def( 2326 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2327 kDumpDocstring) 2328 .def("__eq__", 2329 [](PyValue &self, PyValue &other) { 2330 return self.get().ptr == other.get().ptr; 2331 }) 2332 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2333 .def( 2334 "__str__", 2335 [](PyValue &self) { 2336 PyPrintAccumulator printAccum; 2337 printAccum.parts.append("Value("); 2338 mlirValuePrint(self.get(), printAccum.getCallback(), 2339 printAccum.getUserData()); 2340 printAccum.parts.append(")"); 2341 return printAccum.join(); 2342 }, 2343 kValueDunderStrDocstring) 2344 .def_property_readonly("type", [](PyValue &self) { 2345 return PyType(self.getParentOperation()->getContext(), 2346 mlirValueGetType(self.get())); 2347 }); 2348 PyBlockArgument::bind(m); 2349 PyOpResult::bind(m); 2350 2351 // Container bindings. 2352 PyBlockArgumentList::bind(m); 2353 PyBlockIterator::bind(m); 2354 PyBlockList::bind(m); 2355 PyOperationIterator::bind(m); 2356 PyOperationList::bind(m); 2357 PyOpAttributeMap::bind(m); 2358 PyOpOperandList::bind(m); 2359 PyOpResultList::bind(m); 2360 PyRegionIterator::bind(m); 2361 PyRegionList::bind(m); 2362 } 2363