1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Registration.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include <pybind11/stl.h> 20 21 namespace py = pybind11; 22 using namespace mlir; 23 using namespace mlir::python; 24 25 using llvm::SmallVector; 26 using llvm::StringRef; 27 using llvm::Twine; 28 29 //------------------------------------------------------------------------------ 30 // Docstrings (trivial, non-duplicated docstrings are included inline). 31 //------------------------------------------------------------------------------ 32 33 static const char kContextParseTypeDocstring[] = 34 R"(Parses the assembly form of a type. 35 36 Returns a Type object or raises a ValueError if the type cannot be parsed. 37 38 See also: https://mlir.llvm.org/docs/LangRef/#type-system 39 )"; 40 41 static const char kContextGetFileLocationDocstring[] = 42 R"(Gets a Location representing a file, line and column)"; 43 44 static const char kModuleParseDocstring[] = 45 R"(Parses a module's assembly format from a string. 46 47 Returns a new MlirModule or raises a ValueError if the parsing fails. 48 49 See also: https://mlir.llvm.org/docs/LangRef/ 50 )"; 51 52 static const char kOperationCreateDocstring[] = 53 R"(Creates a new operation. 54 55 Args: 56 name: Operation name (e.g. "dialect.operation"). 57 results: Sequence of Type representing op result types. 58 attributes: Dict of str:Attribute. 59 successors: List of Block for the operation's successors. 60 regions: Number of regions to create. 61 location: A Location object (defaults to resolve from context manager). 62 ip: An InsertionPoint (defaults to resolve from context manager or set to 63 False to disable insertion, even with an insertion point set in the 64 context manager). 65 Returns: 66 A new "detached" Operation object. Detached operations can be added 67 to blocks, which causes them to become "attached." 68 )"; 69 70 static const char kOperationPrintDocstring[] = 71 R"(Prints the assembly form of the operation to a file like object. 72 73 Args: 74 file: The file like object to write to. Defaults to sys.stdout. 75 binary: Whether to write bytes (True) or str (False). Defaults to False. 76 large_elements_limit: Whether to elide elements attributes above this 77 number of elements. Defaults to None (no limit). 78 enable_debug_info: Whether to print debug/location information. Defaults 79 to False. 80 pretty_debug_info: Whether to format debug information for easier reading 81 by a human (warning: the result is unparseable). 82 print_generic_op_form: Whether to print the generic assembly forms of all 83 ops. Defaults to False. 84 use_local_Scope: Whether to print in a way that is more optimized for 85 multi-threaded access but may not be consistent with how the overall 86 module prints. 87 )"; 88 89 static const char kOperationGetAsmDocstring[] = 90 R"(Gets the assembly form of the operation with all options available. 91 92 Args: 93 binary: Whether to return a bytes (True) or str (False) object. Defaults to 94 False. 95 ... others ...: See the print() method for common keyword arguments for 96 configuring the printout. 97 Returns: 98 Either a bytes or str object, depending on the setting of the 'binary' 99 argument. 100 )"; 101 102 static const char kOperationStrDunderDocstring[] = 103 R"(Gets the assembly form of the operation with default options. 104 105 If more advanced control over the assembly formatting or I/O options is needed, 106 use the dedicated print or get_asm method, which supports keyword arguments to 107 customize behavior. 108 )"; 109 110 static const char kDumpDocstring[] = 111 R"(Dumps a debug representation of the object to stderr.)"; 112 113 static const char kAppendBlockDocstring[] = 114 R"(Appends a new block, with argument types as positional args. 115 116 Returns: 117 The created block. 118 )"; 119 120 static const char kValueDunderStrDocstring[] = 121 R"(Returns the string form of the value. 122 123 If the value is a block argument, this is the assembly form of its type and the 124 position in the argument list. If the value is an operation result, this is 125 equivalent to printing the operation that produced it. 126 )"; 127 128 //------------------------------------------------------------------------------ 129 // Utilities. 130 //------------------------------------------------------------------------------ 131 132 // Helper for creating an @classmethod. 133 template <class Func, typename... Args> 134 py::object classmethod(Func f, Args... args) { 135 py::object cf = py::cpp_function(f, args...); 136 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 137 } 138 139 static py::object 140 createCustomDialectWrapper(const std::string &dialectNamespace, 141 py::object dialectDescriptor) { 142 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 143 if (!dialectClass) { 144 // Use the base class. 145 return py::cast(PyDialect(std::move(dialectDescriptor))); 146 } 147 148 // Create the custom implementation. 149 return (*dialectClass)(std::move(dialectDescriptor)); 150 } 151 152 static MlirStringRef toMlirStringRef(const std::string &s) { 153 return mlirStringRefCreate(s.data(), s.size()); 154 } 155 156 //------------------------------------------------------------------------------ 157 // Collections. 158 //------------------------------------------------------------------------------ 159 160 namespace { 161 162 class PyRegionIterator { 163 public: 164 PyRegionIterator(PyOperationRef operation) 165 : operation(std::move(operation)) {} 166 167 PyRegionIterator &dunderIter() { return *this; } 168 169 PyRegion dunderNext() { 170 operation->checkValid(); 171 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 172 throw py::stop_iteration(); 173 } 174 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 175 return PyRegion(operation, region); 176 } 177 178 static void bind(py::module &m) { 179 py::class_<PyRegionIterator>(m, "RegionIterator") 180 .def("__iter__", &PyRegionIterator::dunderIter) 181 .def("__next__", &PyRegionIterator::dunderNext); 182 } 183 184 private: 185 PyOperationRef operation; 186 int nextIndex = 0; 187 }; 188 189 /// Regions of an op are fixed length and indexed numerically so are represented 190 /// with a sequence-like container. 191 class PyRegionList { 192 public: 193 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 194 195 intptr_t dunderLen() { 196 operation->checkValid(); 197 return mlirOperationGetNumRegions(operation->get()); 198 } 199 200 PyRegion dunderGetItem(intptr_t index) { 201 // dunderLen checks validity. 202 if (index < 0 || index >= dunderLen()) { 203 throw SetPyError(PyExc_IndexError, 204 "attempt to access out of bounds region"); 205 } 206 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 207 return PyRegion(operation, region); 208 } 209 210 static void bind(py::module &m) { 211 py::class_<PyRegionList>(m, "RegionSequence") 212 .def("__len__", &PyRegionList::dunderLen) 213 .def("__getitem__", &PyRegionList::dunderGetItem); 214 } 215 216 private: 217 PyOperationRef operation; 218 }; 219 220 class PyBlockIterator { 221 public: 222 PyBlockIterator(PyOperationRef operation, MlirBlock next) 223 : operation(std::move(operation)), next(next) {} 224 225 PyBlockIterator &dunderIter() { return *this; } 226 227 PyBlock dunderNext() { 228 operation->checkValid(); 229 if (mlirBlockIsNull(next)) { 230 throw py::stop_iteration(); 231 } 232 233 PyBlock returnBlock(operation, next); 234 next = mlirBlockGetNextInRegion(next); 235 return returnBlock; 236 } 237 238 static void bind(py::module &m) { 239 py::class_<PyBlockIterator>(m, "BlockIterator") 240 .def("__iter__", &PyBlockIterator::dunderIter) 241 .def("__next__", &PyBlockIterator::dunderNext); 242 } 243 244 private: 245 PyOperationRef operation; 246 MlirBlock next; 247 }; 248 249 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 250 /// we present them as a more full-featured list-like container but optimize 251 /// it for forward iteration. Blocks are always owned by a region. 252 class PyBlockList { 253 public: 254 PyBlockList(PyOperationRef operation, MlirRegion region) 255 : operation(std::move(operation)), region(region) {} 256 257 PyBlockIterator dunderIter() { 258 operation->checkValid(); 259 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 260 } 261 262 intptr_t dunderLen() { 263 operation->checkValid(); 264 intptr_t count = 0; 265 MlirBlock block = mlirRegionGetFirstBlock(region); 266 while (!mlirBlockIsNull(block)) { 267 count += 1; 268 block = mlirBlockGetNextInRegion(block); 269 } 270 return count; 271 } 272 273 PyBlock dunderGetItem(intptr_t index) { 274 operation->checkValid(); 275 if (index < 0) { 276 throw SetPyError(PyExc_IndexError, 277 "attempt to access out of bounds block"); 278 } 279 MlirBlock block = mlirRegionGetFirstBlock(region); 280 while (!mlirBlockIsNull(block)) { 281 if (index == 0) { 282 return PyBlock(operation, block); 283 } 284 block = mlirBlockGetNextInRegion(block); 285 index -= 1; 286 } 287 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 288 } 289 290 PyBlock appendBlock(py::args pyArgTypes) { 291 operation->checkValid(); 292 llvm::SmallVector<MlirType, 4> argTypes; 293 argTypes.reserve(pyArgTypes.size()); 294 for (auto &pyArg : pyArgTypes) { 295 argTypes.push_back(pyArg.cast<PyType &>()); 296 } 297 298 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 299 mlirRegionAppendOwnedBlock(region, block); 300 return PyBlock(operation, block); 301 } 302 303 static void bind(py::module &m) { 304 py::class_<PyBlockList>(m, "BlockList") 305 .def("__getitem__", &PyBlockList::dunderGetItem) 306 .def("__iter__", &PyBlockList::dunderIter) 307 .def("__len__", &PyBlockList::dunderLen) 308 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 309 } 310 311 private: 312 PyOperationRef operation; 313 MlirRegion region; 314 }; 315 316 class PyOperationIterator { 317 public: 318 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 319 : parentOperation(std::move(parentOperation)), next(next) {} 320 321 PyOperationIterator &dunderIter() { return *this; } 322 323 py::object dunderNext() { 324 parentOperation->checkValid(); 325 if (mlirOperationIsNull(next)) { 326 throw py::stop_iteration(); 327 } 328 329 PyOperationRef returnOperation = 330 PyOperation::forOperation(parentOperation->getContext(), next); 331 next = mlirOperationGetNextInBlock(next); 332 return returnOperation->createOpView(); 333 } 334 335 static void bind(py::module &m) { 336 py::class_<PyOperationIterator>(m, "OperationIterator") 337 .def("__iter__", &PyOperationIterator::dunderIter) 338 .def("__next__", &PyOperationIterator::dunderNext); 339 } 340 341 private: 342 PyOperationRef parentOperation; 343 MlirOperation next; 344 }; 345 346 /// Operations are exposed by the C-API as a forward-only linked list. In 347 /// Python, we present them as a more full-featured list-like container but 348 /// optimize it for forward iteration. Iterable operations are always owned 349 /// by a block. 350 class PyOperationList { 351 public: 352 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 353 : parentOperation(std::move(parentOperation)), block(block) {} 354 355 PyOperationIterator dunderIter() { 356 parentOperation->checkValid(); 357 return PyOperationIterator(parentOperation, 358 mlirBlockGetFirstOperation(block)); 359 } 360 361 intptr_t dunderLen() { 362 parentOperation->checkValid(); 363 intptr_t count = 0; 364 MlirOperation childOp = mlirBlockGetFirstOperation(block); 365 while (!mlirOperationIsNull(childOp)) { 366 count += 1; 367 childOp = mlirOperationGetNextInBlock(childOp); 368 } 369 return count; 370 } 371 372 py::object dunderGetItem(intptr_t index) { 373 parentOperation->checkValid(); 374 if (index < 0) { 375 throw SetPyError(PyExc_IndexError, 376 "attempt to access out of bounds operation"); 377 } 378 MlirOperation childOp = mlirBlockGetFirstOperation(block); 379 while (!mlirOperationIsNull(childOp)) { 380 if (index == 0) { 381 return PyOperation::forOperation(parentOperation->getContext(), childOp) 382 ->createOpView(); 383 } 384 childOp = mlirOperationGetNextInBlock(childOp); 385 index -= 1; 386 } 387 throw SetPyError(PyExc_IndexError, 388 "attempt to access out of bounds operation"); 389 } 390 391 static void bind(py::module &m) { 392 py::class_<PyOperationList>(m, "OperationList") 393 .def("__getitem__", &PyOperationList::dunderGetItem) 394 .def("__iter__", &PyOperationList::dunderIter) 395 .def("__len__", &PyOperationList::dunderLen); 396 } 397 398 private: 399 PyOperationRef parentOperation; 400 MlirBlock block; 401 }; 402 403 } // namespace 404 405 //------------------------------------------------------------------------------ 406 // PyMlirContext 407 //------------------------------------------------------------------------------ 408 409 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 410 py::gil_scoped_acquire acquire; 411 auto &liveContexts = getLiveContexts(); 412 liveContexts[context.ptr] = this; 413 } 414 415 PyMlirContext::~PyMlirContext() { 416 // Note that the only public way to construct an instance is via the 417 // forContext method, which always puts the associated handle into 418 // liveContexts. 419 py::gil_scoped_acquire acquire; 420 getLiveContexts().erase(context.ptr); 421 mlirContextDestroy(context); 422 } 423 424 py::object PyMlirContext::getCapsule() { 425 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 426 } 427 428 py::object PyMlirContext::createFromCapsule(py::object capsule) { 429 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 430 if (mlirContextIsNull(rawContext)) 431 throw py::error_already_set(); 432 return forContext(rawContext).releaseObject(); 433 } 434 435 PyMlirContext *PyMlirContext::createNewContextForInit() { 436 MlirContext context = mlirContextCreate(); 437 mlirRegisterAllDialects(context); 438 return new PyMlirContext(context); 439 } 440 441 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 442 py::gil_scoped_acquire acquire; 443 auto &liveContexts = getLiveContexts(); 444 auto it = liveContexts.find(context.ptr); 445 if (it == liveContexts.end()) { 446 // Create. 447 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 448 py::object pyRef = py::cast(unownedContextWrapper); 449 assert(pyRef && "cast to py::object failed"); 450 liveContexts[context.ptr] = unownedContextWrapper; 451 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 452 } 453 // Use existing. 454 py::object pyRef = py::cast(it->second); 455 return PyMlirContextRef(it->second, std::move(pyRef)); 456 } 457 458 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 459 static LiveContextMap liveContexts; 460 return liveContexts; 461 } 462 463 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 464 465 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 466 467 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 468 469 pybind11::object PyMlirContext::contextEnter() { 470 return PyThreadContextEntry::pushContext(*this); 471 } 472 473 void PyMlirContext::contextExit(pybind11::object excType, 474 pybind11::object excVal, 475 pybind11::object excTb) { 476 PyThreadContextEntry::popContext(*this); 477 } 478 479 PyMlirContext &DefaultingPyMlirContext::resolve() { 480 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 481 if (!context) { 482 throw SetPyError( 483 PyExc_RuntimeError, 484 "An MLIR function requires a Context but none was provided in the call " 485 "or from the surrounding environment. Either pass to the function with " 486 "a 'context=' argument or establish a default using 'with Context():'"); 487 } 488 return *context; 489 } 490 491 //------------------------------------------------------------------------------ 492 // PyThreadContextEntry management 493 //------------------------------------------------------------------------------ 494 495 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 496 static thread_local std::vector<PyThreadContextEntry> stack; 497 return stack; 498 } 499 500 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 501 auto &stack = getStack(); 502 if (stack.empty()) 503 return nullptr; 504 return &stack.back(); 505 } 506 507 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 508 py::object insertionPoint, 509 py::object location) { 510 auto &stack = getStack(); 511 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 512 std::move(location)); 513 // If the new stack has more than one entry and the context of the new top 514 // entry matches the previous, copy the insertionPoint and location from the 515 // previous entry if missing from the new top entry. 516 if (stack.size() > 1) { 517 auto &prev = *(stack.rbegin() + 1); 518 auto ¤t = stack.back(); 519 if (current.context.is(prev.context)) { 520 // Default non-context objects from the previous entry. 521 if (!current.insertionPoint) 522 current.insertionPoint = prev.insertionPoint; 523 if (!current.location) 524 current.location = prev.location; 525 } 526 } 527 } 528 529 PyMlirContext *PyThreadContextEntry::getContext() { 530 if (!context) 531 return nullptr; 532 return py::cast<PyMlirContext *>(context); 533 } 534 535 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 536 if (!insertionPoint) 537 return nullptr; 538 return py::cast<PyInsertionPoint *>(insertionPoint); 539 } 540 541 PyLocation *PyThreadContextEntry::getLocation() { 542 if (!location) 543 return nullptr; 544 return py::cast<PyLocation *>(location); 545 } 546 547 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 548 auto *tos = getTopOfStack(); 549 return tos ? tos->getContext() : nullptr; 550 } 551 552 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 553 auto *tos = getTopOfStack(); 554 return tos ? tos->getInsertionPoint() : nullptr; 555 } 556 557 PyLocation *PyThreadContextEntry::getDefaultLocation() { 558 auto *tos = getTopOfStack(); 559 return tos ? tos->getLocation() : nullptr; 560 } 561 562 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 563 py::object contextObj = py::cast(context); 564 push(FrameKind::Context, /*context=*/contextObj, 565 /*insertionPoint=*/py::object(), 566 /*location=*/py::object()); 567 return contextObj; 568 } 569 570 void PyThreadContextEntry::popContext(PyMlirContext &context) { 571 auto &stack = getStack(); 572 if (stack.empty()) 573 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 574 auto &tos = stack.back(); 575 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 576 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 577 stack.pop_back(); 578 } 579 580 py::object 581 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 582 py::object contextObj = 583 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 584 py::object insertionPointObj = py::cast(insertionPoint); 585 push(FrameKind::InsertionPoint, 586 /*context=*/contextObj, 587 /*insertionPoint=*/insertionPointObj, 588 /*location=*/py::object()); 589 return insertionPointObj; 590 } 591 592 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 593 auto &stack = getStack(); 594 if (stack.empty()) 595 throw SetPyError(PyExc_RuntimeError, 596 "Unbalanced InsertionPoint enter/exit"); 597 auto &tos = stack.back(); 598 if (tos.frameKind != FrameKind::InsertionPoint && 599 tos.getInsertionPoint() != &insertionPoint) 600 throw SetPyError(PyExc_RuntimeError, 601 "Unbalanced InsertionPoint enter/exit"); 602 stack.pop_back(); 603 } 604 605 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 606 py::object contextObj = location.getContext().getObject(); 607 py::object locationObj = py::cast(location); 608 push(FrameKind::Location, /*context=*/contextObj, 609 /*insertionPoint=*/py::object(), 610 /*location=*/locationObj); 611 return locationObj; 612 } 613 614 void PyThreadContextEntry::popLocation(PyLocation &location) { 615 auto &stack = getStack(); 616 if (stack.empty()) 617 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 618 auto &tos = stack.back(); 619 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 620 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 621 stack.pop_back(); 622 } 623 624 //------------------------------------------------------------------------------ 625 // PyDialect, PyDialectDescriptor, PyDialects 626 //------------------------------------------------------------------------------ 627 628 MlirDialect PyDialects::getDialectForKey(const std::string &key, 629 bool attrError) { 630 // If the "std" dialect was asked for, substitute the empty namespace :( 631 static const std::string emptyKey; 632 const std::string *canonKey = key == "std" ? &emptyKey : &key; 633 MlirDialect dialect = mlirContextGetOrLoadDialect( 634 getContext()->get(), {canonKey->data(), canonKey->size()}); 635 if (mlirDialectIsNull(dialect)) { 636 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 637 Twine("Dialect '") + key + "' not found"); 638 } 639 return dialect; 640 } 641 642 //------------------------------------------------------------------------------ 643 // PyLocation 644 //------------------------------------------------------------------------------ 645 646 py::object PyLocation::getCapsule() { 647 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 648 } 649 650 PyLocation PyLocation::createFromCapsule(py::object capsule) { 651 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 652 if (mlirLocationIsNull(rawLoc)) 653 throw py::error_already_set(); 654 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 655 rawLoc); 656 } 657 658 py::object PyLocation::contextEnter() { 659 return PyThreadContextEntry::pushLocation(*this); 660 } 661 662 void PyLocation::contextExit(py::object excType, py::object excVal, 663 py::object excTb) { 664 PyThreadContextEntry::popLocation(*this); 665 } 666 667 PyLocation &DefaultingPyLocation::resolve() { 668 auto *location = PyThreadContextEntry::getDefaultLocation(); 669 if (!location) { 670 throw SetPyError( 671 PyExc_RuntimeError, 672 "An MLIR function requires a Location but none was provided in the " 673 "call or from the surrounding environment. Either pass to the function " 674 "with a 'loc=' argument or establish a default using 'with loc:'"); 675 } 676 return *location; 677 } 678 679 //------------------------------------------------------------------------------ 680 // PyModule 681 //------------------------------------------------------------------------------ 682 683 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 684 : BaseContextObject(std::move(contextRef)), module(module) {} 685 686 PyModule::~PyModule() { 687 py::gil_scoped_acquire acquire; 688 auto &liveModules = getContext()->liveModules; 689 assert(liveModules.count(module.ptr) == 1 && 690 "destroying module not in live map"); 691 liveModules.erase(module.ptr); 692 mlirModuleDestroy(module); 693 } 694 695 PyModuleRef PyModule::forModule(MlirModule module) { 696 MlirContext context = mlirModuleGetContext(module); 697 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 698 699 py::gil_scoped_acquire acquire; 700 auto &liveModules = contextRef->liveModules; 701 auto it = liveModules.find(module.ptr); 702 if (it == liveModules.end()) { 703 // Create. 704 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 705 // Note that the default return value policy on cast is automatic_reference, 706 // which does not take ownership (delete will not be called). 707 // Just be explicit. 708 py::object pyRef = 709 py::cast(unownedModule, py::return_value_policy::take_ownership); 710 unownedModule->handle = pyRef; 711 liveModules[module.ptr] = 712 std::make_pair(unownedModule->handle, unownedModule); 713 return PyModuleRef(unownedModule, std::move(pyRef)); 714 } 715 // Use existing. 716 PyModule *existing = it->second.second; 717 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 718 return PyModuleRef(existing, std::move(pyRef)); 719 } 720 721 py::object PyModule::createFromCapsule(py::object capsule) { 722 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 723 if (mlirModuleIsNull(rawModule)) 724 throw py::error_already_set(); 725 return forModule(rawModule).releaseObject(); 726 } 727 728 py::object PyModule::getCapsule() { 729 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 730 } 731 732 //------------------------------------------------------------------------------ 733 // PyOperation 734 //------------------------------------------------------------------------------ 735 736 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 737 : BaseContextObject(std::move(contextRef)), operation(operation) {} 738 739 PyOperation::~PyOperation() { 740 auto &liveOperations = getContext()->liveOperations; 741 assert(liveOperations.count(operation.ptr) == 1 && 742 "destroying operation not in live map"); 743 liveOperations.erase(operation.ptr); 744 if (!isAttached()) { 745 mlirOperationDestroy(operation); 746 } 747 } 748 749 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 750 MlirOperation operation, 751 py::object parentKeepAlive) { 752 auto &liveOperations = contextRef->liveOperations; 753 // Create. 754 PyOperation *unownedOperation = 755 new PyOperation(std::move(contextRef), operation); 756 // Note that the default return value policy on cast is automatic_reference, 757 // which does not take ownership (delete will not be called). 758 // Just be explicit. 759 py::object pyRef = 760 py::cast(unownedOperation, py::return_value_policy::take_ownership); 761 unownedOperation->handle = pyRef; 762 if (parentKeepAlive) { 763 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 764 } 765 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 766 return PyOperationRef(unownedOperation, std::move(pyRef)); 767 } 768 769 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 770 MlirOperation operation, 771 py::object parentKeepAlive) { 772 auto &liveOperations = contextRef->liveOperations; 773 auto it = liveOperations.find(operation.ptr); 774 if (it == liveOperations.end()) { 775 // Create. 776 return createInstance(std::move(contextRef), operation, 777 std::move(parentKeepAlive)); 778 } 779 // Use existing. 780 PyOperation *existing = it->second.second; 781 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 782 return PyOperationRef(existing, std::move(pyRef)); 783 } 784 785 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 786 MlirOperation operation, 787 py::object parentKeepAlive) { 788 auto &liveOperations = contextRef->liveOperations; 789 assert(liveOperations.count(operation.ptr) == 0 && 790 "cannot create detached operation that already exists"); 791 (void)liveOperations; 792 793 PyOperationRef created = createInstance(std::move(contextRef), operation, 794 std::move(parentKeepAlive)); 795 created->attached = false; 796 return created; 797 } 798 799 void PyOperation::checkValid() const { 800 if (!valid) { 801 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 802 } 803 } 804 805 void PyOperationBase::print(py::object fileObject, bool binary, 806 llvm::Optional<int64_t> largeElementsLimit, 807 bool enableDebugInfo, bool prettyDebugInfo, 808 bool printGenericOpForm, bool useLocalScope) { 809 PyOperation &operation = getOperation(); 810 operation.checkValid(); 811 if (fileObject.is_none()) 812 fileObject = py::module::import("sys").attr("stdout"); 813 814 if (!printGenericOpForm && !mlirOperationVerify(operation)) { 815 fileObject.attr("write")("// Verification failed, printing generic form\n"); 816 printGenericOpForm = true; 817 } 818 819 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 820 if (largeElementsLimit) 821 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 822 if (enableDebugInfo) 823 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 824 if (printGenericOpForm) 825 mlirOpPrintingFlagsPrintGenericOpForm(flags); 826 827 PyFileAccumulator accum(fileObject, binary); 828 py::gil_scoped_release(); 829 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 830 accum.getUserData()); 831 mlirOpPrintingFlagsDestroy(flags); 832 } 833 834 py::object PyOperationBase::getAsm(bool binary, 835 llvm::Optional<int64_t> largeElementsLimit, 836 bool enableDebugInfo, bool prettyDebugInfo, 837 bool printGenericOpForm, 838 bool useLocalScope) { 839 py::object fileObject; 840 if (binary) { 841 fileObject = py::module::import("io").attr("BytesIO")(); 842 } else { 843 fileObject = py::module::import("io").attr("StringIO")(); 844 } 845 print(fileObject, /*binary=*/binary, 846 /*largeElementsLimit=*/largeElementsLimit, 847 /*enableDebugInfo=*/enableDebugInfo, 848 /*prettyDebugInfo=*/prettyDebugInfo, 849 /*printGenericOpForm=*/printGenericOpForm, 850 /*useLocalScope=*/useLocalScope); 851 852 return fileObject.attr("getvalue")(); 853 } 854 855 PyOperationRef PyOperation::getParentOperation() { 856 if (!isAttached()) 857 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 858 MlirOperation operation = mlirOperationGetParentOperation(get()); 859 if (mlirOperationIsNull(operation)) 860 throw SetPyError(PyExc_ValueError, "Operation has no parent."); 861 return PyOperation::forOperation(getContext(), operation); 862 } 863 864 PyBlock PyOperation::getBlock() { 865 PyOperationRef parentOperation = getParentOperation(); 866 MlirBlock block = mlirOperationGetBlock(get()); 867 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 868 return PyBlock{std::move(parentOperation), block}; 869 } 870 871 py::object PyOperation::getCapsule() { 872 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 873 } 874 875 py::object PyOperation::createFromCapsule(py::object capsule) { 876 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 877 if (mlirOperationIsNull(rawOperation)) 878 throw py::error_already_set(); 879 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 880 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 881 .releaseObject(); 882 } 883 884 py::object PyOperation::create( 885 std::string name, llvm::Optional<std::vector<PyType *>> results, 886 llvm::Optional<std::vector<PyValue *>> operands, 887 llvm::Optional<py::dict> attributes, 888 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 889 DefaultingPyLocation location, py::object maybeIp) { 890 llvm::SmallVector<MlirValue, 4> mlirOperands; 891 llvm::SmallVector<MlirType, 4> mlirResults; 892 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 893 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 894 895 // General parameter validation. 896 if (regions < 0) 897 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 898 899 // Unpack/validate operands. 900 if (operands) { 901 mlirOperands.reserve(operands->size()); 902 for (PyValue *operand : *operands) { 903 if (!operand) 904 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 905 mlirOperands.push_back(operand->get()); 906 } 907 } 908 909 // Unpack/validate results. 910 if (results) { 911 mlirResults.reserve(results->size()); 912 for (PyType *result : *results) { 913 // TODO: Verify result type originate from the same context. 914 if (!result) 915 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 916 mlirResults.push_back(*result); 917 } 918 } 919 // Unpack/validate attributes. 920 if (attributes) { 921 mlirAttributes.reserve(attributes->size()); 922 for (auto &it : *attributes) { 923 std::string key; 924 try { 925 key = it.first.cast<std::string>(); 926 } catch (py::cast_error &err) { 927 std::string msg = "Invalid attribute key (not a string) when " 928 "attempting to create the operation \"" + 929 name + "\" (" + err.what() + ")"; 930 throw py::cast_error(msg); 931 } 932 try { 933 auto &attribute = it.second.cast<PyAttribute &>(); 934 // TODO: Verify attribute originates from the same context. 935 mlirAttributes.emplace_back(std::move(key), attribute); 936 } catch (py::reference_cast_error &) { 937 // This exception seems thrown when the value is "None". 938 std::string msg = 939 "Found an invalid (`None`?) attribute value for the key \"" + key + 940 "\" when attempting to create the operation \"" + name + "\""; 941 throw py::cast_error(msg); 942 } catch (py::cast_error &err) { 943 std::string msg = "Invalid attribute value for the key \"" + key + 944 "\" when attempting to create the operation \"" + 945 name + "\" (" + err.what() + ")"; 946 throw py::cast_error(msg); 947 } 948 } 949 } 950 // Unpack/validate successors. 951 if (successors) { 952 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 953 mlirSuccessors.reserve(successors->size()); 954 for (auto *successor : *successors) { 955 // TODO: Verify successor originate from the same context. 956 if (!successor) 957 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 958 mlirSuccessors.push_back(successor->get()); 959 } 960 } 961 962 // Apply unpacked/validated to the operation state. Beyond this 963 // point, exceptions cannot be thrown or else the state will leak. 964 MlirOperationState state = 965 mlirOperationStateGet(toMlirStringRef(name), location); 966 if (!mlirOperands.empty()) 967 mlirOperationStateAddOperands(&state, mlirOperands.size(), 968 mlirOperands.data()); 969 if (!mlirResults.empty()) 970 mlirOperationStateAddResults(&state, mlirResults.size(), 971 mlirResults.data()); 972 if (!mlirAttributes.empty()) { 973 // Note that the attribute names directly reference bytes in 974 // mlirAttributes, so that vector must not be changed from here 975 // on. 976 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 977 mlirNamedAttributes.reserve(mlirAttributes.size()); 978 for (auto &it : mlirAttributes) 979 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 980 mlirIdentifierGet(mlirAttributeGetContext(it.second), 981 toMlirStringRef(it.first)), 982 it.second)); 983 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 984 mlirNamedAttributes.data()); 985 } 986 if (!mlirSuccessors.empty()) 987 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 988 mlirSuccessors.data()); 989 if (regions) { 990 llvm::SmallVector<MlirRegion, 4> mlirRegions; 991 mlirRegions.resize(regions); 992 for (int i = 0; i < regions; ++i) 993 mlirRegions[i] = mlirRegionCreate(); 994 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 995 mlirRegions.data()); 996 } 997 998 // Construct the operation. 999 MlirOperation operation = mlirOperationCreate(&state); 1000 PyOperationRef created = 1001 PyOperation::createDetached(location->getContext(), operation); 1002 1003 // InsertPoint active? 1004 if (!maybeIp.is(py::cast(false))) { 1005 PyInsertionPoint *ip; 1006 if (maybeIp.is_none()) { 1007 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1008 } else { 1009 ip = py::cast<PyInsertionPoint *>(maybeIp); 1010 } 1011 if (ip) 1012 ip->insert(*created.get()); 1013 } 1014 1015 return created->createOpView(); 1016 } 1017 1018 py::object PyOperation::createOpView() { 1019 MlirIdentifier ident = mlirOperationGetName(get()); 1020 MlirStringRef identStr = mlirIdentifierStr(ident); 1021 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1022 StringRef(identStr.data, identStr.length)); 1023 if (opViewClass) 1024 return (*opViewClass)(getRef().getObject()); 1025 return py::cast(PyOpView(getRef().getObject())); 1026 } 1027 1028 //------------------------------------------------------------------------------ 1029 // PyOpView 1030 //------------------------------------------------------------------------------ 1031 1032 py::object 1033 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1034 py::list operandList, 1035 llvm::Optional<py::dict> attributes, 1036 llvm::Optional<std::vector<PyBlock *>> successors, 1037 llvm::Optional<int> regions, 1038 DefaultingPyLocation location, py::object maybeIp) { 1039 PyMlirContextRef context = location->getContext(); 1040 // Class level operation construction metadata. 1041 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1042 // Operand and result segment specs are either none, which does no 1043 // variadic unpacking, or a list of ints with segment sizes, where each 1044 // element is either a positive number (typically 1 for a scalar) or -1 to 1045 // indicate that it is derived from the length of the same-indexed operand 1046 // or result (implying that it is a list at that position). 1047 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1048 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1049 1050 std::vector<uint32_t> operandSegmentLengths; 1051 std::vector<uint32_t> resultSegmentLengths; 1052 1053 // Validate/determine region count. 1054 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1055 int opMinRegionCount = std::get<0>(opRegionSpec); 1056 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1057 if (!regions) { 1058 regions = opMinRegionCount; 1059 } 1060 if (*regions < opMinRegionCount) { 1061 throw py::value_error( 1062 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1063 llvm::Twine(opMinRegionCount) + 1064 " regions but was built with regions=" + llvm::Twine(*regions)) 1065 .str()); 1066 } 1067 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1068 throw py::value_error( 1069 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1070 llvm::Twine(opMinRegionCount) + 1071 " regions but was built with regions=" + llvm::Twine(*regions)) 1072 .str()); 1073 } 1074 1075 // Unpack results. 1076 std::vector<PyType *> resultTypes; 1077 resultTypes.reserve(resultTypeList.size()); 1078 if (resultSegmentSpecObj.is_none()) { 1079 // Non-variadic result unpacking. 1080 for (auto it : llvm::enumerate(resultTypeList)) { 1081 try { 1082 resultTypes.push_back(py::cast<PyType *>(it.value())); 1083 if (!resultTypes.back()) 1084 throw py::cast_error(); 1085 } catch (py::cast_error &err) { 1086 throw py::value_error((llvm::Twine("Result ") + 1087 llvm::Twine(it.index()) + " of operation \"" + 1088 name + "\" must be a Type (" + err.what() + ")") 1089 .str()); 1090 } 1091 } 1092 } else { 1093 // Sized result unpacking. 1094 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1095 if (resultSegmentSpec.size() != resultTypeList.size()) { 1096 throw py::value_error((llvm::Twine("Operation \"") + name + 1097 "\" requires " + 1098 llvm::Twine(resultSegmentSpec.size()) + 1099 "result segments but was provided " + 1100 llvm::Twine(resultTypeList.size())) 1101 .str()); 1102 } 1103 resultSegmentLengths.reserve(resultTypeList.size()); 1104 for (auto it : 1105 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1106 int segmentSpec = std::get<1>(it.value()); 1107 if (segmentSpec == 1 || segmentSpec == 0) { 1108 // Unpack unary element. 1109 try { 1110 auto resultType = py::cast<PyType *>(std::get<0>(it.value())); 1111 if (resultType) { 1112 resultTypes.push_back(resultType); 1113 resultSegmentLengths.push_back(1); 1114 } else if (segmentSpec == 0) { 1115 // Allowed to be optional. 1116 resultSegmentLengths.push_back(0); 1117 } else { 1118 throw py::cast_error("was None and result is not optional"); 1119 } 1120 } catch (py::cast_error &err) { 1121 throw py::value_error((llvm::Twine("Result ") + 1122 llvm::Twine(it.index()) + " of operation \"" + 1123 name + "\" must be a Type (" + err.what() + 1124 ")") 1125 .str()); 1126 } 1127 } else if (segmentSpec == -1) { 1128 // Unpack sequence by appending. 1129 try { 1130 if (std::get<0>(it.value()).is_none()) { 1131 // Treat it as an empty list. 1132 resultSegmentLengths.push_back(0); 1133 } else { 1134 // Unpack the list. 1135 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1136 for (py::object segmentItem : segment) { 1137 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1138 if (!resultTypes.back()) { 1139 throw py::cast_error("contained a None item"); 1140 } 1141 } 1142 resultSegmentLengths.push_back(segment.size()); 1143 } 1144 } catch (std::exception &err) { 1145 // NOTE: Sloppy to be using a catch-all here, but there are at least 1146 // three different unrelated exceptions that can be thrown in the 1147 // above "casts". Just keep the scope above small and catch them all. 1148 throw py::value_error((llvm::Twine("Result ") + 1149 llvm::Twine(it.index()) + " of operation \"" + 1150 name + "\" must be a Sequence of Types (" + 1151 err.what() + ")") 1152 .str()); 1153 } 1154 } else { 1155 throw py::value_error("Unexpected segment spec"); 1156 } 1157 } 1158 } 1159 1160 // Unpack operands. 1161 std::vector<PyValue *> operands; 1162 operands.reserve(operands.size()); 1163 if (operandSegmentSpecObj.is_none()) { 1164 // Non-sized operand unpacking. 1165 for (auto it : llvm::enumerate(operandList)) { 1166 try { 1167 operands.push_back(py::cast<PyValue *>(it.value())); 1168 if (!operands.back()) 1169 throw py::cast_error(); 1170 } catch (py::cast_error &err) { 1171 throw py::value_error((llvm::Twine("Operand ") + 1172 llvm::Twine(it.index()) + " of operation \"" + 1173 name + "\" must be a Value (" + err.what() + ")") 1174 .str()); 1175 } 1176 } 1177 } else { 1178 // Sized operand unpacking. 1179 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1180 if (operandSegmentSpec.size() != operandList.size()) { 1181 throw py::value_error((llvm::Twine("Operation \"") + name + 1182 "\" requires " + 1183 llvm::Twine(operandSegmentSpec.size()) + 1184 "operand segments but was provided " + 1185 llvm::Twine(operandList.size())) 1186 .str()); 1187 } 1188 operandSegmentLengths.reserve(operandList.size()); 1189 for (auto it : 1190 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1191 int segmentSpec = std::get<1>(it.value()); 1192 if (segmentSpec == 1 || segmentSpec == 0) { 1193 // Unpack unary element. 1194 try { 1195 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1196 if (operandValue) { 1197 operands.push_back(operandValue); 1198 operandSegmentLengths.push_back(1); 1199 } else if (segmentSpec == 0) { 1200 // Allowed to be optional. 1201 operandSegmentLengths.push_back(0); 1202 } else { 1203 throw py::cast_error("was None and operand is not optional"); 1204 } 1205 } catch (py::cast_error &err) { 1206 throw py::value_error((llvm::Twine("Operand ") + 1207 llvm::Twine(it.index()) + " of operation \"" + 1208 name + "\" must be a Value (" + err.what() + 1209 ")") 1210 .str()); 1211 } 1212 } else if (segmentSpec == -1) { 1213 // Unpack sequence by appending. 1214 try { 1215 if (std::get<0>(it.value()).is_none()) { 1216 // Treat it as an empty list. 1217 operandSegmentLengths.push_back(0); 1218 } else { 1219 // Unpack the list. 1220 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1221 for (py::object segmentItem : segment) { 1222 operands.push_back(py::cast<PyValue *>(segmentItem)); 1223 if (!operands.back()) { 1224 throw py::cast_error("contained a None item"); 1225 } 1226 } 1227 operandSegmentLengths.push_back(segment.size()); 1228 } 1229 } catch (std::exception &err) { 1230 // NOTE: Sloppy to be using a catch-all here, but there are at least 1231 // three different unrelated exceptions that can be thrown in the 1232 // above "casts". Just keep the scope above small and catch them all. 1233 throw py::value_error((llvm::Twine("Operand ") + 1234 llvm::Twine(it.index()) + " of operation \"" + 1235 name + "\" must be a Sequence of Values (" + 1236 err.what() + ")") 1237 .str()); 1238 } 1239 } else { 1240 throw py::value_error("Unexpected segment spec"); 1241 } 1242 } 1243 } 1244 1245 // Merge operand/result segment lengths into attributes if needed. 1246 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1247 // Dup. 1248 if (attributes) { 1249 attributes = py::dict(*attributes); 1250 } else { 1251 attributes = py::dict(); 1252 } 1253 if (attributes->contains("result_segment_sizes") || 1254 attributes->contains("operand_segment_sizes")) { 1255 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1256 "'operand_segment_sizes' attribute is unsupported. " 1257 "Use Operation.create for such low-level access."); 1258 } 1259 1260 // Add result_segment_sizes attribute. 1261 if (!resultSegmentLengths.empty()) { 1262 int64_t size = resultSegmentLengths.size(); 1263 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1264 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1265 resultSegmentLengths.size(), resultSegmentLengths.data()); 1266 (*attributes)["result_segment_sizes"] = 1267 PyAttribute(context, segmentLengthAttr); 1268 } 1269 1270 // Add operand_segment_sizes attribute. 1271 if (!operandSegmentLengths.empty()) { 1272 int64_t size = operandSegmentLengths.size(); 1273 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1274 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1275 operandSegmentLengths.size(), operandSegmentLengths.data()); 1276 (*attributes)["operand_segment_sizes"] = 1277 PyAttribute(context, segmentLengthAttr); 1278 } 1279 } 1280 1281 // Delegate to create. 1282 return PyOperation::create(std::move(name), 1283 /*results=*/std::move(resultTypes), 1284 /*operands=*/std::move(operands), 1285 /*attributes=*/std::move(attributes), 1286 /*successors=*/std::move(successors), 1287 /*regions=*/*regions, location, maybeIp); 1288 } 1289 1290 PyOpView::PyOpView(py::object operationObject) 1291 // Casting through the PyOperationBase base-class and then back to the 1292 // Operation lets us accept any PyOperationBase subclass. 1293 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1294 operationObject(operation.getRef().getObject()) {} 1295 1296 py::object PyOpView::createRawSubclass(py::object userClass) { 1297 // This is... a little gross. The typical pattern is to have a pure python 1298 // class that extends OpView like: 1299 // class AddFOp(_cext.ir.OpView): 1300 // def __init__(self, loc, lhs, rhs): 1301 // operation = loc.context.create_operation( 1302 // "addf", lhs, rhs, results=[lhs.type]) 1303 // super().__init__(operation) 1304 // 1305 // I.e. The goal of the user facing type is to provide a nice constructor 1306 // that has complete freedom for the op under construction. This is at odds 1307 // with our other desire to sometimes create this object by just passing an 1308 // operation (to initialize the base class). We could do *arg and **kwargs 1309 // munging to try to make it work, but instead, we synthesize a new class 1310 // on the fly which extends this user class (AddFOp in this example) and 1311 // *give it* the base class's __init__ method, thus bypassing the 1312 // intermediate subclass's __init__ method entirely. While slightly, 1313 // underhanded, this is safe/legal because the type hierarchy has not changed 1314 // (we just added a new leaf) and we aren't mucking around with __new__. 1315 // Typically, this new class will be stored on the original as "_Raw" and will 1316 // be used for casts and other things that need a variant of the class that 1317 // is initialized purely from an operation. 1318 py::object parentMetaclass = 1319 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1320 py::dict attributes; 1321 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1322 // now. 1323 // auto opViewType = py::type::of<PyOpView>(); 1324 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1325 attributes["__init__"] = opViewType.attr("__init__"); 1326 py::str origName = userClass.attr("__name__"); 1327 py::str newName = py::str("_") + origName; 1328 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1329 } 1330 1331 //------------------------------------------------------------------------------ 1332 // PyInsertionPoint. 1333 //------------------------------------------------------------------------------ 1334 1335 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1336 1337 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1338 : refOperation(beforeOperationBase.getOperation().getRef()), 1339 block((*refOperation)->getBlock()) {} 1340 1341 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1342 PyOperation &operation = operationBase.getOperation(); 1343 if (operation.isAttached()) 1344 throw SetPyError(PyExc_ValueError, 1345 "Attempt to insert operation that is already attached"); 1346 block.getParentOperation()->checkValid(); 1347 MlirOperation beforeOp = {nullptr}; 1348 if (refOperation) { 1349 // Insert before operation. 1350 (*refOperation)->checkValid(); 1351 beforeOp = (*refOperation)->get(); 1352 } else { 1353 // Insert at end (before null) is only valid if the block does not 1354 // already end in a known terminator (violating this will cause assertion 1355 // failures later). 1356 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1357 throw py::index_error("Cannot insert operation at the end of a block " 1358 "that already has a terminator. Did you mean to " 1359 "use 'InsertionPoint.at_block_terminator(block)' " 1360 "versus 'InsertionPoint(block)'?"); 1361 } 1362 } 1363 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1364 operation.setAttached(); 1365 } 1366 1367 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1368 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1369 if (mlirOperationIsNull(firstOp)) { 1370 // Just insert at end. 1371 return PyInsertionPoint(block); 1372 } 1373 1374 // Insert before first op. 1375 PyOperationRef firstOpRef = PyOperation::forOperation( 1376 block.getParentOperation()->getContext(), firstOp); 1377 return PyInsertionPoint{block, std::move(firstOpRef)}; 1378 } 1379 1380 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1381 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1382 if (mlirOperationIsNull(terminator)) 1383 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1384 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1385 block.getParentOperation()->getContext(), terminator); 1386 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1387 } 1388 1389 py::object PyInsertionPoint::contextEnter() { 1390 return PyThreadContextEntry::pushInsertionPoint(*this); 1391 } 1392 1393 void PyInsertionPoint::contextExit(pybind11::object excType, 1394 pybind11::object excVal, 1395 pybind11::object excTb) { 1396 PyThreadContextEntry::popInsertionPoint(*this); 1397 } 1398 1399 //------------------------------------------------------------------------------ 1400 // PyAttribute. 1401 //------------------------------------------------------------------------------ 1402 1403 bool PyAttribute::operator==(const PyAttribute &other) { 1404 return mlirAttributeEqual(attr, other.attr); 1405 } 1406 1407 py::object PyAttribute::getCapsule() { 1408 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1409 } 1410 1411 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1412 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1413 if (mlirAttributeIsNull(rawAttr)) 1414 throw py::error_already_set(); 1415 return PyAttribute( 1416 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1417 } 1418 1419 //------------------------------------------------------------------------------ 1420 // PyNamedAttribute. 1421 //------------------------------------------------------------------------------ 1422 1423 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1424 : ownedName(new std::string(std::move(ownedName))) { 1425 namedAttr = mlirNamedAttributeGet( 1426 mlirIdentifierGet(mlirAttributeGetContext(attr), 1427 toMlirStringRef(*this->ownedName)), 1428 attr); 1429 } 1430 1431 //------------------------------------------------------------------------------ 1432 // PyType. 1433 //------------------------------------------------------------------------------ 1434 1435 bool PyType::operator==(const PyType &other) { 1436 return mlirTypeEqual(type, other.type); 1437 } 1438 1439 py::object PyType::getCapsule() { 1440 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1441 } 1442 1443 PyType PyType::createFromCapsule(py::object capsule) { 1444 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1445 if (mlirTypeIsNull(rawType)) 1446 throw py::error_already_set(); 1447 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1448 rawType); 1449 } 1450 1451 //------------------------------------------------------------------------------ 1452 // PyValue and subclases. 1453 //------------------------------------------------------------------------------ 1454 1455 namespace { 1456 /// CRTP base class for Python MLIR values that subclass Value and should be 1457 /// castable from it. The value hierarchy is one level deep and is not supposed 1458 /// to accommodate other levels unless core MLIR changes. 1459 template <typename DerivedTy> 1460 class PyConcreteValue : public PyValue { 1461 public: 1462 // Derived classes must define statics for: 1463 // IsAFunctionTy isaFunction 1464 // const char *pyClassName 1465 // and redefine bindDerived. 1466 using ClassTy = py::class_<DerivedTy, PyValue>; 1467 using IsAFunctionTy = bool (*)(MlirValue); 1468 1469 PyConcreteValue() = default; 1470 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1471 : PyValue(operationRef, value) {} 1472 PyConcreteValue(PyValue &orig) 1473 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1474 1475 /// Attempts to cast the original value to the derived type and throws on 1476 /// type mismatches. 1477 static MlirValue castFrom(PyValue &orig) { 1478 if (!DerivedTy::isaFunction(orig.get())) { 1479 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1480 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1481 DerivedTy::pyClassName + 1482 " (from " + origRepr + ")"); 1483 } 1484 return orig.get(); 1485 } 1486 1487 /// Binds the Python module objects to functions of this class. 1488 static void bind(py::module &m) { 1489 auto cls = ClassTy(m, DerivedTy::pyClassName); 1490 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); 1491 DerivedTy::bindDerived(cls); 1492 } 1493 1494 /// Implemented by derived classes to add methods to the Python subclass. 1495 static void bindDerived(ClassTy &m) {} 1496 }; 1497 1498 /// Python wrapper for MlirBlockArgument. 1499 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1500 public: 1501 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1502 static constexpr const char *pyClassName = "BlockArgument"; 1503 using PyConcreteValue::PyConcreteValue; 1504 1505 static void bindDerived(ClassTy &c) { 1506 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1507 return PyBlock(self.getParentOperation(), 1508 mlirBlockArgumentGetOwner(self.get())); 1509 }); 1510 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1511 return mlirBlockArgumentGetArgNumber(self.get()); 1512 }); 1513 c.def("set_type", [](PyBlockArgument &self, PyType type) { 1514 return mlirBlockArgumentSetType(self.get(), type); 1515 }); 1516 } 1517 }; 1518 1519 /// Python wrapper for MlirOpResult. 1520 class PyOpResult : public PyConcreteValue<PyOpResult> { 1521 public: 1522 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1523 static constexpr const char *pyClassName = "OpResult"; 1524 using PyConcreteValue::PyConcreteValue; 1525 1526 static void bindDerived(ClassTy &c) { 1527 c.def_property_readonly("owner", [](PyOpResult &self) { 1528 assert( 1529 mlirOperationEqual(self.getParentOperation()->get(), 1530 mlirOpResultGetOwner(self.get())) && 1531 "expected the owner of the value in Python to match that in the IR"); 1532 return self.getParentOperation(); 1533 }); 1534 c.def_property_readonly("result_number", [](PyOpResult &self) { 1535 return mlirOpResultGetResultNumber(self.get()); 1536 }); 1537 } 1538 }; 1539 1540 /// A list of block arguments. Internally, these are stored as consecutive 1541 /// elements, random access is cheap. The argument list is associated with the 1542 /// operation that contains the block (detached blocks are not allowed in 1543 /// Python bindings) and extends its lifetime. 1544 class PyBlockArgumentList { 1545 public: 1546 PyBlockArgumentList(PyOperationRef operation, MlirBlock block) 1547 : operation(std::move(operation)), block(block) {} 1548 1549 /// Returns the length of the block argument list. 1550 intptr_t dunderLen() { 1551 operation->checkValid(); 1552 return mlirBlockGetNumArguments(block); 1553 } 1554 1555 /// Returns `index`-th element of the block argument list. 1556 PyBlockArgument dunderGetItem(intptr_t index) { 1557 if (index < 0 || index >= dunderLen()) { 1558 throw SetPyError(PyExc_IndexError, 1559 "attempt to access out of bounds region"); 1560 } 1561 PyValue value(operation, mlirBlockGetArgument(block, index)); 1562 return PyBlockArgument(value); 1563 } 1564 1565 /// Defines a Python class in the bindings. 1566 static void bind(py::module &m) { 1567 py::class_<PyBlockArgumentList>(m, "BlockArgumentList") 1568 .def("__len__", &PyBlockArgumentList::dunderLen) 1569 .def("__getitem__", &PyBlockArgumentList::dunderGetItem); 1570 } 1571 1572 private: 1573 PyOperationRef operation; 1574 MlirBlock block; 1575 }; 1576 1577 /// A list of operation operands. Internally, these are stored as consecutive 1578 /// elements, random access is cheap. The result list is associated with the 1579 /// operation whose results these are, and extends the lifetime of this 1580 /// operation. 1581 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1582 public: 1583 static constexpr const char *pyClassName = "OpOperandList"; 1584 1585 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1586 intptr_t length = -1, intptr_t step = 1) 1587 : Sliceable(startIndex, 1588 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1589 : length, 1590 step), 1591 operation(operation) {} 1592 1593 intptr_t getNumElements() { 1594 operation->checkValid(); 1595 return mlirOperationGetNumOperands(operation->get()); 1596 } 1597 1598 PyValue getElement(intptr_t pos) { 1599 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); 1600 } 1601 1602 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1603 return PyOpOperandList(operation, startIndex, length, step); 1604 } 1605 1606 private: 1607 PyOperationRef operation; 1608 }; 1609 1610 /// A list of operation results. Internally, these are stored as consecutive 1611 /// elements, random access is cheap. The result list is associated with the 1612 /// operation whose results these are, and extends the lifetime of this 1613 /// operation. 1614 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1615 public: 1616 static constexpr const char *pyClassName = "OpResultList"; 1617 1618 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1619 intptr_t length = -1, intptr_t step = 1) 1620 : Sliceable(startIndex, 1621 length == -1 ? mlirOperationGetNumResults(operation->get()) 1622 : length, 1623 step), 1624 operation(operation) {} 1625 1626 intptr_t getNumElements() { 1627 operation->checkValid(); 1628 return mlirOperationGetNumResults(operation->get()); 1629 } 1630 1631 PyOpResult getElement(intptr_t index) { 1632 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1633 return PyOpResult(value); 1634 } 1635 1636 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1637 return PyOpResultList(operation, startIndex, length, step); 1638 } 1639 1640 private: 1641 PyOperationRef operation; 1642 }; 1643 1644 /// A list of operation attributes. Can be indexed by name, producing 1645 /// attributes, or by index, producing named attributes. 1646 class PyOpAttributeMap { 1647 public: 1648 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1649 1650 PyAttribute dunderGetItemNamed(const std::string &name) { 1651 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1652 toMlirStringRef(name)); 1653 if (mlirAttributeIsNull(attr)) { 1654 throw SetPyError(PyExc_KeyError, 1655 "attempt to access a non-existent attribute"); 1656 } 1657 return PyAttribute(operation->getContext(), attr); 1658 } 1659 1660 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1661 if (index < 0 || index >= dunderLen()) { 1662 throw SetPyError(PyExc_IndexError, 1663 "attempt to access out of bounds attribute"); 1664 } 1665 MlirNamedAttribute namedAttr = 1666 mlirOperationGetAttribute(operation->get(), index); 1667 return PyNamedAttribute( 1668 namedAttr.attribute, 1669 std::string(mlirIdentifierStr(namedAttr.name).data)); 1670 } 1671 1672 void dunderSetItem(const std::string &name, PyAttribute attr) { 1673 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1674 attr); 1675 } 1676 1677 void dunderDelItem(const std::string &name) { 1678 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1679 toMlirStringRef(name)); 1680 if (!removed) 1681 throw SetPyError(PyExc_KeyError, 1682 "attempt to delete a non-existent attribute"); 1683 } 1684 1685 intptr_t dunderLen() { 1686 return mlirOperationGetNumAttributes(operation->get()); 1687 } 1688 1689 bool dunderContains(const std::string &name) { 1690 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 1691 operation->get(), toMlirStringRef(name))); 1692 } 1693 1694 static void bind(py::module &m) { 1695 py::class_<PyOpAttributeMap>(m, "OpAttributeMap") 1696 .def("__contains__", &PyOpAttributeMap::dunderContains) 1697 .def("__len__", &PyOpAttributeMap::dunderLen) 1698 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 1699 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 1700 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 1701 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 1702 } 1703 1704 private: 1705 PyOperationRef operation; 1706 }; 1707 1708 } // end namespace 1709 1710 //------------------------------------------------------------------------------ 1711 // Populates the core exports of the 'ir' submodule. 1712 //------------------------------------------------------------------------------ 1713 1714 void mlir::python::populateIRCore(py::module &m) { 1715 //---------------------------------------------------------------------------- 1716 // Mapping of MlirContext 1717 //---------------------------------------------------------------------------- 1718 py::class_<PyMlirContext>(m, "Context") 1719 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 1720 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 1721 .def("_get_context_again", 1722 [](PyMlirContext &self) { 1723 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 1724 return ref.releaseObject(); 1725 }) 1726 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 1727 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 1728 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 1729 &PyMlirContext::getCapsule) 1730 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 1731 .def("__enter__", &PyMlirContext::contextEnter) 1732 .def("__exit__", &PyMlirContext::contextExit) 1733 .def_property_readonly_static( 1734 "current", 1735 [](py::object & /*class*/) { 1736 auto *context = PyThreadContextEntry::getDefaultContext(); 1737 if (!context) 1738 throw SetPyError(PyExc_ValueError, "No current Context"); 1739 return context; 1740 }, 1741 "Gets the Context bound to the current thread or raises ValueError") 1742 .def_property_readonly( 1743 "dialects", 1744 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1745 "Gets a container for accessing dialects by name") 1746 .def_property_readonly( 1747 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 1748 "Alias for 'dialect'") 1749 .def( 1750 "get_dialect_descriptor", 1751 [=](PyMlirContext &self, std::string &name) { 1752 MlirDialect dialect = mlirContextGetOrLoadDialect( 1753 self.get(), {name.data(), name.size()}); 1754 if (mlirDialectIsNull(dialect)) { 1755 throw SetPyError(PyExc_ValueError, 1756 Twine("Dialect '") + name + "' not found"); 1757 } 1758 return PyDialectDescriptor(self.getRef(), dialect); 1759 }, 1760 "Gets or loads a dialect by name, returning its descriptor object") 1761 .def_property( 1762 "allow_unregistered_dialects", 1763 [](PyMlirContext &self) -> bool { 1764 return mlirContextGetAllowUnregisteredDialects(self.get()); 1765 }, 1766 [](PyMlirContext &self, bool value) { 1767 mlirContextSetAllowUnregisteredDialects(self.get(), value); 1768 }) 1769 .def("is_registered_operation", 1770 [](PyMlirContext &self, std::string &name) { 1771 return mlirContextIsRegisteredOperation( 1772 self.get(), MlirStringRef{name.data(), name.size()}); 1773 }); 1774 1775 //---------------------------------------------------------------------------- 1776 // Mapping of PyDialectDescriptor 1777 //---------------------------------------------------------------------------- 1778 py::class_<PyDialectDescriptor>(m, "DialectDescriptor") 1779 .def_property_readonly("namespace", 1780 [](PyDialectDescriptor &self) { 1781 MlirStringRef ns = 1782 mlirDialectGetNamespace(self.get()); 1783 return py::str(ns.data, ns.length); 1784 }) 1785 .def("__repr__", [](PyDialectDescriptor &self) { 1786 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 1787 std::string repr("<DialectDescriptor "); 1788 repr.append(ns.data, ns.length); 1789 repr.append(">"); 1790 return repr; 1791 }); 1792 1793 //---------------------------------------------------------------------------- 1794 // Mapping of PyDialects 1795 //---------------------------------------------------------------------------- 1796 py::class_<PyDialects>(m, "Dialects") 1797 .def("__getitem__", 1798 [=](PyDialects &self, std::string keyName) { 1799 MlirDialect dialect = 1800 self.getDialectForKey(keyName, /*attrError=*/false); 1801 py::object descriptor = 1802 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1803 return createCustomDialectWrapper(keyName, std::move(descriptor)); 1804 }) 1805 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 1806 MlirDialect dialect = 1807 self.getDialectForKey(attrName, /*attrError=*/true); 1808 py::object descriptor = 1809 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 1810 return createCustomDialectWrapper(attrName, std::move(descriptor)); 1811 }); 1812 1813 //---------------------------------------------------------------------------- 1814 // Mapping of PyDialect 1815 //---------------------------------------------------------------------------- 1816 py::class_<PyDialect>(m, "Dialect") 1817 .def(py::init<py::object>(), "descriptor") 1818 .def_property_readonly( 1819 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 1820 .def("__repr__", [](py::object self) { 1821 auto clazz = self.attr("__class__"); 1822 return py::str("<Dialect ") + 1823 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 1824 clazz.attr("__module__") + py::str(".") + 1825 clazz.attr("__name__") + py::str(")>"); 1826 }); 1827 1828 //---------------------------------------------------------------------------- 1829 // Mapping of Location 1830 //---------------------------------------------------------------------------- 1831 py::class_<PyLocation>(m, "Location") 1832 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 1833 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 1834 .def("__enter__", &PyLocation::contextEnter) 1835 .def("__exit__", &PyLocation::contextExit) 1836 .def("__eq__", 1837 [](PyLocation &self, PyLocation &other) -> bool { 1838 return mlirLocationEqual(self, other); 1839 }) 1840 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 1841 .def_property_readonly_static( 1842 "current", 1843 [](py::object & /*class*/) { 1844 auto *loc = PyThreadContextEntry::getDefaultLocation(); 1845 if (!loc) 1846 throw SetPyError(PyExc_ValueError, "No current Location"); 1847 return loc; 1848 }, 1849 "Gets the Location bound to the current thread or raises ValueError") 1850 .def_static( 1851 "unknown", 1852 [](DefaultingPyMlirContext context) { 1853 return PyLocation(context->getRef(), 1854 mlirLocationUnknownGet(context->get())); 1855 }, 1856 py::arg("context") = py::none(), 1857 "Gets a Location representing an unknown location") 1858 .def_static( 1859 "file", 1860 [](std::string filename, int line, int col, 1861 DefaultingPyMlirContext context) { 1862 return PyLocation( 1863 context->getRef(), 1864 mlirLocationFileLineColGet( 1865 context->get(), toMlirStringRef(filename), line, col)); 1866 }, 1867 py::arg("filename"), py::arg("line"), py::arg("col"), 1868 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 1869 .def_property_readonly( 1870 "context", 1871 [](PyLocation &self) { return self.getContext().getObject(); }, 1872 "Context that owns the Location") 1873 .def("__repr__", [](PyLocation &self) { 1874 PyPrintAccumulator printAccum; 1875 mlirLocationPrint(self, printAccum.getCallback(), 1876 printAccum.getUserData()); 1877 return printAccum.join(); 1878 }); 1879 1880 //---------------------------------------------------------------------------- 1881 // Mapping of Module 1882 //---------------------------------------------------------------------------- 1883 py::class_<PyModule>(m, "Module") 1884 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 1885 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 1886 .def_static( 1887 "parse", 1888 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 1889 MlirModule module = mlirModuleCreateParse( 1890 context->get(), toMlirStringRef(moduleAsm)); 1891 // TODO: Rework error reporting once diagnostic engine is exposed 1892 // in C API. 1893 if (mlirModuleIsNull(module)) { 1894 throw SetPyError( 1895 PyExc_ValueError, 1896 "Unable to parse module assembly (see diagnostics)"); 1897 } 1898 return PyModule::forModule(module).releaseObject(); 1899 }, 1900 py::arg("asm"), py::arg("context") = py::none(), 1901 kModuleParseDocstring) 1902 .def_static( 1903 "create", 1904 [](DefaultingPyLocation loc) { 1905 MlirModule module = mlirModuleCreateEmpty(loc); 1906 return PyModule::forModule(module).releaseObject(); 1907 }, 1908 py::arg("loc") = py::none(), "Creates an empty module") 1909 .def_property_readonly( 1910 "context", 1911 [](PyModule &self) { return self.getContext().getObject(); }, 1912 "Context that created the Module") 1913 .def_property_readonly( 1914 "operation", 1915 [](PyModule &self) { 1916 return PyOperation::forOperation(self.getContext(), 1917 mlirModuleGetOperation(self.get()), 1918 self.getRef().releaseObject()) 1919 .releaseObject(); 1920 }, 1921 "Accesses the module as an operation") 1922 .def_property_readonly( 1923 "body", 1924 [](PyModule &self) { 1925 PyOperationRef module_op = PyOperation::forOperation( 1926 self.getContext(), mlirModuleGetOperation(self.get()), 1927 self.getRef().releaseObject()); 1928 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 1929 return returnBlock; 1930 }, 1931 "Return the block for this module") 1932 .def( 1933 "dump", 1934 [](PyModule &self) { 1935 mlirOperationDump(mlirModuleGetOperation(self.get())); 1936 }, 1937 kDumpDocstring) 1938 .def( 1939 "__str__", 1940 [](PyModule &self) { 1941 MlirOperation operation = mlirModuleGetOperation(self.get()); 1942 PyPrintAccumulator printAccum; 1943 mlirOperationPrint(operation, printAccum.getCallback(), 1944 printAccum.getUserData()); 1945 return printAccum.join(); 1946 }, 1947 kOperationStrDunderDocstring); 1948 1949 //---------------------------------------------------------------------------- 1950 // Mapping of Operation. 1951 //---------------------------------------------------------------------------- 1952 py::class_<PyOperationBase>(m, "_OperationBase") 1953 .def("__eq__", 1954 [](PyOperationBase &self, PyOperationBase &other) { 1955 return &self.getOperation() == &other.getOperation(); 1956 }) 1957 .def("__eq__", 1958 [](PyOperationBase &self, py::object other) { return false; }) 1959 .def_property_readonly("attributes", 1960 [](PyOperationBase &self) { 1961 return PyOpAttributeMap( 1962 self.getOperation().getRef()); 1963 }) 1964 .def_property_readonly("operands", 1965 [](PyOperationBase &self) { 1966 return PyOpOperandList( 1967 self.getOperation().getRef()); 1968 }) 1969 .def_property_readonly("regions", 1970 [](PyOperationBase &self) { 1971 return PyRegionList( 1972 self.getOperation().getRef()); 1973 }) 1974 .def_property_readonly( 1975 "results", 1976 [](PyOperationBase &self) { 1977 return PyOpResultList(self.getOperation().getRef()); 1978 }, 1979 "Returns the list of Operation results.") 1980 .def_property_readonly( 1981 "result", 1982 [](PyOperationBase &self) { 1983 auto &operation = self.getOperation(); 1984 auto numResults = mlirOperationGetNumResults(operation); 1985 if (numResults != 1) { 1986 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 1987 throw SetPyError( 1988 PyExc_ValueError, 1989 Twine("Cannot call .result on operation ") + 1990 StringRef(name.data, name.length) + " which has " + 1991 Twine(numResults) + 1992 " results (it is only valid for operations with a " 1993 "single result)"); 1994 } 1995 return PyOpResult(operation.getRef(), 1996 mlirOperationGetResult(operation, 0)); 1997 }, 1998 "Shortcut to get an op result if it has only one (throws an error " 1999 "otherwise).") 2000 .def("__iter__", 2001 [](PyOperationBase &self) { 2002 return PyRegionIterator(self.getOperation().getRef()); 2003 }) 2004 .def( 2005 "__str__", 2006 [](PyOperationBase &self) { 2007 return self.getAsm(/*binary=*/false, 2008 /*largeElementsLimit=*/llvm::None, 2009 /*enableDebugInfo=*/false, 2010 /*prettyDebugInfo=*/false, 2011 /*printGenericOpForm=*/false, 2012 /*useLocalScope=*/false); 2013 }, 2014 "Returns the assembly form of the operation.") 2015 .def("print", &PyOperationBase::print, 2016 // Careful: Lots of arguments must match up with print method. 2017 py::arg("file") = py::none(), py::arg("binary") = false, 2018 py::arg("large_elements_limit") = py::none(), 2019 py::arg("enable_debug_info") = false, 2020 py::arg("pretty_debug_info") = false, 2021 py::arg("print_generic_op_form") = false, 2022 py::arg("use_local_scope") = false, kOperationPrintDocstring) 2023 .def("get_asm", &PyOperationBase::getAsm, 2024 // Careful: Lots of arguments must match up with get_asm method. 2025 py::arg("binary") = false, 2026 py::arg("large_elements_limit") = py::none(), 2027 py::arg("enable_debug_info") = false, 2028 py::arg("pretty_debug_info") = false, 2029 py::arg("print_generic_op_form") = false, 2030 py::arg("use_local_scope") = false, kOperationGetAsmDocstring) 2031 .def( 2032 "verify", 2033 [](PyOperationBase &self) { 2034 return mlirOperationVerify(self.getOperation()); 2035 }, 2036 "Verify the operation and return true if it passes, false if it " 2037 "fails."); 2038 2039 py::class_<PyOperation, PyOperationBase>(m, "Operation") 2040 .def_static("create", &PyOperation::create, py::arg("name"), 2041 py::arg("results") = py::none(), 2042 py::arg("operands") = py::none(), 2043 py::arg("attributes") = py::none(), 2044 py::arg("successors") = py::none(), py::arg("regions") = 0, 2045 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2046 kOperationCreateDocstring) 2047 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2048 &PyOperation::getCapsule) 2049 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2050 .def_property_readonly("name", 2051 [](PyOperation &self) { 2052 MlirOperation operation = self.get(); 2053 MlirStringRef name = mlirIdentifierStr( 2054 mlirOperationGetName(operation)); 2055 return py::str(name.data, name.length); 2056 }) 2057 .def_property_readonly( 2058 "context", 2059 [](PyOperation &self) { return self.getContext().getObject(); }, 2060 "Context that owns the Operation") 2061 .def_property_readonly("opview", &PyOperation::createOpView); 2062 2063 auto opViewClass = 2064 py::class_<PyOpView, PyOperationBase>(m, "OpView") 2065 .def(py::init<py::object>()) 2066 .def_property_readonly("operation", &PyOpView::getOperationObject) 2067 .def_property_readonly( 2068 "context", 2069 [](PyOpView &self) { 2070 return self.getOperation().getContext().getObject(); 2071 }, 2072 "Context that owns the Operation") 2073 .def("__str__", [](PyOpView &self) { 2074 return py::str(self.getOperationObject()); 2075 }); 2076 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2077 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2078 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2079 opViewClass.attr("build_generic") = classmethod( 2080 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2081 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2082 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2083 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2084 "Builds a specific, generated OpView based on class level attributes."); 2085 2086 //---------------------------------------------------------------------------- 2087 // Mapping of PyRegion. 2088 //---------------------------------------------------------------------------- 2089 py::class_<PyRegion>(m, "Region") 2090 .def_property_readonly( 2091 "blocks", 2092 [](PyRegion &self) { 2093 return PyBlockList(self.getParentOperation(), self.get()); 2094 }, 2095 "Returns a forward-optimized sequence of blocks.") 2096 .def( 2097 "__iter__", 2098 [](PyRegion &self) { 2099 self.checkValid(); 2100 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2101 return PyBlockIterator(self.getParentOperation(), firstBlock); 2102 }, 2103 "Iterates over blocks in the region.") 2104 .def("__eq__", 2105 [](PyRegion &self, PyRegion &other) { 2106 return self.get().ptr == other.get().ptr; 2107 }) 2108 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2109 2110 //---------------------------------------------------------------------------- 2111 // Mapping of PyBlock. 2112 //---------------------------------------------------------------------------- 2113 py::class_<PyBlock>(m, "Block") 2114 .def_property_readonly( 2115 "arguments", 2116 [](PyBlock &self) { 2117 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2118 }, 2119 "Returns a list of block arguments.") 2120 .def_property_readonly( 2121 "operations", 2122 [](PyBlock &self) { 2123 return PyOperationList(self.getParentOperation(), self.get()); 2124 }, 2125 "Returns a forward-optimized sequence of operations.") 2126 .def( 2127 "__iter__", 2128 [](PyBlock &self) { 2129 self.checkValid(); 2130 MlirOperation firstOperation = 2131 mlirBlockGetFirstOperation(self.get()); 2132 return PyOperationIterator(self.getParentOperation(), 2133 firstOperation); 2134 }, 2135 "Iterates over operations in the block.") 2136 .def("__eq__", 2137 [](PyBlock &self, PyBlock &other) { 2138 return self.get().ptr == other.get().ptr; 2139 }) 2140 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2141 .def( 2142 "__str__", 2143 [](PyBlock &self) { 2144 self.checkValid(); 2145 PyPrintAccumulator printAccum; 2146 mlirBlockPrint(self.get(), printAccum.getCallback(), 2147 printAccum.getUserData()); 2148 return printAccum.join(); 2149 }, 2150 "Returns the assembly form of the block."); 2151 2152 //---------------------------------------------------------------------------- 2153 // Mapping of PyInsertionPoint. 2154 //---------------------------------------------------------------------------- 2155 2156 py::class_<PyInsertionPoint>(m, "InsertionPoint") 2157 .def(py::init<PyBlock &>(), py::arg("block"), 2158 "Inserts after the last operation but still inside the block.") 2159 .def("__enter__", &PyInsertionPoint::contextEnter) 2160 .def("__exit__", &PyInsertionPoint::contextExit) 2161 .def_property_readonly_static( 2162 "current", 2163 [](py::object & /*class*/) { 2164 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2165 if (!ip) 2166 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2167 return ip; 2168 }, 2169 "Gets the InsertionPoint bound to the current thread or raises " 2170 "ValueError if none has been set") 2171 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2172 "Inserts before a referenced operation.") 2173 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2174 py::arg("block"), "Inserts at the beginning of the block.") 2175 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2176 py::arg("block"), "Inserts before the block terminator.") 2177 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2178 "Inserts an operation."); 2179 2180 //---------------------------------------------------------------------------- 2181 // Mapping of PyAttribute. 2182 //---------------------------------------------------------------------------- 2183 py::class_<PyAttribute>(m, "Attribute") 2184 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2185 &PyAttribute::getCapsule) 2186 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2187 .def_static( 2188 "parse", 2189 [](std::string attrSpec, DefaultingPyMlirContext context) { 2190 MlirAttribute type = mlirAttributeParseGet( 2191 context->get(), toMlirStringRef(attrSpec)); 2192 // TODO: Rework error reporting once diagnostic engine is exposed 2193 // in C API. 2194 if (mlirAttributeIsNull(type)) { 2195 throw SetPyError(PyExc_ValueError, 2196 Twine("Unable to parse attribute: '") + 2197 attrSpec + "'"); 2198 } 2199 return PyAttribute(context->getRef(), type); 2200 }, 2201 py::arg("asm"), py::arg("context") = py::none(), 2202 "Parses an attribute from an assembly form") 2203 .def_property_readonly( 2204 "context", 2205 [](PyAttribute &self) { return self.getContext().getObject(); }, 2206 "Context that owns the Attribute") 2207 .def_property_readonly("type", 2208 [](PyAttribute &self) { 2209 return PyType(self.getContext()->getRef(), 2210 mlirAttributeGetType(self)); 2211 }) 2212 .def( 2213 "get_named", 2214 [](PyAttribute &self, std::string name) { 2215 return PyNamedAttribute(self, std::move(name)); 2216 }, 2217 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2218 .def("__eq__", 2219 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2220 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2221 .def( 2222 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2223 kDumpDocstring) 2224 .def( 2225 "__str__", 2226 [](PyAttribute &self) { 2227 PyPrintAccumulator printAccum; 2228 mlirAttributePrint(self, printAccum.getCallback(), 2229 printAccum.getUserData()); 2230 return printAccum.join(); 2231 }, 2232 "Returns the assembly form of the Attribute.") 2233 .def("__repr__", [](PyAttribute &self) { 2234 // Generally, assembly formats are not printed for __repr__ because 2235 // this can cause exceptionally long debug output and exceptions. 2236 // However, attribute values are generally considered useful and are 2237 // printed. This may need to be re-evaluated if debug dumps end up 2238 // being excessive. 2239 PyPrintAccumulator printAccum; 2240 printAccum.parts.append("Attribute("); 2241 mlirAttributePrint(self, printAccum.getCallback(), 2242 printAccum.getUserData()); 2243 printAccum.parts.append(")"); 2244 return printAccum.join(); 2245 }); 2246 2247 //---------------------------------------------------------------------------- 2248 // Mapping of PyNamedAttribute 2249 //---------------------------------------------------------------------------- 2250 py::class_<PyNamedAttribute>(m, "NamedAttribute") 2251 .def("__repr__", 2252 [](PyNamedAttribute &self) { 2253 PyPrintAccumulator printAccum; 2254 printAccum.parts.append("NamedAttribute("); 2255 printAccum.parts.append( 2256 mlirIdentifierStr(self.namedAttr.name).data); 2257 printAccum.parts.append("="); 2258 mlirAttributePrint(self.namedAttr.attribute, 2259 printAccum.getCallback(), 2260 printAccum.getUserData()); 2261 printAccum.parts.append(")"); 2262 return printAccum.join(); 2263 }) 2264 .def_property_readonly( 2265 "name", 2266 [](PyNamedAttribute &self) { 2267 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2268 mlirIdentifierStr(self.namedAttr.name).length); 2269 }, 2270 "The name of the NamedAttribute binding") 2271 .def_property_readonly( 2272 "attr", 2273 [](PyNamedAttribute &self) { 2274 // TODO: When named attribute is removed/refactored, also remove 2275 // this constructor (it does an inefficient table lookup). 2276 auto contextRef = PyMlirContext::forContext( 2277 mlirAttributeGetContext(self.namedAttr.attribute)); 2278 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2279 }, 2280 py::keep_alive<0, 1>(), 2281 "The underlying generic attribute of the NamedAttribute binding"); 2282 2283 //---------------------------------------------------------------------------- 2284 // Mapping of PyType. 2285 //---------------------------------------------------------------------------- 2286 py::class_<PyType>(m, "Type") 2287 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2288 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2289 .def_static( 2290 "parse", 2291 [](std::string typeSpec, DefaultingPyMlirContext context) { 2292 MlirType type = 2293 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2294 // TODO: Rework error reporting once diagnostic engine is exposed 2295 // in C API. 2296 if (mlirTypeIsNull(type)) { 2297 throw SetPyError(PyExc_ValueError, 2298 Twine("Unable to parse type: '") + typeSpec + 2299 "'"); 2300 } 2301 return PyType(context->getRef(), type); 2302 }, 2303 py::arg("asm"), py::arg("context") = py::none(), 2304 kContextParseTypeDocstring) 2305 .def_property_readonly( 2306 "context", [](PyType &self) { return self.getContext().getObject(); }, 2307 "Context that owns the Type") 2308 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2309 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2310 .def( 2311 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2312 .def( 2313 "__str__", 2314 [](PyType &self) { 2315 PyPrintAccumulator printAccum; 2316 mlirTypePrint(self, printAccum.getCallback(), 2317 printAccum.getUserData()); 2318 return printAccum.join(); 2319 }, 2320 "Returns the assembly form of the type.") 2321 .def("__repr__", [](PyType &self) { 2322 // Generally, assembly formats are not printed for __repr__ because 2323 // this can cause exceptionally long debug output and exceptions. 2324 // However, types are an exception as they typically have compact 2325 // assembly forms and printing them is useful. 2326 PyPrintAccumulator printAccum; 2327 printAccum.parts.append("Type("); 2328 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2329 printAccum.parts.append(")"); 2330 return printAccum.join(); 2331 }); 2332 2333 //---------------------------------------------------------------------------- 2334 // Mapping of Value. 2335 //---------------------------------------------------------------------------- 2336 py::class_<PyValue>(m, "Value") 2337 .def_property_readonly( 2338 "context", 2339 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2340 "Context in which the value lives.") 2341 .def( 2342 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2343 kDumpDocstring) 2344 .def("__eq__", 2345 [](PyValue &self, PyValue &other) { 2346 return self.get().ptr == other.get().ptr; 2347 }) 2348 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2349 .def( 2350 "__str__", 2351 [](PyValue &self) { 2352 PyPrintAccumulator printAccum; 2353 printAccum.parts.append("Value("); 2354 mlirValuePrint(self.get(), printAccum.getCallback(), 2355 printAccum.getUserData()); 2356 printAccum.parts.append(")"); 2357 return printAccum.join(); 2358 }, 2359 kValueDunderStrDocstring) 2360 .def_property_readonly("type", [](PyValue &self) { 2361 return PyType(self.getParentOperation()->getContext(), 2362 mlirValueGetType(self.get())); 2363 }); 2364 PyBlockArgument::bind(m); 2365 PyOpResult::bind(m); 2366 2367 // Container bindings. 2368 PyBlockArgumentList::bind(m); 2369 PyBlockIterator::bind(m); 2370 PyBlockList::bind(m); 2371 PyOperationIterator::bind(m); 2372 PyOperationList::bind(m); 2373 PyOpAttributeMap::bind(m); 2374 PyOpOperandList::bind(m); 2375 PyOpResultList::bind(m); 2376 PyRegionIterator::bind(m); 2377 PyRegionList::bind(m); 2378 } 2379