1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 #include <utility> 25 26 namespace py = pybind11; 27 using namespace mlir; 28 using namespace mlir::python; 29 30 using llvm::SmallVector; 31 using llvm::StringRef; 32 using llvm::Twine; 33 34 //------------------------------------------------------------------------------ 35 // Docstrings (trivial, non-duplicated docstrings are included inline). 36 //------------------------------------------------------------------------------ 37 38 static const char kContextParseTypeDocstring[] = 39 R"(Parses the assembly form of a type. 40 41 Returns a Type object or raises a ValueError if the type cannot be parsed. 
42 43 See also: https://mlir.llvm.org/docs/LangRef/#type-system 44 )"; 45 46 static const char kContextGetCallSiteLocationDocstring[] = 47 R"(Gets a Location representing a caller and callsite)"; 48 49 static const char kContextGetFileLocationDocstring[] = 50 R"(Gets a Location representing a file, line and column)"; 51 52 static const char kContextGetFusedLocationDocstring[] = 53 R"(Gets a Location representing a fused location with optional metadata)"; 54 55 static const char kContextGetNameLocationDocString[] = 56 R"(Gets a Location representing a named location with optional child location)"; 57 58 static const char kModuleParseDocstring[] = 59 R"(Parses a module's assembly format from a string. 60 61 Returns a new MlirModule or raises a ValueError if the parsing fails. 62 63 See also: https://mlir.llvm.org/docs/LangRef/ 64 )"; 65 66 static const char kOperationCreateDocstring[] = 67 R"(Creates a new operation. 68 69 Args: 70 name: Operation name (e.g. "dialect.operation"). 71 results: Sequence of Type representing op result types. 72 attributes: Dict of str:Attribute. 73 successors: List of Block for the operation's successors. 74 regions: Number of regions to create. 75 location: A Location object (defaults to resolve from context manager). 76 ip: An InsertionPoint (defaults to resolve from context manager or set to 77 False to disable insertion, even with an insertion point set in the 78 context manager). 79 Returns: 80 A new "detached" Operation object. Detached operations can be added 81 to blocks, which causes them to become "attached." 82 )"; 83 84 static const char kOperationPrintDocstring[] = 85 R"(Prints the assembly form of the operation to a file like object. 86 87 Args: 88 file: The file like object to write to. Defaults to sys.stdout. 89 binary: Whether to write bytes (True) or str (False). Defaults to False. 90 large_elements_limit: Whether to elide elements attributes above this 91 number of elements. Defaults to None (no limit). 
92 enable_debug_info: Whether to print debug/location information. Defaults 93 to False. 94 pretty_debug_info: Whether to format debug information for easier reading 95 by a human (warning: the result is unparseable). 96 print_generic_op_form: Whether to print the generic assembly forms of all 97 ops. Defaults to False. 98 use_local_Scope: Whether to print in a way that is more optimized for 99 multi-threaded access but may not be consistent with how the overall 100 module prints. 101 assume_verified: By default, if not printing generic form, the verifier 102 will be run and if it fails, generic form will be printed with a comment 103 about failed verification. While a reasonable default for interactive use, 104 for systematic use, it is often better for the caller to verify explicitly 105 and report failures in a more robust fashion. Set this to True if doing this 106 in order to avoid running a redundant verification. If the IR is actually 107 invalid, behavior is undefined. 108 )"; 109 110 static const char kOperationGetAsmDocstring[] = 111 R"(Gets the assembly form of the operation with all options available. 112 113 Args: 114 binary: Whether to return a bytes (True) or str (False) object. Defaults to 115 False. 116 ... others ...: See the print() method for common keyword arguments for 117 configuring the printout. 118 Returns: 119 Either a bytes or str object, depending on the setting of the 'binary' 120 argument. 121 )"; 122 123 static const char kOperationStrDunderDocstring[] = 124 R"(Gets the assembly form of the operation with default options. 125 126 If more advanced control over the assembly formatting or I/O options is needed, 127 use the dedicated print or get_asm method, which supports keyword arguments to 128 customize behavior. 
129 )"; 130 131 static const char kDumpDocstring[] = 132 R"(Dumps a debug representation of the object to stderr.)"; 133 134 static const char kAppendBlockDocstring[] = 135 R"(Appends a new block, with argument types as positional args. 136 137 Returns: 138 The created block. 139 )"; 140 141 static const char kValueDunderStrDocstring[] = 142 R"(Returns the string form of the value. 143 144 If the value is a block argument, this is the assembly form of its type and the 145 position in the argument list. If the value is an operation result, this is 146 equivalent to printing the operation that produced it. 147 )"; 148 149 //------------------------------------------------------------------------------ 150 // Utilities. 151 //------------------------------------------------------------------------------ 152 153 /// Helper for creating an @classmethod. 154 template <class Func, typename... Args> 155 py::object classmethod(Func f, Args... args) { 156 py::object cf = py::cpp_function(f, args...); 157 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 158 } 159 160 static py::object 161 createCustomDialectWrapper(const std::string &dialectNamespace, 162 py::object dialectDescriptor) { 163 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 164 if (!dialectClass) { 165 // Use the base class. 166 return py::cast(PyDialect(std::move(dialectDescriptor))); 167 } 168 169 // Create the custom implementation. 170 return (*dialectClass)(std::move(dialectDescriptor)); 171 } 172 173 static MlirStringRef toMlirStringRef(const std::string &s) { 174 return mlirStringRefCreate(s.data(), s.size()); 175 } 176 177 /// Wrapper for the global LLVM debugging flag. 178 struct PyGlobalDebugFlag { 179 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 180 181 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } 182 183 static void bind(py::module &m) { 184 // Debug flags. 
185 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 186 .def_property_static("flag", &PyGlobalDebugFlag::get, 187 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 188 } 189 }; 190 191 //------------------------------------------------------------------------------ 192 // Collections. 193 //------------------------------------------------------------------------------ 194 195 namespace { 196 197 class PyRegionIterator { 198 public: 199 PyRegionIterator(PyOperationRef operation) 200 : operation(std::move(operation)) {} 201 202 PyRegionIterator &dunderIter() { return *this; } 203 204 PyRegion dunderNext() { 205 operation->checkValid(); 206 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 207 throw py::stop_iteration(); 208 } 209 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 210 return PyRegion(operation, region); 211 } 212 213 static void bind(py::module &m) { 214 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 215 .def("__iter__", &PyRegionIterator::dunderIter) 216 .def("__next__", &PyRegionIterator::dunderNext); 217 } 218 219 private: 220 PyOperationRef operation; 221 int nextIndex = 0; 222 }; 223 224 /// Regions of an op are fixed length and indexed numerically so are represented 225 /// with a sequence-like container. 226 class PyRegionList { 227 public: 228 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 229 230 intptr_t dunderLen() { 231 operation->checkValid(); 232 return mlirOperationGetNumRegions(operation->get()); 233 } 234 235 PyRegion dunderGetItem(intptr_t index) { 236 // dunderLen checks validity. 
237 if (index < 0 || index >= dunderLen()) { 238 throw SetPyError(PyExc_IndexError, 239 "attempt to access out of bounds region"); 240 } 241 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 242 return PyRegion(operation, region); 243 } 244 245 static void bind(py::module &m) { 246 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 247 .def("__len__", &PyRegionList::dunderLen) 248 .def("__getitem__", &PyRegionList::dunderGetItem); 249 } 250 251 private: 252 PyOperationRef operation; 253 }; 254 255 class PyBlockIterator { 256 public: 257 PyBlockIterator(PyOperationRef operation, MlirBlock next) 258 : operation(std::move(operation)), next(next) {} 259 260 PyBlockIterator &dunderIter() { return *this; } 261 262 PyBlock dunderNext() { 263 operation->checkValid(); 264 if (mlirBlockIsNull(next)) { 265 throw py::stop_iteration(); 266 } 267 268 PyBlock returnBlock(operation, next); 269 next = mlirBlockGetNextInRegion(next); 270 return returnBlock; 271 } 272 273 static void bind(py::module &m) { 274 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 275 .def("__iter__", &PyBlockIterator::dunderIter) 276 .def("__next__", &PyBlockIterator::dunderNext); 277 } 278 279 private: 280 PyOperationRef operation; 281 MlirBlock next; 282 }; 283 284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 285 /// we present them as a more full-featured list-like container but optimize 286 /// it for forward iteration. Blocks are always owned by a region. 
287 class PyBlockList { 288 public: 289 PyBlockList(PyOperationRef operation, MlirRegion region) 290 : operation(std::move(operation)), region(region) {} 291 292 PyBlockIterator dunderIter() { 293 operation->checkValid(); 294 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 295 } 296 297 intptr_t dunderLen() { 298 operation->checkValid(); 299 intptr_t count = 0; 300 MlirBlock block = mlirRegionGetFirstBlock(region); 301 while (!mlirBlockIsNull(block)) { 302 count += 1; 303 block = mlirBlockGetNextInRegion(block); 304 } 305 return count; 306 } 307 308 PyBlock dunderGetItem(intptr_t index) { 309 operation->checkValid(); 310 if (index < 0) { 311 throw SetPyError(PyExc_IndexError, 312 "attempt to access out of bounds block"); 313 } 314 MlirBlock block = mlirRegionGetFirstBlock(region); 315 while (!mlirBlockIsNull(block)) { 316 if (index == 0) { 317 return PyBlock(operation, block); 318 } 319 block = mlirBlockGetNextInRegion(block); 320 index -= 1; 321 } 322 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 323 } 324 325 PyBlock appendBlock(const py::args &pyArgTypes) { 326 operation->checkValid(); 327 llvm::SmallVector<MlirType, 4> argTypes; 328 llvm::SmallVector<MlirLocation, 4> argLocs; 329 argTypes.reserve(pyArgTypes.size()); 330 argLocs.reserve(pyArgTypes.size()); 331 for (auto &pyArg : pyArgTypes) { 332 argTypes.push_back(pyArg.cast<PyType &>()); 333 // TODO: Pass in a proper location here. 
334 argLocs.push_back( 335 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 336 } 337 338 MlirBlock block = 339 mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 340 mlirRegionAppendOwnedBlock(region, block); 341 return PyBlock(operation, block); 342 } 343 344 static void bind(py::module &m) { 345 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 346 .def("__getitem__", &PyBlockList::dunderGetItem) 347 .def("__iter__", &PyBlockList::dunderIter) 348 .def("__len__", &PyBlockList::dunderLen) 349 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 350 } 351 352 private: 353 PyOperationRef operation; 354 MlirRegion region; 355 }; 356 357 class PyOperationIterator { 358 public: 359 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 360 : parentOperation(std::move(parentOperation)), next(next) {} 361 362 PyOperationIterator &dunderIter() { return *this; } 363 364 py::object dunderNext() { 365 parentOperation->checkValid(); 366 if (mlirOperationIsNull(next)) { 367 throw py::stop_iteration(); 368 } 369 370 PyOperationRef returnOperation = 371 PyOperation::forOperation(parentOperation->getContext(), next); 372 next = mlirOperationGetNextInBlock(next); 373 return returnOperation->createOpView(); 374 } 375 376 static void bind(py::module &m) { 377 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 378 .def("__iter__", &PyOperationIterator::dunderIter) 379 .def("__next__", &PyOperationIterator::dunderNext); 380 } 381 382 private: 383 PyOperationRef parentOperation; 384 MlirOperation next; 385 }; 386 387 /// Operations are exposed by the C-API as a forward-only linked list. In 388 /// Python, we present them as a more full-featured list-like container but 389 /// optimize it for forward iteration. Iterable operations are always owned 390 /// by a block. 
391 class PyOperationList { 392 public: 393 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 394 : parentOperation(std::move(parentOperation)), block(block) {} 395 396 PyOperationIterator dunderIter() { 397 parentOperation->checkValid(); 398 return PyOperationIterator(parentOperation, 399 mlirBlockGetFirstOperation(block)); 400 } 401 402 intptr_t dunderLen() { 403 parentOperation->checkValid(); 404 intptr_t count = 0; 405 MlirOperation childOp = mlirBlockGetFirstOperation(block); 406 while (!mlirOperationIsNull(childOp)) { 407 count += 1; 408 childOp = mlirOperationGetNextInBlock(childOp); 409 } 410 return count; 411 } 412 413 py::object dunderGetItem(intptr_t index) { 414 parentOperation->checkValid(); 415 if (index < 0) { 416 throw SetPyError(PyExc_IndexError, 417 "attempt to access out of bounds operation"); 418 } 419 MlirOperation childOp = mlirBlockGetFirstOperation(block); 420 while (!mlirOperationIsNull(childOp)) { 421 if (index == 0) { 422 return PyOperation::forOperation(parentOperation->getContext(), childOp) 423 ->createOpView(); 424 } 425 childOp = mlirOperationGetNextInBlock(childOp); 426 index -= 1; 427 } 428 throw SetPyError(PyExc_IndexError, 429 "attempt to access out of bounds operation"); 430 } 431 432 static void bind(py::module &m) { 433 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 434 .def("__getitem__", &PyOperationList::dunderGetItem) 435 .def("__iter__", &PyOperationList::dunderIter) 436 .def("__len__", &PyOperationList::dunderLen); 437 } 438 439 private: 440 PyOperationRef parentOperation; 441 MlirBlock block; 442 }; 443 444 } // namespace 445 446 //------------------------------------------------------------------------------ 447 // PyMlirContext 448 //------------------------------------------------------------------------------ 449 450 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 451 py::gil_scoped_acquire acquire; 452 auto &liveContexts = getLiveContexts(); 453 
liveContexts[context.ptr] = this; 454 } 455 456 PyMlirContext::~PyMlirContext() { 457 // Note that the only public way to construct an instance is via the 458 // forContext method, which always puts the associated handle into 459 // liveContexts. 460 py::gil_scoped_acquire acquire; 461 getLiveContexts().erase(context.ptr); 462 mlirContextDestroy(context); 463 } 464 465 py::object PyMlirContext::getCapsule() { 466 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 467 } 468 469 py::object PyMlirContext::createFromCapsule(py::object capsule) { 470 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 471 if (mlirContextIsNull(rawContext)) 472 throw py::error_already_set(); 473 return forContext(rawContext).releaseObject(); 474 } 475 476 PyMlirContext *PyMlirContext::createNewContextForInit() { 477 MlirContext context = mlirContextCreate(); 478 mlirRegisterAllDialects(context); 479 return new PyMlirContext(context); 480 } 481 482 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 483 py::gil_scoped_acquire acquire; 484 auto &liveContexts = getLiveContexts(); 485 auto it = liveContexts.find(context.ptr); 486 if (it == liveContexts.end()) { 487 // Create. 488 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 489 py::object pyRef = py::cast(unownedContextWrapper); 490 assert(pyRef && "cast to py::object failed"); 491 liveContexts[context.ptr] = unownedContextWrapper; 492 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 493 } 494 // Use existing. 
495 py::object pyRef = py::cast(it->second); 496 return PyMlirContextRef(it->second, std::move(pyRef)); 497 } 498 499 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 500 static LiveContextMap liveContexts; 501 return liveContexts; 502 } 503 504 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 505 506 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 507 508 size_t PyMlirContext::clearLiveOperations() { 509 for (auto &op : liveOperations) 510 op.second.second->setInvalid(); 511 size_t numInvalidated = liveOperations.size(); 512 liveOperations.clear(); 513 return numInvalidated; 514 } 515 516 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 517 518 pybind11::object PyMlirContext::contextEnter() { 519 return PyThreadContextEntry::pushContext(*this); 520 } 521 522 void PyMlirContext::contextExit(const pybind11::object &excType, 523 const pybind11::object &excVal, 524 const pybind11::object &excTb) { 525 PyThreadContextEntry::popContext(*this); 526 } 527 528 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { 529 // Note that ownership is transferred to the delete callback below by way of 530 // an explicit inc_ref (borrow). 531 PyDiagnosticHandler *pyHandler = 532 new PyDiagnosticHandler(get(), std::move(callback)); 533 py::object pyHandlerObject = 534 py::cast(pyHandler, py::return_value_policy::take_ownership); 535 pyHandlerObject.inc_ref(); 536 537 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 538 // guaranteed to be known to pybind. 
539 auto handlerCallback = 540 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 541 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 542 py::object pyDiagnosticObject = 543 py::cast(pyDiagnostic, py::return_value_policy::take_ownership); 544 545 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 546 bool result = false; 547 { 548 // Since this can be called from arbitrary C++ contexts, always get the 549 // gil. 550 py::gil_scoped_acquire gil; 551 try { 552 result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); 553 } catch (std::exception &e) { 554 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 555 e.what()); 556 pyHandler->hadError = true; 557 } 558 } 559 560 pyDiagnostic->invalidate(); 561 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 562 }; 563 auto deleteCallback = +[](void *userData) { 564 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 565 assert(pyHandler->registeredID && "handler is not registered"); 566 pyHandler->registeredID.reset(); 567 568 // Decrement reference, balancing the inc_ref() above. 569 py::object pyHandlerObject = 570 py::cast(pyHandler, py::return_value_policy::reference); 571 pyHandlerObject.dec_ref(); 572 }; 573 574 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 575 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 576 return pyHandlerObject; 577 } 578 579 PyMlirContext &DefaultingPyMlirContext::resolve() { 580 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 581 if (!context) { 582 throw SetPyError( 583 PyExc_RuntimeError, 584 "An MLIR function requires a Context but none was provided in the call " 585 "or from the surrounding environment. 
Either pass to the function with " 586 "a 'context=' argument or establish a default using 'with Context():'"); 587 } 588 return *context; 589 } 590 591 //------------------------------------------------------------------------------ 592 // PyThreadContextEntry management 593 //------------------------------------------------------------------------------ 594 595 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 596 static thread_local std::vector<PyThreadContextEntry> stack; 597 return stack; 598 } 599 600 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 601 auto &stack = getStack(); 602 if (stack.empty()) 603 return nullptr; 604 return &stack.back(); 605 } 606 607 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 608 py::object insertionPoint, 609 py::object location) { 610 auto &stack = getStack(); 611 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 612 std::move(location)); 613 // If the new stack has more than one entry and the context of the new top 614 // entry matches the previous, copy the insertionPoint and location from the 615 // previous entry if missing from the new top entry. 616 if (stack.size() > 1) { 617 auto &prev = *(stack.rbegin() + 1); 618 auto ¤t = stack.back(); 619 if (current.context.is(prev.context)) { 620 // Default non-context objects from the previous entry. 
621 if (!current.insertionPoint) 622 current.insertionPoint = prev.insertionPoint; 623 if (!current.location) 624 current.location = prev.location; 625 } 626 } 627 } 628 629 PyMlirContext *PyThreadContextEntry::getContext() { 630 if (!context) 631 return nullptr; 632 return py::cast<PyMlirContext *>(context); 633 } 634 635 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 636 if (!insertionPoint) 637 return nullptr; 638 return py::cast<PyInsertionPoint *>(insertionPoint); 639 } 640 641 PyLocation *PyThreadContextEntry::getLocation() { 642 if (!location) 643 return nullptr; 644 return py::cast<PyLocation *>(location); 645 } 646 647 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 648 auto *tos = getTopOfStack(); 649 return tos ? tos->getContext() : nullptr; 650 } 651 652 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 653 auto *tos = getTopOfStack(); 654 return tos ? tos->getInsertionPoint() : nullptr; 655 } 656 657 PyLocation *PyThreadContextEntry::getDefaultLocation() { 658 auto *tos = getTopOfStack(); 659 return tos ? 
tos->getLocation() : nullptr; 660 } 661 662 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 663 py::object contextObj = py::cast(context); 664 push(FrameKind::Context, /*context=*/contextObj, 665 /*insertionPoint=*/py::object(), 666 /*location=*/py::object()); 667 return contextObj; 668 } 669 670 void PyThreadContextEntry::popContext(PyMlirContext &context) { 671 auto &stack = getStack(); 672 if (stack.empty()) 673 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 674 auto &tos = stack.back(); 675 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 676 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 677 stack.pop_back(); 678 } 679 680 py::object 681 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 682 py::object contextObj = 683 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 684 py::object insertionPointObj = py::cast(insertionPoint); 685 push(FrameKind::InsertionPoint, 686 /*context=*/contextObj, 687 /*insertionPoint=*/insertionPointObj, 688 /*location=*/py::object()); 689 return insertionPointObj; 690 } 691 692 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 693 auto &stack = getStack(); 694 if (stack.empty()) 695 throw SetPyError(PyExc_RuntimeError, 696 "Unbalanced InsertionPoint enter/exit"); 697 auto &tos = stack.back(); 698 if (tos.frameKind != FrameKind::InsertionPoint && 699 tos.getInsertionPoint() != &insertionPoint) 700 throw SetPyError(PyExc_RuntimeError, 701 "Unbalanced InsertionPoint enter/exit"); 702 stack.pop_back(); 703 } 704 705 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 706 py::object contextObj = location.getContext().getObject(); 707 py::object locationObj = py::cast(location); 708 push(FrameKind::Location, /*context=*/contextObj, 709 /*insertionPoint=*/py::object(), 710 /*location=*/locationObj); 711 return locationObj; 712 } 713 714 void 
PyThreadContextEntry::popLocation(PyLocation &location) { 715 auto &stack = getStack(); 716 if (stack.empty()) 717 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 718 auto &tos = stack.back(); 719 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 720 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 721 stack.pop_back(); 722 } 723 724 //------------------------------------------------------------------------------ 725 // PyDiagnostic* 726 //------------------------------------------------------------------------------ 727 728 void PyDiagnostic::invalidate() { 729 valid = false; 730 if (materializedNotes) { 731 for (auto ¬eObject : *materializedNotes) { 732 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); 733 note->invalidate(); 734 } 735 } 736 } 737 738 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 739 py::object callback) 740 : context(context), callback(std::move(callback)) {} 741 742 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 743 744 void PyDiagnosticHandler::detach() { 745 if (!registeredID) 746 return; 747 MlirDiagnosticHandlerID localID = *registeredID; 748 mlirContextDetachDiagnosticHandler(context, localID); 749 assert(!registeredID && "should have unregistered"); 750 // Not strictly necessary but keeps stale pointers from being around to cause 751 // issues. 
752 context = {nullptr}; 753 } 754 755 void PyDiagnostic::checkValid() { 756 if (!valid) { 757 throw std::invalid_argument( 758 "Diagnostic is invalid (used outside of callback)"); 759 } 760 } 761 762 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 763 checkValid(); 764 return mlirDiagnosticGetSeverity(diagnostic); 765 } 766 767 PyLocation PyDiagnostic::getLocation() { 768 checkValid(); 769 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 770 MlirContext context = mlirLocationGetContext(loc); 771 return PyLocation(PyMlirContext::forContext(context), loc); 772 } 773 774 py::str PyDiagnostic::getMessage() { 775 checkValid(); 776 py::object fileObject = py::module::import("io").attr("StringIO")(); 777 PyFileAccumulator accum(fileObject, /*binary=*/false); 778 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 779 return fileObject.attr("getvalue")(); 780 } 781 782 py::tuple PyDiagnostic::getNotes() { 783 checkValid(); 784 if (materializedNotes) 785 return *materializedNotes; 786 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 787 materializedNotes = py::tuple(numNotes); 788 for (intptr_t i = 0; i < numNotes; ++i) { 789 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 790 py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); 791 PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); 792 } 793 return *materializedNotes; 794 } 795 796 //------------------------------------------------------------------------------ 797 // PyDialect, PyDialectDescriptor, PyDialects 798 //------------------------------------------------------------------------------ 799 800 MlirDialect PyDialects::getDialectForKey(const std::string &key, 801 bool attrError) { 802 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 803 {key.data(), key.size()}); 804 if (mlirDialectIsNull(dialect)) { 805 throw SetPyError(attrError ? 
PyExc_AttributeError : PyExc_IndexError, 806 Twine("Dialect '") + key + "' not found"); 807 } 808 return dialect; 809 } 810 811 //------------------------------------------------------------------------------ 812 // PyLocation 813 //------------------------------------------------------------------------------ 814 815 py::object PyLocation::getCapsule() { 816 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 817 } 818 819 PyLocation PyLocation::createFromCapsule(py::object capsule) { 820 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 821 if (mlirLocationIsNull(rawLoc)) 822 throw py::error_already_set(); 823 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 824 rawLoc); 825 } 826 827 py::object PyLocation::contextEnter() { 828 return PyThreadContextEntry::pushLocation(*this); 829 } 830 831 void PyLocation::contextExit(const pybind11::object &excType, 832 const pybind11::object &excVal, 833 const pybind11::object &excTb) { 834 PyThreadContextEntry::popLocation(*this); 835 } 836 837 PyLocation &DefaultingPyLocation::resolve() { 838 auto *location = PyThreadContextEntry::getDefaultLocation(); 839 if (!location) { 840 throw SetPyError( 841 PyExc_RuntimeError, 842 "An MLIR function requires a Location but none was provided in the " 843 "call or from the surrounding environment. 
Either pass to the function " 844 "with a 'loc=' argument or establish a default using 'with loc:'"); 845 } 846 return *location; 847 } 848 849 //------------------------------------------------------------------------------ 850 // PyModule 851 //------------------------------------------------------------------------------ 852 853 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 854 : BaseContextObject(std::move(contextRef)), module(module) {} 855 856 PyModule::~PyModule() { 857 py::gil_scoped_acquire acquire; 858 auto &liveModules = getContext()->liveModules; 859 assert(liveModules.count(module.ptr) == 1 && 860 "destroying module not in live map"); 861 liveModules.erase(module.ptr); 862 mlirModuleDestroy(module); 863 } 864 865 PyModuleRef PyModule::forModule(MlirModule module) { 866 MlirContext context = mlirModuleGetContext(module); 867 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 868 869 py::gil_scoped_acquire acquire; 870 auto &liveModules = contextRef->liveModules; 871 auto it = liveModules.find(module.ptr); 872 if (it == liveModules.end()) { 873 // Create. 874 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 875 // Note that the default return value policy on cast is automatic_reference, 876 // which does not take ownership (delete will not be called). 877 // Just be explicit. 878 py::object pyRef = 879 py::cast(unownedModule, py::return_value_policy::take_ownership); 880 unownedModule->handle = pyRef; 881 liveModules[module.ptr] = 882 std::make_pair(unownedModule->handle, unownedModule); 883 return PyModuleRef(unownedModule, std::move(pyRef)); 884 } 885 // Use existing. 
886 PyModule *existing = it->second.second; 887 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 888 return PyModuleRef(existing, std::move(pyRef)); 889 } 890 891 py::object PyModule::createFromCapsule(py::object capsule) { 892 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 893 if (mlirModuleIsNull(rawModule)) 894 throw py::error_already_set(); 895 return forModule(rawModule).releaseObject(); 896 } 897 898 py::object PyModule::getCapsule() { 899 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 900 } 901 902 //------------------------------------------------------------------------------ 903 // PyOperation 904 //------------------------------------------------------------------------------ 905 906 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 907 : BaseContextObject(std::move(contextRef)), operation(operation) {} 908 909 PyOperation::~PyOperation() { 910 // If the operation has already been invalidated there is nothing to do. 911 if (!valid) 912 return; 913 auto &liveOperations = getContext()->liveOperations; 914 assert(liveOperations.count(operation.ptr) == 1 && 915 "destroying operation not in live map"); 916 liveOperations.erase(operation.ptr); 917 if (!isAttached()) { 918 mlirOperationDestroy(operation); 919 } 920 } 921 922 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 923 MlirOperation operation, 924 py::object parentKeepAlive) { 925 auto &liveOperations = contextRef->liveOperations; 926 // Create. 927 PyOperation *unownedOperation = 928 new PyOperation(std::move(contextRef), operation); 929 // Note that the default return value policy on cast is automatic_reference, 930 // which does not take ownership (delete will not be called). 931 // Just be explicit. 
932 py::object pyRef = 933 py::cast(unownedOperation, py::return_value_policy::take_ownership); 934 unownedOperation->handle = pyRef; 935 if (parentKeepAlive) { 936 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 937 } 938 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 939 return PyOperationRef(unownedOperation, std::move(pyRef)); 940 } 941 942 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 943 MlirOperation operation, 944 py::object parentKeepAlive) { 945 auto &liveOperations = contextRef->liveOperations; 946 auto it = liveOperations.find(operation.ptr); 947 if (it == liveOperations.end()) { 948 // Create. 949 return createInstance(std::move(contextRef), operation, 950 std::move(parentKeepAlive)); 951 } 952 // Use existing. 953 PyOperation *existing = it->second.second; 954 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 955 return PyOperationRef(existing, std::move(pyRef)); 956 } 957 958 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 959 MlirOperation operation, 960 py::object parentKeepAlive) { 961 auto &liveOperations = contextRef->liveOperations; 962 assert(liveOperations.count(operation.ptr) == 0 && 963 "cannot create detached operation that already exists"); 964 (void)liveOperations; 965 966 PyOperationRef created = createInstance(std::move(contextRef), operation, 967 std::move(parentKeepAlive)); 968 created->attached = false; 969 return created; 970 } 971 972 void PyOperation::checkValid() const { 973 if (!valid) { 974 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 975 } 976 } 977 978 void PyOperationBase::print(py::object fileObject, bool binary, 979 llvm::Optional<int64_t> largeElementsLimit, 980 bool enableDebugInfo, bool prettyDebugInfo, 981 bool printGenericOpForm, bool useLocalScope, 982 bool assumeVerified) { 983 PyOperation &operation = getOperation(); 984 operation.checkValid(); 985 if 
(fileObject.is_none()) 986 fileObject = py::module::import("sys").attr("stdout"); 987 988 if (!assumeVerified && !printGenericOpForm && 989 !mlirOperationVerify(operation)) { 990 std::string message("// Verification failed, printing generic form\n"); 991 if (binary) { 992 fileObject.attr("write")(py::bytes(message)); 993 } else { 994 fileObject.attr("write")(py::str(message)); 995 } 996 printGenericOpForm = true; 997 } 998 999 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 1000 if (largeElementsLimit) 1001 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 1002 if (enableDebugInfo) 1003 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 1004 if (printGenericOpForm) 1005 mlirOpPrintingFlagsPrintGenericOpForm(flags); 1006 1007 PyFileAccumulator accum(fileObject, binary); 1008 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1009 accum.getUserData()); 1010 mlirOpPrintingFlagsDestroy(flags); 1011 } 1012 1013 py::object PyOperationBase::getAsm(bool binary, 1014 llvm::Optional<int64_t> largeElementsLimit, 1015 bool enableDebugInfo, bool prettyDebugInfo, 1016 bool printGenericOpForm, bool useLocalScope, 1017 bool assumeVerified) { 1018 py::object fileObject; 1019 if (binary) { 1020 fileObject = py::module::import("io").attr("BytesIO")(); 1021 } else { 1022 fileObject = py::module::import("io").attr("StringIO")(); 1023 } 1024 print(fileObject, /*binary=*/binary, 1025 /*largeElementsLimit=*/largeElementsLimit, 1026 /*enableDebugInfo=*/enableDebugInfo, 1027 /*prettyDebugInfo=*/prettyDebugInfo, 1028 /*printGenericOpForm=*/printGenericOpForm, 1029 /*useLocalScope=*/useLocalScope, 1030 /*assumeVerified=*/assumeVerified); 1031 1032 return fileObject.attr("getvalue")(); 1033 } 1034 1035 void PyOperationBase::moveAfter(PyOperationBase &other) { 1036 PyOperation &operation = getOperation(); 1037 PyOperation &otherOp = other.getOperation(); 1038 operation.checkValid(); 1039 otherOp.checkValid(); 1040 
mlirOperationMoveAfter(operation, otherOp); 1041 operation.parentKeepAlive = otherOp.parentKeepAlive; 1042 } 1043 1044 void PyOperationBase::moveBefore(PyOperationBase &other) { 1045 PyOperation &operation = getOperation(); 1046 PyOperation &otherOp = other.getOperation(); 1047 operation.checkValid(); 1048 otherOp.checkValid(); 1049 mlirOperationMoveBefore(operation, otherOp); 1050 operation.parentKeepAlive = otherOp.parentKeepAlive; 1051 } 1052 1053 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 1054 checkValid(); 1055 if (!isAttached()) 1056 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 1057 MlirOperation operation = mlirOperationGetParentOperation(get()); 1058 if (mlirOperationIsNull(operation)) 1059 return {}; 1060 return PyOperation::forOperation(getContext(), operation); 1061 } 1062 1063 PyBlock PyOperation::getBlock() { 1064 checkValid(); 1065 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 1066 MlirBlock block = mlirOperationGetBlock(get()); 1067 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1068 assert(parentOperation && "Operation has no parent"); 1069 return PyBlock{std::move(*parentOperation), block}; 1070 } 1071 1072 py::object PyOperation::getCapsule() { 1073 checkValid(); 1074 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 1075 } 1076 1077 py::object PyOperation::createFromCapsule(py::object capsule) { 1078 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1079 if (mlirOperationIsNull(rawOperation)) 1080 throw py::error_already_set(); 1081 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1082 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1083 .releaseObject(); 1084 } 1085 1086 static void maybeInsertOperation(PyOperationRef &op, 1087 const py::object &maybeIp) { 1088 // InsertPoint active? 
1089 if (!maybeIp.is(py::cast(false))) { 1090 PyInsertionPoint *ip; 1091 if (maybeIp.is_none()) { 1092 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1093 } else { 1094 ip = py::cast<PyInsertionPoint *>(maybeIp); 1095 } 1096 if (ip) 1097 ip->insert(*op.get()); 1098 } 1099 } 1100 1101 py::object PyOperation::create( 1102 const std::string &name, llvm::Optional<std::vector<PyType *>> results, 1103 llvm::Optional<std::vector<PyValue *>> operands, 1104 llvm::Optional<py::dict> attributes, 1105 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 1106 DefaultingPyLocation location, const py::object &maybeIp) { 1107 llvm::SmallVector<MlirValue, 4> mlirOperands; 1108 llvm::SmallVector<MlirType, 4> mlirResults; 1109 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1110 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1111 1112 // General parameter validation. 1113 if (regions < 0) 1114 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 1115 1116 // Unpack/validate operands. 1117 if (operands) { 1118 mlirOperands.reserve(operands->size()); 1119 for (PyValue *operand : *operands) { 1120 if (!operand) 1121 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 1122 mlirOperands.push_back(operand->get()); 1123 } 1124 } 1125 1126 // Unpack/validate results. 1127 if (results) { 1128 mlirResults.reserve(results->size()); 1129 for (PyType *result : *results) { 1130 // TODO: Verify result type originate from the same context. 1131 if (!result) 1132 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 1133 mlirResults.push_back(*result); 1134 } 1135 } 1136 // Unpack/validate attributes. 
1137 if (attributes) { 1138 mlirAttributes.reserve(attributes->size()); 1139 for (auto &it : *attributes) { 1140 std::string key; 1141 try { 1142 key = it.first.cast<std::string>(); 1143 } catch (py::cast_error &err) { 1144 std::string msg = "Invalid attribute key (not a string) when " 1145 "attempting to create the operation \"" + 1146 name + "\" (" + err.what() + ")"; 1147 throw py::cast_error(msg); 1148 } 1149 try { 1150 auto &attribute = it.second.cast<PyAttribute &>(); 1151 // TODO: Verify attribute originates from the same context. 1152 mlirAttributes.emplace_back(std::move(key), attribute); 1153 } catch (py::reference_cast_error &) { 1154 // This exception seems thrown when the value is "None". 1155 std::string msg = 1156 "Found an invalid (`None`?) attribute value for the key \"" + key + 1157 "\" when attempting to create the operation \"" + name + "\""; 1158 throw py::cast_error(msg); 1159 } catch (py::cast_error &err) { 1160 std::string msg = "Invalid attribute value for the key \"" + key + 1161 "\" when attempting to create the operation \"" + 1162 name + "\" (" + err.what() + ")"; 1163 throw py::cast_error(msg); 1164 } 1165 } 1166 } 1167 // Unpack/validate successors. 1168 if (successors) { 1169 mlirSuccessors.reserve(successors->size()); 1170 for (auto *successor : *successors) { 1171 // TODO: Verify successor originate from the same context. 1172 if (!successor) 1173 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1174 mlirSuccessors.push_back(successor->get()); 1175 } 1176 } 1177 1178 // Apply unpacked/validated to the operation state. Beyond this 1179 // point, exceptions cannot be thrown or else the state will leak. 
1180 MlirOperationState state = 1181 mlirOperationStateGet(toMlirStringRef(name), location); 1182 if (!mlirOperands.empty()) 1183 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1184 mlirOperands.data()); 1185 if (!mlirResults.empty()) 1186 mlirOperationStateAddResults(&state, mlirResults.size(), 1187 mlirResults.data()); 1188 if (!mlirAttributes.empty()) { 1189 // Note that the attribute names directly reference bytes in 1190 // mlirAttributes, so that vector must not be changed from here 1191 // on. 1192 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1193 mlirNamedAttributes.reserve(mlirAttributes.size()); 1194 for (auto &it : mlirAttributes) 1195 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1196 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1197 toMlirStringRef(it.first)), 1198 it.second)); 1199 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1200 mlirNamedAttributes.data()); 1201 } 1202 if (!mlirSuccessors.empty()) 1203 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1204 mlirSuccessors.data()); 1205 if (regions) { 1206 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1207 mlirRegions.resize(regions); 1208 for (int i = 0; i < regions; ++i) 1209 mlirRegions[i] = mlirRegionCreate(); 1210 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1211 mlirRegions.data()); 1212 } 1213 1214 // Construct the operation. 
1215 MlirOperation operation = mlirOperationCreate(&state); 1216 PyOperationRef created = 1217 PyOperation::createDetached(location->getContext(), operation); 1218 maybeInsertOperation(created, maybeIp); 1219 1220 return created->createOpView(); 1221 } 1222 1223 py::object PyOperation::clone(const py::object &maybeIp) { 1224 MlirOperation clonedOperation = mlirOperationClone(operation); 1225 PyOperationRef cloned = 1226 PyOperation::createDetached(getContext(), clonedOperation); 1227 maybeInsertOperation(cloned, maybeIp); 1228 1229 return cloned->createOpView(); 1230 } 1231 1232 py::object PyOperation::createOpView() { 1233 checkValid(); 1234 MlirIdentifier ident = mlirOperationGetName(get()); 1235 MlirStringRef identStr = mlirIdentifierStr(ident); 1236 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1237 StringRef(identStr.data, identStr.length)); 1238 if (opViewClass) 1239 return (*opViewClass)(getRef().getObject()); 1240 return py::cast(PyOpView(getRef().getObject())); 1241 } 1242 1243 void PyOperation::erase() { 1244 checkValid(); 1245 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1246 // Python reference to a child operation is live. All children should also 1247 // have their `valid` bit set to false. 
1248 auto &liveOperations = getContext()->liveOperations; 1249 if (liveOperations.count(operation.ptr)) 1250 liveOperations.erase(operation.ptr); 1251 mlirOperationDestroy(operation); 1252 valid = false; 1253 } 1254 1255 //------------------------------------------------------------------------------ 1256 // PyOpView 1257 //------------------------------------------------------------------------------ 1258 1259 py::object PyOpView::buildGeneric( 1260 const py::object &cls, py::list resultTypeList, py::list operandList, 1261 llvm::Optional<py::dict> attributes, 1262 llvm::Optional<std::vector<PyBlock *>> successors, 1263 llvm::Optional<int> regions, DefaultingPyLocation location, 1264 const py::object &maybeIp) { 1265 PyMlirContextRef context = location->getContext(); 1266 // Class level operation construction metadata. 1267 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1268 // Operand and result segment specs are either none, which does no 1269 // variadic unpacking, or a list of ints with segment sizes, where each 1270 // element is either a positive number (typically 1 for a scalar) or -1 to 1271 // indicate that it is derived from the length of the same-indexed operand 1272 // or result (implying that it is a list at that position). 1273 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1274 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1275 1276 std::vector<uint32_t> operandSegmentLengths; 1277 std::vector<uint32_t> resultSegmentLengths; 1278 1279 // Validate/determine region count. 
1280 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1281 int opMinRegionCount = std::get<0>(opRegionSpec); 1282 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1283 if (!regions) { 1284 regions = opMinRegionCount; 1285 } 1286 if (*regions < opMinRegionCount) { 1287 throw py::value_error( 1288 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1289 llvm::Twine(opMinRegionCount) + 1290 " regions but was built with regions=" + llvm::Twine(*regions)) 1291 .str()); 1292 } 1293 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1294 throw py::value_error( 1295 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1296 llvm::Twine(opMinRegionCount) + 1297 " regions but was built with regions=" + llvm::Twine(*regions)) 1298 .str()); 1299 } 1300 1301 // Unpack results. 1302 std::vector<PyType *> resultTypes; 1303 resultTypes.reserve(resultTypeList.size()); 1304 if (resultSegmentSpecObj.is_none()) { 1305 // Non-variadic result unpacking. 1306 for (const auto &it : llvm::enumerate(resultTypeList)) { 1307 try { 1308 resultTypes.push_back(py::cast<PyType *>(it.value())); 1309 if (!resultTypes.back()) 1310 throw py::cast_error(); 1311 } catch (py::cast_error &err) { 1312 throw py::value_error((llvm::Twine("Result ") + 1313 llvm::Twine(it.index()) + " of operation \"" + 1314 name + "\" must be a Type (" + err.what() + ")") 1315 .str()); 1316 } 1317 } 1318 } else { 1319 // Sized result unpacking. 
1320 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1321 if (resultSegmentSpec.size() != resultTypeList.size()) { 1322 throw py::value_error((llvm::Twine("Operation \"") + name + 1323 "\" requires " + 1324 llvm::Twine(resultSegmentSpec.size()) + 1325 " result segments but was provided " + 1326 llvm::Twine(resultTypeList.size())) 1327 .str()); 1328 } 1329 resultSegmentLengths.reserve(resultTypeList.size()); 1330 for (const auto &it : 1331 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1332 int segmentSpec = std::get<1>(it.value()); 1333 if (segmentSpec == 1 || segmentSpec == 0) { 1334 // Unpack unary element. 1335 try { 1336 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1337 if (resultType) { 1338 resultTypes.push_back(resultType); 1339 resultSegmentLengths.push_back(1); 1340 } else if (segmentSpec == 0) { 1341 // Allowed to be optional. 1342 resultSegmentLengths.push_back(0); 1343 } else { 1344 throw py::cast_error("was None and result is not optional"); 1345 } 1346 } catch (py::cast_error &err) { 1347 throw py::value_error((llvm::Twine("Result ") + 1348 llvm::Twine(it.index()) + " of operation \"" + 1349 name + "\" must be a Type (" + err.what() + 1350 ")") 1351 .str()); 1352 } 1353 } else if (segmentSpec == -1) { 1354 // Unpack sequence by appending. 1355 try { 1356 if (std::get<0>(it.value()).is_none()) { 1357 // Treat it as an empty list. 1358 resultSegmentLengths.push_back(0); 1359 } else { 1360 // Unpack the list. 
1361 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1362 for (py::object segmentItem : segment) { 1363 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1364 if (!resultTypes.back()) { 1365 throw py::cast_error("contained a None item"); 1366 } 1367 } 1368 resultSegmentLengths.push_back(segment.size()); 1369 } 1370 } catch (std::exception &err) { 1371 // NOTE: Sloppy to be using a catch-all here, but there are at least 1372 // three different unrelated exceptions that can be thrown in the 1373 // above "casts". Just keep the scope above small and catch them all. 1374 throw py::value_error((llvm::Twine("Result ") + 1375 llvm::Twine(it.index()) + " of operation \"" + 1376 name + "\" must be a Sequence of Types (" + 1377 err.what() + ")") 1378 .str()); 1379 } 1380 } else { 1381 throw py::value_error("Unexpected segment spec"); 1382 } 1383 } 1384 } 1385 1386 // Unpack operands. 1387 std::vector<PyValue *> operands; 1388 operands.reserve(operands.size()); 1389 if (operandSegmentSpecObj.is_none()) { 1390 // Non-sized operand unpacking. 1391 for (const auto &it : llvm::enumerate(operandList)) { 1392 try { 1393 operands.push_back(py::cast<PyValue *>(it.value())); 1394 if (!operands.back()) 1395 throw py::cast_error(); 1396 } catch (py::cast_error &err) { 1397 throw py::value_error((llvm::Twine("Operand ") + 1398 llvm::Twine(it.index()) + " of operation \"" + 1399 name + "\" must be a Value (" + err.what() + ")") 1400 .str()); 1401 } 1402 } 1403 } else { 1404 // Sized operand unpacking. 
1405 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1406 if (operandSegmentSpec.size() != operandList.size()) { 1407 throw py::value_error((llvm::Twine("Operation \"") + name + 1408 "\" requires " + 1409 llvm::Twine(operandSegmentSpec.size()) + 1410 "operand segments but was provided " + 1411 llvm::Twine(operandList.size())) 1412 .str()); 1413 } 1414 operandSegmentLengths.reserve(operandList.size()); 1415 for (const auto &it : 1416 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1417 int segmentSpec = std::get<1>(it.value()); 1418 if (segmentSpec == 1 || segmentSpec == 0) { 1419 // Unpack unary element. 1420 try { 1421 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1422 if (operandValue) { 1423 operands.push_back(operandValue); 1424 operandSegmentLengths.push_back(1); 1425 } else if (segmentSpec == 0) { 1426 // Allowed to be optional. 1427 operandSegmentLengths.push_back(0); 1428 } else { 1429 throw py::cast_error("was None and operand is not optional"); 1430 } 1431 } catch (py::cast_error &err) { 1432 throw py::value_error((llvm::Twine("Operand ") + 1433 llvm::Twine(it.index()) + " of operation \"" + 1434 name + "\" must be a Value (" + err.what() + 1435 ")") 1436 .str()); 1437 } 1438 } else if (segmentSpec == -1) { 1439 // Unpack sequence by appending. 1440 try { 1441 if (std::get<0>(it.value()).is_none()) { 1442 // Treat it as an empty list. 1443 operandSegmentLengths.push_back(0); 1444 } else { 1445 // Unpack the list. 
1446 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1447 for (py::object segmentItem : segment) { 1448 operands.push_back(py::cast<PyValue *>(segmentItem)); 1449 if (!operands.back()) { 1450 throw py::cast_error("contained a None item"); 1451 } 1452 } 1453 operandSegmentLengths.push_back(segment.size()); 1454 } 1455 } catch (std::exception &err) { 1456 // NOTE: Sloppy to be using a catch-all here, but there are at least 1457 // three different unrelated exceptions that can be thrown in the 1458 // above "casts". Just keep the scope above small and catch them all. 1459 throw py::value_error((llvm::Twine("Operand ") + 1460 llvm::Twine(it.index()) + " of operation \"" + 1461 name + "\" must be a Sequence of Values (" + 1462 err.what() + ")") 1463 .str()); 1464 } 1465 } else { 1466 throw py::value_error("Unexpected segment spec"); 1467 } 1468 } 1469 } 1470 1471 // Merge operand/result segment lengths into attributes if needed. 1472 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1473 // Dup. 1474 if (attributes) { 1475 attributes = py::dict(*attributes); 1476 } else { 1477 attributes = py::dict(); 1478 } 1479 if (attributes->contains("result_segment_sizes") || 1480 attributes->contains("operand_segment_sizes")) { 1481 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1482 "'operand_segment_sizes' attribute is unsupported. " 1483 "Use Operation.create for such low-level access."); 1484 } 1485 1486 // Add result_segment_sizes attribute. 1487 if (!resultSegmentLengths.empty()) { 1488 int64_t size = resultSegmentLengths.size(); 1489 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1490 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1491 resultSegmentLengths.size(), resultSegmentLengths.data()); 1492 (*attributes)["result_segment_sizes"] = 1493 PyAttribute(context, segmentLengthAttr); 1494 } 1495 1496 // Add operand_segment_sizes attribute. 
1497 if (!operandSegmentLengths.empty()) { 1498 int64_t size = operandSegmentLengths.size(); 1499 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1500 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1501 operandSegmentLengths.size(), operandSegmentLengths.data()); 1502 (*attributes)["operand_segment_sizes"] = 1503 PyAttribute(context, segmentLengthAttr); 1504 } 1505 } 1506 1507 // Delegate to create. 1508 return PyOperation::create(name, 1509 /*results=*/std::move(resultTypes), 1510 /*operands=*/std::move(operands), 1511 /*attributes=*/std::move(attributes), 1512 /*successors=*/std::move(successors), 1513 /*regions=*/*regions, location, maybeIp); 1514 } 1515 1516 PyOpView::PyOpView(const py::object &operationObject) 1517 // Casting through the PyOperationBase base-class and then back to the 1518 // Operation lets us accept any PyOperationBase subclass. 1519 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1520 operationObject(operation.getRef().getObject()) {} 1521 1522 py::object PyOpView::createRawSubclass(const py::object &userClass) { 1523 // This is... a little gross. The typical pattern is to have a pure python 1524 // class that extends OpView like: 1525 // class AddFOp(_cext.ir.OpView): 1526 // def __init__(self, loc, lhs, rhs): 1527 // operation = loc.context.create_operation( 1528 // "addf", lhs, rhs, results=[lhs.type]) 1529 // super().__init__(operation) 1530 // 1531 // I.e. The goal of the user facing type is to provide a nice constructor 1532 // that has complete freedom for the op under construction. This is at odds 1533 // with our other desire to sometimes create this object by just passing an 1534 // operation (to initialize the base class). 
We could do *arg and **kwargs 1535 // munging to try to make it work, but instead, we synthesize a new class 1536 // on the fly which extends this user class (AddFOp in this example) and 1537 // *give it* the base class's __init__ method, thus bypassing the 1538 // intermediate subclass's __init__ method entirely. While slightly, 1539 // underhanded, this is safe/legal because the type hierarchy has not changed 1540 // (we just added a new leaf) and we aren't mucking around with __new__. 1541 // Typically, this new class will be stored on the original as "_Raw" and will 1542 // be used for casts and other things that need a variant of the class that 1543 // is initialized purely from an operation. 1544 py::object parentMetaclass = 1545 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1546 py::dict attributes; 1547 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1548 // now. 1549 // auto opViewType = py::type::of<PyOpView>(); 1550 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1551 attributes["__init__"] = opViewType.attr("__init__"); 1552 py::str origName = userClass.attr("__name__"); 1553 py::str newName = py::str("_") + origName; 1554 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1555 } 1556 1557 //------------------------------------------------------------------------------ 1558 // PyInsertionPoint. 
1559 //------------------------------------------------------------------------------ 1560 1561 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1562 1563 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1564 : refOperation(beforeOperationBase.getOperation().getRef()), 1565 block((*refOperation)->getBlock()) {} 1566 1567 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1568 PyOperation &operation = operationBase.getOperation(); 1569 if (operation.isAttached()) 1570 throw SetPyError(PyExc_ValueError, 1571 "Attempt to insert operation that is already attached"); 1572 block.getParentOperation()->checkValid(); 1573 MlirOperation beforeOp = {nullptr}; 1574 if (refOperation) { 1575 // Insert before operation. 1576 (*refOperation)->checkValid(); 1577 beforeOp = (*refOperation)->get(); 1578 } else { 1579 // Insert at end (before null) is only valid if the block does not 1580 // already end in a known terminator (violating this will cause assertion 1581 // failures later). 1582 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1583 throw py::index_error("Cannot insert operation at the end of a block " 1584 "that already has a terminator. Did you mean to " 1585 "use 'InsertionPoint.at_block_terminator(block)' " 1586 "versus 'InsertionPoint(block)'?"); 1587 } 1588 } 1589 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1590 operation.setAttached(); 1591 } 1592 1593 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1594 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1595 if (mlirOperationIsNull(firstOp)) { 1596 // Just insert at end. 1597 return PyInsertionPoint(block); 1598 } 1599 1600 // Insert before first op. 
1601 PyOperationRef firstOpRef = PyOperation::forOperation( 1602 block.getParentOperation()->getContext(), firstOp); 1603 return PyInsertionPoint{block, std::move(firstOpRef)}; 1604 } 1605 1606 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1607 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1608 if (mlirOperationIsNull(terminator)) 1609 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1610 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1611 block.getParentOperation()->getContext(), terminator); 1612 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1613 } 1614 1615 py::object PyInsertionPoint::contextEnter() { 1616 return PyThreadContextEntry::pushInsertionPoint(*this); 1617 } 1618 1619 void PyInsertionPoint::contextExit(const pybind11::object &excType, 1620 const pybind11::object &excVal, 1621 const pybind11::object &excTb) { 1622 PyThreadContextEntry::popInsertionPoint(*this); 1623 } 1624 1625 //------------------------------------------------------------------------------ 1626 // PyAttribute. 1627 //------------------------------------------------------------------------------ 1628 1629 bool PyAttribute::operator==(const PyAttribute &other) { 1630 return mlirAttributeEqual(attr, other.attr); 1631 } 1632 1633 py::object PyAttribute::getCapsule() { 1634 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1635 } 1636 1637 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1638 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1639 if (mlirAttributeIsNull(rawAttr)) 1640 throw py::error_already_set(); 1641 return PyAttribute( 1642 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1643 } 1644 1645 //------------------------------------------------------------------------------ 1646 // PyNamedAttribute. 
1647 //------------------------------------------------------------------------------ 1648 1649 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1650 : ownedName(new std::string(std::move(ownedName))) { 1651 namedAttr = mlirNamedAttributeGet( 1652 mlirIdentifierGet(mlirAttributeGetContext(attr), 1653 toMlirStringRef(*this->ownedName)), 1654 attr); 1655 } 1656 1657 //------------------------------------------------------------------------------ 1658 // PyType. 1659 //------------------------------------------------------------------------------ 1660 1661 bool PyType::operator==(const PyType &other) { 1662 return mlirTypeEqual(type, other.type); 1663 } 1664 1665 py::object PyType::getCapsule() { 1666 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1667 } 1668 1669 PyType PyType::createFromCapsule(py::object capsule) { 1670 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1671 if (mlirTypeIsNull(rawType)) 1672 throw py::error_already_set(); 1673 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1674 rawType); 1675 } 1676 1677 //------------------------------------------------------------------------------ 1678 // PyValue and subclases. 
1679 //------------------------------------------------------------------------------ 1680 1681 pybind11::object PyValue::getCapsule() { 1682 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1683 } 1684 1685 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1686 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1687 if (mlirValueIsNull(value)) 1688 throw py::error_already_set(); 1689 MlirOperation owner; 1690 if (mlirValueIsAOpResult(value)) 1691 owner = mlirOpResultGetOwner(value); 1692 if (mlirValueIsABlockArgument(value)) 1693 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1694 if (mlirOperationIsNull(owner)) 1695 throw py::error_already_set(); 1696 MlirContext ctx = mlirOperationGetContext(owner); 1697 PyOperationRef ownerRef = 1698 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1699 return PyValue(ownerRef, value); 1700 } 1701 1702 //------------------------------------------------------------------------------ 1703 // PySymbolTable. 
1704 //------------------------------------------------------------------------------ 1705 1706 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1707 : operation(operation.getOperation().getRef()) { 1708 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1709 if (mlirSymbolTableIsNull(symbolTable)) { 1710 throw py::cast_error("Operation is not a Symbol Table."); 1711 } 1712 } 1713 1714 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1715 operation->checkValid(); 1716 MlirOperation symbol = mlirSymbolTableLookup( 1717 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1718 if (mlirOperationIsNull(symbol)) 1719 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1720 1721 return PyOperation::forOperation(operation->getContext(), symbol, 1722 operation.getObject()) 1723 ->createOpView(); 1724 } 1725 1726 void PySymbolTable::erase(PyOperationBase &symbol) { 1727 operation->checkValid(); 1728 symbol.getOperation().checkValid(); 1729 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1730 // The operation is also erased, so we must invalidate it. There may be Python 1731 // references to this operation so we don't want to delete it from the list of 1732 // live operations here. 
1733 symbol.getOperation().valid = false; 1734 } 1735 1736 void PySymbolTable::dunderDel(const std::string &name) { 1737 py::object operation = dunderGetItem(name); 1738 erase(py::cast<PyOperationBase &>(operation)); 1739 } 1740 1741 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1742 operation->checkValid(); 1743 symbol.getOperation().checkValid(); 1744 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1745 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1746 if (mlirAttributeIsNull(symbolAttr)) 1747 throw py::value_error("Expected operation to have a symbol name."); 1748 return PyAttribute( 1749 symbol.getOperation().getContext(), 1750 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1751 } 1752 1753 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1754 // Op must already be a symbol. 1755 PyOperation &operation = symbol.getOperation(); 1756 operation.checkValid(); 1757 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1758 MlirAttribute existingNameAttr = 1759 mlirOperationGetAttributeByName(operation.get(), attrName); 1760 if (mlirAttributeIsNull(existingNameAttr)) 1761 throw py::value_error("Expected operation to have a symbol name."); 1762 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1763 } 1764 1765 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1766 const std::string &name) { 1767 // Op must already be a symbol. 
1768 PyOperation &operation = symbol.getOperation(); 1769 operation.checkValid(); 1770 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1771 MlirAttribute existingNameAttr = 1772 mlirOperationGetAttributeByName(operation.get(), attrName); 1773 if (mlirAttributeIsNull(existingNameAttr)) 1774 throw py::value_error("Expected operation to have a symbol name."); 1775 MlirAttribute newNameAttr = 1776 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1777 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1778 } 1779 1780 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1781 PyOperation &operation = symbol.getOperation(); 1782 operation.checkValid(); 1783 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1784 MlirAttribute existingVisAttr = 1785 mlirOperationGetAttributeByName(operation.get(), attrName); 1786 if (mlirAttributeIsNull(existingVisAttr)) 1787 throw py::value_error("Expected operation to have a symbol visibility."); 1788 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1789 } 1790 1791 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1792 const std::string &visibility) { 1793 if (visibility != "public" && visibility != "private" && 1794 visibility != "nested") 1795 throw py::value_error( 1796 "Expected visibility to be 'public', 'private' or 'nested'"); 1797 PyOperation &operation = symbol.getOperation(); 1798 operation.checkValid(); 1799 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1800 MlirAttribute existingVisAttr = 1801 mlirOperationGetAttributeByName(operation.get(), attrName); 1802 if (mlirAttributeIsNull(existingVisAttr)) 1803 throw py::value_error("Expected operation to have a symbol visibility."); 1804 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1805 toMlirStringRef(visibility)); 1806 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1807 } 
1808 1809 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1810 const std::string &newSymbol, 1811 PyOperationBase &from) { 1812 PyOperation &fromOperation = from.getOperation(); 1813 fromOperation.checkValid(); 1814 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1815 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1816 from.getOperation()))) 1817 1818 throw py::value_error("Symbol rename failed"); 1819 } 1820 1821 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1822 bool allSymUsesVisible, 1823 py::object callback) { 1824 PyOperation &fromOperation = from.getOperation(); 1825 fromOperation.checkValid(); 1826 struct UserData { 1827 PyMlirContextRef context; 1828 py::object callback; 1829 bool gotException; 1830 std::string exceptionWhat; 1831 py::object exceptionType; 1832 }; 1833 UserData userData{ 1834 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1835 mlirSymbolTableWalkSymbolTables( 1836 fromOperation.get(), allSymUsesVisible, 1837 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1838 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1839 auto pyFoundOp = 1840 PyOperation::forOperation(calleeUserData->context, foundOp); 1841 if (calleeUserData->gotException) 1842 return; 1843 try { 1844 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1845 } catch (py::error_already_set &e) { 1846 calleeUserData->gotException = true; 1847 calleeUserData->exceptionWhat = e.what(); 1848 calleeUserData->exceptionType = e.type(); 1849 } 1850 }, 1851 static_cast<void *>(&userData)); 1852 if (userData.gotException) { 1853 std::string message("Exception raised in callback: "); 1854 message.append(userData.exceptionWhat); 1855 throw std::runtime_error(message); 1856 } 1857 } 1858 1859 namespace { 1860 /// CRTP base class for Python MLIR values that subclass Value and should be 1861 /// castable from it. 
The value hierarchy is one level deep and is not supposed 1862 /// to accommodate other levels unless core MLIR changes. 1863 template <typename DerivedTy> 1864 class PyConcreteValue : public PyValue { 1865 public: 1866 // Derived classes must define statics for: 1867 // IsAFunctionTy isaFunction 1868 // const char *pyClassName 1869 // and redefine bindDerived. 1870 using ClassTy = py::class_<DerivedTy, PyValue>; 1871 using IsAFunctionTy = bool (*)(MlirValue); 1872 1873 PyConcreteValue() = default; 1874 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1875 : PyValue(operationRef, value) {} 1876 PyConcreteValue(PyValue &orig) 1877 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1878 1879 /// Attempts to cast the original value to the derived type and throws on 1880 /// type mismatches. 1881 static MlirValue castFrom(PyValue &orig) { 1882 if (!DerivedTy::isaFunction(orig.get())) { 1883 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1884 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1885 DerivedTy::pyClassName + 1886 " (from " + origRepr + ")"); 1887 } 1888 return orig.get(); 1889 } 1890 1891 /// Binds the Python module objects to functions of this class. 1892 static void bind(py::module &m) { 1893 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1894 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1895 cls.def_static( 1896 "isinstance", 1897 [](PyValue &otherValue) -> bool { 1898 return DerivedTy::isaFunction(otherValue); 1899 }, 1900 py::arg("other_value")); 1901 DerivedTy::bindDerived(cls); 1902 } 1903 1904 /// Implemented by derived classes to add methods to the Python subclass. 1905 static void bindDerived(ClassTy &m) {} 1906 }; 1907 1908 /// Python wrapper for MlirBlockArgument. 
1909 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1910 public: 1911 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1912 static constexpr const char *pyClassName = "BlockArgument"; 1913 using PyConcreteValue::PyConcreteValue; 1914 1915 static void bindDerived(ClassTy &c) { 1916 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1917 return PyBlock(self.getParentOperation(), 1918 mlirBlockArgumentGetOwner(self.get())); 1919 }); 1920 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1921 return mlirBlockArgumentGetArgNumber(self.get()); 1922 }); 1923 c.def( 1924 "set_type", 1925 [](PyBlockArgument &self, PyType type) { 1926 return mlirBlockArgumentSetType(self.get(), type); 1927 }, 1928 py::arg("type")); 1929 } 1930 }; 1931 1932 /// Python wrapper for MlirOpResult. 1933 class PyOpResult : public PyConcreteValue<PyOpResult> { 1934 public: 1935 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1936 static constexpr const char *pyClassName = "OpResult"; 1937 using PyConcreteValue::PyConcreteValue; 1938 1939 static void bindDerived(ClassTy &c) { 1940 c.def_property_readonly("owner", [](PyOpResult &self) { 1941 assert( 1942 mlirOperationEqual(self.getParentOperation()->get(), 1943 mlirOpResultGetOwner(self.get())) && 1944 "expected the owner of the value in Python to match that in the IR"); 1945 return self.getParentOperation().getObject(); 1946 }); 1947 c.def_property_readonly("result_number", [](PyOpResult &self) { 1948 return mlirOpResultGetResultNumber(self.get()); 1949 }); 1950 } 1951 }; 1952 1953 /// Returns the list of types of the values held by container. 
1954 template <typename Container> 1955 static std::vector<PyType> getValueTypes(Container &container, 1956 PyMlirContextRef &context) { 1957 std::vector<PyType> result; 1958 result.reserve(container.getNumElements()); 1959 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1960 result.push_back( 1961 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1962 } 1963 return result; 1964 } 1965 1966 /// A list of block arguments. Internally, these are stored as consecutive 1967 /// elements, random access is cheap. The argument list is associated with the 1968 /// operation that contains the block (detached blocks are not allowed in 1969 /// Python bindings) and extends its lifetime. 1970 class PyBlockArgumentList 1971 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1972 public: 1973 static constexpr const char *pyClassName = "BlockArgumentList"; 1974 1975 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1976 intptr_t startIndex = 0, intptr_t length = -1, 1977 intptr_t step = 1) 1978 : Sliceable(startIndex, 1979 length == -1 ? mlirBlockGetNumArguments(block) : length, 1980 step), 1981 operation(std::move(operation)), block(block) {} 1982 1983 /// Returns the number of arguments in the list. 1984 intptr_t getNumElements() { 1985 operation->checkValid(); 1986 return mlirBlockGetNumArguments(block); 1987 } 1988 1989 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1990 PyBlockArgument getElement(intptr_t pos) { 1991 MlirValue argument = mlirBlockGetArgument(block, pos); 1992 return PyBlockArgument(operation, argument); 1993 } 1994 1995 /// Returns a sublist of this list. 
1996 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1997 intptr_t step) { 1998 return PyBlockArgumentList(operation, block, startIndex, length, step); 1999 } 2000 2001 static void bindDerived(ClassTy &c) { 2002 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 2003 return getValueTypes(self, self.operation->getContext()); 2004 }); 2005 } 2006 2007 private: 2008 PyOperationRef operation; 2009 MlirBlock block; 2010 }; 2011 2012 /// A list of operation operands. Internally, these are stored as consecutive 2013 /// elements, random access is cheap. The result list is associated with the 2014 /// operation whose results these are, and extends the lifetime of this 2015 /// operation. 2016 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 2017 public: 2018 static constexpr const char *pyClassName = "OpOperandList"; 2019 2020 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2021 intptr_t length = -1, intptr_t step = 1) 2022 : Sliceable(startIndex, 2023 length == -1 ? 
mlirOperationGetNumOperands(operation->get()) 2024 : length, 2025 step), 2026 operation(operation) {} 2027 2028 intptr_t getNumElements() { 2029 operation->checkValid(); 2030 return mlirOperationGetNumOperands(operation->get()); 2031 } 2032 2033 PyValue getElement(intptr_t pos) { 2034 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2035 MlirOperation owner; 2036 if (mlirValueIsAOpResult(operand)) 2037 owner = mlirOpResultGetOwner(operand); 2038 else if (mlirValueIsABlockArgument(operand)) 2039 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2040 else 2041 assert(false && "Value must be an block arg or op result."); 2042 PyOperationRef pyOwner = 2043 PyOperation::forOperation(operation->getContext(), owner); 2044 return PyValue(pyOwner, operand); 2045 } 2046 2047 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2048 return PyOpOperandList(operation, startIndex, length, step); 2049 } 2050 2051 void dunderSetItem(intptr_t index, PyValue value) { 2052 index = wrapIndex(index); 2053 mlirOperationSetOperand(operation->get(), index, value.get()); 2054 } 2055 2056 static void bindDerived(ClassTy &c) { 2057 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2058 } 2059 2060 private: 2061 PyOperationRef operation; 2062 }; 2063 2064 /// A list of operation results. Internally, these are stored as consecutive 2065 /// elements, random access is cheap. The result list is associated with the 2066 /// operation whose results these are, and extends the lifetime of this 2067 /// operation. 2068 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 2069 public: 2070 static constexpr const char *pyClassName = "OpResultList"; 2071 2072 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 2073 intptr_t length = -1, intptr_t step = 1) 2074 : Sliceable(startIndex, 2075 length == -1 ? 
mlirOperationGetNumResults(operation->get()) 2076 : length, 2077 step), 2078 operation(operation) {} 2079 2080 intptr_t getNumElements() { 2081 operation->checkValid(); 2082 return mlirOperationGetNumResults(operation->get()); 2083 } 2084 2085 PyOpResult getElement(intptr_t index) { 2086 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 2087 return PyOpResult(value); 2088 } 2089 2090 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2091 return PyOpResultList(operation, startIndex, length, step); 2092 } 2093 2094 static void bindDerived(ClassTy &c) { 2095 c.def_property_readonly("types", [](PyOpResultList &self) { 2096 return getValueTypes(self, self.operation->getContext()); 2097 }); 2098 } 2099 2100 private: 2101 PyOperationRef operation; 2102 }; 2103 2104 /// A list of operation attributes. Can be indexed by name, producing 2105 /// attributes, or by index, producing named attributes. 2106 class PyOpAttributeMap { 2107 public: 2108 PyOpAttributeMap(PyOperationRef operation) 2109 : operation(std::move(operation)) {} 2110 2111 PyAttribute dunderGetItemNamed(const std::string &name) { 2112 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2113 toMlirStringRef(name)); 2114 if (mlirAttributeIsNull(attr)) { 2115 throw SetPyError(PyExc_KeyError, 2116 "attempt to access a non-existent attribute"); 2117 } 2118 return PyAttribute(operation->getContext(), attr); 2119 } 2120 2121 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2122 if (index < 0 || index >= dunderLen()) { 2123 throw SetPyError(PyExc_IndexError, 2124 "attempt to access out of bounds attribute"); 2125 } 2126 MlirNamedAttribute namedAttr = 2127 mlirOperationGetAttribute(operation->get(), index); 2128 return PyNamedAttribute( 2129 namedAttr.attribute, 2130 std::string(mlirIdentifierStr(namedAttr.name).data, 2131 mlirIdentifierStr(namedAttr.name).length)); 2132 } 2133 2134 void dunderSetItem(const std::string &name, const PyAttribute 
&attr) { 2135 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2136 attr); 2137 } 2138 2139 void dunderDelItem(const std::string &name) { 2140 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2141 toMlirStringRef(name)); 2142 if (!removed) 2143 throw SetPyError(PyExc_KeyError, 2144 "attempt to delete a non-existent attribute"); 2145 } 2146 2147 intptr_t dunderLen() { 2148 return mlirOperationGetNumAttributes(operation->get()); 2149 } 2150 2151 bool dunderContains(const std::string &name) { 2152 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2153 operation->get(), toMlirStringRef(name))); 2154 } 2155 2156 static void bind(py::module &m) { 2157 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2158 .def("__contains__", &PyOpAttributeMap::dunderContains) 2159 .def("__len__", &PyOpAttributeMap::dunderLen) 2160 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2161 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2162 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2163 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2164 } 2165 2166 private: 2167 PyOperationRef operation; 2168 }; 2169 2170 } // namespace 2171 2172 //------------------------------------------------------------------------------ 2173 // Populates the core exports of the 'ir' submodule. 2174 //------------------------------------------------------------------------------ 2175 2176 void mlir::python::populateIRCore(py::module &m) { 2177 //---------------------------------------------------------------------------- 2178 // Enums. 
2179 //---------------------------------------------------------------------------- 2180 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local()) 2181 .value("ERROR", MlirDiagnosticError) 2182 .value("WARNING", MlirDiagnosticWarning) 2183 .value("NOTE", MlirDiagnosticNote) 2184 .value("REMARK", MlirDiagnosticRemark); 2185 2186 //---------------------------------------------------------------------------- 2187 // Mapping of Diagnostics. 2188 //---------------------------------------------------------------------------- 2189 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local()) 2190 .def_property_readonly("severity", &PyDiagnostic::getSeverity) 2191 .def_property_readonly("location", &PyDiagnostic::getLocation) 2192 .def_property_readonly("message", &PyDiagnostic::getMessage) 2193 .def_property_readonly("notes", &PyDiagnostic::getNotes) 2194 .def("__str__", [](PyDiagnostic &self) -> py::str { 2195 if (!self.isValid()) 2196 return "<Invalid Diagnostic>"; 2197 return self.getMessage(); 2198 }); 2199 2200 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) 2201 .def("detach", &PyDiagnosticHandler::detach) 2202 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) 2203 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 2204 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2205 .def("__exit__", &PyDiagnosticHandler::contextExit); 2206 2207 //---------------------------------------------------------------------------- 2208 // Mapping of MlirContext. 
2209 //---------------------------------------------------------------------------- 2210 py::class_<PyMlirContext>(m, "Context", py::module_local()) 2211 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2212 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2213 .def("_get_context_again", 2214 [](PyMlirContext &self) { 2215 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2216 return ref.releaseObject(); 2217 }) 2218 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2219 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) 2220 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2221 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2222 &PyMlirContext::getCapsule) 2223 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2224 .def("__enter__", &PyMlirContext::contextEnter) 2225 .def("__exit__", &PyMlirContext::contextExit) 2226 .def_property_readonly_static( 2227 "current", 2228 [](py::object & /*class*/) { 2229 auto *context = PyThreadContextEntry::getDefaultContext(); 2230 if (!context) 2231 throw SetPyError(PyExc_ValueError, "No current Context"); 2232 return context; 2233 }, 2234 "Gets the Context bound to the current thread or raises ValueError") 2235 .def_property_readonly( 2236 "dialects", 2237 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2238 "Gets a container for accessing dialects by name") 2239 .def_property_readonly( 2240 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2241 "Alias for 'dialect'") 2242 .def( 2243 "get_dialect_descriptor", 2244 [=](PyMlirContext &self, std::string &name) { 2245 MlirDialect dialect = mlirContextGetOrLoadDialect( 2246 self.get(), {name.data(), name.size()}); 2247 if (mlirDialectIsNull(dialect)) { 2248 throw SetPyError(PyExc_ValueError, 2249 Twine("Dialect '") + name + "' not found"); 2250 } 2251 return PyDialectDescriptor(self.getRef(), dialect); 2252 }, 2253 py::arg("dialect_name"), 
2254 "Gets or loads a dialect by name, returning its descriptor object") 2255 .def_property( 2256 "allow_unregistered_dialects", 2257 [](PyMlirContext &self) -> bool { 2258 return mlirContextGetAllowUnregisteredDialects(self.get()); 2259 }, 2260 [](PyMlirContext &self, bool value) { 2261 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2262 }) 2263 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2264 py::arg("callback"), 2265 "Attaches a diagnostic handler that will receive callbacks") 2266 .def( 2267 "enable_multithreading", 2268 [](PyMlirContext &self, bool enable) { 2269 mlirContextEnableMultithreading(self.get(), enable); 2270 }, 2271 py::arg("enable")) 2272 .def( 2273 "is_registered_operation", 2274 [](PyMlirContext &self, std::string &name) { 2275 return mlirContextIsRegisteredOperation( 2276 self.get(), MlirStringRef{name.data(), name.size()}); 2277 }, 2278 py::arg("operation_name")); 2279 2280 //---------------------------------------------------------------------------- 2281 // Mapping of PyDialectDescriptor 2282 //---------------------------------------------------------------------------- 2283 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2284 .def_property_readonly("namespace", 2285 [](PyDialectDescriptor &self) { 2286 MlirStringRef ns = 2287 mlirDialectGetNamespace(self.get()); 2288 return py::str(ns.data, ns.length); 2289 }) 2290 .def("__repr__", [](PyDialectDescriptor &self) { 2291 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2292 std::string repr("<DialectDescriptor "); 2293 repr.append(ns.data, ns.length); 2294 repr.append(">"); 2295 return repr; 2296 }); 2297 2298 //---------------------------------------------------------------------------- 2299 // Mapping of PyDialects 2300 //---------------------------------------------------------------------------- 2301 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2302 .def("__getitem__", 2303 [=](PyDialects &self, 
std::string keyName) { 2304 MlirDialect dialect = 2305 self.getDialectForKey(keyName, /*attrError=*/false); 2306 py::object descriptor = 2307 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2308 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2309 }) 2310 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2311 MlirDialect dialect = 2312 self.getDialectForKey(attrName, /*attrError=*/true); 2313 py::object descriptor = 2314 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2315 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2316 }); 2317 2318 //---------------------------------------------------------------------------- 2319 // Mapping of PyDialect 2320 //---------------------------------------------------------------------------- 2321 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2322 .def(py::init<py::object>(), py::arg("descriptor")) 2323 .def_property_readonly( 2324 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2325 .def("__repr__", [](py::object self) { 2326 auto clazz = self.attr("__class__"); 2327 return py::str("<Dialect ") + 2328 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2329 clazz.attr("__module__") + py::str(".") + 2330 clazz.attr("__name__") + py::str(")>"); 2331 }); 2332 2333 //---------------------------------------------------------------------------- 2334 // Mapping of Location 2335 //---------------------------------------------------------------------------- 2336 py::class_<PyLocation>(m, "Location", py::module_local()) 2337 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2338 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2339 .def("__enter__", &PyLocation::contextEnter) 2340 .def("__exit__", &PyLocation::contextExit) 2341 .def("__eq__", 2342 [](PyLocation &self, PyLocation &other) -> bool { 2343 return mlirLocationEqual(self, other); 2344 }) 2345 .def("__eq__", 
[](PyLocation &self, py::object other) { return false; }) 2346 .def_property_readonly_static( 2347 "current", 2348 [](py::object & /*class*/) { 2349 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2350 if (!loc) 2351 throw SetPyError(PyExc_ValueError, "No current Location"); 2352 return loc; 2353 }, 2354 "Gets the Location bound to the current thread or raises ValueError") 2355 .def_static( 2356 "unknown", 2357 [](DefaultingPyMlirContext context) { 2358 return PyLocation(context->getRef(), 2359 mlirLocationUnknownGet(context->get())); 2360 }, 2361 py::arg("context") = py::none(), 2362 "Gets a Location representing an unknown location") 2363 .def_static( 2364 "callsite", 2365 [](PyLocation callee, const std::vector<PyLocation> &frames, 2366 DefaultingPyMlirContext context) { 2367 if (frames.empty()) 2368 throw py::value_error("No caller frames provided"); 2369 MlirLocation caller = frames.back().get(); 2370 for (const PyLocation &frame : 2371 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2372 caller = mlirLocationCallSiteGet(frame.get(), caller); 2373 return PyLocation(context->getRef(), 2374 mlirLocationCallSiteGet(callee.get(), caller)); 2375 }, 2376 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2377 kContextGetCallSiteLocationDocstring) 2378 .def_static( 2379 "file", 2380 [](std::string filename, int line, int col, 2381 DefaultingPyMlirContext context) { 2382 return PyLocation( 2383 context->getRef(), 2384 mlirLocationFileLineColGet( 2385 context->get(), toMlirStringRef(filename), line, col)); 2386 }, 2387 py::arg("filename"), py::arg("line"), py::arg("col"), 2388 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2389 .def_static( 2390 "fused", 2391 [](const std::vector<PyLocation> &pyLocations, 2392 llvm::Optional<PyAttribute> metadata, 2393 DefaultingPyMlirContext context) { 2394 llvm::SmallVector<MlirLocation, 4> locations; 2395 locations.reserve(pyLocations.size()); 2396 for (auto &pyLocation : 
pyLocations) 2397 locations.push_back(pyLocation.get()); 2398 MlirLocation location = mlirLocationFusedGet( 2399 context->get(), locations.size(), locations.data(), 2400 metadata ? metadata->get() : MlirAttribute{0}); 2401 return PyLocation(context->getRef(), location); 2402 }, 2403 py::arg("locations"), py::arg("metadata") = py::none(), 2404 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2405 .def_static( 2406 "name", 2407 [](std::string name, llvm::Optional<PyLocation> childLoc, 2408 DefaultingPyMlirContext context) { 2409 return PyLocation( 2410 context->getRef(), 2411 mlirLocationNameGet( 2412 context->get(), toMlirStringRef(name), 2413 childLoc ? childLoc->get() 2414 : mlirLocationUnknownGet(context->get()))); 2415 }, 2416 py::arg("name"), py::arg("childLoc") = py::none(), 2417 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2418 .def_property_readonly( 2419 "context", 2420 [](PyLocation &self) { return self.getContext().getObject(); }, 2421 "Context that owns the Location") 2422 .def( 2423 "emit_error", 2424 [](PyLocation &self, std::string message) { 2425 mlirEmitError(self, message.c_str()); 2426 }, 2427 py::arg("message"), "Emits an error at this location") 2428 .def("__repr__", [](PyLocation &self) { 2429 PyPrintAccumulator printAccum; 2430 mlirLocationPrint(self, printAccum.getCallback(), 2431 printAccum.getUserData()); 2432 return printAccum.join(); 2433 }); 2434 2435 //---------------------------------------------------------------------------- 2436 // Mapping of Module 2437 //---------------------------------------------------------------------------- 2438 py::class_<PyModule>(m, "Module", py::module_local()) 2439 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2440 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2441 .def_static( 2442 "parse", 2443 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2444 MlirModule module = mlirModuleCreateParse( 2445 
context->get(), toMlirStringRef(moduleAsm)); 2446 // TODO: Rework error reporting once diagnostic engine is exposed 2447 // in C API. 2448 if (mlirModuleIsNull(module)) { 2449 throw SetPyError( 2450 PyExc_ValueError, 2451 "Unable to parse module assembly (see diagnostics)"); 2452 } 2453 return PyModule::forModule(module).releaseObject(); 2454 }, 2455 py::arg("asm"), py::arg("context") = py::none(), 2456 kModuleParseDocstring) 2457 .def_static( 2458 "create", 2459 [](DefaultingPyLocation loc) { 2460 MlirModule module = mlirModuleCreateEmpty(loc); 2461 return PyModule::forModule(module).releaseObject(); 2462 }, 2463 py::arg("loc") = py::none(), "Creates an empty module") 2464 .def_property_readonly( 2465 "context", 2466 [](PyModule &self) { return self.getContext().getObject(); }, 2467 "Context that created the Module") 2468 .def_property_readonly( 2469 "operation", 2470 [](PyModule &self) { 2471 return PyOperation::forOperation(self.getContext(), 2472 mlirModuleGetOperation(self.get()), 2473 self.getRef().releaseObject()) 2474 .releaseObject(); 2475 }, 2476 "Accesses the module as an operation") 2477 .def_property_readonly( 2478 "body", 2479 [](PyModule &self) { 2480 PyOperationRef moduleOp = PyOperation::forOperation( 2481 self.getContext(), mlirModuleGetOperation(self.get()), 2482 self.getRef().releaseObject()); 2483 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 2484 return returnBlock; 2485 }, 2486 "Return the block for this module") 2487 .def( 2488 "dump", 2489 [](PyModule &self) { 2490 mlirOperationDump(mlirModuleGetOperation(self.get())); 2491 }, 2492 kDumpDocstring) 2493 .def( 2494 "__str__", 2495 [](py::object self) { 2496 // Defer to the operation's __str__. 2497 return self.attr("operation").attr("__str__")(); 2498 }, 2499 kOperationStrDunderDocstring); 2500 2501 //---------------------------------------------------------------------------- 2502 // Mapping of Operation. 
2503 //---------------------------------------------------------------------------- 2504 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2505 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2506 [](PyOperationBase &self) { 2507 return self.getOperation().getCapsule(); 2508 }) 2509 .def("__eq__", 2510 [](PyOperationBase &self, PyOperationBase &other) { 2511 return &self.getOperation() == &other.getOperation(); 2512 }) 2513 .def("__eq__", 2514 [](PyOperationBase &self, py::object other) { return false; }) 2515 .def("__hash__", 2516 [](PyOperationBase &self) { 2517 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2518 }) 2519 .def_property_readonly("attributes", 2520 [](PyOperationBase &self) { 2521 return PyOpAttributeMap( 2522 self.getOperation().getRef()); 2523 }) 2524 .def_property_readonly("operands", 2525 [](PyOperationBase &self) { 2526 return PyOpOperandList( 2527 self.getOperation().getRef()); 2528 }) 2529 .def_property_readonly("regions", 2530 [](PyOperationBase &self) { 2531 return PyRegionList( 2532 self.getOperation().getRef()); 2533 }) 2534 .def_property_readonly( 2535 "results", 2536 [](PyOperationBase &self) { 2537 return PyOpResultList(self.getOperation().getRef()); 2538 }, 2539 "Returns the list of Operation results.") 2540 .def_property_readonly( 2541 "result", 2542 [](PyOperationBase &self) { 2543 auto &operation = self.getOperation(); 2544 auto numResults = mlirOperationGetNumResults(operation); 2545 if (numResults != 1) { 2546 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2547 throw SetPyError( 2548 PyExc_ValueError, 2549 Twine("Cannot call .result on operation ") + 2550 StringRef(name.data, name.length) + " which has " + 2551 Twine(numResults) + 2552 " results (it is only valid for operations with a " 2553 "single result)"); 2554 } 2555 return PyOpResult(operation.getRef(), 2556 mlirOperationGetResult(operation, 0)); 2557 }, 2558 "Shortcut to get an op result if it has only one (throws 
an error " 2559 "otherwise).") 2560 .def_property_readonly( 2561 "location", 2562 [](PyOperationBase &self) { 2563 PyOperation &operation = self.getOperation(); 2564 return PyLocation(operation.getContext(), 2565 mlirOperationGetLocation(operation.get())); 2566 }, 2567 "Returns the source location the operation was defined or derived " 2568 "from.") 2569 .def( 2570 "__str__", 2571 [](PyOperationBase &self) { 2572 return self.getAsm(/*binary=*/false, 2573 /*largeElementsLimit=*/llvm::None, 2574 /*enableDebugInfo=*/false, 2575 /*prettyDebugInfo=*/false, 2576 /*printGenericOpForm=*/false, 2577 /*useLocalScope=*/false, 2578 /*assumeVerified=*/false); 2579 }, 2580 "Returns the assembly form of the operation.") 2581 .def("print", &PyOperationBase::print, 2582 // Careful: Lots of arguments must match up with print method. 2583 py::arg("file") = py::none(), py::arg("binary") = false, 2584 py::arg("large_elements_limit") = py::none(), 2585 py::arg("enable_debug_info") = false, 2586 py::arg("pretty_debug_info") = false, 2587 py::arg("print_generic_op_form") = false, 2588 py::arg("use_local_scope") = false, 2589 py::arg("assume_verified") = false, kOperationPrintDocstring) 2590 .def("get_asm", &PyOperationBase::getAsm, 2591 // Careful: Lots of arguments must match up with get_asm method. 
2592 py::arg("binary") = false, 2593 py::arg("large_elements_limit") = py::none(), 2594 py::arg("enable_debug_info") = false, 2595 py::arg("pretty_debug_info") = false, 2596 py::arg("print_generic_op_form") = false, 2597 py::arg("use_local_scope") = false, 2598 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2599 .def( 2600 "verify", 2601 [](PyOperationBase &self) { 2602 return mlirOperationVerify(self.getOperation()); 2603 }, 2604 "Verify the operation and return true if it passes, false if it " 2605 "fails.") 2606 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2607 "Puts self immediately after the other operation in its parent " 2608 "block.") 2609 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2610 "Puts self immediately before the other operation in its parent " 2611 "block.") 2612 .def( 2613 "detach_from_parent", 2614 [](PyOperationBase &self) { 2615 PyOperation &operation = self.getOperation(); 2616 operation.checkValid(); 2617 if (!operation.isAttached()) 2618 throw py::value_error("Detached operation has no parent."); 2619 2620 operation.detachFromParent(); 2621 return operation.createOpView(); 2622 }, 2623 "Detaches the operation from its parent block."); 2624 2625 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2626 .def_static("create", &PyOperation::create, py::arg("name"), 2627 py::arg("results") = py::none(), 2628 py::arg("operands") = py::none(), 2629 py::arg("attributes") = py::none(), 2630 py::arg("successors") = py::none(), py::arg("regions") = 0, 2631 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2632 kOperationCreateDocstring) 2633 .def_property_readonly("parent", 2634 [](PyOperation &self) -> py::object { 2635 auto parent = self.getParentOperation(); 2636 if (parent) 2637 return parent->getObject(); 2638 return py::none(); 2639 }) 2640 .def("erase", &PyOperation::erase) 2641 .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) 2642 
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2643 &PyOperation::getCapsule) 2644 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2645 .def_property_readonly("name", 2646 [](PyOperation &self) { 2647 self.checkValid(); 2648 MlirOperation operation = self.get(); 2649 MlirStringRef name = mlirIdentifierStr( 2650 mlirOperationGetName(operation)); 2651 return py::str(name.data, name.length); 2652 }) 2653 .def_property_readonly( 2654 "context", 2655 [](PyOperation &self) { 2656 self.checkValid(); 2657 return self.getContext().getObject(); 2658 }, 2659 "Context that owns the Operation") 2660 .def_property_readonly("opview", &PyOperation::createOpView); 2661 2662 auto opViewClass = 2663 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2664 .def(py::init<py::object>(), py::arg("operation")) 2665 .def_property_readonly("operation", &PyOpView::getOperationObject) 2666 .def_property_readonly( 2667 "context", 2668 [](PyOpView &self) { 2669 return self.getOperation().getContext().getObject(); 2670 }, 2671 "Context that owns the Operation") 2672 .def("__str__", [](PyOpView &self) { 2673 return py::str(self.getOperationObject()); 2674 }); 2675 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2676 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2677 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2678 opViewClass.attr("build_generic") = classmethod( 2679 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2680 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2681 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2682 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2683 "Builds a specific, generated OpView based on class level attributes."); 2684 2685 //---------------------------------------------------------------------------- 2686 // Mapping of PyRegion. 
2687 //---------------------------------------------------------------------------- 2688 py::class_<PyRegion>(m, "Region", py::module_local()) 2689 .def_property_readonly( 2690 "blocks", 2691 [](PyRegion &self) { 2692 return PyBlockList(self.getParentOperation(), self.get()); 2693 }, 2694 "Returns a forward-optimized sequence of blocks.") 2695 .def_property_readonly( 2696 "owner", 2697 [](PyRegion &self) { 2698 return self.getParentOperation()->createOpView(); 2699 }, 2700 "Returns the operation owning this region.") 2701 .def( 2702 "__iter__", 2703 [](PyRegion &self) { 2704 self.checkValid(); 2705 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2706 return PyBlockIterator(self.getParentOperation(), firstBlock); 2707 }, 2708 "Iterates over blocks in the region.") 2709 .def("__eq__", 2710 [](PyRegion &self, PyRegion &other) { 2711 return self.get().ptr == other.get().ptr; 2712 }) 2713 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2714 2715 //---------------------------------------------------------------------------- 2716 // Mapping of PyBlock. 
2717 //---------------------------------------------------------------------------- 2718 py::class_<PyBlock>(m, "Block", py::module_local()) 2719 .def_property_readonly( 2720 "owner", 2721 [](PyBlock &self) { 2722 return self.getParentOperation()->createOpView(); 2723 }, 2724 "Returns the owning operation of this block.") 2725 .def_property_readonly( 2726 "region", 2727 [](PyBlock &self) { 2728 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2729 return PyRegion(self.getParentOperation(), region); 2730 }, 2731 "Returns the owning region of this block.") 2732 .def_property_readonly( 2733 "arguments", 2734 [](PyBlock &self) { 2735 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2736 }, 2737 "Returns a list of block arguments.") 2738 .def_property_readonly( 2739 "operations", 2740 [](PyBlock &self) { 2741 return PyOperationList(self.getParentOperation(), self.get()); 2742 }, 2743 "Returns a forward-optimized sequence of operations.") 2744 .def_static( 2745 "create_at_start", 2746 [](PyRegion &parent, py::list pyArgTypes) { 2747 parent.checkValid(); 2748 llvm::SmallVector<MlirType, 4> argTypes; 2749 llvm::SmallVector<MlirLocation, 4> argLocs; 2750 argTypes.reserve(pyArgTypes.size()); 2751 argLocs.reserve(pyArgTypes.size()); 2752 for (auto &pyArg : pyArgTypes) { 2753 argTypes.push_back(pyArg.cast<PyType &>()); 2754 // TODO: Pass in a proper location here. 
2755 argLocs.push_back( 2756 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2757 } 2758 2759 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2760 argLocs.data()); 2761 mlirRegionInsertOwnedBlock(parent, 0, block); 2762 return PyBlock(parent.getParentOperation(), block); 2763 }, 2764 py::arg("parent"), py::arg("arg_types") = py::list(), 2765 "Creates and returns a new Block at the beginning of the given " 2766 "region (with given argument types).") 2767 .def( 2768 "append_to", 2769 [](PyBlock &self, PyRegion ®ion) { 2770 MlirBlock b = self.get(); 2771 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) 2772 mlirBlockDetach(b); 2773 mlirRegionAppendOwnedBlock(region.get(), b); 2774 }, 2775 "Append this block to a region, transferring ownership if necessary") 2776 .def( 2777 "create_before", 2778 [](PyBlock &self, py::args pyArgTypes) { 2779 self.checkValid(); 2780 llvm::SmallVector<MlirType, 4> argTypes; 2781 llvm::SmallVector<MlirLocation, 4> argLocs; 2782 argTypes.reserve(pyArgTypes.size()); 2783 argLocs.reserve(pyArgTypes.size()); 2784 for (auto &pyArg : pyArgTypes) { 2785 argTypes.push_back(pyArg.cast<PyType &>()); 2786 // TODO: Pass in a proper location here. 
2787 argLocs.push_back( 2788 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2789 } 2790 2791 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2792 argLocs.data()); 2793 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2794 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2795 return PyBlock(self.getParentOperation(), block); 2796 }, 2797 "Creates and returns a new Block before this block " 2798 "(with given argument types).") 2799 .def( 2800 "create_after", 2801 [](PyBlock &self, py::args pyArgTypes) { 2802 self.checkValid(); 2803 llvm::SmallVector<MlirType, 4> argTypes; 2804 llvm::SmallVector<MlirLocation, 4> argLocs; 2805 argTypes.reserve(pyArgTypes.size()); 2806 argLocs.reserve(pyArgTypes.size()); 2807 for (auto &pyArg : pyArgTypes) { 2808 argTypes.push_back(pyArg.cast<PyType &>()); 2809 2810 // TODO: Pass in a proper location here. 2811 argLocs.push_back( 2812 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2813 } 2814 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2815 argLocs.data()); 2816 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2817 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2818 return PyBlock(self.getParentOperation(), block); 2819 }, 2820 "Creates and returns a new Block after this block " 2821 "(with given argument types).") 2822 .def( 2823 "__iter__", 2824 [](PyBlock &self) { 2825 self.checkValid(); 2826 MlirOperation firstOperation = 2827 mlirBlockGetFirstOperation(self.get()); 2828 return PyOperationIterator(self.getParentOperation(), 2829 firstOperation); 2830 }, 2831 "Iterates over operations in the block.") 2832 .def("__eq__", 2833 [](PyBlock &self, PyBlock &other) { 2834 return self.get().ptr == other.get().ptr; 2835 }) 2836 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2837 .def( 2838 "__str__", 2839 [](PyBlock &self) { 2840 self.checkValid(); 2841 PyPrintAccumulator printAccum; 2842 
mlirBlockPrint(self.get(), printAccum.getCallback(), 2843 printAccum.getUserData()); 2844 return printAccum.join(); 2845 }, 2846 "Returns the assembly form of the block.") 2847 .def( 2848 "append", 2849 [](PyBlock &self, PyOperationBase &operation) { 2850 if (operation.getOperation().isAttached()) 2851 operation.getOperation().detachFromParent(); 2852 2853 MlirOperation mlirOperation = operation.getOperation().get(); 2854 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2855 operation.getOperation().setAttached( 2856 self.getParentOperation().getObject()); 2857 }, 2858 py::arg("operation"), 2859 "Appends an operation to this block. If the operation is currently " 2860 "in another block, it will be moved."); 2861 2862 //---------------------------------------------------------------------------- 2863 // Mapping of PyInsertionPoint. 2864 //---------------------------------------------------------------------------- 2865 2866 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2867 .def(py::init<PyBlock &>(), py::arg("block"), 2868 "Inserts after the last operation but still inside the block.") 2869 .def("__enter__", &PyInsertionPoint::contextEnter) 2870 .def("__exit__", &PyInsertionPoint::contextExit) 2871 .def_property_readonly_static( 2872 "current", 2873 [](py::object & /*class*/) { 2874 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2875 if (!ip) 2876 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2877 return ip; 2878 }, 2879 "Gets the InsertionPoint bound to the current thread or raises " 2880 "ValueError if none has been set") 2881 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2882 "Inserts before a referenced operation.") 2883 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2884 py::arg("block"), "Inserts at the beginning of the block.") 2885 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2886 py::arg("block"), "Inserts before the block 
terminator.") 2887 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2888 "Inserts an operation.") 2889 .def_property_readonly( 2890 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2891 "Returns the block that this InsertionPoint points to."); 2892 2893 //---------------------------------------------------------------------------- 2894 // Mapping of PyAttribute. 2895 //---------------------------------------------------------------------------- 2896 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2897 // Delegate to the PyAttribute copy constructor, which will also lifetime 2898 // extend the backing context which owns the MlirAttribute. 2899 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2900 "Casts the passed attribute to the generic Attribute") 2901 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2902 &PyAttribute::getCapsule) 2903 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2904 .def_static( 2905 "parse", 2906 [](std::string attrSpec, DefaultingPyMlirContext context) { 2907 MlirAttribute type = mlirAttributeParseGet( 2908 context->get(), toMlirStringRef(attrSpec)); 2909 // TODO: Rework error reporting once diagnostic engine is exposed 2910 // in C API. 
2911 if (mlirAttributeIsNull(type)) { 2912 throw SetPyError(PyExc_ValueError, 2913 Twine("Unable to parse attribute: '") + 2914 attrSpec + "'"); 2915 } 2916 return PyAttribute(context->getRef(), type); 2917 }, 2918 py::arg("asm"), py::arg("context") = py::none(), 2919 "Parses an attribute from an assembly form") 2920 .def_property_readonly( 2921 "context", 2922 [](PyAttribute &self) { return self.getContext().getObject(); }, 2923 "Context that owns the Attribute") 2924 .def_property_readonly("type", 2925 [](PyAttribute &self) { 2926 return PyType(self.getContext()->getRef(), 2927 mlirAttributeGetType(self)); 2928 }) 2929 .def( 2930 "get_named", 2931 [](PyAttribute &self, std::string name) { 2932 return PyNamedAttribute(self, std::move(name)); 2933 }, 2934 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2935 .def("__eq__", 2936 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2937 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2938 .def("__hash__", 2939 [](PyAttribute &self) { 2940 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2941 }) 2942 .def( 2943 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2944 kDumpDocstring) 2945 .def( 2946 "__str__", 2947 [](PyAttribute &self) { 2948 PyPrintAccumulator printAccum; 2949 mlirAttributePrint(self, printAccum.getCallback(), 2950 printAccum.getUserData()); 2951 return printAccum.join(); 2952 }, 2953 "Returns the assembly form of the Attribute.") 2954 .def("__repr__", [](PyAttribute &self) { 2955 // Generally, assembly formats are not printed for __repr__ because 2956 // this can cause exceptionally long debug output and exceptions. 2957 // However, attribute values are generally considered useful and are 2958 // printed. This may need to be re-evaluated if debug dumps end up 2959 // being excessive. 
2960 PyPrintAccumulator printAccum; 2961 printAccum.parts.append("Attribute("); 2962 mlirAttributePrint(self, printAccum.getCallback(), 2963 printAccum.getUserData()); 2964 printAccum.parts.append(")"); 2965 return printAccum.join(); 2966 }); 2967 2968 //---------------------------------------------------------------------------- 2969 // Mapping of PyNamedAttribute 2970 //---------------------------------------------------------------------------- 2971 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2972 .def("__repr__", 2973 [](PyNamedAttribute &self) { 2974 PyPrintAccumulator printAccum; 2975 printAccum.parts.append("NamedAttribute("); 2976 printAccum.parts.append( 2977 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2978 mlirIdentifierStr(self.namedAttr.name).length)); 2979 printAccum.parts.append("="); 2980 mlirAttributePrint(self.namedAttr.attribute, 2981 printAccum.getCallback(), 2982 printAccum.getUserData()); 2983 printAccum.parts.append(")"); 2984 return printAccum.join(); 2985 }) 2986 .def_property_readonly( 2987 "name", 2988 [](PyNamedAttribute &self) { 2989 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2990 mlirIdentifierStr(self.namedAttr.name).length); 2991 }, 2992 "The name of the NamedAttribute binding") 2993 .def_property_readonly( 2994 "attr", 2995 [](PyNamedAttribute &self) { 2996 // TODO: When named attribute is removed/refactored, also remove 2997 // this constructor (it does an inefficient table lookup). 2998 auto contextRef = PyMlirContext::forContext( 2999 mlirAttributeGetContext(self.namedAttr.attribute)); 3000 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 3001 }, 3002 py::keep_alive<0, 1>(), 3003 "The underlying generic attribute of the NamedAttribute binding"); 3004 3005 //---------------------------------------------------------------------------- 3006 // Mapping of PyType. 
3007 //---------------------------------------------------------------------------- 3008 py::class_<PyType>(m, "Type", py::module_local()) 3009 // Delegate to the PyType copy constructor, which will also lifetime 3010 // extend the backing context which owns the MlirType. 3011 .def(py::init<PyType &>(), py::arg("cast_from_type"), 3012 "Casts the passed type to the generic Type") 3013 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 3014 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 3015 .def_static( 3016 "parse", 3017 [](std::string typeSpec, DefaultingPyMlirContext context) { 3018 MlirType type = 3019 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 3020 // TODO: Rework error reporting once diagnostic engine is exposed 3021 // in C API. 3022 if (mlirTypeIsNull(type)) { 3023 throw SetPyError(PyExc_ValueError, 3024 Twine("Unable to parse type: '") + typeSpec + 3025 "'"); 3026 } 3027 return PyType(context->getRef(), type); 3028 }, 3029 py::arg("asm"), py::arg("context") = py::none(), 3030 kContextParseTypeDocstring) 3031 .def_property_readonly( 3032 "context", [](PyType &self) { return self.getContext().getObject(); }, 3033 "Context that owns the Type") 3034 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3035 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 3036 .def("__hash__", 3037 [](PyType &self) { 3038 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3039 }) 3040 .def( 3041 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3042 .def( 3043 "__str__", 3044 [](PyType &self) { 3045 PyPrintAccumulator printAccum; 3046 mlirTypePrint(self, printAccum.getCallback(), 3047 printAccum.getUserData()); 3048 return printAccum.join(); 3049 }, 3050 "Returns the assembly form of the type.") 3051 .def("__repr__", [](PyType &self) { 3052 // Generally, assembly formats are not printed for __repr__ because 3053 // this can cause exceptionally long debug 
output and exceptions. 3054 // However, types are an exception as they typically have compact 3055 // assembly forms and printing them is useful. 3056 PyPrintAccumulator printAccum; 3057 printAccum.parts.append("Type("); 3058 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 3059 printAccum.parts.append(")"); 3060 return printAccum.join(); 3061 }); 3062 3063 //---------------------------------------------------------------------------- 3064 // Mapping of Value. 3065 //---------------------------------------------------------------------------- 3066 py::class_<PyValue>(m, "Value", py::module_local()) 3067 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3068 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3069 .def_property_readonly( 3070 "context", 3071 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3072 "Context in which the value lives.") 3073 .def( 3074 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3075 kDumpDocstring) 3076 .def_property_readonly( 3077 "owner", 3078 [](PyValue &self) { 3079 assert(mlirOperationEqual(self.getParentOperation()->get(), 3080 mlirOpResultGetOwner(self.get())) && 3081 "expected the owner of the value in Python to match that in " 3082 "the IR"); 3083 return self.getParentOperation().getObject(); 3084 }) 3085 .def("__eq__", 3086 [](PyValue &self, PyValue &other) { 3087 return self.get().ptr == other.get().ptr; 3088 }) 3089 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 3090 .def("__hash__", 3091 [](PyValue &self) { 3092 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3093 }) 3094 .def( 3095 "__str__", 3096 [](PyValue &self) { 3097 PyPrintAccumulator printAccum; 3098 printAccum.parts.append("Value("); 3099 mlirValuePrint(self.get(), printAccum.getCallback(), 3100 printAccum.getUserData()); 3101 printAccum.parts.append(")"); 3102 return printAccum.join(); 3103 }, 3104 kValueDunderStrDocstring) 3105 
.def_property_readonly("type", [](PyValue &self) { 3106 return PyType(self.getParentOperation()->getContext(), 3107 mlirValueGetType(self.get())); 3108 }); 3109 PyBlockArgument::bind(m); 3110 PyOpResult::bind(m); 3111 3112 //---------------------------------------------------------------------------- 3113 // Mapping of SymbolTable. 3114 //---------------------------------------------------------------------------- 3115 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 3116 .def(py::init<PyOperationBase &>()) 3117 .def("__getitem__", &PySymbolTable::dunderGetItem) 3118 .def("insert", &PySymbolTable::insert, py::arg("operation")) 3119 .def("erase", &PySymbolTable::erase, py::arg("operation")) 3120 .def("__delitem__", &PySymbolTable::dunderDel) 3121 .def("__contains__", 3122 [](PySymbolTable &table, const std::string &name) { 3123 return !mlirOperationIsNull(mlirSymbolTableLookup( 3124 table, mlirStringRefCreate(name.data(), name.length()))); 3125 }) 3126 // Static helpers. 3127 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3128 py::arg("symbol"), py::arg("name")) 3129 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3130 py::arg("symbol")) 3131 .def_static("get_visibility", &PySymbolTable::getVisibility, 3132 py::arg("symbol")) 3133 .def_static("set_visibility", &PySymbolTable::setVisibility, 3134 py::arg("symbol"), py::arg("visibility")) 3135 .def_static("replace_all_symbol_uses", 3136 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 3137 py::arg("new_symbol"), py::arg("from_op")) 3138 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3139 py::arg("from_op"), py::arg("all_sym_uses_visible"), 3140 py::arg("callback")); 3141 3142 // Container bindings. 
PyBlockArgumentList::bind(m); // Returned by Block.arguments.
  PyBlockIterator::bind(m);     // Returned by Region.__iter__.
  PyBlockList::bind(m);         // Returned by Region.blocks.
  PyOperationIterator::bind(m); // Returned by Block.__iter__.
  PyOperationList::bind(m);     // Returned by Block.operations.
  PyOpAttributeMap::bind(m);    // Returned by _OperationBase.attributes.
  PyOpOperandList::bind(m);     // Returned by _OperationBase.operands.
  PyOpResultList::bind(m);      // Returned by _OperationBase.results.
  // No consumer is visible in this chunk; presumably used when iterating a
  // PyRegionList (NOTE(review): confirm).
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);        // Returned by _OperationBase.regions.

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
}