1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 #include <utility> 25 26 namespace py = pybind11; 27 using namespace mlir; 28 using namespace mlir::python; 29 30 using llvm::SmallVector; 31 using llvm::StringRef; 32 using llvm::Twine; 33 34 //------------------------------------------------------------------------------ 35 // Docstrings (trivial, non-duplicated docstrings are included inline). 36 //------------------------------------------------------------------------------ 37 38 static const char kContextParseTypeDocstring[] = 39 R"(Parses the assembly form of a type. 40 41 Returns a Type object or raises a ValueError if the type cannot be parsed. 42 43 See also: https://mlir.llvm.org/docs/LangRef/#type-system 44 )"; 45 46 static const char kContextGetCallSiteLocationDocstring[] = 47 R"(Gets a Location representing a caller and callsite)"; 48 49 static const char kContextGetFileLocationDocstring[] = 50 R"(Gets a Location representing a file, line and column)"; 51 52 static const char kContextGetFusedLocationDocstring[] = 53 R"(Gets a Location representing a fused location with optional metadata)"; 54 55 static const char kContextGetNameLocationDocString[] = 56 R"(Gets a Location representing a named location with optional child location)"; 57 58 static const char kModuleParseDocstring[] = 59 R"(Parses a module's assembly format from a string. 60 61 Returns a new MlirModule or raises a ValueError if the parsing fails. 62 63 See also: https://mlir.llvm.org/docs/LangRef/ 64 )"; 65 66 static const char kOperationCreateDocstring[] = 67 R"(Creates a new operation. 68 69 Args: 70 name: Operation name (e.g. "dialect.operation"). 71 results: Sequence of Type representing op result types. 72 attributes: Dict of str:Attribute. 73 successors: List of Block for the operation's successors. 74 regions: Number of regions to create. 75 location: A Location object (defaults to resolve from context manager). 76 ip: An InsertionPoint (defaults to resolve from context manager or set to 77 False to disable insertion, even with an insertion point set in the 78 context manager). 79 Returns: 80 A new "detached" Operation object. Detached operations can be added 81 to blocks, which causes them to become "attached." 82 )"; 83 84 static const char kOperationPrintDocstring[] = 85 R"(Prints the assembly form of the operation to a file like object. 86 87 Args: 88 file: The file like object to write to. Defaults to sys.stdout. 89 binary: Whether to write bytes (True) or str (False). Defaults to False. 90 large_elements_limit: Whether to elide elements attributes above this 91 number of elements. Defaults to None (no limit). 92 enable_debug_info: Whether to print debug/location information. Defaults 93 to False. 94 pretty_debug_info: Whether to format debug information for easier reading 95 by a human (warning: the result is unparseable). 96 print_generic_op_form: Whether to print the generic assembly forms of all 97 ops. Defaults to False. 98 use_local_Scope: Whether to print in a way that is more optimized for 99 multi-threaded access but may not be consistent with how the overall 100 module prints. 101 assume_verified: By default, if not printing generic form, the verifier 102 will be run and if it fails, generic form will be printed with a comment 103 about failed verification. While a reasonable default for interactive use, 104 for systematic use, it is often better for the caller to verify explicitly 105 and report failures in a more robust fashion. Set this to True if doing this 106 in order to avoid running a redundant verification. If the IR is actually 107 invalid, behavior is undefined. 108 )"; 109 110 static const char kOperationGetAsmDocstring[] = 111 R"(Gets the assembly form of the operation with all options available. 112 113 Args: 114 binary: Whether to return a bytes (True) or str (False) object. Defaults to 115 False. 116 ... others ...: See the print() method for common keyword arguments for 117 configuring the printout. 118 Returns: 119 Either a bytes or str object, depending on the setting of the 'binary' 120 argument. 121 )"; 122 123 static const char kOperationStrDunderDocstring[] = 124 R"(Gets the assembly form of the operation with default options. 125 126 If more advanced control over the assembly formatting or I/O options is needed, 127 use the dedicated print or get_asm method, which supports keyword arguments to 128 customize behavior. 129 )"; 130 131 static const char kDumpDocstring[] = 132 R"(Dumps a debug representation of the object to stderr.)"; 133 134 static const char kAppendBlockDocstring[] = 135 R"(Appends a new block, with argument types as positional args. 136 137 Returns: 138 The created block. 139 )"; 140 141 static const char kValueDunderStrDocstring[] = 142 R"(Returns the string form of the value. 143 144 If the value is a block argument, this is the assembly form of its type and the 145 position in the argument list. If the value is an operation result, this is 146 equivalent to printing the operation that produced it. 147 )"; 148 149 //------------------------------------------------------------------------------ 150 // Utilities. 151 //------------------------------------------------------------------------------ 152 153 /// Helper for creating an @classmethod. 154 template <class Func, typename... Args> 155 py::object classmethod(Func f, Args... args) { 156 py::object cf = py::cpp_function(f, args...); 157 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 158 } 159 160 static py::object 161 createCustomDialectWrapper(const std::string &dialectNamespace, 162 py::object dialectDescriptor) { 163 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 164 if (!dialectClass) { 165 // Use the base class. 166 return py::cast(PyDialect(std::move(dialectDescriptor))); 167 } 168 169 // Create the custom implementation. 170 return (*dialectClass)(std::move(dialectDescriptor)); 171 } 172 173 static MlirStringRef toMlirStringRef(const std::string &s) { 174 return mlirStringRefCreate(s.data(), s.size()); 175 } 176 177 /// Wrapper for the global LLVM debugging flag. 178 struct PyGlobalDebugFlag { 179 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 180 181 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } 182 183 static void bind(py::module &m) { 184 // Debug flags. 185 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 186 .def_property_static("flag", &PyGlobalDebugFlag::get, 187 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 188 } 189 }; 190 191 //------------------------------------------------------------------------------ 192 // Collections. 193 //------------------------------------------------------------------------------ 194 195 namespace { 196 197 class PyRegionIterator { 198 public: 199 PyRegionIterator(PyOperationRef operation) 200 : operation(std::move(operation)) {} 201 202 PyRegionIterator &dunderIter() { return *this; } 203 204 PyRegion dunderNext() { 205 operation->checkValid(); 206 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 207 throw py::stop_iteration(); 208 } 209 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 210 return PyRegion(operation, region); 211 } 212 213 static void bind(py::module &m) { 214 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 215 .def("__iter__", &PyRegionIterator::dunderIter) 216 .def("__next__", &PyRegionIterator::dunderNext); 217 } 218 219 private: 220 PyOperationRef operation; 221 int nextIndex = 0; 222 }; 223 224 /// Regions of an op are fixed length and indexed numerically so are represented 225 /// with a sequence-like container. 226 class PyRegionList { 227 public: 228 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 229 230 intptr_t dunderLen() { 231 operation->checkValid(); 232 return mlirOperationGetNumRegions(operation->get()); 233 } 234 235 PyRegion dunderGetItem(intptr_t index) { 236 // dunderLen checks validity. 237 if (index < 0 || index >= dunderLen()) { 238 throw SetPyError(PyExc_IndexError, 239 "attempt to access out of bounds region"); 240 } 241 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 242 return PyRegion(operation, region); 243 } 244 245 static void bind(py::module &m) { 246 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 247 .def("__len__", &PyRegionList::dunderLen) 248 .def("__getitem__", &PyRegionList::dunderGetItem); 249 } 250 251 private: 252 PyOperationRef operation; 253 }; 254 255 class PyBlockIterator { 256 public: 257 PyBlockIterator(PyOperationRef operation, MlirBlock next) 258 : operation(std::move(operation)), next(next) {} 259 260 PyBlockIterator &dunderIter() { return *this; } 261 262 PyBlock dunderNext() { 263 operation->checkValid(); 264 if (mlirBlockIsNull(next)) { 265 throw py::stop_iteration(); 266 } 267 268 PyBlock returnBlock(operation, next); 269 next = mlirBlockGetNextInRegion(next); 270 return returnBlock; 271 } 272 273 static void bind(py::module &m) { 274 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 275 .def("__iter__", &PyBlockIterator::dunderIter) 276 .def("__next__", &PyBlockIterator::dunderNext); 277 } 278 279 private: 280 PyOperationRef operation; 281 MlirBlock next; 282 }; 283 284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 285 /// we present them as a more full-featured list-like container but optimize 286 /// it for forward iteration. Blocks are always owned by a region. 287 class PyBlockList { 288 public: 289 PyBlockList(PyOperationRef operation, MlirRegion region) 290 : operation(std::move(operation)), region(region) {} 291 292 PyBlockIterator dunderIter() { 293 operation->checkValid(); 294 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 295 } 296 297 intptr_t dunderLen() { 298 operation->checkValid(); 299 intptr_t count = 0; 300 MlirBlock block = mlirRegionGetFirstBlock(region); 301 while (!mlirBlockIsNull(block)) { 302 count += 1; 303 block = mlirBlockGetNextInRegion(block); 304 } 305 return count; 306 } 307 308 PyBlock dunderGetItem(intptr_t index) { 309 operation->checkValid(); 310 if (index < 0) { 311 throw SetPyError(PyExc_IndexError, 312 "attempt to access out of bounds block"); 313 } 314 MlirBlock block = mlirRegionGetFirstBlock(region); 315 while (!mlirBlockIsNull(block)) { 316 if (index == 0) { 317 return PyBlock(operation, block); 318 } 319 block = mlirBlockGetNextInRegion(block); 320 index -= 1; 321 } 322 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 323 } 324 325 PyBlock appendBlock(const py::args &pyArgTypes) { 326 operation->checkValid(); 327 llvm::SmallVector<MlirType, 4> argTypes; 328 argTypes.reserve(pyArgTypes.size()); 329 for (auto &pyArg : pyArgTypes) { 330 argTypes.push_back(pyArg.cast<PyType &>()); 331 } 332 333 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 334 mlirRegionAppendOwnedBlock(region, block); 335 return PyBlock(operation, block); 336 } 337 338 static void bind(py::module &m) { 339 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 340 .def("__getitem__", &PyBlockList::dunderGetItem) 341 .def("__iter__", &PyBlockList::dunderIter) 342 .def("__len__", &PyBlockList::dunderLen) 343 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 344 } 345 346 private: 347 PyOperationRef operation; 348 MlirRegion region; 349 }; 350 351 class PyOperationIterator { 352 public: 353 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 354 : parentOperation(std::move(parentOperation)), next(next) {} 355 356 PyOperationIterator &dunderIter() { return *this; } 357 358 py::object dunderNext() { 359 parentOperation->checkValid(); 360 if (mlirOperationIsNull(next)) { 361 throw py::stop_iteration(); 362 } 363 364 PyOperationRef returnOperation = 365 PyOperation::forOperation(parentOperation->getContext(), next); 366 next = mlirOperationGetNextInBlock(next); 367 return returnOperation->createOpView(); 368 } 369 370 static void bind(py::module &m) { 371 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 372 .def("__iter__", &PyOperationIterator::dunderIter) 373 .def("__next__", &PyOperationIterator::dunderNext); 374 } 375 376 private: 377 PyOperationRef parentOperation; 378 MlirOperation next; 379 }; 380 381 /// Operations are exposed by the C-API as a forward-only linked list. In 382 /// Python, we present them as a more full-featured list-like container but 383 /// optimize it for forward iteration. Iterable operations are always owned 384 /// by a block. 385 class PyOperationList { 386 public: 387 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 388 : parentOperation(std::move(parentOperation)), block(block) {} 389 390 PyOperationIterator dunderIter() { 391 parentOperation->checkValid(); 392 return PyOperationIterator(parentOperation, 393 mlirBlockGetFirstOperation(block)); 394 } 395 396 intptr_t dunderLen() { 397 parentOperation->checkValid(); 398 intptr_t count = 0; 399 MlirOperation childOp = mlirBlockGetFirstOperation(block); 400 while (!mlirOperationIsNull(childOp)) { 401 count += 1; 402 childOp = mlirOperationGetNextInBlock(childOp); 403 } 404 return count; 405 } 406 407 py::object dunderGetItem(intptr_t index) { 408 parentOperation->checkValid(); 409 if (index < 0) { 410 throw SetPyError(PyExc_IndexError, 411 "attempt to access out of bounds operation"); 412 } 413 MlirOperation childOp = mlirBlockGetFirstOperation(block); 414 while (!mlirOperationIsNull(childOp)) { 415 if (index == 0) { 416 return PyOperation::forOperation(parentOperation->getContext(), childOp) 417 ->createOpView(); 418 } 419 childOp = mlirOperationGetNextInBlock(childOp); 420 index -= 1; 421 } 422 throw SetPyError(PyExc_IndexError, 423 "attempt to access out of bounds operation"); 424 } 425 426 static void bind(py::module &m) { 427 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 428 .def("__getitem__", &PyOperationList::dunderGetItem) 429 .def("__iter__", &PyOperationList::dunderIter) 430 .def("__len__", &PyOperationList::dunderLen); 431 } 432 433 private: 434 PyOperationRef parentOperation; 435 MlirBlock block; 436 }; 437 438 } // namespace 439 440 //------------------------------------------------------------------------------ 441 // PyMlirContext 442 //------------------------------------------------------------------------------ 443 444 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 445 py::gil_scoped_acquire acquire; 446 auto &liveContexts = getLiveContexts(); 447 liveContexts[context.ptr] = this; 448 } 449 450 PyMlirContext::~PyMlirContext() { 451 // Note that the only public way to construct an instance is via the 452 // forContext method, which always puts the associated handle into 453 // liveContexts. 454 py::gil_scoped_acquire acquire; 455 getLiveContexts().erase(context.ptr); 456 mlirContextDestroy(context); 457 } 458 459 py::object PyMlirContext::getCapsule() { 460 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 461 } 462 463 py::object PyMlirContext::createFromCapsule(py::object capsule) { 464 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 465 if (mlirContextIsNull(rawContext)) 466 throw py::error_already_set(); 467 return forContext(rawContext).releaseObject(); 468 } 469 470 PyMlirContext *PyMlirContext::createNewContextForInit() { 471 MlirContext context = mlirContextCreate(); 472 mlirRegisterAllDialects(context); 473 return new PyMlirContext(context); 474 } 475 476 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 477 py::gil_scoped_acquire acquire; 478 auto &liveContexts = getLiveContexts(); 479 auto it = liveContexts.find(context.ptr); 480 if (it == liveContexts.end()) { 481 // Create. 482 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 483 py::object pyRef = py::cast(unownedContextWrapper); 484 assert(pyRef && "cast to py::object failed"); 485 liveContexts[context.ptr] = unownedContextWrapper; 486 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 487 } 488 // Use existing. 489 py::object pyRef = py::cast(it->second); 490 return PyMlirContextRef(it->second, std::move(pyRef)); 491 } 492 493 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 494 static LiveContextMap liveContexts; 495 return liveContexts; 496 } 497 498 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 499 500 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 501 502 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 503 504 pybind11::object PyMlirContext::contextEnter() { 505 return PyThreadContextEntry::pushContext(*this); 506 } 507 508 void PyMlirContext::contextExit(const pybind11::object &excType, 509 const pybind11::object &excVal, 510 const pybind11::object &excTb) { 511 PyThreadContextEntry::popContext(*this); 512 } 513 514 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { 515 // Note that ownership is transferred to the delete callback below by way of 516 // an explicit inc_ref (borrow). 517 PyDiagnosticHandler *pyHandler = 518 new PyDiagnosticHandler(get(), std::move(callback)); 519 py::object pyHandlerObject = 520 py::cast(pyHandler, py::return_value_policy::take_ownership); 521 pyHandlerObject.inc_ref(); 522 523 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 524 // guaranteed to be known to pybind. 525 auto handlerCallback = 526 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 527 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 528 py::object pyDiagnosticObject = 529 py::cast(pyDiagnostic, py::return_value_policy::take_ownership); 530 531 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 532 bool result = false; 533 { 534 // Since this can be called from arbitrary C++ contexts, always get the 535 // gil. 536 py::gil_scoped_acquire gil; 537 try { 538 result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); 539 } catch (std::exception &e) { 540 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 541 e.what()); 542 pyHandler->hadError = true; 543 } 544 } 545 546 pyDiagnostic->invalidate(); 547 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 548 }; 549 auto deleteCallback = +[](void *userData) { 550 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 551 assert(pyHandler->registeredID && "handler is not registered"); 552 pyHandler->registeredID.reset(); 553 554 // Decrement reference, balancing the inc_ref() above. 555 py::object pyHandlerObject = 556 py::cast(pyHandler, py::return_value_policy::reference); 557 pyHandlerObject.dec_ref(); 558 }; 559 560 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 561 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 562 return pyHandlerObject; 563 } 564 565 PyMlirContext &DefaultingPyMlirContext::resolve() { 566 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 567 if (!context) { 568 throw SetPyError( 569 PyExc_RuntimeError, 570 "An MLIR function requires a Context but none was provided in the call " 571 "or from the surrounding environment. Either pass to the function with " 572 "a 'context=' argument or establish a default using 'with Context():'"); 573 } 574 return *context; 575 } 576 577 //------------------------------------------------------------------------------ 578 // PyThreadContextEntry management 579 //------------------------------------------------------------------------------ 580 581 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 582 static thread_local std::vector<PyThreadContextEntry> stack; 583 return stack; 584 } 585 586 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 587 auto &stack = getStack(); 588 if (stack.empty()) 589 return nullptr; 590 return &stack.back(); 591 } 592 593 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 594 py::object insertionPoint, 595 py::object location) { 596 auto &stack = getStack(); 597 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 598 std::move(location)); 599 // If the new stack has more than one entry and the context of the new top 600 // entry matches the previous, copy the insertionPoint and location from the 601 // previous entry if missing from the new top entry. 602 if (stack.size() > 1) { 603 auto &prev = *(stack.rbegin() + 1); 604 auto ¤t = stack.back(); 605 if (current.context.is(prev.context)) { 606 // Default non-context objects from the previous entry. 607 if (!current.insertionPoint) 608 current.insertionPoint = prev.insertionPoint; 609 if (!current.location) 610 current.location = prev.location; 611 } 612 } 613 } 614 615 PyMlirContext *PyThreadContextEntry::getContext() { 616 if (!context) 617 return nullptr; 618 return py::cast<PyMlirContext *>(context); 619 } 620 621 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 622 if (!insertionPoint) 623 return nullptr; 624 return py::cast<PyInsertionPoint *>(insertionPoint); 625 } 626 627 PyLocation *PyThreadContextEntry::getLocation() { 628 if (!location) 629 return nullptr; 630 return py::cast<PyLocation *>(location); 631 } 632 633 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 634 auto *tos = getTopOfStack(); 635 return tos ? tos->getContext() : nullptr; 636 } 637 638 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 639 auto *tos = getTopOfStack(); 640 return tos ? tos->getInsertionPoint() : nullptr; 641 } 642 643 PyLocation *PyThreadContextEntry::getDefaultLocation() { 644 auto *tos = getTopOfStack(); 645 return tos ? tos->getLocation() : nullptr; 646 } 647 648 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 649 py::object contextObj = py::cast(context); 650 push(FrameKind::Context, /*context=*/contextObj, 651 /*insertionPoint=*/py::object(), 652 /*location=*/py::object()); 653 return contextObj; 654 } 655 656 void PyThreadContextEntry::popContext(PyMlirContext &context) { 657 auto &stack = getStack(); 658 if (stack.empty()) 659 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 660 auto &tos = stack.back(); 661 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 662 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 663 stack.pop_back(); 664 } 665 666 py::object 667 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 668 py::object contextObj = 669 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 670 py::object insertionPointObj = py::cast(insertionPoint); 671 push(FrameKind::InsertionPoint, 672 /*context=*/contextObj, 673 /*insertionPoint=*/insertionPointObj, 674 /*location=*/py::object()); 675 return insertionPointObj; 676 } 677 678 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 679 auto &stack = getStack(); 680 if (stack.empty()) 681 throw SetPyError(PyExc_RuntimeError, 682 "Unbalanced InsertionPoint enter/exit"); 683 auto &tos = stack.back(); 684 if (tos.frameKind != FrameKind::InsertionPoint && 685 tos.getInsertionPoint() != &insertionPoint) 686 throw SetPyError(PyExc_RuntimeError, 687 "Unbalanced InsertionPoint enter/exit"); 688 stack.pop_back(); 689 } 690 691 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 692 py::object contextObj = location.getContext().getObject(); 693 py::object locationObj = py::cast(location); 694 push(FrameKind::Location, /*context=*/contextObj, 695 /*insertionPoint=*/py::object(), 696 /*location=*/locationObj); 697 return locationObj; 698 } 699 700 void PyThreadContextEntry::popLocation(PyLocation &location) { 701 auto &stack = getStack(); 702 if (stack.empty()) 703 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 704 auto &tos = stack.back(); 705 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 706 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 707 stack.pop_back(); 708 } 709 710 //------------------------------------------------------------------------------ 711 // PyDiagnostic* 712 //------------------------------------------------------------------------------ 713 714 void PyDiagnostic::invalidate() { 715 valid = false; 716 if (materializedNotes) { 717 for (auto ¬eObject : *materializedNotes) { 718 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); 719 note->invalidate(); 720 } 721 } 722 } 723 724 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 725 py::object callback) 726 : context(context), callback(std::move(callback)) {} 727 728 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 729 730 void PyDiagnosticHandler::detach() { 731 if (!registeredID) 732 return; 733 MlirDiagnosticHandlerID localID = *registeredID; 734 mlirContextDetachDiagnosticHandler(context, localID); 735 assert(!registeredID && "should have unregistered"); 736 // Not strictly necessary but keeps stale pointers from being around to cause 737 // issues. 738 context = {nullptr}; 739 } 740 741 void PyDiagnostic::checkValid() { 742 if (!valid) { 743 throw std::invalid_argument( 744 "Diagnostic is invalid (used outside of callback)"); 745 } 746 } 747 748 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 749 checkValid(); 750 return mlirDiagnosticGetSeverity(diagnostic); 751 } 752 753 PyLocation PyDiagnostic::getLocation() { 754 checkValid(); 755 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 756 MlirContext context = mlirLocationGetContext(loc); 757 return PyLocation(PyMlirContext::forContext(context), loc); 758 } 759 760 py::str PyDiagnostic::getMessage() { 761 checkValid(); 762 py::object fileObject = py::module::import("io").attr("StringIO")(); 763 PyFileAccumulator accum(fileObject, /*binary=*/false); 764 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 765 return fileObject.attr("getvalue")(); 766 } 767 768 py::tuple PyDiagnostic::getNotes() { 769 checkValid(); 770 if (materializedNotes) 771 return *materializedNotes; 772 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 773 materializedNotes = py::tuple(numNotes); 774 for (intptr_t i = 0; i < numNotes; ++i) { 775 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 776 py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); 777 PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); 778 } 779 return *materializedNotes; 780 } 781 782 //------------------------------------------------------------------------------ 783 // PyDialect, PyDialectDescriptor, PyDialects 784 //------------------------------------------------------------------------------ 785 786 MlirDialect PyDialects::getDialectForKey(const std::string &key, 787 bool attrError) { 788 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 789 {key.data(), key.size()}); 790 if (mlirDialectIsNull(dialect)) { 791 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 792 Twine("Dialect '") + key + "' not found"); 793 } 794 return dialect; 795 } 796 797 //------------------------------------------------------------------------------ 798 // PyLocation 799 //------------------------------------------------------------------------------ 800 801 py::object PyLocation::getCapsule() { 802 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 803 } 804 805 PyLocation PyLocation::createFromCapsule(py::object capsule) { 806 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 807 if (mlirLocationIsNull(rawLoc)) 808 throw py::error_already_set(); 809 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 810 rawLoc); 811 } 812 813 py::object PyLocation::contextEnter() { 814 return PyThreadContextEntry::pushLocation(*this); 815 } 816 817 void PyLocation::contextExit(const pybind11::object &excType, 818 const pybind11::object &excVal, 819 const pybind11::object &excTb) { 820 PyThreadContextEntry::popLocation(*this); 821 } 822 823 PyLocation &DefaultingPyLocation::resolve() { 824 auto *location = PyThreadContextEntry::getDefaultLocation(); 825 if (!location) { 826 throw SetPyError( 827 PyExc_RuntimeError, 828 "An MLIR function requires a Location but none was provided in the " 829 "call or from the surrounding environment. Either pass to the function " 830 "with a 'loc=' argument or establish a default using 'with loc:'"); 831 } 832 return *location; 833 } 834 835 //------------------------------------------------------------------------------ 836 // PyModule 837 //------------------------------------------------------------------------------ 838 839 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 840 : BaseContextObject(std::move(contextRef)), module(module) {} 841 842 PyModule::~PyModule() { 843 py::gil_scoped_acquire acquire; 844 auto &liveModules = getContext()->liveModules; 845 assert(liveModules.count(module.ptr) == 1 && 846 "destroying module not in live map"); 847 liveModules.erase(module.ptr); 848 mlirModuleDestroy(module); 849 } 850 851 PyModuleRef PyModule::forModule(MlirModule module) { 852 MlirContext context = mlirModuleGetContext(module); 853 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 854 855 py::gil_scoped_acquire acquire; 856 auto &liveModules = contextRef->liveModules; 857 auto it = liveModules.find(module.ptr); 858 if (it == liveModules.end()) { 859 // Create. 860 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 861 // Note that the default return value policy on cast is automatic_reference, 862 // which does not take ownership (delete will not be called). 863 // Just be explicit. 864 py::object pyRef = 865 py::cast(unownedModule, py::return_value_policy::take_ownership); 866 unownedModule->handle = pyRef; 867 liveModules[module.ptr] = 868 std::make_pair(unownedModule->handle, unownedModule); 869 return PyModuleRef(unownedModule, std::move(pyRef)); 870 } 871 // Use existing. 872 PyModule *existing = it->second.second; 873 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 874 return PyModuleRef(existing, std::move(pyRef)); 875 } 876 877 py::object PyModule::createFromCapsule(py::object capsule) { 878 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 879 if (mlirModuleIsNull(rawModule)) 880 throw py::error_already_set(); 881 return forModule(rawModule).releaseObject(); 882 } 883 884 py::object PyModule::getCapsule() { 885 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 886 } 887 888 //------------------------------------------------------------------------------ 889 // PyOperation 890 //------------------------------------------------------------------------------ 891 892 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 893 : BaseContextObject(std::move(contextRef)), operation(operation) {} 894 895 PyOperation::~PyOperation() { 896 // If the operation has already been invalidated there is nothing to do. 897 if (!valid) 898 return; 899 auto &liveOperations = getContext()->liveOperations; 900 assert(liveOperations.count(operation.ptr) == 1 && 901 "destroying operation not in live map"); 902 liveOperations.erase(operation.ptr); 903 if (!isAttached()) { 904 mlirOperationDestroy(operation); 905 } 906 } 907 908 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 909 MlirOperation operation, 910 py::object parentKeepAlive) { 911 auto &liveOperations = contextRef->liveOperations; 912 // Create. 913 PyOperation *unownedOperation = 914 new PyOperation(std::move(contextRef), operation); 915 // Note that the default return value policy on cast is automatic_reference, 916 // which does not take ownership (delete will not be called). 917 // Just be explicit. 918 py::object pyRef = 919 py::cast(unownedOperation, py::return_value_policy::take_ownership); 920 unownedOperation->handle = pyRef; 921 if (parentKeepAlive) { 922 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 923 } 924 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 925 return PyOperationRef(unownedOperation, std::move(pyRef)); 926 } 927 928 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 929 MlirOperation operation, 930 py::object parentKeepAlive) { 931 auto &liveOperations = contextRef->liveOperations; 932 auto it = liveOperations.find(operation.ptr); 933 if (it == liveOperations.end()) { 934 // Create. 935 return createInstance(std::move(contextRef), operation, 936 std::move(parentKeepAlive)); 937 } 938 // Use existing. 939 PyOperation *existing = it->second.second; 940 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 941 return PyOperationRef(existing, std::move(pyRef)); 942 } 943 944 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 945 MlirOperation operation, 946 py::object parentKeepAlive) { 947 auto &liveOperations = contextRef->liveOperations; 948 assert(liveOperations.count(operation.ptr) == 0 && 949 "cannot create detached operation that already exists"); 950 (void)liveOperations; 951 952 PyOperationRef created = createInstance(std::move(contextRef), operation, 953 std::move(parentKeepAlive)); 954 created->attached = false; 955 return created; 956 } 957 958 void PyOperation::checkValid() const { 959 if (!valid) { 960 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 961 } 962 } 963 964 void PyOperationBase::print(py::object fileObject, bool binary, 965 llvm::Optional<int64_t> largeElementsLimit, 966 bool enableDebugInfo, bool prettyDebugInfo, 967 bool printGenericOpForm, bool useLocalScope, 968 bool assumeVerified) { 969 PyOperation &operation = getOperation(); 970 operation.checkValid(); 971 if (fileObject.is_none()) 972 fileObject = py::module::import("sys").attr("stdout"); 973 974 if (!assumeVerified && !printGenericOpForm && 975 !mlirOperationVerify(operation)) { 976 std::string message("// Verification failed, printing generic form\n"); 977 if (binary) { 978 fileObject.attr("write")(py::bytes(message)); 979 } else { 980 fileObject.attr("write")(py::str(message)); 981 } 982 printGenericOpForm = true; 983 } 984 985 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 986 if (largeElementsLimit) 987 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 988 if (enableDebugInfo) 989 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 990 if (printGenericOpForm) 991 mlirOpPrintingFlagsPrintGenericOpForm(flags); 992 993 PyFileAccumulator accum(fileObject, binary); 994 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 995 accum.getUserData()); 996 mlirOpPrintingFlagsDestroy(flags); 997 } 998 999 py::object PyOperationBase::getAsm(bool binary, 1000 llvm::Optional<int64_t> largeElementsLimit, 1001 bool enableDebugInfo, bool prettyDebugInfo, 1002 bool printGenericOpForm, bool useLocalScope, 1003 bool assumeVerified) { 1004 py::object fileObject; 1005 if (binary) { 1006 fileObject = py::module::import("io").attr("BytesIO")(); 1007 } else { 1008 fileObject = py::module::import("io").attr("StringIO")(); 1009 } 1010 print(fileObject, /*binary=*/binary, 1011 /*largeElementsLimit=*/largeElementsLimit, 1012 /*enableDebugInfo=*/enableDebugInfo, 1013 /*prettyDebugInfo=*/prettyDebugInfo, 1014 /*printGenericOpForm=*/printGenericOpForm, 1015 /*useLocalScope=*/useLocalScope, 1016 /*assumeVerified=*/assumeVerified); 1017 1018 return fileObject.attr("getvalue")(); 1019 } 1020 1021 void PyOperationBase::moveAfter(PyOperationBase &other) { 1022 PyOperation &operation = getOperation(); 1023 PyOperation &otherOp = other.getOperation(); 1024 operation.checkValid(); 1025 otherOp.checkValid(); 1026 mlirOperationMoveAfter(operation, otherOp); 1027 operation.parentKeepAlive = otherOp.parentKeepAlive; 1028 } 1029 1030 void PyOperationBase::moveBefore(PyOperationBase &other) { 1031 PyOperation &operation = getOperation(); 1032 PyOperation &otherOp = other.getOperation(); 1033 operation.checkValid(); 1034 otherOp.checkValid(); 1035 mlirOperationMoveBefore(operation, otherOp); 1036 operation.parentKeepAlive = otherOp.parentKeepAlive; 1037 } 1038 1039 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 1040 checkValid(); 1041 if (!isAttached()) 1042 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 1043 MlirOperation operation = mlirOperationGetParentOperation(get()); 1044 if (mlirOperationIsNull(operation)) 1045 return {}; 1046 return PyOperation::forOperation(getContext(), operation); 1047 } 1048 1049 PyBlock PyOperation::getBlock() { 1050 checkValid(); 1051 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 1052 MlirBlock block = mlirOperationGetBlock(get()); 1053 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1054 assert(parentOperation && "Operation has no parent"); 1055 return PyBlock{std::move(*parentOperation), block}; 1056 } 1057 1058 py::object PyOperation::getCapsule() { 1059 checkValid(); 1060 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 1061 } 1062 1063 py::object PyOperation::createFromCapsule(py::object capsule) { 1064 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1065 if (mlirOperationIsNull(rawOperation)) 1066 throw py::error_already_set(); 1067 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1068 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1069 .releaseObject(); 1070 } 1071 1072 py::object PyOperation::create( 1073 const std::string &name, llvm::Optional<std::vector<PyType *>> results, 1074 llvm::Optional<std::vector<PyValue *>> operands, 1075 llvm::Optional<py::dict> attributes, 1076 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 1077 DefaultingPyLocation location, const py::object &maybeIp) { 1078 llvm::SmallVector<MlirValue, 4> mlirOperands; 1079 llvm::SmallVector<MlirType, 4> mlirResults; 1080 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1081 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1082 1083 // General parameter validation. 1084 if (regions < 0) 1085 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 1086 1087 // Unpack/validate operands. 1088 if (operands) { 1089 mlirOperands.reserve(operands->size()); 1090 for (PyValue *operand : *operands) { 1091 if (!operand) 1092 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 1093 mlirOperands.push_back(operand->get()); 1094 } 1095 } 1096 1097 // Unpack/validate results. 1098 if (results) { 1099 mlirResults.reserve(results->size()); 1100 for (PyType *result : *results) { 1101 // TODO: Verify result type originate from the same context. 1102 if (!result) 1103 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 1104 mlirResults.push_back(*result); 1105 } 1106 } 1107 // Unpack/validate attributes. 1108 if (attributes) { 1109 mlirAttributes.reserve(attributes->size()); 1110 for (auto &it : *attributes) { 1111 std::string key; 1112 try { 1113 key = it.first.cast<std::string>(); 1114 } catch (py::cast_error &err) { 1115 std::string msg = "Invalid attribute key (not a string) when " 1116 "attempting to create the operation \"" + 1117 name + "\" (" + err.what() + ")"; 1118 throw py::cast_error(msg); 1119 } 1120 try { 1121 auto &attribute = it.second.cast<PyAttribute &>(); 1122 // TODO: Verify attribute originates from the same context. 1123 mlirAttributes.emplace_back(std::move(key), attribute); 1124 } catch (py::reference_cast_error &) { 1125 // This exception seems thrown when the value is "None". 1126 std::string msg = 1127 "Found an invalid (`None`?) attribute value for the key \"" + key + 1128 "\" when attempting to create the operation \"" + name + "\""; 1129 throw py::cast_error(msg); 1130 } catch (py::cast_error &err) { 1131 std::string msg = "Invalid attribute value for the key \"" + key + 1132 "\" when attempting to create the operation \"" + 1133 name + "\" (" + err.what() + ")"; 1134 throw py::cast_error(msg); 1135 } 1136 } 1137 } 1138 // Unpack/validate successors. 1139 if (successors) { 1140 mlirSuccessors.reserve(successors->size()); 1141 for (auto *successor : *successors) { 1142 // TODO: Verify successor originate from the same context. 1143 if (!successor) 1144 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1145 mlirSuccessors.push_back(successor->get()); 1146 } 1147 } 1148 1149 // Apply unpacked/validated to the operation state. Beyond this 1150 // point, exceptions cannot be thrown or else the state will leak. 1151 MlirOperationState state = 1152 mlirOperationStateGet(toMlirStringRef(name), location); 1153 if (!mlirOperands.empty()) 1154 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1155 mlirOperands.data()); 1156 if (!mlirResults.empty()) 1157 mlirOperationStateAddResults(&state, mlirResults.size(), 1158 mlirResults.data()); 1159 if (!mlirAttributes.empty()) { 1160 // Note that the attribute names directly reference bytes in 1161 // mlirAttributes, so that vector must not be changed from here 1162 // on. 1163 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1164 mlirNamedAttributes.reserve(mlirAttributes.size()); 1165 for (auto &it : mlirAttributes) 1166 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1167 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1168 toMlirStringRef(it.first)), 1169 it.second)); 1170 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1171 mlirNamedAttributes.data()); 1172 } 1173 if (!mlirSuccessors.empty()) 1174 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1175 mlirSuccessors.data()); 1176 if (regions) { 1177 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1178 mlirRegions.resize(regions); 1179 for (int i = 0; i < regions; ++i) 1180 mlirRegions[i] = mlirRegionCreate(); 1181 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1182 mlirRegions.data()); 1183 } 1184 1185 // Construct the operation. 1186 MlirOperation operation = mlirOperationCreate(&state); 1187 PyOperationRef created = 1188 PyOperation::createDetached(location->getContext(), operation); 1189 1190 // InsertPoint active? 1191 if (!maybeIp.is(py::cast(false))) { 1192 PyInsertionPoint *ip; 1193 if (maybeIp.is_none()) { 1194 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1195 } else { 1196 ip = py::cast<PyInsertionPoint *>(maybeIp); 1197 } 1198 if (ip) 1199 ip->insert(*created.get()); 1200 } 1201 1202 return created->createOpView(); 1203 } 1204 1205 py::object PyOperation::createOpView() { 1206 checkValid(); 1207 MlirIdentifier ident = mlirOperationGetName(get()); 1208 MlirStringRef identStr = mlirIdentifierStr(ident); 1209 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1210 StringRef(identStr.data, identStr.length)); 1211 if (opViewClass) 1212 return (*opViewClass)(getRef().getObject()); 1213 return py::cast(PyOpView(getRef().getObject())); 1214 } 1215 1216 void PyOperation::erase() { 1217 checkValid(); 1218 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1219 // Python reference to a child operation is live. All children should also 1220 // have their `valid` bit set to false. 1221 auto &liveOperations = getContext()->liveOperations; 1222 if (liveOperations.count(operation.ptr)) 1223 liveOperations.erase(operation.ptr); 1224 mlirOperationDestroy(operation); 1225 valid = false; 1226 } 1227 1228 //------------------------------------------------------------------------------ 1229 // PyOpView 1230 //------------------------------------------------------------------------------ 1231 1232 py::object PyOpView::buildGeneric( 1233 const py::object &cls, py::list resultTypeList, py::list operandList, 1234 llvm::Optional<py::dict> attributes, 1235 llvm::Optional<std::vector<PyBlock *>> successors, 1236 llvm::Optional<int> regions, DefaultingPyLocation location, 1237 const py::object &maybeIp) { 1238 PyMlirContextRef context = location->getContext(); 1239 // Class level operation construction metadata. 1240 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1241 // Operand and result segment specs are either none, which does no 1242 // variadic unpacking, or a list of ints with segment sizes, where each 1243 // element is either a positive number (typically 1 for a scalar) or -1 to 1244 // indicate that it is derived from the length of the same-indexed operand 1245 // or result (implying that it is a list at that position). 1246 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1247 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1248 1249 std::vector<uint32_t> operandSegmentLengths; 1250 std::vector<uint32_t> resultSegmentLengths; 1251 1252 // Validate/determine region count. 1253 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1254 int opMinRegionCount = std::get<0>(opRegionSpec); 1255 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1256 if (!regions) { 1257 regions = opMinRegionCount; 1258 } 1259 if (*regions < opMinRegionCount) { 1260 throw py::value_error( 1261 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1262 llvm::Twine(opMinRegionCount) + 1263 " regions but was built with regions=" + llvm::Twine(*regions)) 1264 .str()); 1265 } 1266 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1267 throw py::value_error( 1268 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1269 llvm::Twine(opMinRegionCount) + 1270 " regions but was built with regions=" + llvm::Twine(*regions)) 1271 .str()); 1272 } 1273 1274 // Unpack results. 1275 std::vector<PyType *> resultTypes; 1276 resultTypes.reserve(resultTypeList.size()); 1277 if (resultSegmentSpecObj.is_none()) { 1278 // Non-variadic result unpacking. 1279 for (const auto &it : llvm::enumerate(resultTypeList)) { 1280 try { 1281 resultTypes.push_back(py::cast<PyType *>(it.value())); 1282 if (!resultTypes.back()) 1283 throw py::cast_error(); 1284 } catch (py::cast_error &err) { 1285 throw py::value_error((llvm::Twine("Result ") + 1286 llvm::Twine(it.index()) + " of operation \"" + 1287 name + "\" must be a Type (" + err.what() + ")") 1288 .str()); 1289 } 1290 } 1291 } else { 1292 // Sized result unpacking. 1293 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1294 if (resultSegmentSpec.size() != resultTypeList.size()) { 1295 throw py::value_error((llvm::Twine("Operation \"") + name + 1296 "\" requires " + 1297 llvm::Twine(resultSegmentSpec.size()) + 1298 " result segments but was provided " + 1299 llvm::Twine(resultTypeList.size())) 1300 .str()); 1301 } 1302 resultSegmentLengths.reserve(resultTypeList.size()); 1303 for (const auto &it : 1304 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1305 int segmentSpec = std::get<1>(it.value()); 1306 if (segmentSpec == 1 || segmentSpec == 0) { 1307 // Unpack unary element. 1308 try { 1309 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1310 if (resultType) { 1311 resultTypes.push_back(resultType); 1312 resultSegmentLengths.push_back(1); 1313 } else if (segmentSpec == 0) { 1314 // Allowed to be optional. 1315 resultSegmentLengths.push_back(0); 1316 } else { 1317 throw py::cast_error("was None and result is not optional"); 1318 } 1319 } catch (py::cast_error &err) { 1320 throw py::value_error((llvm::Twine("Result ") + 1321 llvm::Twine(it.index()) + " of operation \"" + 1322 name + "\" must be a Type (" + err.what() + 1323 ")") 1324 .str()); 1325 } 1326 } else if (segmentSpec == -1) { 1327 // Unpack sequence by appending. 1328 try { 1329 if (std::get<0>(it.value()).is_none()) { 1330 // Treat it as an empty list. 1331 resultSegmentLengths.push_back(0); 1332 } else { 1333 // Unpack the list. 1334 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1335 for (py::object segmentItem : segment) { 1336 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1337 if (!resultTypes.back()) { 1338 throw py::cast_error("contained a None item"); 1339 } 1340 } 1341 resultSegmentLengths.push_back(segment.size()); 1342 } 1343 } catch (std::exception &err) { 1344 // NOTE: Sloppy to be using a catch-all here, but there are at least 1345 // three different unrelated exceptions that can be thrown in the 1346 // above "casts". Just keep the scope above small and catch them all. 1347 throw py::value_error((llvm::Twine("Result ") + 1348 llvm::Twine(it.index()) + " of operation \"" + 1349 name + "\" must be a Sequence of Types (" + 1350 err.what() + ")") 1351 .str()); 1352 } 1353 } else { 1354 throw py::value_error("Unexpected segment spec"); 1355 } 1356 } 1357 } 1358 1359 // Unpack operands. 1360 std::vector<PyValue *> operands; 1361 operands.reserve(operands.size()); 1362 if (operandSegmentSpecObj.is_none()) { 1363 // Non-sized operand unpacking. 1364 for (const auto &it : llvm::enumerate(operandList)) { 1365 try { 1366 operands.push_back(py::cast<PyValue *>(it.value())); 1367 if (!operands.back()) 1368 throw py::cast_error(); 1369 } catch (py::cast_error &err) { 1370 throw py::value_error((llvm::Twine("Operand ") + 1371 llvm::Twine(it.index()) + " of operation \"" + 1372 name + "\" must be a Value (" + err.what() + ")") 1373 .str()); 1374 } 1375 } 1376 } else { 1377 // Sized operand unpacking. 1378 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1379 if (operandSegmentSpec.size() != operandList.size()) { 1380 throw py::value_error((llvm::Twine("Operation \"") + name + 1381 "\" requires " + 1382 llvm::Twine(operandSegmentSpec.size()) + 1383 "operand segments but was provided " + 1384 llvm::Twine(operandList.size())) 1385 .str()); 1386 } 1387 operandSegmentLengths.reserve(operandList.size()); 1388 for (const auto &it : 1389 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1390 int segmentSpec = std::get<1>(it.value()); 1391 if (segmentSpec == 1 || segmentSpec == 0) { 1392 // Unpack unary element. 1393 try { 1394 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1395 if (operandValue) { 1396 operands.push_back(operandValue); 1397 operandSegmentLengths.push_back(1); 1398 } else if (segmentSpec == 0) { 1399 // Allowed to be optional. 1400 operandSegmentLengths.push_back(0); 1401 } else { 1402 throw py::cast_error("was None and operand is not optional"); 1403 } 1404 } catch (py::cast_error &err) { 1405 throw py::value_error((llvm::Twine("Operand ") + 1406 llvm::Twine(it.index()) + " of operation \"" + 1407 name + "\" must be a Value (" + err.what() + 1408 ")") 1409 .str()); 1410 } 1411 } else if (segmentSpec == -1) { 1412 // Unpack sequence by appending. 1413 try { 1414 if (std::get<0>(it.value()).is_none()) { 1415 // Treat it as an empty list. 1416 operandSegmentLengths.push_back(0); 1417 } else { 1418 // Unpack the list. 1419 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1420 for (py::object segmentItem : segment) { 1421 operands.push_back(py::cast<PyValue *>(segmentItem)); 1422 if (!operands.back()) { 1423 throw py::cast_error("contained a None item"); 1424 } 1425 } 1426 operandSegmentLengths.push_back(segment.size()); 1427 } 1428 } catch (std::exception &err) { 1429 // NOTE: Sloppy to be using a catch-all here, but there are at least 1430 // three different unrelated exceptions that can be thrown in the 1431 // above "casts". Just keep the scope above small and catch them all. 1432 throw py::value_error((llvm::Twine("Operand ") + 1433 llvm::Twine(it.index()) + " of operation \"" + 1434 name + "\" must be a Sequence of Values (" + 1435 err.what() + ")") 1436 .str()); 1437 } 1438 } else { 1439 throw py::value_error("Unexpected segment spec"); 1440 } 1441 } 1442 } 1443 1444 // Merge operand/result segment lengths into attributes if needed. 1445 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1446 // Dup. 1447 if (attributes) { 1448 attributes = py::dict(*attributes); 1449 } else { 1450 attributes = py::dict(); 1451 } 1452 if (attributes->contains("result_segment_sizes") || 1453 attributes->contains("operand_segment_sizes")) { 1454 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1455 "'operand_segment_sizes' attribute is unsupported. " 1456 "Use Operation.create for such low-level access."); 1457 } 1458 1459 // Add result_segment_sizes attribute. 1460 if (!resultSegmentLengths.empty()) { 1461 int64_t size = resultSegmentLengths.size(); 1462 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1463 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1464 resultSegmentLengths.size(), resultSegmentLengths.data()); 1465 (*attributes)["result_segment_sizes"] = 1466 PyAttribute(context, segmentLengthAttr); 1467 } 1468 1469 // Add operand_segment_sizes attribute. 1470 if (!operandSegmentLengths.empty()) { 1471 int64_t size = operandSegmentLengths.size(); 1472 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1473 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1474 operandSegmentLengths.size(), operandSegmentLengths.data()); 1475 (*attributes)["operand_segment_sizes"] = 1476 PyAttribute(context, segmentLengthAttr); 1477 } 1478 } 1479 1480 // Delegate to create. 1481 return PyOperation::create(name, 1482 /*results=*/std::move(resultTypes), 1483 /*operands=*/std::move(operands), 1484 /*attributes=*/std::move(attributes), 1485 /*successors=*/std::move(successors), 1486 /*regions=*/*regions, location, maybeIp); 1487 } 1488 1489 PyOpView::PyOpView(const py::object &operationObject) 1490 // Casting through the PyOperationBase base-class and then back to the 1491 // Operation lets us accept any PyOperationBase subclass. 1492 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1493 operationObject(operation.getRef().getObject()) {} 1494 1495 py::object PyOpView::createRawSubclass(const py::object &userClass) { 1496 // This is... a little gross. The typical pattern is to have a pure python 1497 // class that extends OpView like: 1498 // class AddFOp(_cext.ir.OpView): 1499 // def __init__(self, loc, lhs, rhs): 1500 // operation = loc.context.create_operation( 1501 // "addf", lhs, rhs, results=[lhs.type]) 1502 // super().__init__(operation) 1503 // 1504 // I.e. The goal of the user facing type is to provide a nice constructor 1505 // that has complete freedom for the op under construction. This is at odds 1506 // with our other desire to sometimes create this object by just passing an 1507 // operation (to initialize the base class). We could do *arg and **kwargs 1508 // munging to try to make it work, but instead, we synthesize a new class 1509 // on the fly which extends this user class (AddFOp in this example) and 1510 // *give it* the base class's __init__ method, thus bypassing the 1511 // intermediate subclass's __init__ method entirely. While slightly, 1512 // underhanded, this is safe/legal because the type hierarchy has not changed 1513 // (we just added a new leaf) and we aren't mucking around with __new__. 1514 // Typically, this new class will be stored on the original as "_Raw" and will 1515 // be used for casts and other things that need a variant of the class that 1516 // is initialized purely from an operation. 1517 py::object parentMetaclass = 1518 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1519 py::dict attributes; 1520 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1521 // now. 1522 // auto opViewType = py::type::of<PyOpView>(); 1523 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1524 attributes["__init__"] = opViewType.attr("__init__"); 1525 py::str origName = userClass.attr("__name__"); 1526 py::str newName = py::str("_") + origName; 1527 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1528 } 1529 1530 //------------------------------------------------------------------------------ 1531 // PyInsertionPoint. 1532 //------------------------------------------------------------------------------ 1533 1534 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1535 1536 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1537 : refOperation(beforeOperationBase.getOperation().getRef()), 1538 block((*refOperation)->getBlock()) {} 1539 1540 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1541 PyOperation &operation = operationBase.getOperation(); 1542 if (operation.isAttached()) 1543 throw SetPyError(PyExc_ValueError, 1544 "Attempt to insert operation that is already attached"); 1545 block.getParentOperation()->checkValid(); 1546 MlirOperation beforeOp = {nullptr}; 1547 if (refOperation) { 1548 // Insert before operation. 1549 (*refOperation)->checkValid(); 1550 beforeOp = (*refOperation)->get(); 1551 } else { 1552 // Insert at end (before null) is only valid if the block does not 1553 // already end in a known terminator (violating this will cause assertion 1554 // failures later). 1555 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1556 throw py::index_error("Cannot insert operation at the end of a block " 1557 "that already has a terminator. Did you mean to " 1558 "use 'InsertionPoint.at_block_terminator(block)' " 1559 "versus 'InsertionPoint(block)'?"); 1560 } 1561 } 1562 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1563 operation.setAttached(); 1564 } 1565 1566 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1567 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1568 if (mlirOperationIsNull(firstOp)) { 1569 // Just insert at end. 1570 return PyInsertionPoint(block); 1571 } 1572 1573 // Insert before first op. 1574 PyOperationRef firstOpRef = PyOperation::forOperation( 1575 block.getParentOperation()->getContext(), firstOp); 1576 return PyInsertionPoint{block, std::move(firstOpRef)}; 1577 } 1578 1579 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1580 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1581 if (mlirOperationIsNull(terminator)) 1582 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1583 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1584 block.getParentOperation()->getContext(), terminator); 1585 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1586 } 1587 1588 py::object PyInsertionPoint::contextEnter() { 1589 return PyThreadContextEntry::pushInsertionPoint(*this); 1590 } 1591 1592 void PyInsertionPoint::contextExit(const pybind11::object &excType, 1593 const pybind11::object &excVal, 1594 const pybind11::object &excTb) { 1595 PyThreadContextEntry::popInsertionPoint(*this); 1596 } 1597 1598 //------------------------------------------------------------------------------ 1599 // PyAttribute. 1600 //------------------------------------------------------------------------------ 1601 1602 bool PyAttribute::operator==(const PyAttribute &other) { 1603 return mlirAttributeEqual(attr, other.attr); 1604 } 1605 1606 py::object PyAttribute::getCapsule() { 1607 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1608 } 1609 1610 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1611 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1612 if (mlirAttributeIsNull(rawAttr)) 1613 throw py::error_already_set(); 1614 return PyAttribute( 1615 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1616 } 1617 1618 //------------------------------------------------------------------------------ 1619 // PyNamedAttribute. 1620 //------------------------------------------------------------------------------ 1621 1622 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1623 : ownedName(new std::string(std::move(ownedName))) { 1624 namedAttr = mlirNamedAttributeGet( 1625 mlirIdentifierGet(mlirAttributeGetContext(attr), 1626 toMlirStringRef(*this->ownedName)), 1627 attr); 1628 } 1629 1630 //------------------------------------------------------------------------------ 1631 // PyType. 1632 //------------------------------------------------------------------------------ 1633 1634 bool PyType::operator==(const PyType &other) { 1635 return mlirTypeEqual(type, other.type); 1636 } 1637 1638 py::object PyType::getCapsule() { 1639 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1640 } 1641 1642 PyType PyType::createFromCapsule(py::object capsule) { 1643 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1644 if (mlirTypeIsNull(rawType)) 1645 throw py::error_already_set(); 1646 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1647 rawType); 1648 } 1649 1650 //------------------------------------------------------------------------------ 1651 // PyValue and subclases. 1652 //------------------------------------------------------------------------------ 1653 1654 pybind11::object PyValue::getCapsule() { 1655 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1656 } 1657 1658 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1659 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1660 if (mlirValueIsNull(value)) 1661 throw py::error_already_set(); 1662 MlirOperation owner; 1663 if (mlirValueIsAOpResult(value)) 1664 owner = mlirOpResultGetOwner(value); 1665 if (mlirValueIsABlockArgument(value)) 1666 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1667 if (mlirOperationIsNull(owner)) 1668 throw py::error_already_set(); 1669 MlirContext ctx = mlirOperationGetContext(owner); 1670 PyOperationRef ownerRef = 1671 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1672 return PyValue(ownerRef, value); 1673 } 1674 1675 //------------------------------------------------------------------------------ 1676 // PySymbolTable. 1677 //------------------------------------------------------------------------------ 1678 1679 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1680 : operation(operation.getOperation().getRef()) { 1681 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1682 if (mlirSymbolTableIsNull(symbolTable)) { 1683 throw py::cast_error("Operation is not a Symbol Table."); 1684 } 1685 } 1686 1687 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1688 operation->checkValid(); 1689 MlirOperation symbol = mlirSymbolTableLookup( 1690 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1691 if (mlirOperationIsNull(symbol)) 1692 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1693 1694 return PyOperation::forOperation(operation->getContext(), symbol, 1695 operation.getObject()) 1696 ->createOpView(); 1697 } 1698 1699 void PySymbolTable::erase(PyOperationBase &symbol) { 1700 operation->checkValid(); 1701 symbol.getOperation().checkValid(); 1702 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1703 // The operation is also erased, so we must invalidate it. There may be Python 1704 // references to this operation so we don't want to delete it from the list of 1705 // live operations here. 1706 symbol.getOperation().valid = false; 1707 } 1708 1709 void PySymbolTable::dunderDel(const std::string &name) { 1710 py::object operation = dunderGetItem(name); 1711 erase(py::cast<PyOperationBase &>(operation)); 1712 } 1713 1714 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1715 operation->checkValid(); 1716 symbol.getOperation().checkValid(); 1717 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1718 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1719 if (mlirAttributeIsNull(symbolAttr)) 1720 throw py::value_error("Expected operation to have a symbol name."); 1721 return PyAttribute( 1722 symbol.getOperation().getContext(), 1723 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1724 } 1725 1726 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1727 // Op must already be a symbol. 1728 PyOperation &operation = symbol.getOperation(); 1729 operation.checkValid(); 1730 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1731 MlirAttribute existingNameAttr = 1732 mlirOperationGetAttributeByName(operation.get(), attrName); 1733 if (mlirAttributeIsNull(existingNameAttr)) 1734 throw py::value_error("Expected operation to have a symbol name."); 1735 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1736 } 1737 1738 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1739 const std::string &name) { 1740 // Op must already be a symbol. 1741 PyOperation &operation = symbol.getOperation(); 1742 operation.checkValid(); 1743 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1744 MlirAttribute existingNameAttr = 1745 mlirOperationGetAttributeByName(operation.get(), attrName); 1746 if (mlirAttributeIsNull(existingNameAttr)) 1747 throw py::value_error("Expected operation to have a symbol name."); 1748 MlirAttribute newNameAttr = 1749 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1750 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1751 } 1752 1753 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1754 PyOperation &operation = symbol.getOperation(); 1755 operation.checkValid(); 1756 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1757 MlirAttribute existingVisAttr = 1758 mlirOperationGetAttributeByName(operation.get(), attrName); 1759 if (mlirAttributeIsNull(existingVisAttr)) 1760 throw py::value_error("Expected operation to have a symbol visibility."); 1761 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1762 } 1763 1764 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1765 const std::string &visibility) { 1766 if (visibility != "public" && visibility != "private" && 1767 visibility != "nested") 1768 throw py::value_error( 1769 "Expected visibility to be 'public', 'private' or 'nested'"); 1770 PyOperation &operation = symbol.getOperation(); 1771 operation.checkValid(); 1772 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1773 MlirAttribute existingVisAttr = 1774 mlirOperationGetAttributeByName(operation.get(), attrName); 1775 if (mlirAttributeIsNull(existingVisAttr)) 1776 throw py::value_error("Expected operation to have a symbol visibility."); 1777 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1778 toMlirStringRef(visibility)); 1779 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1780 } 1781 1782 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1783 const std::string &newSymbol, 1784 PyOperationBase &from) { 1785 PyOperation &fromOperation = from.getOperation(); 1786 fromOperation.checkValid(); 1787 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1788 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1789 from.getOperation()))) 1790 1791 throw py::value_error("Symbol rename failed"); 1792 } 1793 1794 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1795 bool allSymUsesVisible, 1796 py::object callback) { 1797 PyOperation &fromOperation = from.getOperation(); 1798 fromOperation.checkValid(); 1799 struct UserData { 1800 PyMlirContextRef context; 1801 py::object callback; 1802 bool gotException; 1803 std::string exceptionWhat; 1804 py::object exceptionType; 1805 }; 1806 UserData userData{ 1807 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1808 mlirSymbolTableWalkSymbolTables( 1809 fromOperation.get(), allSymUsesVisible, 1810 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1811 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1812 auto pyFoundOp = 1813 PyOperation::forOperation(calleeUserData->context, foundOp); 1814 if (calleeUserData->gotException) 1815 return; 1816 try { 1817 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1818 } catch (py::error_already_set &e) { 1819 calleeUserData->gotException = true; 1820 calleeUserData->exceptionWhat = e.what(); 1821 calleeUserData->exceptionType = e.type(); 1822 } 1823 }, 1824 static_cast<void *>(&userData)); 1825 if (userData.gotException) { 1826 std::string message("Exception raised in callback: "); 1827 message.append(userData.exceptionWhat); 1828 throw std::runtime_error(message); 1829 } 1830 } 1831 1832 namespace { 1833 /// CRTP base class for Python MLIR values that subclass Value and should be 1834 /// castable from it. The value hierarchy is one level deep and is not supposed 1835 /// to accommodate other levels unless core MLIR changes. 1836 template <typename DerivedTy> 1837 class PyConcreteValue : public PyValue { 1838 public: 1839 // Derived classes must define statics for: 1840 // IsAFunctionTy isaFunction 1841 // const char *pyClassName 1842 // and redefine bindDerived. 1843 using ClassTy = py::class_<DerivedTy, PyValue>; 1844 using IsAFunctionTy = bool (*)(MlirValue); 1845 1846 PyConcreteValue() = default; 1847 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1848 : PyValue(operationRef, value) {} 1849 PyConcreteValue(PyValue &orig) 1850 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1851 1852 /// Attempts to cast the original value to the derived type and throws on 1853 /// type mismatches. 1854 static MlirValue castFrom(PyValue &orig) { 1855 if (!DerivedTy::isaFunction(orig.get())) { 1856 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1857 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1858 DerivedTy::pyClassName + 1859 " (from " + origRepr + ")"); 1860 } 1861 return orig.get(); 1862 } 1863 1864 /// Binds the Python module objects to functions of this class. 1865 static void bind(py::module &m) { 1866 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1867 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1868 cls.def_static( 1869 "isinstance", 1870 [](PyValue &otherValue) -> bool { 1871 return DerivedTy::isaFunction(otherValue); 1872 }, 1873 py::arg("other_value")); 1874 DerivedTy::bindDerived(cls); 1875 } 1876 1877 /// Implemented by derived classes to add methods to the Python subclass. 1878 static void bindDerived(ClassTy &m) {} 1879 }; 1880 1881 /// Python wrapper for MlirBlockArgument. 1882 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1883 public: 1884 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1885 static constexpr const char *pyClassName = "BlockArgument"; 1886 using PyConcreteValue::PyConcreteValue; 1887 1888 static void bindDerived(ClassTy &c) { 1889 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1890 return PyBlock(self.getParentOperation(), 1891 mlirBlockArgumentGetOwner(self.get())); 1892 }); 1893 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1894 return mlirBlockArgumentGetArgNumber(self.get()); 1895 }); 1896 c.def( 1897 "set_type", 1898 [](PyBlockArgument &self, PyType type) { 1899 return mlirBlockArgumentSetType(self.get(), type); 1900 }, 1901 py::arg("type")); 1902 } 1903 }; 1904 1905 /// Python wrapper for MlirOpResult. 1906 class PyOpResult : public PyConcreteValue<PyOpResult> { 1907 public: 1908 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1909 static constexpr const char *pyClassName = "OpResult"; 1910 using PyConcreteValue::PyConcreteValue; 1911 1912 static void bindDerived(ClassTy &c) { 1913 c.def_property_readonly("owner", [](PyOpResult &self) { 1914 assert( 1915 mlirOperationEqual(self.getParentOperation()->get(), 1916 mlirOpResultGetOwner(self.get())) && 1917 "expected the owner of the value in Python to match that in the IR"); 1918 return self.getParentOperation().getObject(); 1919 }); 1920 c.def_property_readonly("result_number", [](PyOpResult &self) { 1921 return mlirOpResultGetResultNumber(self.get()); 1922 }); 1923 } 1924 }; 1925 1926 /// Returns the list of types of the values held by container. 1927 template <typename Container> 1928 static std::vector<PyType> getValueTypes(Container &container, 1929 PyMlirContextRef &context) { 1930 std::vector<PyType> result; 1931 result.reserve(container.getNumElements()); 1932 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1933 result.push_back( 1934 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1935 } 1936 return result; 1937 } 1938 1939 /// A list of block arguments. Internally, these are stored as consecutive 1940 /// elements, random access is cheap. The argument list is associated with the 1941 /// operation that contains the block (detached blocks are not allowed in 1942 /// Python bindings) and extends its lifetime. 1943 class PyBlockArgumentList 1944 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1945 public: 1946 static constexpr const char *pyClassName = "BlockArgumentList"; 1947 1948 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1949 intptr_t startIndex = 0, intptr_t length = -1, 1950 intptr_t step = 1) 1951 : Sliceable(startIndex, 1952 length == -1 ? mlirBlockGetNumArguments(block) : length, 1953 step), 1954 operation(std::move(operation)), block(block) {} 1955 1956 /// Returns the number of arguments in the list. 1957 intptr_t getNumElements() { 1958 operation->checkValid(); 1959 return mlirBlockGetNumArguments(block); 1960 } 1961 1962 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1963 PyBlockArgument getElement(intptr_t pos) { 1964 MlirValue argument = mlirBlockGetArgument(block, pos); 1965 return PyBlockArgument(operation, argument); 1966 } 1967 1968 /// Returns a sublist of this list. 1969 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1970 intptr_t step) { 1971 return PyBlockArgumentList(operation, block, startIndex, length, step); 1972 } 1973 1974 static void bindDerived(ClassTy &c) { 1975 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1976 return getValueTypes(self, self.operation->getContext()); 1977 }); 1978 } 1979 1980 private: 1981 PyOperationRef operation; 1982 MlirBlock block; 1983 }; 1984 1985 /// A list of operation operands. Internally, these are stored as consecutive 1986 /// elements, random access is cheap. The result list is associated with the 1987 /// operation whose results these are, and extends the lifetime of this 1988 /// operation. 1989 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1990 public: 1991 static constexpr const char *pyClassName = "OpOperandList"; 1992 1993 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1994 intptr_t length = -1, intptr_t step = 1) 1995 : Sliceable(startIndex, 1996 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1997 : length, 1998 step), 1999 operation(operation) {} 2000 2001 intptr_t getNumElements() { 2002 operation->checkValid(); 2003 return mlirOperationGetNumOperands(operation->get()); 2004 } 2005 2006 PyValue getElement(intptr_t pos) { 2007 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2008 MlirOperation owner; 2009 if (mlirValueIsAOpResult(operand)) 2010 owner = mlirOpResultGetOwner(operand); 2011 else if (mlirValueIsABlockArgument(operand)) 2012 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2013 else 2014 assert(false && "Value must be an block arg or op result."); 2015 PyOperationRef pyOwner = 2016 PyOperation::forOperation(operation->getContext(), owner); 2017 return PyValue(pyOwner, operand); 2018 } 2019 2020 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2021 return PyOpOperandList(operation, startIndex, length, step); 2022 } 2023 2024 void dunderSetItem(intptr_t index, PyValue value) { 2025 index = wrapIndex(index); 2026 mlirOperationSetOperand(operation->get(), index, value.get()); 2027 } 2028 2029 static void bindDerived(ClassTy &c) { 2030 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2031 } 2032 2033 private: 2034 PyOperationRef operation; 2035 }; 2036 2037 /// A list of operation results. Internally, these are stored as consecutive 2038 /// elements, random access is cheap. The result list is associated with the 2039 /// operation whose results these are, and extends the lifetime of this 2040 /// operation. 2041 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 2042 public: 2043 static constexpr const char *pyClassName = "OpResultList"; 2044 2045 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 2046 intptr_t length = -1, intptr_t step = 1) 2047 : Sliceable(startIndex, 2048 length == -1 ? mlirOperationGetNumResults(operation->get()) 2049 : length, 2050 step), 2051 operation(operation) {} 2052 2053 intptr_t getNumElements() { 2054 operation->checkValid(); 2055 return mlirOperationGetNumResults(operation->get()); 2056 } 2057 2058 PyOpResult getElement(intptr_t index) { 2059 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 2060 return PyOpResult(value); 2061 } 2062 2063 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2064 return PyOpResultList(operation, startIndex, length, step); 2065 } 2066 2067 static void bindDerived(ClassTy &c) { 2068 c.def_property_readonly("types", [](PyOpResultList &self) { 2069 return getValueTypes(self, self.operation->getContext()); 2070 }); 2071 } 2072 2073 private: 2074 PyOperationRef operation; 2075 }; 2076 2077 /// A list of operation attributes. Can be indexed by name, producing 2078 /// attributes, or by index, producing named attributes. 2079 class PyOpAttributeMap { 2080 public: 2081 PyOpAttributeMap(PyOperationRef operation) 2082 : operation(std::move(operation)) {} 2083 2084 PyAttribute dunderGetItemNamed(const std::string &name) { 2085 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2086 toMlirStringRef(name)); 2087 if (mlirAttributeIsNull(attr)) { 2088 throw SetPyError(PyExc_KeyError, 2089 "attempt to access a non-existent attribute"); 2090 } 2091 return PyAttribute(operation->getContext(), attr); 2092 } 2093 2094 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2095 if (index < 0 || index >= dunderLen()) { 2096 throw SetPyError(PyExc_IndexError, 2097 "attempt to access out of bounds attribute"); 2098 } 2099 MlirNamedAttribute namedAttr = 2100 mlirOperationGetAttribute(operation->get(), index); 2101 return PyNamedAttribute( 2102 namedAttr.attribute, 2103 std::string(mlirIdentifierStr(namedAttr.name).data, 2104 mlirIdentifierStr(namedAttr.name).length)); 2105 } 2106 2107 void dunderSetItem(const std::string &name, const PyAttribute &attr) { 2108 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2109 attr); 2110 } 2111 2112 void dunderDelItem(const std::string &name) { 2113 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2114 toMlirStringRef(name)); 2115 if (!removed) 2116 throw SetPyError(PyExc_KeyError, 2117 "attempt to delete a non-existent attribute"); 2118 } 2119 2120 intptr_t dunderLen() { 2121 return mlirOperationGetNumAttributes(operation->get()); 2122 } 2123 2124 bool dunderContains(const std::string &name) { 2125 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2126 operation->get(), toMlirStringRef(name))); 2127 } 2128 2129 static void bind(py::module &m) { 2130 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2131 .def("__contains__", &PyOpAttributeMap::dunderContains) 2132 .def("__len__", &PyOpAttributeMap::dunderLen) 2133 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2134 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2135 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2136 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2137 } 2138 2139 private: 2140 PyOperationRef operation; 2141 }; 2142 2143 } // namespace 2144 2145 //------------------------------------------------------------------------------ 2146 // Populates the core exports of the 'ir' submodule. 2147 //------------------------------------------------------------------------------ 2148 2149 void mlir::python::populateIRCore(py::module &m) { 2150 //---------------------------------------------------------------------------- 2151 // Enums. 2152 //---------------------------------------------------------------------------- 2153 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local()) 2154 .value("ERROR", MlirDiagnosticError) 2155 .value("WARNING", MlirDiagnosticWarning) 2156 .value("NOTE", MlirDiagnosticNote) 2157 .value("REMARK", MlirDiagnosticRemark); 2158 2159 //---------------------------------------------------------------------------- 2160 // Mapping of Diagnostics. 2161 //---------------------------------------------------------------------------- 2162 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local()) 2163 .def_property_readonly("severity", &PyDiagnostic::getSeverity) 2164 .def_property_readonly("location", &PyDiagnostic::getLocation) 2165 .def_property_readonly("message", &PyDiagnostic::getMessage) 2166 .def_property_readonly("notes", &PyDiagnostic::getNotes) 2167 .def("__str__", [](PyDiagnostic &self) -> py::str { 2168 if (!self.isValid()) 2169 return "<Invalid Diagnostic>"; 2170 return self.getMessage(); 2171 }); 2172 2173 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) 2174 .def("detach", &PyDiagnosticHandler::detach) 2175 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) 2176 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 2177 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2178 .def("__exit__", &PyDiagnosticHandler::contextExit); 2179 2180 //---------------------------------------------------------------------------- 2181 // Mapping of MlirContext. 2182 //---------------------------------------------------------------------------- 2183 py::class_<PyMlirContext>(m, "Context", py::module_local()) 2184 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2185 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2186 .def("_get_context_again", 2187 [](PyMlirContext &self) { 2188 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2189 return ref.releaseObject(); 2190 }) 2191 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2192 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2193 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2194 &PyMlirContext::getCapsule) 2195 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2196 .def("__enter__", &PyMlirContext::contextEnter) 2197 .def("__exit__", &PyMlirContext::contextExit) 2198 .def_property_readonly_static( 2199 "current", 2200 [](py::object & /*class*/) { 2201 auto *context = PyThreadContextEntry::getDefaultContext(); 2202 if (!context) 2203 throw SetPyError(PyExc_ValueError, "No current Context"); 2204 return context; 2205 }, 2206 "Gets the Context bound to the current thread or raises ValueError") 2207 .def_property_readonly( 2208 "dialects", 2209 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2210 "Gets a container for accessing dialects by name") 2211 .def_property_readonly( 2212 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2213 "Alias for 'dialect'") 2214 .def( 2215 "get_dialect_descriptor", 2216 [=](PyMlirContext &self, std::string &name) { 2217 MlirDialect dialect = mlirContextGetOrLoadDialect( 2218 self.get(), {name.data(), name.size()}); 2219 if (mlirDialectIsNull(dialect)) { 2220 throw SetPyError(PyExc_ValueError, 2221 Twine("Dialect '") + name + "' not found"); 2222 } 2223 return PyDialectDescriptor(self.getRef(), dialect); 2224 }, 2225 py::arg("dialect_name"), 2226 "Gets or loads a dialect by name, returning its descriptor object") 2227 .def_property( 2228 "allow_unregistered_dialects", 2229 [](PyMlirContext &self) -> bool { 2230 return mlirContextGetAllowUnregisteredDialects(self.get()); 2231 }, 2232 [](PyMlirContext &self, bool value) { 2233 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2234 }) 2235 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2236 py::arg("callback"), 2237 "Attaches a diagnostic handler that will receive callbacks") 2238 .def( 2239 "enable_multithreading", 2240 [](PyMlirContext &self, bool enable) { 2241 mlirContextEnableMultithreading(self.get(), enable); 2242 }, 2243 py::arg("enable")) 2244 .def( 2245 "is_registered_operation", 2246 [](PyMlirContext &self, std::string &name) { 2247 return mlirContextIsRegisteredOperation( 2248 self.get(), MlirStringRef{name.data(), name.size()}); 2249 }, 2250 py::arg("operation_name")); 2251 2252 //---------------------------------------------------------------------------- 2253 // Mapping of PyDialectDescriptor 2254 //---------------------------------------------------------------------------- 2255 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2256 .def_property_readonly("namespace", 2257 [](PyDialectDescriptor &self) { 2258 MlirStringRef ns = 2259 mlirDialectGetNamespace(self.get()); 2260 return py::str(ns.data, ns.length); 2261 }) 2262 .def("__repr__", [](PyDialectDescriptor &self) { 2263 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2264 std::string repr("<DialectDescriptor "); 2265 repr.append(ns.data, ns.length); 2266 repr.append(">"); 2267 return repr; 2268 }); 2269 2270 //---------------------------------------------------------------------------- 2271 // Mapping of PyDialects 2272 //---------------------------------------------------------------------------- 2273 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2274 .def("__getitem__", 2275 [=](PyDialects &self, std::string keyName) { 2276 MlirDialect dialect = 2277 self.getDialectForKey(keyName, /*attrError=*/false); 2278 py::object descriptor = 2279 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2280 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2281 }) 2282 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2283 MlirDialect dialect = 2284 self.getDialectForKey(attrName, /*attrError=*/true); 2285 py::object descriptor = 2286 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2287 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2288 }); 2289 2290 //---------------------------------------------------------------------------- 2291 // Mapping of PyDialect 2292 //---------------------------------------------------------------------------- 2293 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2294 .def(py::init<py::object>(), py::arg("descriptor")) 2295 .def_property_readonly( 2296 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2297 .def("__repr__", [](py::object self) { 2298 auto clazz = self.attr("__class__"); 2299 return py::str("<Dialect ") + 2300 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2301 clazz.attr("__module__") + py::str(".") + 2302 clazz.attr("__name__") + py::str(")>"); 2303 }); 2304 2305 //---------------------------------------------------------------------------- 2306 // Mapping of Location 2307 //---------------------------------------------------------------------------- 2308 py::class_<PyLocation>(m, "Location", py::module_local()) 2309 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2310 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2311 .def("__enter__", &PyLocation::contextEnter) 2312 .def("__exit__", &PyLocation::contextExit) 2313 .def("__eq__", 2314 [](PyLocation &self, PyLocation &other) -> bool { 2315 return mlirLocationEqual(self, other); 2316 }) 2317 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2318 .def_property_readonly_static( 2319 "current", 2320 [](py::object & /*class*/) { 2321 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2322 if (!loc) 2323 throw SetPyError(PyExc_ValueError, "No current Location"); 2324 return loc; 2325 }, 2326 "Gets the Location bound to the current thread or raises ValueError") 2327 .def_static( 2328 "unknown", 2329 [](DefaultingPyMlirContext context) { 2330 return PyLocation(context->getRef(), 2331 mlirLocationUnknownGet(context->get())); 2332 }, 2333 py::arg("context") = py::none(), 2334 "Gets a Location representing an unknown location") 2335 .def_static( 2336 "callsite", 2337 [](PyLocation callee, const std::vector<PyLocation> &frames, 2338 DefaultingPyMlirContext context) { 2339 if (frames.empty()) 2340 throw py::value_error("No caller frames provided"); 2341 MlirLocation caller = frames.back().get(); 2342 for (const PyLocation &frame : 2343 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2344 caller = mlirLocationCallSiteGet(frame.get(), caller); 2345 return PyLocation(context->getRef(), 2346 mlirLocationCallSiteGet(callee.get(), caller)); 2347 }, 2348 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2349 kContextGetCallSiteLocationDocstring) 2350 .def_static( 2351 "file", 2352 [](std::string filename, int line, int col, 2353 DefaultingPyMlirContext context) { 2354 return PyLocation( 2355 context->getRef(), 2356 mlirLocationFileLineColGet( 2357 context->get(), toMlirStringRef(filename), line, col)); 2358 }, 2359 py::arg("filename"), py::arg("line"), py::arg("col"), 2360 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2361 .def_static( 2362 "fused", 2363 [](const std::vector<PyLocation> &pyLocations, 2364 llvm::Optional<PyAttribute> metadata, 2365 DefaultingPyMlirContext context) { 2366 llvm::SmallVector<MlirLocation, 4> locations; 2367 locations.reserve(pyLocations.size()); 2368 for (auto &pyLocation : pyLocations) 2369 locations.push_back(pyLocation.get()); 2370 MlirLocation location = mlirLocationFusedGet( 2371 context->get(), locations.size(), locations.data(), 2372 metadata ? metadata->get() : MlirAttribute{0}); 2373 return PyLocation(context->getRef(), location); 2374 }, 2375 py::arg("locations"), py::arg("metadata") = py::none(), 2376 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2377 .def_static( 2378 "name", 2379 [](std::string name, llvm::Optional<PyLocation> childLoc, 2380 DefaultingPyMlirContext context) { 2381 return PyLocation( 2382 context->getRef(), 2383 mlirLocationNameGet( 2384 context->get(), toMlirStringRef(name), 2385 childLoc ? childLoc->get() 2386 : mlirLocationUnknownGet(context->get()))); 2387 }, 2388 py::arg("name"), py::arg("childLoc") = py::none(), 2389 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2390 .def_property_readonly( 2391 "context", 2392 [](PyLocation &self) { return self.getContext().getObject(); }, 2393 "Context that owns the Location") 2394 .def( 2395 "emit_error", 2396 [](PyLocation &self, std::string message) { 2397 mlirEmitError(self, message.c_str()); 2398 }, 2399 py::arg("message"), "Emits an error at this location") 2400 .def("__repr__", [](PyLocation &self) { 2401 PyPrintAccumulator printAccum; 2402 mlirLocationPrint(self, printAccum.getCallback(), 2403 printAccum.getUserData()); 2404 return printAccum.join(); 2405 }); 2406 2407 //---------------------------------------------------------------------------- 2408 // Mapping of Module 2409 //---------------------------------------------------------------------------- 2410 py::class_<PyModule>(m, "Module", py::module_local()) 2411 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2412 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2413 .def_static( 2414 "parse", 2415 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2416 MlirModule module = mlirModuleCreateParse( 2417 context->get(), toMlirStringRef(moduleAsm)); 2418 // TODO: Rework error reporting once diagnostic engine is exposed 2419 // in C API. 2420 if (mlirModuleIsNull(module)) { 2421 throw SetPyError( 2422 PyExc_ValueError, 2423 "Unable to parse module assembly (see diagnostics)"); 2424 } 2425 return PyModule::forModule(module).releaseObject(); 2426 }, 2427 py::arg("asm"), py::arg("context") = py::none(), 2428 kModuleParseDocstring) 2429 .def_static( 2430 "create", 2431 [](DefaultingPyLocation loc) { 2432 MlirModule module = mlirModuleCreateEmpty(loc); 2433 return PyModule::forModule(module).releaseObject(); 2434 }, 2435 py::arg("loc") = py::none(), "Creates an empty module") 2436 .def_property_readonly( 2437 "context", 2438 [](PyModule &self) { return self.getContext().getObject(); }, 2439 "Context that created the Module") 2440 .def_property_readonly( 2441 "operation", 2442 [](PyModule &self) { 2443 return PyOperation::forOperation(self.getContext(), 2444 mlirModuleGetOperation(self.get()), 2445 self.getRef().releaseObject()) 2446 .releaseObject(); 2447 }, 2448 "Accesses the module as an operation") 2449 .def_property_readonly( 2450 "body", 2451 [](PyModule &self) { 2452 PyOperationRef moduleOp = PyOperation::forOperation( 2453 self.getContext(), mlirModuleGetOperation(self.get()), 2454 self.getRef().releaseObject()); 2455 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 2456 return returnBlock; 2457 }, 2458 "Return the block for this module") 2459 .def( 2460 "dump", 2461 [](PyModule &self) { 2462 mlirOperationDump(mlirModuleGetOperation(self.get())); 2463 }, 2464 kDumpDocstring) 2465 .def( 2466 "__str__", 2467 [](py::object self) { 2468 // Defer to the operation's __str__. 2469 return self.attr("operation").attr("__str__")(); 2470 }, 2471 kOperationStrDunderDocstring); 2472 2473 //---------------------------------------------------------------------------- 2474 // Mapping of Operation. 2475 //---------------------------------------------------------------------------- 2476 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2477 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2478 [](PyOperationBase &self) { 2479 return self.getOperation().getCapsule(); 2480 }) 2481 .def("__eq__", 2482 [](PyOperationBase &self, PyOperationBase &other) { 2483 return &self.getOperation() == &other.getOperation(); 2484 }) 2485 .def("__eq__", 2486 [](PyOperationBase &self, py::object other) { return false; }) 2487 .def("__hash__", 2488 [](PyOperationBase &self) { 2489 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2490 }) 2491 .def_property_readonly("attributes", 2492 [](PyOperationBase &self) { 2493 return PyOpAttributeMap( 2494 self.getOperation().getRef()); 2495 }) 2496 .def_property_readonly("operands", 2497 [](PyOperationBase &self) { 2498 return PyOpOperandList( 2499 self.getOperation().getRef()); 2500 }) 2501 .def_property_readonly("regions", 2502 [](PyOperationBase &self) { 2503 return PyRegionList( 2504 self.getOperation().getRef()); 2505 }) 2506 .def_property_readonly( 2507 "results", 2508 [](PyOperationBase &self) { 2509 return PyOpResultList(self.getOperation().getRef()); 2510 }, 2511 "Returns the list of Operation results.") 2512 .def_property_readonly( 2513 "result", 2514 [](PyOperationBase &self) { 2515 auto &operation = self.getOperation(); 2516 auto numResults = mlirOperationGetNumResults(operation); 2517 if (numResults != 1) { 2518 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2519 throw SetPyError( 2520 PyExc_ValueError, 2521 Twine("Cannot call .result on operation ") + 2522 StringRef(name.data, name.length) + " which has " + 2523 Twine(numResults) + 2524 " results (it is only valid for operations with a " 2525 "single result)"); 2526 } 2527 return PyOpResult(operation.getRef(), 2528 mlirOperationGetResult(operation, 0)); 2529 }, 2530 "Shortcut to get an op result if it has only one (throws an error " 2531 "otherwise).") 2532 .def_property_readonly( 2533 "location", 2534 [](PyOperationBase &self) { 2535 PyOperation &operation = self.getOperation(); 2536 return PyLocation(operation.getContext(), 2537 mlirOperationGetLocation(operation.get())); 2538 }, 2539 "Returns the source location the operation was defined or derived " 2540 "from.") 2541 .def( 2542 "__str__", 2543 [](PyOperationBase &self) { 2544 return self.getAsm(/*binary=*/false, 2545 /*largeElementsLimit=*/llvm::None, 2546 /*enableDebugInfo=*/false, 2547 /*prettyDebugInfo=*/false, 2548 /*printGenericOpForm=*/false, 2549 /*useLocalScope=*/false, 2550 /*assumeVerified=*/false); 2551 }, 2552 "Returns the assembly form of the operation.") 2553 .def("print", &PyOperationBase::print, 2554 // Careful: Lots of arguments must match up with print method. 2555 py::arg("file") = py::none(), py::arg("binary") = false, 2556 py::arg("large_elements_limit") = py::none(), 2557 py::arg("enable_debug_info") = false, 2558 py::arg("pretty_debug_info") = false, 2559 py::arg("print_generic_op_form") = false, 2560 py::arg("use_local_scope") = false, 2561 py::arg("assume_verified") = false, kOperationPrintDocstring) 2562 .def("get_asm", &PyOperationBase::getAsm, 2563 // Careful: Lots of arguments must match up with get_asm method. 2564 py::arg("binary") = false, 2565 py::arg("large_elements_limit") = py::none(), 2566 py::arg("enable_debug_info") = false, 2567 py::arg("pretty_debug_info") = false, 2568 py::arg("print_generic_op_form") = false, 2569 py::arg("use_local_scope") = false, 2570 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2571 .def( 2572 "verify", 2573 [](PyOperationBase &self) { 2574 return mlirOperationVerify(self.getOperation()); 2575 }, 2576 "Verify the operation and return true if it passes, false if it " 2577 "fails.") 2578 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2579 "Puts self immediately after the other operation in its parent " 2580 "block.") 2581 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2582 "Puts self immediately before the other operation in its parent " 2583 "block.") 2584 .def( 2585 "detach_from_parent", 2586 [](PyOperationBase &self) { 2587 PyOperation &operation = self.getOperation(); 2588 operation.checkValid(); 2589 if (!operation.isAttached()) 2590 throw py::value_error("Detached operation has no parent."); 2591 2592 operation.detachFromParent(); 2593 return operation.createOpView(); 2594 }, 2595 "Detaches the operation from its parent block."); 2596 2597 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2598 .def_static("create", &PyOperation::create, py::arg("name"), 2599 py::arg("results") = py::none(), 2600 py::arg("operands") = py::none(), 2601 py::arg("attributes") = py::none(), 2602 py::arg("successors") = py::none(), py::arg("regions") = 0, 2603 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2604 kOperationCreateDocstring) 2605 .def_property_readonly("parent", 2606 [](PyOperation &self) -> py::object { 2607 auto parent = self.getParentOperation(); 2608 if (parent) 2609 return parent->getObject(); 2610 return py::none(); 2611 }) 2612 .def("erase", &PyOperation::erase) 2613 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2614 &PyOperation::getCapsule) 2615 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2616 .def_property_readonly("name", 2617 [](PyOperation &self) { 2618 self.checkValid(); 2619 MlirOperation operation = self.get(); 2620 MlirStringRef name = mlirIdentifierStr( 2621 mlirOperationGetName(operation)); 2622 return py::str(name.data, name.length); 2623 }) 2624 .def_property_readonly( 2625 "context", 2626 [](PyOperation &self) { 2627 self.checkValid(); 2628 return self.getContext().getObject(); 2629 }, 2630 "Context that owns the Operation") 2631 .def_property_readonly("opview", &PyOperation::createOpView); 2632 2633 auto opViewClass = 2634 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2635 .def(py::init<py::object>(), py::arg("operation")) 2636 .def_property_readonly("operation", &PyOpView::getOperationObject) 2637 .def_property_readonly( 2638 "context", 2639 [](PyOpView &self) { 2640 return self.getOperation().getContext().getObject(); 2641 }, 2642 "Context that owns the Operation") 2643 .def("__str__", [](PyOpView &self) { 2644 return py::str(self.getOperationObject()); 2645 }); 2646 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2647 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2648 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2649 opViewClass.attr("build_generic") = classmethod( 2650 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2651 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2652 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2653 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2654 "Builds a specific, generated OpView based on class level attributes."); 2655 2656 //---------------------------------------------------------------------------- 2657 // Mapping of PyRegion. 2658 //---------------------------------------------------------------------------- 2659 py::class_<PyRegion>(m, "Region", py::module_local()) 2660 .def_property_readonly( 2661 "blocks", 2662 [](PyRegion &self) { 2663 return PyBlockList(self.getParentOperation(), self.get()); 2664 }, 2665 "Returns a forward-optimized sequence of blocks.") 2666 .def_property_readonly( 2667 "owner", 2668 [](PyRegion &self) { 2669 return self.getParentOperation()->createOpView(); 2670 }, 2671 "Returns the operation owning this region.") 2672 .def( 2673 "__iter__", 2674 [](PyRegion &self) { 2675 self.checkValid(); 2676 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2677 return PyBlockIterator(self.getParentOperation(), firstBlock); 2678 }, 2679 "Iterates over blocks in the region.") 2680 .def("__eq__", 2681 [](PyRegion &self, PyRegion &other) { 2682 return self.get().ptr == other.get().ptr; 2683 }) 2684 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2685 2686 //---------------------------------------------------------------------------- 2687 // Mapping of PyBlock. 2688 //---------------------------------------------------------------------------- 2689 py::class_<PyBlock>(m, "Block", py::module_local()) 2690 .def_property_readonly( 2691 "owner", 2692 [](PyBlock &self) { 2693 return self.getParentOperation()->createOpView(); 2694 }, 2695 "Returns the owning operation of this block.") 2696 .def_property_readonly( 2697 "region", 2698 [](PyBlock &self) { 2699 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2700 return PyRegion(self.getParentOperation(), region); 2701 }, 2702 "Returns the owning region of this block.") 2703 .def_property_readonly( 2704 "arguments", 2705 [](PyBlock &self) { 2706 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2707 }, 2708 "Returns a list of block arguments.") 2709 .def_property_readonly( 2710 "operations", 2711 [](PyBlock &self) { 2712 return PyOperationList(self.getParentOperation(), self.get()); 2713 }, 2714 "Returns a forward-optimized sequence of operations.") 2715 .def_static( 2716 "create_at_start", 2717 [](PyRegion &parent, py::list pyArgTypes) { 2718 parent.checkValid(); 2719 llvm::SmallVector<MlirType, 4> argTypes; 2720 argTypes.reserve(pyArgTypes.size()); 2721 for (auto &pyArg : pyArgTypes) { 2722 argTypes.push_back(pyArg.cast<PyType &>()); 2723 } 2724 2725 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2726 mlirRegionInsertOwnedBlock(parent, 0, block); 2727 return PyBlock(parent.getParentOperation(), block); 2728 }, 2729 py::arg("parent"), py::arg("arg_types") = py::list(), 2730 "Creates and returns a new Block at the beginning of the given " 2731 "region (with given argument types).") 2732 .def( 2733 "create_before", 2734 [](PyBlock &self, py::args pyArgTypes) { 2735 self.checkValid(); 2736 llvm::SmallVector<MlirType, 4> argTypes; 2737 argTypes.reserve(pyArgTypes.size()); 2738 for (auto &pyArg : pyArgTypes) { 2739 argTypes.push_back(pyArg.cast<PyType &>()); 2740 } 2741 2742 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2743 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2744 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2745 return PyBlock(self.getParentOperation(), block); 2746 }, 2747 "Creates and returns a new Block before this block " 2748 "(with given argument types).") 2749 .def( 2750 "create_after", 2751 [](PyBlock &self, py::args pyArgTypes) { 2752 self.checkValid(); 2753 llvm::SmallVector<MlirType, 4> argTypes; 2754 argTypes.reserve(pyArgTypes.size()); 2755 for (auto &pyArg : pyArgTypes) { 2756 argTypes.push_back(pyArg.cast<PyType &>()); 2757 } 2758 2759 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2760 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2761 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2762 return PyBlock(self.getParentOperation(), block); 2763 }, 2764 "Creates and returns a new Block after this block " 2765 "(with given argument types).") 2766 .def( 2767 "__iter__", 2768 [](PyBlock &self) { 2769 self.checkValid(); 2770 MlirOperation firstOperation = 2771 mlirBlockGetFirstOperation(self.get()); 2772 return PyOperationIterator(self.getParentOperation(), 2773 firstOperation); 2774 }, 2775 "Iterates over operations in the block.") 2776 .def("__eq__", 2777 [](PyBlock &self, PyBlock &other) { 2778 return self.get().ptr == other.get().ptr; 2779 }) 2780 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2781 .def( 2782 "__str__", 2783 [](PyBlock &self) { 2784 self.checkValid(); 2785 PyPrintAccumulator printAccum; 2786 mlirBlockPrint(self.get(), printAccum.getCallback(), 2787 printAccum.getUserData()); 2788 return printAccum.join(); 2789 }, 2790 "Returns the assembly form of the block.") 2791 .def( 2792 "append", 2793 [](PyBlock &self, PyOperationBase &operation) { 2794 if (operation.getOperation().isAttached()) 2795 operation.getOperation().detachFromParent(); 2796 2797 MlirOperation mlirOperation = operation.getOperation().get(); 2798 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2799 operation.getOperation().setAttached( 2800 self.getParentOperation().getObject()); 2801 }, 2802 py::arg("operation"), 2803 "Appends an operation to this block. If the operation is currently " 2804 "in another block, it will be moved."); 2805 2806 //---------------------------------------------------------------------------- 2807 // Mapping of PyInsertionPoint. 2808 //---------------------------------------------------------------------------- 2809 2810 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2811 .def(py::init<PyBlock &>(), py::arg("block"), 2812 "Inserts after the last operation but still inside the block.") 2813 .def("__enter__", &PyInsertionPoint::contextEnter) 2814 .def("__exit__", &PyInsertionPoint::contextExit) 2815 .def_property_readonly_static( 2816 "current", 2817 [](py::object & /*class*/) { 2818 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2819 if (!ip) 2820 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2821 return ip; 2822 }, 2823 "Gets the InsertionPoint bound to the current thread or raises " 2824 "ValueError if none has been set") 2825 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2826 "Inserts before a referenced operation.") 2827 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2828 py::arg("block"), "Inserts at the beginning of the block.") 2829 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2830 py::arg("block"), "Inserts before the block terminator.") 2831 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2832 "Inserts an operation.") 2833 .def_property_readonly( 2834 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2835 "Returns the block that this InsertionPoint points to."); 2836 2837 //---------------------------------------------------------------------------- 2838 // Mapping of PyAttribute. 2839 //---------------------------------------------------------------------------- 2840 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2841 // Delegate to the PyAttribute copy constructor, which will also lifetime 2842 // extend the backing context which owns the MlirAttribute. 2843 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2844 "Casts the passed attribute to the generic Attribute") 2845 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2846 &PyAttribute::getCapsule) 2847 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2848 .def_static( 2849 "parse", 2850 [](std::string attrSpec, DefaultingPyMlirContext context) { 2851 MlirAttribute type = mlirAttributeParseGet( 2852 context->get(), toMlirStringRef(attrSpec)); 2853 // TODO: Rework error reporting once diagnostic engine is exposed 2854 // in C API. 2855 if (mlirAttributeIsNull(type)) { 2856 throw SetPyError(PyExc_ValueError, 2857 Twine("Unable to parse attribute: '") + 2858 attrSpec + "'"); 2859 } 2860 return PyAttribute(context->getRef(), type); 2861 }, 2862 py::arg("asm"), py::arg("context") = py::none(), 2863 "Parses an attribute from an assembly form") 2864 .def_property_readonly( 2865 "context", 2866 [](PyAttribute &self) { return self.getContext().getObject(); }, 2867 "Context that owns the Attribute") 2868 .def_property_readonly("type", 2869 [](PyAttribute &self) { 2870 return PyType(self.getContext()->getRef(), 2871 mlirAttributeGetType(self)); 2872 }) 2873 .def( 2874 "get_named", 2875 [](PyAttribute &self, std::string name) { 2876 return PyNamedAttribute(self, std::move(name)); 2877 }, 2878 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2879 .def("__eq__", 2880 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2881 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2882 .def("__hash__", 2883 [](PyAttribute &self) { 2884 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2885 }) 2886 .def( 2887 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2888 kDumpDocstring) 2889 .def( 2890 "__str__", 2891 [](PyAttribute &self) { 2892 PyPrintAccumulator printAccum; 2893 mlirAttributePrint(self, printAccum.getCallback(), 2894 printAccum.getUserData()); 2895 return printAccum.join(); 2896 }, 2897 "Returns the assembly form of the Attribute.") 2898 .def("__repr__", [](PyAttribute &self) { 2899 // Generally, assembly formats are not printed for __repr__ because 2900 // this can cause exceptionally long debug output and exceptions. 2901 // However, attribute values are generally considered useful and are 2902 // printed. This may need to be re-evaluated if debug dumps end up 2903 // being excessive. 2904 PyPrintAccumulator printAccum; 2905 printAccum.parts.append("Attribute("); 2906 mlirAttributePrint(self, printAccum.getCallback(), 2907 printAccum.getUserData()); 2908 printAccum.parts.append(")"); 2909 return printAccum.join(); 2910 }); 2911 2912 //---------------------------------------------------------------------------- 2913 // Mapping of PyNamedAttribute 2914 //---------------------------------------------------------------------------- 2915 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2916 .def("__repr__", 2917 [](PyNamedAttribute &self) { 2918 PyPrintAccumulator printAccum; 2919 printAccum.parts.append("NamedAttribute("); 2920 printAccum.parts.append( 2921 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2922 mlirIdentifierStr(self.namedAttr.name).length)); 2923 printAccum.parts.append("="); 2924 mlirAttributePrint(self.namedAttr.attribute, 2925 printAccum.getCallback(), 2926 printAccum.getUserData()); 2927 printAccum.parts.append(")"); 2928 return printAccum.join(); 2929 }) 2930 .def_property_readonly( 2931 "name", 2932 [](PyNamedAttribute &self) { 2933 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2934 mlirIdentifierStr(self.namedAttr.name).length); 2935 }, 2936 "The name of the NamedAttribute binding") 2937 .def_property_readonly( 2938 "attr", 2939 [](PyNamedAttribute &self) { 2940 // TODO: When named attribute is removed/refactored, also remove 2941 // this constructor (it does an inefficient table lookup). 2942 auto contextRef = PyMlirContext::forContext( 2943 mlirAttributeGetContext(self.namedAttr.attribute)); 2944 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2945 }, 2946 py::keep_alive<0, 1>(), 2947 "The underlying generic attribute of the NamedAttribute binding"); 2948 2949 //---------------------------------------------------------------------------- 2950 // Mapping of PyType. 2951 //---------------------------------------------------------------------------- 2952 py::class_<PyType>(m, "Type", py::module_local()) 2953 // Delegate to the PyType copy constructor, which will also lifetime 2954 // extend the backing context which owns the MlirType. 2955 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2956 "Casts the passed type to the generic Type") 2957 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2958 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2959 .def_static( 2960 "parse", 2961 [](std::string typeSpec, DefaultingPyMlirContext context) { 2962 MlirType type = 2963 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2964 // TODO: Rework error reporting once diagnostic engine is exposed 2965 // in C API. 2966 if (mlirTypeIsNull(type)) { 2967 throw SetPyError(PyExc_ValueError, 2968 Twine("Unable to parse type: '") + typeSpec + 2969 "'"); 2970 } 2971 return PyType(context->getRef(), type); 2972 }, 2973 py::arg("asm"), py::arg("context") = py::none(), 2974 kContextParseTypeDocstring) 2975 .def_property_readonly( 2976 "context", [](PyType &self) { return self.getContext().getObject(); }, 2977 "Context that owns the Type") 2978 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2979 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2980 .def("__hash__", 2981 [](PyType &self) { 2982 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2983 }) 2984 .def( 2985 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2986 .def( 2987 "__str__", 2988 [](PyType &self) { 2989 PyPrintAccumulator printAccum; 2990 mlirTypePrint(self, printAccum.getCallback(), 2991 printAccum.getUserData()); 2992 return printAccum.join(); 2993 }, 2994 "Returns the assembly form of the type.") 2995 .def("__repr__", [](PyType &self) { 2996 // Generally, assembly formats are not printed for __repr__ because 2997 // this can cause exceptionally long debug output and exceptions. 2998 // However, types are an exception as they typically have compact 2999 // assembly forms and printing them is useful. 3000 PyPrintAccumulator printAccum; 3001 printAccum.parts.append("Type("); 3002 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 3003 printAccum.parts.append(")"); 3004 return printAccum.join(); 3005 }); 3006 3007 //---------------------------------------------------------------------------- 3008 // Mapping of Value. 3009 //---------------------------------------------------------------------------- 3010 py::class_<PyValue>(m, "Value", py::module_local()) 3011 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3012 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3013 .def_property_readonly( 3014 "context", 3015 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3016 "Context in which the value lives.") 3017 .def( 3018 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3019 kDumpDocstring) 3020 .def_property_readonly( 3021 "owner", 3022 [](PyValue &self) { 3023 assert(mlirOperationEqual(self.getParentOperation()->get(), 3024 mlirOpResultGetOwner(self.get())) && 3025 "expected the owner of the value in Python to match that in " 3026 "the IR"); 3027 return self.getParentOperation().getObject(); 3028 }) 3029 .def("__eq__", 3030 [](PyValue &self, PyValue &other) { 3031 return self.get().ptr == other.get().ptr; 3032 }) 3033 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 3034 .def("__hash__", 3035 [](PyValue &self) { 3036 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3037 }) 3038 .def( 3039 "__str__", 3040 [](PyValue &self) { 3041 PyPrintAccumulator printAccum; 3042 printAccum.parts.append("Value("); 3043 mlirValuePrint(self.get(), printAccum.getCallback(), 3044 printAccum.getUserData()); 3045 printAccum.parts.append(")"); 3046 return printAccum.join(); 3047 }, 3048 kValueDunderStrDocstring) 3049 .def_property_readonly("type", [](PyValue &self) { 3050 return PyType(self.getParentOperation()->getContext(), 3051 mlirValueGetType(self.get())); 3052 }); 3053 PyBlockArgument::bind(m); 3054 PyOpResult::bind(m); 3055 3056 //---------------------------------------------------------------------------- 3057 // Mapping of SymbolTable. 3058 //---------------------------------------------------------------------------- 3059 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 3060 .def(py::init<PyOperationBase &>()) 3061 .def("__getitem__", &PySymbolTable::dunderGetItem) 3062 .def("insert", &PySymbolTable::insert, py::arg("operation")) 3063 .def("erase", &PySymbolTable::erase, py::arg("operation")) 3064 .def("__delitem__", &PySymbolTable::dunderDel) 3065 .def("__contains__", 3066 [](PySymbolTable &table, const std::string &name) { 3067 return !mlirOperationIsNull(mlirSymbolTableLookup( 3068 table, mlirStringRefCreate(name.data(), name.length()))); 3069 }) 3070 // Static helpers. 3071 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3072 py::arg("symbol"), py::arg("name")) 3073 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3074 py::arg("symbol")) 3075 .def_static("get_visibility", &PySymbolTable::getVisibility, 3076 py::arg("symbol")) 3077 .def_static("set_visibility", &PySymbolTable::setVisibility, 3078 py::arg("symbol"), py::arg("visibility")) 3079 .def_static("replace_all_symbol_uses", 3080 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 3081 py::arg("new_symbol"), py::arg("from_op")) 3082 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3083 py::arg("from_op"), py::arg("all_sym_uses_visible"), 3084 py::arg("callback")); 3085 3086 // Container bindings. 3087 PyBlockArgumentList::bind(m); 3088 PyBlockIterator::bind(m); 3089 PyBlockList::bind(m); 3090 PyOperationIterator::bind(m); 3091 PyOperationList::bind(m); 3092 PyOpAttributeMap::bind(m); 3093 PyOpOperandList::bind(m); 3094 PyOpResultList::bind(m); 3095 PyRegionIterator::bind(m); 3096 PyRegionList::bind(m); 3097 3098 // Debug bindings. 3099 PyGlobalDebugFlag::bind(m); 3100 } 3101