1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 23 #include <utility> 24 25 namespace py = pybind11; 26 using namespace mlir; 27 using namespace mlir::python; 28 29 using llvm::SmallVector; 30 using llvm::StringRef; 31 using llvm::Twine; 32 33 //------------------------------------------------------------------------------ 34 // Docstrings (trivial, non-duplicated docstrings are included inline). 35 //------------------------------------------------------------------------------ 36 37 static const char kContextParseTypeDocstring[] = 38 R"(Parses the assembly form of a type. 39 40 Returns a Type object or raises a ValueError if the type cannot be parsed. 
41 42 See also: https://mlir.llvm.org/docs/LangRef/#type-system 43 )"; 44 45 static const char kContextGetCallSiteLocationDocstring[] = 46 R"(Gets a Location representing a caller and callsite)"; 47 48 static const char kContextGetFileLocationDocstring[] = 49 R"(Gets a Location representing a file, line and column)"; 50 51 static const char kContextGetFusedLocationDocstring[] = 52 R"(Gets a Location representing a fused location with optional metadata)"; 53 54 static const char kContextGetNameLocationDocString[] = 55 R"(Gets a Location representing a named location with optional child location)"; 56 57 static const char kModuleParseDocstring[] = 58 R"(Parses a module's assembly format from a string. 59 60 Returns a new MlirModule or raises a ValueError if the parsing fails. 61 62 See also: https://mlir.llvm.org/docs/LangRef/ 63 )"; 64 65 static const char kOperationCreateDocstring[] = 66 R"(Creates a new operation. 67 68 Args: 69 name: Operation name (e.g. "dialect.operation"). 70 results: Sequence of Type representing op result types. 71 attributes: Dict of str:Attribute. 72 successors: List of Block for the operation's successors. 73 regions: Number of regions to create. 74 location: A Location object (defaults to resolve from context manager). 75 ip: An InsertionPoint (defaults to resolve from context manager or set to 76 False to disable insertion, even with an insertion point set in the 77 context manager). 78 Returns: 79 A new "detached" Operation object. Detached operations can be added 80 to blocks, which causes them to become "attached." 81 )"; 82 83 static const char kOperationPrintDocstring[] = 84 R"(Prints the assembly form of the operation to a file like object. 85 86 Args: 87 file: The file like object to write to. Defaults to sys.stdout. 88 binary: Whether to write bytes (True) or str (False). Defaults to False. 89 large_elements_limit: Whether to elide elements attributes above this 90 number of elements. Defaults to None (no limit). 
91 enable_debug_info: Whether to print debug/location information. Defaults 92 to False. 93 pretty_debug_info: Whether to format debug information for easier reading 94 by a human (warning: the result is unparseable). 95 print_generic_op_form: Whether to print the generic assembly forms of all 96 ops. Defaults to False. 97 use_local_Scope: Whether to print in a way that is more optimized for 98 multi-threaded access but may not be consistent with how the overall 99 module prints. 100 assume_verified: By default, if not printing generic form, the verifier 101 will be run and if it fails, generic form will be printed with a comment 102 about failed verification. While a reasonable default for interactive use, 103 for systematic use, it is often better for the caller to verify explicitly 104 and report failures in a more robust fashion. Set this to True if doing this 105 in order to avoid running a redundant verification. If the IR is actually 106 invalid, behavior is undefined. 107 )"; 108 109 static const char kOperationGetAsmDocstring[] = 110 R"(Gets the assembly form of the operation with all options available. 111 112 Args: 113 binary: Whether to return a bytes (True) or str (False) object. Defaults to 114 False. 115 ... others ...: See the print() method for common keyword arguments for 116 configuring the printout. 117 Returns: 118 Either a bytes or str object, depending on the setting of the 'binary' 119 argument. 120 )"; 121 122 static const char kOperationStrDunderDocstring[] = 123 R"(Gets the assembly form of the operation with default options. 124 125 If more advanced control over the assembly formatting or I/O options is needed, 126 use the dedicated print or get_asm method, which supports keyword arguments to 127 customize behavior. 
128 )"; 129 130 static const char kDumpDocstring[] = 131 R"(Dumps a debug representation of the object to stderr.)"; 132 133 static const char kAppendBlockDocstring[] = 134 R"(Appends a new block, with argument types as positional args. 135 136 Returns: 137 The created block. 138 )"; 139 140 static const char kValueDunderStrDocstring[] = 141 R"(Returns the string form of the value. 142 143 If the value is a block argument, this is the assembly form of its type and the 144 position in the argument list. If the value is an operation result, this is 145 equivalent to printing the operation that produced it. 146 )"; 147 148 //------------------------------------------------------------------------------ 149 // Utilities. 150 //------------------------------------------------------------------------------ 151 152 /// Helper for creating an @classmethod. 153 template <class Func, typename... Args> 154 py::object classmethod(Func f, Args... args) { 155 py::object cf = py::cpp_function(f, args...); 156 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 157 } 158 159 static py::object 160 createCustomDialectWrapper(const std::string &dialectNamespace, 161 py::object dialectDescriptor) { 162 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 163 if (!dialectClass) { 164 // Use the base class. 165 return py::cast(PyDialect(std::move(dialectDescriptor))); 166 } 167 168 // Create the custom implementation. 169 return (*dialectClass)(std::move(dialectDescriptor)); 170 } 171 172 static MlirStringRef toMlirStringRef(const std::string &s) { 173 return mlirStringRefCreate(s.data(), s.size()); 174 } 175 176 /// Wrapper for the global LLVM debugging flag. 177 struct PyGlobalDebugFlag { 178 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 179 180 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } 181 182 static void bind(py::module &m) { 183 // Debug flags. 
184 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 185 .def_property_static("flag", &PyGlobalDebugFlag::get, 186 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 187 } 188 }; 189 190 //------------------------------------------------------------------------------ 191 // Collections. 192 //------------------------------------------------------------------------------ 193 194 namespace { 195 196 class PyRegionIterator { 197 public: 198 PyRegionIterator(PyOperationRef operation) 199 : operation(std::move(operation)) {} 200 201 PyRegionIterator &dunderIter() { return *this; } 202 203 PyRegion dunderNext() { 204 operation->checkValid(); 205 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 206 throw py::stop_iteration(); 207 } 208 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 209 return PyRegion(operation, region); 210 } 211 212 static void bind(py::module &m) { 213 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 214 .def("__iter__", &PyRegionIterator::dunderIter) 215 .def("__next__", &PyRegionIterator::dunderNext); 216 } 217 218 private: 219 PyOperationRef operation; 220 int nextIndex = 0; 221 }; 222 223 /// Regions of an op are fixed length and indexed numerically so are represented 224 /// with a sequence-like container. 225 class PyRegionList { 226 public: 227 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 228 229 intptr_t dunderLen() { 230 operation->checkValid(); 231 return mlirOperationGetNumRegions(operation->get()); 232 } 233 234 PyRegion dunderGetItem(intptr_t index) { 235 // dunderLen checks validity. 
236 if (index < 0 || index >= dunderLen()) { 237 throw SetPyError(PyExc_IndexError, 238 "attempt to access out of bounds region"); 239 } 240 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 241 return PyRegion(operation, region); 242 } 243 244 static void bind(py::module &m) { 245 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 246 .def("__len__", &PyRegionList::dunderLen) 247 .def("__getitem__", &PyRegionList::dunderGetItem); 248 } 249 250 private: 251 PyOperationRef operation; 252 }; 253 254 class PyBlockIterator { 255 public: 256 PyBlockIterator(PyOperationRef operation, MlirBlock next) 257 : operation(std::move(operation)), next(next) {} 258 259 PyBlockIterator &dunderIter() { return *this; } 260 261 PyBlock dunderNext() { 262 operation->checkValid(); 263 if (mlirBlockIsNull(next)) { 264 throw py::stop_iteration(); 265 } 266 267 PyBlock returnBlock(operation, next); 268 next = mlirBlockGetNextInRegion(next); 269 return returnBlock; 270 } 271 272 static void bind(py::module &m) { 273 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 274 .def("__iter__", &PyBlockIterator::dunderIter) 275 .def("__next__", &PyBlockIterator::dunderNext); 276 } 277 278 private: 279 PyOperationRef operation; 280 MlirBlock next; 281 }; 282 283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 284 /// we present them as a more full-featured list-like container but optimize 285 /// it for forward iteration. Blocks are always owned by a region. 
286 class PyBlockList { 287 public: 288 PyBlockList(PyOperationRef operation, MlirRegion region) 289 : operation(std::move(operation)), region(region) {} 290 291 PyBlockIterator dunderIter() { 292 operation->checkValid(); 293 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 294 } 295 296 intptr_t dunderLen() { 297 operation->checkValid(); 298 intptr_t count = 0; 299 MlirBlock block = mlirRegionGetFirstBlock(region); 300 while (!mlirBlockIsNull(block)) { 301 count += 1; 302 block = mlirBlockGetNextInRegion(block); 303 } 304 return count; 305 } 306 307 PyBlock dunderGetItem(intptr_t index) { 308 operation->checkValid(); 309 if (index < 0) { 310 throw SetPyError(PyExc_IndexError, 311 "attempt to access out of bounds block"); 312 } 313 MlirBlock block = mlirRegionGetFirstBlock(region); 314 while (!mlirBlockIsNull(block)) { 315 if (index == 0) { 316 return PyBlock(operation, block); 317 } 318 block = mlirBlockGetNextInRegion(block); 319 index -= 1; 320 } 321 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 322 } 323 324 PyBlock appendBlock(const py::args &pyArgTypes) { 325 operation->checkValid(); 326 llvm::SmallVector<MlirType, 4> argTypes; 327 llvm::SmallVector<MlirLocation, 4> argLocs; 328 argTypes.reserve(pyArgTypes.size()); 329 argLocs.reserve(pyArgTypes.size()); 330 for (auto &pyArg : pyArgTypes) { 331 argTypes.push_back(pyArg.cast<PyType &>()); 332 // TODO: Pass in a proper location here. 
333 argLocs.push_back( 334 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 335 } 336 337 MlirBlock block = 338 mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 339 mlirRegionAppendOwnedBlock(region, block); 340 return PyBlock(operation, block); 341 } 342 343 static void bind(py::module &m) { 344 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 345 .def("__getitem__", &PyBlockList::dunderGetItem) 346 .def("__iter__", &PyBlockList::dunderIter) 347 .def("__len__", &PyBlockList::dunderLen) 348 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 349 } 350 351 private: 352 PyOperationRef operation; 353 MlirRegion region; 354 }; 355 356 class PyOperationIterator { 357 public: 358 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 359 : parentOperation(std::move(parentOperation)), next(next) {} 360 361 PyOperationIterator &dunderIter() { return *this; } 362 363 py::object dunderNext() { 364 parentOperation->checkValid(); 365 if (mlirOperationIsNull(next)) { 366 throw py::stop_iteration(); 367 } 368 369 PyOperationRef returnOperation = 370 PyOperation::forOperation(parentOperation->getContext(), next); 371 next = mlirOperationGetNextInBlock(next); 372 return returnOperation->createOpView(); 373 } 374 375 static void bind(py::module &m) { 376 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 377 .def("__iter__", &PyOperationIterator::dunderIter) 378 .def("__next__", &PyOperationIterator::dunderNext); 379 } 380 381 private: 382 PyOperationRef parentOperation; 383 MlirOperation next; 384 }; 385 386 /// Operations are exposed by the C-API as a forward-only linked list. In 387 /// Python, we present them as a more full-featured list-like container but 388 /// optimize it for forward iteration. Iterable operations are always owned 389 /// by a block. 
390 class PyOperationList { 391 public: 392 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 393 : parentOperation(std::move(parentOperation)), block(block) {} 394 395 PyOperationIterator dunderIter() { 396 parentOperation->checkValid(); 397 return PyOperationIterator(parentOperation, 398 mlirBlockGetFirstOperation(block)); 399 } 400 401 intptr_t dunderLen() { 402 parentOperation->checkValid(); 403 intptr_t count = 0; 404 MlirOperation childOp = mlirBlockGetFirstOperation(block); 405 while (!mlirOperationIsNull(childOp)) { 406 count += 1; 407 childOp = mlirOperationGetNextInBlock(childOp); 408 } 409 return count; 410 } 411 412 py::object dunderGetItem(intptr_t index) { 413 parentOperation->checkValid(); 414 if (index < 0) { 415 throw SetPyError(PyExc_IndexError, 416 "attempt to access out of bounds operation"); 417 } 418 MlirOperation childOp = mlirBlockGetFirstOperation(block); 419 while (!mlirOperationIsNull(childOp)) { 420 if (index == 0) { 421 return PyOperation::forOperation(parentOperation->getContext(), childOp) 422 ->createOpView(); 423 } 424 childOp = mlirOperationGetNextInBlock(childOp); 425 index -= 1; 426 } 427 throw SetPyError(PyExc_IndexError, 428 "attempt to access out of bounds operation"); 429 } 430 431 static void bind(py::module &m) { 432 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 433 .def("__getitem__", &PyOperationList::dunderGetItem) 434 .def("__iter__", &PyOperationList::dunderIter) 435 .def("__len__", &PyOperationList::dunderLen); 436 } 437 438 private: 439 PyOperationRef parentOperation; 440 MlirBlock block; 441 }; 442 443 } // namespace 444 445 //------------------------------------------------------------------------------ 446 // PyMlirContext 447 //------------------------------------------------------------------------------ 448 449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 450 py::gil_scoped_acquire acquire; 451 auto &liveContexts = getLiveContexts(); 452 
liveContexts[context.ptr] = this; 453 } 454 455 PyMlirContext::~PyMlirContext() { 456 // Note that the only public way to construct an instance is via the 457 // forContext method, which always puts the associated handle into 458 // liveContexts. 459 py::gil_scoped_acquire acquire; 460 getLiveContexts().erase(context.ptr); 461 mlirContextDestroy(context); 462 } 463 464 py::object PyMlirContext::getCapsule() { 465 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 466 } 467 468 py::object PyMlirContext::createFromCapsule(py::object capsule) { 469 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 470 if (mlirContextIsNull(rawContext)) 471 throw py::error_already_set(); 472 return forContext(rawContext).releaseObject(); 473 } 474 475 PyMlirContext *PyMlirContext::createNewContextForInit() { 476 MlirContext context = mlirContextCreate(); 477 mlirRegisterAllDialects(context); 478 return new PyMlirContext(context); 479 } 480 481 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 482 py::gil_scoped_acquire acquire; 483 auto &liveContexts = getLiveContexts(); 484 auto it = liveContexts.find(context.ptr); 485 if (it == liveContexts.end()) { 486 // Create. 487 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 488 py::object pyRef = py::cast(unownedContextWrapper); 489 assert(pyRef && "cast to py::object failed"); 490 liveContexts[context.ptr] = unownedContextWrapper; 491 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 492 } 493 // Use existing. 
494 py::object pyRef = py::cast(it->second); 495 return PyMlirContextRef(it->second, std::move(pyRef)); 496 } 497 498 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 499 static LiveContextMap liveContexts; 500 return liveContexts; 501 } 502 503 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 504 505 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 506 507 size_t PyMlirContext::clearLiveOperations() { 508 for (auto &op : liveOperations) 509 op.second.second->setInvalid(); 510 size_t numInvalidated = liveOperations.size(); 511 liveOperations.clear(); 512 return numInvalidated; 513 } 514 515 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 516 517 pybind11::object PyMlirContext::contextEnter() { 518 return PyThreadContextEntry::pushContext(*this); 519 } 520 521 void PyMlirContext::contextExit(const pybind11::object &excType, 522 const pybind11::object &excVal, 523 const pybind11::object &excTb) { 524 PyThreadContextEntry::popContext(*this); 525 } 526 527 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { 528 // Note that ownership is transferred to the delete callback below by way of 529 // an explicit inc_ref (borrow). 530 PyDiagnosticHandler *pyHandler = 531 new PyDiagnosticHandler(get(), std::move(callback)); 532 py::object pyHandlerObject = 533 py::cast(pyHandler, py::return_value_policy::take_ownership); 534 pyHandlerObject.inc_ref(); 535 536 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 537 // guaranteed to be known to pybind. 
538 auto handlerCallback = 539 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 540 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 541 py::object pyDiagnosticObject = 542 py::cast(pyDiagnostic, py::return_value_policy::take_ownership); 543 544 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 545 bool result = false; 546 { 547 // Since this can be called from arbitrary C++ contexts, always get the 548 // gil. 549 py::gil_scoped_acquire gil; 550 try { 551 result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); 552 } catch (std::exception &e) { 553 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 554 e.what()); 555 pyHandler->hadError = true; 556 } 557 } 558 559 pyDiagnostic->invalidate(); 560 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 561 }; 562 auto deleteCallback = +[](void *userData) { 563 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 564 assert(pyHandler->registeredID && "handler is not registered"); 565 pyHandler->registeredID.reset(); 566 567 // Decrement reference, balancing the inc_ref() above. 568 py::object pyHandlerObject = 569 py::cast(pyHandler, py::return_value_policy::reference); 570 pyHandlerObject.dec_ref(); 571 }; 572 573 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 574 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 575 return pyHandlerObject; 576 } 577 578 PyMlirContext &DefaultingPyMlirContext::resolve() { 579 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 580 if (!context) { 581 throw SetPyError( 582 PyExc_RuntimeError, 583 "An MLIR function requires a Context but none was provided in the call " 584 "or from the surrounding environment. 
Either pass to the function with " 585 "a 'context=' argument or establish a default using 'with Context():'"); 586 } 587 return *context; 588 } 589 590 //------------------------------------------------------------------------------ 591 // PyThreadContextEntry management 592 //------------------------------------------------------------------------------ 593 594 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 595 static thread_local std::vector<PyThreadContextEntry> stack; 596 return stack; 597 } 598 599 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 600 auto &stack = getStack(); 601 if (stack.empty()) 602 return nullptr; 603 return &stack.back(); 604 } 605 606 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 607 py::object insertionPoint, 608 py::object location) { 609 auto &stack = getStack(); 610 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 611 std::move(location)); 612 // If the new stack has more than one entry and the context of the new top 613 // entry matches the previous, copy the insertionPoint and location from the 614 // previous entry if missing from the new top entry. 615 if (stack.size() > 1) { 616 auto &prev = *(stack.rbegin() + 1); 617 auto ¤t = stack.back(); 618 if (current.context.is(prev.context)) { 619 // Default non-context objects from the previous entry. 
620 if (!current.insertionPoint) 621 current.insertionPoint = prev.insertionPoint; 622 if (!current.location) 623 current.location = prev.location; 624 } 625 } 626 } 627 628 PyMlirContext *PyThreadContextEntry::getContext() { 629 if (!context) 630 return nullptr; 631 return py::cast<PyMlirContext *>(context); 632 } 633 634 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 635 if (!insertionPoint) 636 return nullptr; 637 return py::cast<PyInsertionPoint *>(insertionPoint); 638 } 639 640 PyLocation *PyThreadContextEntry::getLocation() { 641 if (!location) 642 return nullptr; 643 return py::cast<PyLocation *>(location); 644 } 645 646 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 647 auto *tos = getTopOfStack(); 648 return tos ? tos->getContext() : nullptr; 649 } 650 651 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 652 auto *tos = getTopOfStack(); 653 return tos ? tos->getInsertionPoint() : nullptr; 654 } 655 656 PyLocation *PyThreadContextEntry::getDefaultLocation() { 657 auto *tos = getTopOfStack(); 658 return tos ? 
tos->getLocation() : nullptr; 659 } 660 661 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 662 py::object contextObj = py::cast(context); 663 push(FrameKind::Context, /*context=*/contextObj, 664 /*insertionPoint=*/py::object(), 665 /*location=*/py::object()); 666 return contextObj; 667 } 668 669 void PyThreadContextEntry::popContext(PyMlirContext &context) { 670 auto &stack = getStack(); 671 if (stack.empty()) 672 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 673 auto &tos = stack.back(); 674 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 675 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 676 stack.pop_back(); 677 } 678 679 py::object 680 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 681 py::object contextObj = 682 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 683 py::object insertionPointObj = py::cast(insertionPoint); 684 push(FrameKind::InsertionPoint, 685 /*context=*/contextObj, 686 /*insertionPoint=*/insertionPointObj, 687 /*location=*/py::object()); 688 return insertionPointObj; 689 } 690 691 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 692 auto &stack = getStack(); 693 if (stack.empty()) 694 throw SetPyError(PyExc_RuntimeError, 695 "Unbalanced InsertionPoint enter/exit"); 696 auto &tos = stack.back(); 697 if (tos.frameKind != FrameKind::InsertionPoint && 698 tos.getInsertionPoint() != &insertionPoint) 699 throw SetPyError(PyExc_RuntimeError, 700 "Unbalanced InsertionPoint enter/exit"); 701 stack.pop_back(); 702 } 703 704 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 705 py::object contextObj = location.getContext().getObject(); 706 py::object locationObj = py::cast(location); 707 push(FrameKind::Location, /*context=*/contextObj, 708 /*insertionPoint=*/py::object(), 709 /*location=*/locationObj); 710 return locationObj; 711 } 712 713 void 
PyThreadContextEntry::popLocation(PyLocation &location) { 714 auto &stack = getStack(); 715 if (stack.empty()) 716 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 717 auto &tos = stack.back(); 718 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 719 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 720 stack.pop_back(); 721 } 722 723 //------------------------------------------------------------------------------ 724 // PyDiagnostic* 725 //------------------------------------------------------------------------------ 726 727 void PyDiagnostic::invalidate() { 728 valid = false; 729 if (materializedNotes) { 730 for (auto ¬eObject : *materializedNotes) { 731 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); 732 note->invalidate(); 733 } 734 } 735 } 736 737 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 738 py::object callback) 739 : context(context), callback(std::move(callback)) {} 740 741 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 742 743 void PyDiagnosticHandler::detach() { 744 if (!registeredID) 745 return; 746 MlirDiagnosticHandlerID localID = *registeredID; 747 mlirContextDetachDiagnosticHandler(context, localID); 748 assert(!registeredID && "should have unregistered"); 749 // Not strictly necessary but keeps stale pointers from being around to cause 750 // issues. 
751 context = {nullptr}; 752 } 753 754 void PyDiagnostic::checkValid() { 755 if (!valid) { 756 throw std::invalid_argument( 757 "Diagnostic is invalid (used outside of callback)"); 758 } 759 } 760 761 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 762 checkValid(); 763 return mlirDiagnosticGetSeverity(diagnostic); 764 } 765 766 PyLocation PyDiagnostic::getLocation() { 767 checkValid(); 768 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 769 MlirContext context = mlirLocationGetContext(loc); 770 return PyLocation(PyMlirContext::forContext(context), loc); 771 } 772 773 py::str PyDiagnostic::getMessage() { 774 checkValid(); 775 py::object fileObject = py::module::import("io").attr("StringIO")(); 776 PyFileAccumulator accum(fileObject, /*binary=*/false); 777 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 778 return fileObject.attr("getvalue")(); 779 } 780 781 py::tuple PyDiagnostic::getNotes() { 782 checkValid(); 783 if (materializedNotes) 784 return *materializedNotes; 785 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 786 materializedNotes = py::tuple(numNotes); 787 for (intptr_t i = 0; i < numNotes; ++i) { 788 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 789 py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); 790 PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); 791 } 792 return *materializedNotes; 793 } 794 795 //------------------------------------------------------------------------------ 796 // PyDialect, PyDialectDescriptor, PyDialects 797 //------------------------------------------------------------------------------ 798 799 MlirDialect PyDialects::getDialectForKey(const std::string &key, 800 bool attrError) { 801 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 802 {key.data(), key.size()}); 803 if (mlirDialectIsNull(dialect)) { 804 throw SetPyError(attrError ? 
PyExc_AttributeError : PyExc_IndexError, 805 Twine("Dialect '") + key + "' not found"); 806 } 807 return dialect; 808 } 809 810 //------------------------------------------------------------------------------ 811 // PyLocation 812 //------------------------------------------------------------------------------ 813 814 py::object PyLocation::getCapsule() { 815 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 816 } 817 818 PyLocation PyLocation::createFromCapsule(py::object capsule) { 819 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 820 if (mlirLocationIsNull(rawLoc)) 821 throw py::error_already_set(); 822 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 823 rawLoc); 824 } 825 826 py::object PyLocation::contextEnter() { 827 return PyThreadContextEntry::pushLocation(*this); 828 } 829 830 void PyLocation::contextExit(const pybind11::object &excType, 831 const pybind11::object &excVal, 832 const pybind11::object &excTb) { 833 PyThreadContextEntry::popLocation(*this); 834 } 835 836 PyLocation &DefaultingPyLocation::resolve() { 837 auto *location = PyThreadContextEntry::getDefaultLocation(); 838 if (!location) { 839 throw SetPyError( 840 PyExc_RuntimeError, 841 "An MLIR function requires a Location but none was provided in the " 842 "call or from the surrounding environment. 
Either pass to the function " 843 "with a 'loc=' argument or establish a default using 'with loc:'"); 844 } 845 return *location; 846 } 847 848 //------------------------------------------------------------------------------ 849 // PyModule 850 //------------------------------------------------------------------------------ 851 852 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 853 : BaseContextObject(std::move(contextRef)), module(module) {} 854 855 PyModule::~PyModule() { 856 py::gil_scoped_acquire acquire; 857 auto &liveModules = getContext()->liveModules; 858 assert(liveModules.count(module.ptr) == 1 && 859 "destroying module not in live map"); 860 liveModules.erase(module.ptr); 861 mlirModuleDestroy(module); 862 } 863 864 PyModuleRef PyModule::forModule(MlirModule module) { 865 MlirContext context = mlirModuleGetContext(module); 866 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 867 868 py::gil_scoped_acquire acquire; 869 auto &liveModules = contextRef->liveModules; 870 auto it = liveModules.find(module.ptr); 871 if (it == liveModules.end()) { 872 // Create. 873 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 874 // Note that the default return value policy on cast is automatic_reference, 875 // which does not take ownership (delete will not be called). 876 // Just be explicit. 877 py::object pyRef = 878 py::cast(unownedModule, py::return_value_policy::take_ownership); 879 unownedModule->handle = pyRef; 880 liveModules[module.ptr] = 881 std::make_pair(unownedModule->handle, unownedModule); 882 return PyModuleRef(unownedModule, std::move(pyRef)); 883 } 884 // Use existing. 
885 PyModule *existing = it->second.second; 886 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 887 return PyModuleRef(existing, std::move(pyRef)); 888 } 889 890 py::object PyModule::createFromCapsule(py::object capsule) { 891 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 892 if (mlirModuleIsNull(rawModule)) 893 throw py::error_already_set(); 894 return forModule(rawModule).releaseObject(); 895 } 896 897 py::object PyModule::getCapsule() { 898 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 899 } 900 901 //------------------------------------------------------------------------------ 902 // PyOperation 903 //------------------------------------------------------------------------------ 904 905 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 906 : BaseContextObject(std::move(contextRef)), operation(operation) {} 907 908 PyOperation::~PyOperation() { 909 // If the operation has already been invalidated there is nothing to do. 910 if (!valid) 911 return; 912 auto &liveOperations = getContext()->liveOperations; 913 assert(liveOperations.count(operation.ptr) == 1 && 914 "destroying operation not in live map"); 915 liveOperations.erase(operation.ptr); 916 if (!isAttached()) { 917 mlirOperationDestroy(operation); 918 } 919 } 920 921 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 922 MlirOperation operation, 923 py::object parentKeepAlive) { 924 auto &liveOperations = contextRef->liveOperations; 925 // Create. 926 PyOperation *unownedOperation = 927 new PyOperation(std::move(contextRef), operation); 928 // Note that the default return value policy on cast is automatic_reference, 929 // which does not take ownership (delete will not be called). 930 // Just be explicit. 
931 py::object pyRef = 932 py::cast(unownedOperation, py::return_value_policy::take_ownership); 933 unownedOperation->handle = pyRef; 934 if (parentKeepAlive) { 935 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 936 } 937 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 938 return PyOperationRef(unownedOperation, std::move(pyRef)); 939 } 940 941 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 942 MlirOperation operation, 943 py::object parentKeepAlive) { 944 auto &liveOperations = contextRef->liveOperations; 945 auto it = liveOperations.find(operation.ptr); 946 if (it == liveOperations.end()) { 947 // Create. 948 return createInstance(std::move(contextRef), operation, 949 std::move(parentKeepAlive)); 950 } 951 // Use existing. 952 PyOperation *existing = it->second.second; 953 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 954 return PyOperationRef(existing, std::move(pyRef)); 955 } 956 957 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 958 MlirOperation operation, 959 py::object parentKeepAlive) { 960 auto &liveOperations = contextRef->liveOperations; 961 assert(liveOperations.count(operation.ptr) == 0 && 962 "cannot create detached operation that already exists"); 963 (void)liveOperations; 964 965 PyOperationRef created = createInstance(std::move(contextRef), operation, 966 std::move(parentKeepAlive)); 967 created->attached = false; 968 return created; 969 } 970 971 void PyOperation::checkValid() const { 972 if (!valid) { 973 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 974 } 975 } 976 977 void PyOperationBase::print(py::object fileObject, bool binary, 978 llvm::Optional<int64_t> largeElementsLimit, 979 bool enableDebugInfo, bool prettyDebugInfo, 980 bool printGenericOpForm, bool useLocalScope, 981 bool assumeVerified) { 982 PyOperation &operation = getOperation(); 983 operation.checkValid(); 984 if 
(fileObject.is_none()) 985 fileObject = py::module::import("sys").attr("stdout"); 986 987 if (!assumeVerified && !printGenericOpForm && 988 !mlirOperationVerify(operation)) { 989 std::string message("// Verification failed, printing generic form\n"); 990 if (binary) { 991 fileObject.attr("write")(py::bytes(message)); 992 } else { 993 fileObject.attr("write")(py::str(message)); 994 } 995 printGenericOpForm = true; 996 } 997 998 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 999 if (largeElementsLimit) 1000 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 1001 if (enableDebugInfo) 1002 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 1003 if (printGenericOpForm) 1004 mlirOpPrintingFlagsPrintGenericOpForm(flags); 1005 1006 PyFileAccumulator accum(fileObject, binary); 1007 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1008 accum.getUserData()); 1009 mlirOpPrintingFlagsDestroy(flags); 1010 } 1011 1012 py::object PyOperationBase::getAsm(bool binary, 1013 llvm::Optional<int64_t> largeElementsLimit, 1014 bool enableDebugInfo, bool prettyDebugInfo, 1015 bool printGenericOpForm, bool useLocalScope, 1016 bool assumeVerified) { 1017 py::object fileObject; 1018 if (binary) { 1019 fileObject = py::module::import("io").attr("BytesIO")(); 1020 } else { 1021 fileObject = py::module::import("io").attr("StringIO")(); 1022 } 1023 print(fileObject, /*binary=*/binary, 1024 /*largeElementsLimit=*/largeElementsLimit, 1025 /*enableDebugInfo=*/enableDebugInfo, 1026 /*prettyDebugInfo=*/prettyDebugInfo, 1027 /*printGenericOpForm=*/printGenericOpForm, 1028 /*useLocalScope=*/useLocalScope, 1029 /*assumeVerified=*/assumeVerified); 1030 1031 return fileObject.attr("getvalue")(); 1032 } 1033 1034 void PyOperationBase::moveAfter(PyOperationBase &other) { 1035 PyOperation &operation = getOperation(); 1036 PyOperation &otherOp = other.getOperation(); 1037 operation.checkValid(); 1038 otherOp.checkValid(); 1039 
mlirOperationMoveAfter(operation, otherOp); 1040 operation.parentKeepAlive = otherOp.parentKeepAlive; 1041 } 1042 1043 void PyOperationBase::moveBefore(PyOperationBase &other) { 1044 PyOperation &operation = getOperation(); 1045 PyOperation &otherOp = other.getOperation(); 1046 operation.checkValid(); 1047 otherOp.checkValid(); 1048 mlirOperationMoveBefore(operation, otherOp); 1049 operation.parentKeepAlive = otherOp.parentKeepAlive; 1050 } 1051 1052 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 1053 checkValid(); 1054 if (!isAttached()) 1055 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 1056 MlirOperation operation = mlirOperationGetParentOperation(get()); 1057 if (mlirOperationIsNull(operation)) 1058 return {}; 1059 return PyOperation::forOperation(getContext(), operation); 1060 } 1061 1062 PyBlock PyOperation::getBlock() { 1063 checkValid(); 1064 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 1065 MlirBlock block = mlirOperationGetBlock(get()); 1066 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1067 assert(parentOperation && "Operation has no parent"); 1068 return PyBlock{std::move(*parentOperation), block}; 1069 } 1070 1071 py::object PyOperation::getCapsule() { 1072 checkValid(); 1073 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 1074 } 1075 1076 py::object PyOperation::createFromCapsule(py::object capsule) { 1077 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1078 if (mlirOperationIsNull(rawOperation)) 1079 throw py::error_already_set(); 1080 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1081 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1082 .releaseObject(); 1083 } 1084 1085 static void maybeInsertOperation(PyOperationRef &op, 1086 const py::object &maybeIp) { 1087 // InsertPoint active? 
1088 if (!maybeIp.is(py::cast(false))) { 1089 PyInsertionPoint *ip; 1090 if (maybeIp.is_none()) { 1091 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1092 } else { 1093 ip = py::cast<PyInsertionPoint *>(maybeIp); 1094 } 1095 if (ip) 1096 ip->insert(*op.get()); 1097 } 1098 } 1099 1100 py::object PyOperation::create( 1101 const std::string &name, llvm::Optional<std::vector<PyType *>> results, 1102 llvm::Optional<std::vector<PyValue *>> operands, 1103 llvm::Optional<py::dict> attributes, 1104 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 1105 DefaultingPyLocation location, const py::object &maybeIp) { 1106 llvm::SmallVector<MlirValue, 4> mlirOperands; 1107 llvm::SmallVector<MlirType, 4> mlirResults; 1108 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1109 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1110 1111 // General parameter validation. 1112 if (regions < 0) 1113 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 1114 1115 // Unpack/validate operands. 1116 if (operands) { 1117 mlirOperands.reserve(operands->size()); 1118 for (PyValue *operand : *operands) { 1119 if (!operand) 1120 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 1121 mlirOperands.push_back(operand->get()); 1122 } 1123 } 1124 1125 // Unpack/validate results. 1126 if (results) { 1127 mlirResults.reserve(results->size()); 1128 for (PyType *result : *results) { 1129 // TODO: Verify result type originate from the same context. 1130 if (!result) 1131 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 1132 mlirResults.push_back(*result); 1133 } 1134 } 1135 // Unpack/validate attributes. 
1136 if (attributes) { 1137 mlirAttributes.reserve(attributes->size()); 1138 for (auto &it : *attributes) { 1139 std::string key; 1140 try { 1141 key = it.first.cast<std::string>(); 1142 } catch (py::cast_error &err) { 1143 std::string msg = "Invalid attribute key (not a string) when " 1144 "attempting to create the operation \"" + 1145 name + "\" (" + err.what() + ")"; 1146 throw py::cast_error(msg); 1147 } 1148 try { 1149 auto &attribute = it.second.cast<PyAttribute &>(); 1150 // TODO: Verify attribute originates from the same context. 1151 mlirAttributes.emplace_back(std::move(key), attribute); 1152 } catch (py::reference_cast_error &) { 1153 // This exception seems thrown when the value is "None". 1154 std::string msg = 1155 "Found an invalid (`None`?) attribute value for the key \"" + key + 1156 "\" when attempting to create the operation \"" + name + "\""; 1157 throw py::cast_error(msg); 1158 } catch (py::cast_error &err) { 1159 std::string msg = "Invalid attribute value for the key \"" + key + 1160 "\" when attempting to create the operation \"" + 1161 name + "\" (" + err.what() + ")"; 1162 throw py::cast_error(msg); 1163 } 1164 } 1165 } 1166 // Unpack/validate successors. 1167 if (successors) { 1168 mlirSuccessors.reserve(successors->size()); 1169 for (auto *successor : *successors) { 1170 // TODO: Verify successor originate from the same context. 1171 if (!successor) 1172 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1173 mlirSuccessors.push_back(successor->get()); 1174 } 1175 } 1176 1177 // Apply unpacked/validated to the operation state. Beyond this 1178 // point, exceptions cannot be thrown or else the state will leak. 
1179 MlirOperationState state = 1180 mlirOperationStateGet(toMlirStringRef(name), location); 1181 if (!mlirOperands.empty()) 1182 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1183 mlirOperands.data()); 1184 if (!mlirResults.empty()) 1185 mlirOperationStateAddResults(&state, mlirResults.size(), 1186 mlirResults.data()); 1187 if (!mlirAttributes.empty()) { 1188 // Note that the attribute names directly reference bytes in 1189 // mlirAttributes, so that vector must not be changed from here 1190 // on. 1191 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1192 mlirNamedAttributes.reserve(mlirAttributes.size()); 1193 for (auto &it : mlirAttributes) 1194 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1195 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1196 toMlirStringRef(it.first)), 1197 it.second)); 1198 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1199 mlirNamedAttributes.data()); 1200 } 1201 if (!mlirSuccessors.empty()) 1202 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1203 mlirSuccessors.data()); 1204 if (regions) { 1205 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1206 mlirRegions.resize(regions); 1207 for (int i = 0; i < regions; ++i) 1208 mlirRegions[i] = mlirRegionCreate(); 1209 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1210 mlirRegions.data()); 1211 } 1212 1213 // Construct the operation. 
1214 MlirOperation operation = mlirOperationCreate(&state); 1215 PyOperationRef created = 1216 PyOperation::createDetached(location->getContext(), operation); 1217 maybeInsertOperation(created, maybeIp); 1218 1219 return created->createOpView(); 1220 } 1221 1222 py::object PyOperation::clone(const py::object &maybeIp) { 1223 MlirOperation clonedOperation = mlirOperationClone(operation); 1224 PyOperationRef cloned = 1225 PyOperation::createDetached(getContext(), clonedOperation); 1226 maybeInsertOperation(cloned, maybeIp); 1227 1228 return cloned->createOpView(); 1229 } 1230 1231 py::object PyOperation::createOpView() { 1232 checkValid(); 1233 MlirIdentifier ident = mlirOperationGetName(get()); 1234 MlirStringRef identStr = mlirIdentifierStr(ident); 1235 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1236 StringRef(identStr.data, identStr.length)); 1237 if (opViewClass) 1238 return (*opViewClass)(getRef().getObject()); 1239 return py::cast(PyOpView(getRef().getObject())); 1240 } 1241 1242 void PyOperation::erase() { 1243 checkValid(); 1244 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1245 // Python reference to a child operation is live. All children should also 1246 // have their `valid` bit set to false. 
1247 auto &liveOperations = getContext()->liveOperations; 1248 if (liveOperations.count(operation.ptr)) 1249 liveOperations.erase(operation.ptr); 1250 mlirOperationDestroy(operation); 1251 valid = false; 1252 } 1253 1254 //------------------------------------------------------------------------------ 1255 // PyOpView 1256 //------------------------------------------------------------------------------ 1257 1258 py::object PyOpView::buildGeneric( 1259 const py::object &cls, py::list resultTypeList, py::list operandList, 1260 llvm::Optional<py::dict> attributes, 1261 llvm::Optional<std::vector<PyBlock *>> successors, 1262 llvm::Optional<int> regions, DefaultingPyLocation location, 1263 const py::object &maybeIp) { 1264 PyMlirContextRef context = location->getContext(); 1265 // Class level operation construction metadata. 1266 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1267 // Operand and result segment specs are either none, which does no 1268 // variadic unpacking, or a list of ints with segment sizes, where each 1269 // element is either a positive number (typically 1 for a scalar) or -1 to 1270 // indicate that it is derived from the length of the same-indexed operand 1271 // or result (implying that it is a list at that position). 1272 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1273 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1274 1275 std::vector<uint32_t> operandSegmentLengths; 1276 std::vector<uint32_t> resultSegmentLengths; 1277 1278 // Validate/determine region count. 
1279 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1280 int opMinRegionCount = std::get<0>(opRegionSpec); 1281 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1282 if (!regions) { 1283 regions = opMinRegionCount; 1284 } 1285 if (*regions < opMinRegionCount) { 1286 throw py::value_error( 1287 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1288 llvm::Twine(opMinRegionCount) + 1289 " regions but was built with regions=" + llvm::Twine(*regions)) 1290 .str()); 1291 } 1292 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1293 throw py::value_error( 1294 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1295 llvm::Twine(opMinRegionCount) + 1296 " regions but was built with regions=" + llvm::Twine(*regions)) 1297 .str()); 1298 } 1299 1300 // Unpack results. 1301 std::vector<PyType *> resultTypes; 1302 resultTypes.reserve(resultTypeList.size()); 1303 if (resultSegmentSpecObj.is_none()) { 1304 // Non-variadic result unpacking. 1305 for (const auto &it : llvm::enumerate(resultTypeList)) { 1306 try { 1307 resultTypes.push_back(py::cast<PyType *>(it.value())); 1308 if (!resultTypes.back()) 1309 throw py::cast_error(); 1310 } catch (py::cast_error &err) { 1311 throw py::value_error((llvm::Twine("Result ") + 1312 llvm::Twine(it.index()) + " of operation \"" + 1313 name + "\" must be a Type (" + err.what() + ")") 1314 .str()); 1315 } 1316 } 1317 } else { 1318 // Sized result unpacking. 
1319 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1320 if (resultSegmentSpec.size() != resultTypeList.size()) { 1321 throw py::value_error((llvm::Twine("Operation \"") + name + 1322 "\" requires " + 1323 llvm::Twine(resultSegmentSpec.size()) + 1324 " result segments but was provided " + 1325 llvm::Twine(resultTypeList.size())) 1326 .str()); 1327 } 1328 resultSegmentLengths.reserve(resultTypeList.size()); 1329 for (const auto &it : 1330 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1331 int segmentSpec = std::get<1>(it.value()); 1332 if (segmentSpec == 1 || segmentSpec == 0) { 1333 // Unpack unary element. 1334 try { 1335 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1336 if (resultType) { 1337 resultTypes.push_back(resultType); 1338 resultSegmentLengths.push_back(1); 1339 } else if (segmentSpec == 0) { 1340 // Allowed to be optional. 1341 resultSegmentLengths.push_back(0); 1342 } else { 1343 throw py::cast_error("was None and result is not optional"); 1344 } 1345 } catch (py::cast_error &err) { 1346 throw py::value_error((llvm::Twine("Result ") + 1347 llvm::Twine(it.index()) + " of operation \"" + 1348 name + "\" must be a Type (" + err.what() + 1349 ")") 1350 .str()); 1351 } 1352 } else if (segmentSpec == -1) { 1353 // Unpack sequence by appending. 1354 try { 1355 if (std::get<0>(it.value()).is_none()) { 1356 // Treat it as an empty list. 1357 resultSegmentLengths.push_back(0); 1358 } else { 1359 // Unpack the list. 
1360 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1361 for (py::object segmentItem : segment) { 1362 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1363 if (!resultTypes.back()) { 1364 throw py::cast_error("contained a None item"); 1365 } 1366 } 1367 resultSegmentLengths.push_back(segment.size()); 1368 } 1369 } catch (std::exception &err) { 1370 // NOTE: Sloppy to be using a catch-all here, but there are at least 1371 // three different unrelated exceptions that can be thrown in the 1372 // above "casts". Just keep the scope above small and catch them all. 1373 throw py::value_error((llvm::Twine("Result ") + 1374 llvm::Twine(it.index()) + " of operation \"" + 1375 name + "\" must be a Sequence of Types (" + 1376 err.what() + ")") 1377 .str()); 1378 } 1379 } else { 1380 throw py::value_error("Unexpected segment spec"); 1381 } 1382 } 1383 } 1384 1385 // Unpack operands. 1386 std::vector<PyValue *> operands; 1387 operands.reserve(operands.size()); 1388 if (operandSegmentSpecObj.is_none()) { 1389 // Non-sized operand unpacking. 1390 for (const auto &it : llvm::enumerate(operandList)) { 1391 try { 1392 operands.push_back(py::cast<PyValue *>(it.value())); 1393 if (!operands.back()) 1394 throw py::cast_error(); 1395 } catch (py::cast_error &err) { 1396 throw py::value_error((llvm::Twine("Operand ") + 1397 llvm::Twine(it.index()) + " of operation \"" + 1398 name + "\" must be a Value (" + err.what() + ")") 1399 .str()); 1400 } 1401 } 1402 } else { 1403 // Sized operand unpacking. 
1404 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1405 if (operandSegmentSpec.size() != operandList.size()) { 1406 throw py::value_error((llvm::Twine("Operation \"") + name + 1407 "\" requires " + 1408 llvm::Twine(operandSegmentSpec.size()) + 1409 "operand segments but was provided " + 1410 llvm::Twine(operandList.size())) 1411 .str()); 1412 } 1413 operandSegmentLengths.reserve(operandList.size()); 1414 for (const auto &it : 1415 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1416 int segmentSpec = std::get<1>(it.value()); 1417 if (segmentSpec == 1 || segmentSpec == 0) { 1418 // Unpack unary element. 1419 try { 1420 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1421 if (operandValue) { 1422 operands.push_back(operandValue); 1423 operandSegmentLengths.push_back(1); 1424 } else if (segmentSpec == 0) { 1425 // Allowed to be optional. 1426 operandSegmentLengths.push_back(0); 1427 } else { 1428 throw py::cast_error("was None and operand is not optional"); 1429 } 1430 } catch (py::cast_error &err) { 1431 throw py::value_error((llvm::Twine("Operand ") + 1432 llvm::Twine(it.index()) + " of operation \"" + 1433 name + "\" must be a Value (" + err.what() + 1434 ")") 1435 .str()); 1436 } 1437 } else if (segmentSpec == -1) { 1438 // Unpack sequence by appending. 1439 try { 1440 if (std::get<0>(it.value()).is_none()) { 1441 // Treat it as an empty list. 1442 operandSegmentLengths.push_back(0); 1443 } else { 1444 // Unpack the list. 
1445 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1446 for (py::object segmentItem : segment) { 1447 operands.push_back(py::cast<PyValue *>(segmentItem)); 1448 if (!operands.back()) { 1449 throw py::cast_error("contained a None item"); 1450 } 1451 } 1452 operandSegmentLengths.push_back(segment.size()); 1453 } 1454 } catch (std::exception &err) { 1455 // NOTE: Sloppy to be using a catch-all here, but there are at least 1456 // three different unrelated exceptions that can be thrown in the 1457 // above "casts". Just keep the scope above small and catch them all. 1458 throw py::value_error((llvm::Twine("Operand ") + 1459 llvm::Twine(it.index()) + " of operation \"" + 1460 name + "\" must be a Sequence of Values (" + 1461 err.what() + ")") 1462 .str()); 1463 } 1464 } else { 1465 throw py::value_error("Unexpected segment spec"); 1466 } 1467 } 1468 } 1469 1470 // Merge operand/result segment lengths into attributes if needed. 1471 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1472 // Dup. 1473 if (attributes) { 1474 attributes = py::dict(*attributes); 1475 } else { 1476 attributes = py::dict(); 1477 } 1478 if (attributes->contains("result_segment_sizes") || 1479 attributes->contains("operand_segment_sizes")) { 1480 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1481 "'operand_segment_sizes' attribute is unsupported. " 1482 "Use Operation.create for such low-level access."); 1483 } 1484 1485 // Add result_segment_sizes attribute. 1486 if (!resultSegmentLengths.empty()) { 1487 int64_t size = resultSegmentLengths.size(); 1488 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1489 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1490 resultSegmentLengths.size(), resultSegmentLengths.data()); 1491 (*attributes)["result_segment_sizes"] = 1492 PyAttribute(context, segmentLengthAttr); 1493 } 1494 1495 // Add operand_segment_sizes attribute. 
1496 if (!operandSegmentLengths.empty()) { 1497 int64_t size = operandSegmentLengths.size(); 1498 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1499 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1500 operandSegmentLengths.size(), operandSegmentLengths.data()); 1501 (*attributes)["operand_segment_sizes"] = 1502 PyAttribute(context, segmentLengthAttr); 1503 } 1504 } 1505 1506 // Delegate to create. 1507 return PyOperation::create(name, 1508 /*results=*/std::move(resultTypes), 1509 /*operands=*/std::move(operands), 1510 /*attributes=*/std::move(attributes), 1511 /*successors=*/std::move(successors), 1512 /*regions=*/*regions, location, maybeIp); 1513 } 1514 1515 PyOpView::PyOpView(const py::object &operationObject) 1516 // Casting through the PyOperationBase base-class and then back to the 1517 // Operation lets us accept any PyOperationBase subclass. 1518 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1519 operationObject(operation.getRef().getObject()) {} 1520 1521 py::object PyOpView::createRawSubclass(const py::object &userClass) { 1522 // This is... a little gross. The typical pattern is to have a pure python 1523 // class that extends OpView like: 1524 // class AddFOp(_cext.ir.OpView): 1525 // def __init__(self, loc, lhs, rhs): 1526 // operation = loc.context.create_operation( 1527 // "addf", lhs, rhs, results=[lhs.type]) 1528 // super().__init__(operation) 1529 // 1530 // I.e. The goal of the user facing type is to provide a nice constructor 1531 // that has complete freedom for the op under construction. This is at odds 1532 // with our other desire to sometimes create this object by just passing an 1533 // operation (to initialize the base class). 
We could do *arg and **kwargs 1534 // munging to try to make it work, but instead, we synthesize a new class 1535 // on the fly which extends this user class (AddFOp in this example) and 1536 // *give it* the base class's __init__ method, thus bypassing the 1537 // intermediate subclass's __init__ method entirely. While slightly, 1538 // underhanded, this is safe/legal because the type hierarchy has not changed 1539 // (we just added a new leaf) and we aren't mucking around with __new__. 1540 // Typically, this new class will be stored on the original as "_Raw" and will 1541 // be used for casts and other things that need a variant of the class that 1542 // is initialized purely from an operation. 1543 py::object parentMetaclass = 1544 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1545 py::dict attributes; 1546 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1547 // now. 1548 // auto opViewType = py::type::of<PyOpView>(); 1549 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1550 attributes["__init__"] = opViewType.attr("__init__"); 1551 py::str origName = userClass.attr("__name__"); 1552 py::str newName = py::str("_") + origName; 1553 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1554 } 1555 1556 //------------------------------------------------------------------------------ 1557 // PyInsertionPoint. 
1558 //------------------------------------------------------------------------------ 1559 1560 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1561 1562 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1563 : refOperation(beforeOperationBase.getOperation().getRef()), 1564 block((*refOperation)->getBlock()) {} 1565 1566 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1567 PyOperation &operation = operationBase.getOperation(); 1568 if (operation.isAttached()) 1569 throw SetPyError(PyExc_ValueError, 1570 "Attempt to insert operation that is already attached"); 1571 block.getParentOperation()->checkValid(); 1572 MlirOperation beforeOp = {nullptr}; 1573 if (refOperation) { 1574 // Insert before operation. 1575 (*refOperation)->checkValid(); 1576 beforeOp = (*refOperation)->get(); 1577 } else { 1578 // Insert at end (before null) is only valid if the block does not 1579 // already end in a known terminator (violating this will cause assertion 1580 // failures later). 1581 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1582 throw py::index_error("Cannot insert operation at the end of a block " 1583 "that already has a terminator. Did you mean to " 1584 "use 'InsertionPoint.at_block_terminator(block)' " 1585 "versus 'InsertionPoint(block)'?"); 1586 } 1587 } 1588 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1589 operation.setAttached(); 1590 } 1591 1592 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1593 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1594 if (mlirOperationIsNull(firstOp)) { 1595 // Just insert at end. 1596 return PyInsertionPoint(block); 1597 } 1598 1599 // Insert before first op. 
1600 PyOperationRef firstOpRef = PyOperation::forOperation( 1601 block.getParentOperation()->getContext(), firstOp); 1602 return PyInsertionPoint{block, std::move(firstOpRef)}; 1603 } 1604 1605 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1606 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1607 if (mlirOperationIsNull(terminator)) 1608 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1609 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1610 block.getParentOperation()->getContext(), terminator); 1611 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1612 } 1613 1614 py::object PyInsertionPoint::contextEnter() { 1615 return PyThreadContextEntry::pushInsertionPoint(*this); 1616 } 1617 1618 void PyInsertionPoint::contextExit(const pybind11::object &excType, 1619 const pybind11::object &excVal, 1620 const pybind11::object &excTb) { 1621 PyThreadContextEntry::popInsertionPoint(*this); 1622 } 1623 1624 //------------------------------------------------------------------------------ 1625 // PyAttribute. 1626 //------------------------------------------------------------------------------ 1627 1628 bool PyAttribute::operator==(const PyAttribute &other) { 1629 return mlirAttributeEqual(attr, other.attr); 1630 } 1631 1632 py::object PyAttribute::getCapsule() { 1633 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1634 } 1635 1636 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1637 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1638 if (mlirAttributeIsNull(rawAttr)) 1639 throw py::error_already_set(); 1640 return PyAttribute( 1641 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1642 } 1643 1644 //------------------------------------------------------------------------------ 1645 // PyNamedAttribute. 
1646 //------------------------------------------------------------------------------ 1647 1648 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1649 : ownedName(new std::string(std::move(ownedName))) { 1650 namedAttr = mlirNamedAttributeGet( 1651 mlirIdentifierGet(mlirAttributeGetContext(attr), 1652 toMlirStringRef(*this->ownedName)), 1653 attr); 1654 } 1655 1656 //------------------------------------------------------------------------------ 1657 // PyType. 1658 //------------------------------------------------------------------------------ 1659 1660 bool PyType::operator==(const PyType &other) { 1661 return mlirTypeEqual(type, other.type); 1662 } 1663 1664 py::object PyType::getCapsule() { 1665 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1666 } 1667 1668 PyType PyType::createFromCapsule(py::object capsule) { 1669 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1670 if (mlirTypeIsNull(rawType)) 1671 throw py::error_already_set(); 1672 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1673 rawType); 1674 } 1675 1676 //------------------------------------------------------------------------------ 1677 // PyValue and subclases. 
1678 //------------------------------------------------------------------------------ 1679 1680 pybind11::object PyValue::getCapsule() { 1681 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1682 } 1683 1684 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1685 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1686 if (mlirValueIsNull(value)) 1687 throw py::error_already_set(); 1688 MlirOperation owner; 1689 if (mlirValueIsAOpResult(value)) 1690 owner = mlirOpResultGetOwner(value); 1691 if (mlirValueIsABlockArgument(value)) 1692 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1693 if (mlirOperationIsNull(owner)) 1694 throw py::error_already_set(); 1695 MlirContext ctx = mlirOperationGetContext(owner); 1696 PyOperationRef ownerRef = 1697 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1698 return PyValue(ownerRef, value); 1699 } 1700 1701 //------------------------------------------------------------------------------ 1702 // PySymbolTable. 
1703 //------------------------------------------------------------------------------ 1704 1705 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1706 : operation(operation.getOperation().getRef()) { 1707 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1708 if (mlirSymbolTableIsNull(symbolTable)) { 1709 throw py::cast_error("Operation is not a Symbol Table."); 1710 } 1711 } 1712 1713 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1714 operation->checkValid(); 1715 MlirOperation symbol = mlirSymbolTableLookup( 1716 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1717 if (mlirOperationIsNull(symbol)) 1718 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1719 1720 return PyOperation::forOperation(operation->getContext(), symbol, 1721 operation.getObject()) 1722 ->createOpView(); 1723 } 1724 1725 void PySymbolTable::erase(PyOperationBase &symbol) { 1726 operation->checkValid(); 1727 symbol.getOperation().checkValid(); 1728 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1729 // The operation is also erased, so we must invalidate it. There may be Python 1730 // references to this operation so we don't want to delete it from the list of 1731 // live operations here. 
1732 symbol.getOperation().valid = false; 1733 } 1734 1735 void PySymbolTable::dunderDel(const std::string &name) { 1736 py::object operation = dunderGetItem(name); 1737 erase(py::cast<PyOperationBase &>(operation)); 1738 } 1739 1740 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1741 operation->checkValid(); 1742 symbol.getOperation().checkValid(); 1743 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1744 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1745 if (mlirAttributeIsNull(symbolAttr)) 1746 throw py::value_error("Expected operation to have a symbol name."); 1747 return PyAttribute( 1748 symbol.getOperation().getContext(), 1749 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1750 } 1751 1752 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1753 // Op must already be a symbol. 1754 PyOperation &operation = symbol.getOperation(); 1755 operation.checkValid(); 1756 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1757 MlirAttribute existingNameAttr = 1758 mlirOperationGetAttributeByName(operation.get(), attrName); 1759 if (mlirAttributeIsNull(existingNameAttr)) 1760 throw py::value_error("Expected operation to have a symbol name."); 1761 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1762 } 1763 1764 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1765 const std::string &name) { 1766 // Op must already be a symbol. 
1767 PyOperation &operation = symbol.getOperation(); 1768 operation.checkValid(); 1769 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1770 MlirAttribute existingNameAttr = 1771 mlirOperationGetAttributeByName(operation.get(), attrName); 1772 if (mlirAttributeIsNull(existingNameAttr)) 1773 throw py::value_error("Expected operation to have a symbol name."); 1774 MlirAttribute newNameAttr = 1775 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1776 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1777 } 1778 1779 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1780 PyOperation &operation = symbol.getOperation(); 1781 operation.checkValid(); 1782 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1783 MlirAttribute existingVisAttr = 1784 mlirOperationGetAttributeByName(operation.get(), attrName); 1785 if (mlirAttributeIsNull(existingVisAttr)) 1786 throw py::value_error("Expected operation to have a symbol visibility."); 1787 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1788 } 1789 1790 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1791 const std::string &visibility) { 1792 if (visibility != "public" && visibility != "private" && 1793 visibility != "nested") 1794 throw py::value_error( 1795 "Expected visibility to be 'public', 'private' or 'nested'"); 1796 PyOperation &operation = symbol.getOperation(); 1797 operation.checkValid(); 1798 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1799 MlirAttribute existingVisAttr = 1800 mlirOperationGetAttributeByName(operation.get(), attrName); 1801 if (mlirAttributeIsNull(existingVisAttr)) 1802 throw py::value_error("Expected operation to have a symbol visibility."); 1803 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1804 toMlirStringRef(visibility)); 1805 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1806 } 
1807 1808 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1809 const std::string &newSymbol, 1810 PyOperationBase &from) { 1811 PyOperation &fromOperation = from.getOperation(); 1812 fromOperation.checkValid(); 1813 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1814 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1815 from.getOperation()))) 1816 1817 throw py::value_error("Symbol rename failed"); 1818 } 1819 1820 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1821 bool allSymUsesVisible, 1822 py::object callback) { 1823 PyOperation &fromOperation = from.getOperation(); 1824 fromOperation.checkValid(); 1825 struct UserData { 1826 PyMlirContextRef context; 1827 py::object callback; 1828 bool gotException; 1829 std::string exceptionWhat; 1830 py::object exceptionType; 1831 }; 1832 UserData userData{ 1833 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1834 mlirSymbolTableWalkSymbolTables( 1835 fromOperation.get(), allSymUsesVisible, 1836 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1837 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1838 auto pyFoundOp = 1839 PyOperation::forOperation(calleeUserData->context, foundOp); 1840 if (calleeUserData->gotException) 1841 return; 1842 try { 1843 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1844 } catch (py::error_already_set &e) { 1845 calleeUserData->gotException = true; 1846 calleeUserData->exceptionWhat = e.what(); 1847 calleeUserData->exceptionType = e.type(); 1848 } 1849 }, 1850 static_cast<void *>(&userData)); 1851 if (userData.gotException) { 1852 std::string message("Exception raised in callback: "); 1853 message.append(userData.exceptionWhat); 1854 throw std::runtime_error(message); 1855 } 1856 } 1857 1858 namespace { 1859 /// CRTP base class for Python MLIR values that subclass Value and should be 1860 /// castable from it. 
The value hierarchy is one level deep and is not supposed 1861 /// to accommodate other levels unless core MLIR changes. 1862 template <typename DerivedTy> 1863 class PyConcreteValue : public PyValue { 1864 public: 1865 // Derived classes must define statics for: 1866 // IsAFunctionTy isaFunction 1867 // const char *pyClassName 1868 // and redefine bindDerived. 1869 using ClassTy = py::class_<DerivedTy, PyValue>; 1870 using IsAFunctionTy = bool (*)(MlirValue); 1871 1872 PyConcreteValue() = default; 1873 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1874 : PyValue(operationRef, value) {} 1875 PyConcreteValue(PyValue &orig) 1876 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1877 1878 /// Attempts to cast the original value to the derived type and throws on 1879 /// type mismatches. 1880 static MlirValue castFrom(PyValue &orig) { 1881 if (!DerivedTy::isaFunction(orig.get())) { 1882 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1883 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1884 DerivedTy::pyClassName + 1885 " (from " + origRepr + ")"); 1886 } 1887 return orig.get(); 1888 } 1889 1890 /// Binds the Python module objects to functions of this class. 1891 static void bind(py::module &m) { 1892 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1893 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1894 cls.def_static( 1895 "isinstance", 1896 [](PyValue &otherValue) -> bool { 1897 return DerivedTy::isaFunction(otherValue); 1898 }, 1899 py::arg("other_value")); 1900 DerivedTy::bindDerived(cls); 1901 } 1902 1903 /// Implemented by derived classes to add methods to the Python subclass. 1904 static void bindDerived(ClassTy &m) {} 1905 }; 1906 1907 /// Python wrapper for MlirBlockArgument. 
1908 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1909 public: 1910 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1911 static constexpr const char *pyClassName = "BlockArgument"; 1912 using PyConcreteValue::PyConcreteValue; 1913 1914 static void bindDerived(ClassTy &c) { 1915 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1916 return PyBlock(self.getParentOperation(), 1917 mlirBlockArgumentGetOwner(self.get())); 1918 }); 1919 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1920 return mlirBlockArgumentGetArgNumber(self.get()); 1921 }); 1922 c.def( 1923 "set_type", 1924 [](PyBlockArgument &self, PyType type) { 1925 return mlirBlockArgumentSetType(self.get(), type); 1926 }, 1927 py::arg("type")); 1928 } 1929 }; 1930 1931 /// Python wrapper for MlirOpResult. 1932 class PyOpResult : public PyConcreteValue<PyOpResult> { 1933 public: 1934 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1935 static constexpr const char *pyClassName = "OpResult"; 1936 using PyConcreteValue::PyConcreteValue; 1937 1938 static void bindDerived(ClassTy &c) { 1939 c.def_property_readonly("owner", [](PyOpResult &self) { 1940 assert( 1941 mlirOperationEqual(self.getParentOperation()->get(), 1942 mlirOpResultGetOwner(self.get())) && 1943 "expected the owner of the value in Python to match that in the IR"); 1944 return self.getParentOperation().getObject(); 1945 }); 1946 c.def_property_readonly("result_number", [](PyOpResult &self) { 1947 return mlirOpResultGetResultNumber(self.get()); 1948 }); 1949 } 1950 }; 1951 1952 /// Returns the list of types of the values held by container. 
1953 template <typename Container> 1954 static std::vector<PyType> getValueTypes(Container &container, 1955 PyMlirContextRef &context) { 1956 std::vector<PyType> result; 1957 result.reserve(container.getNumElements()); 1958 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1959 result.push_back( 1960 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1961 } 1962 return result; 1963 } 1964 1965 /// A list of block arguments. Internally, these are stored as consecutive 1966 /// elements, random access is cheap. The argument list is associated with the 1967 /// operation that contains the block (detached blocks are not allowed in 1968 /// Python bindings) and extends its lifetime. 1969 class PyBlockArgumentList 1970 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1971 public: 1972 static constexpr const char *pyClassName = "BlockArgumentList"; 1973 1974 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1975 intptr_t startIndex = 0, intptr_t length = -1, 1976 intptr_t step = 1) 1977 : Sliceable(startIndex, 1978 length == -1 ? mlirBlockGetNumArguments(block) : length, 1979 step), 1980 operation(std::move(operation)), block(block) {} 1981 1982 /// Returns the number of arguments in the list. 1983 intptr_t getNumElements() { 1984 operation->checkValid(); 1985 return mlirBlockGetNumArguments(block); 1986 } 1987 1988 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1989 PyBlockArgument getElement(intptr_t pos) { 1990 MlirValue argument = mlirBlockGetArgument(block, pos); 1991 return PyBlockArgument(operation, argument); 1992 } 1993 1994 /// Returns a sublist of this list. 
1995 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1996 intptr_t step) { 1997 return PyBlockArgumentList(operation, block, startIndex, length, step); 1998 } 1999 2000 static void bindDerived(ClassTy &c) { 2001 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 2002 return getValueTypes(self, self.operation->getContext()); 2003 }); 2004 } 2005 2006 private: 2007 PyOperationRef operation; 2008 MlirBlock block; 2009 }; 2010 2011 /// A list of operation operands. Internally, these are stored as consecutive 2012 /// elements, random access is cheap. The result list is associated with the 2013 /// operation whose results these are, and extends the lifetime of this 2014 /// operation. 2015 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 2016 public: 2017 static constexpr const char *pyClassName = "OpOperandList"; 2018 2019 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2020 intptr_t length = -1, intptr_t step = 1) 2021 : Sliceable(startIndex, 2022 length == -1 ? 
mlirOperationGetNumOperands(operation->get()) 2023 : length, 2024 step), 2025 operation(operation) {} 2026 2027 intptr_t getNumElements() { 2028 operation->checkValid(); 2029 return mlirOperationGetNumOperands(operation->get()); 2030 } 2031 2032 PyValue getElement(intptr_t pos) { 2033 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2034 MlirOperation owner; 2035 if (mlirValueIsAOpResult(operand)) 2036 owner = mlirOpResultGetOwner(operand); 2037 else if (mlirValueIsABlockArgument(operand)) 2038 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2039 else 2040 assert(false && "Value must be an block arg or op result."); 2041 PyOperationRef pyOwner = 2042 PyOperation::forOperation(operation->getContext(), owner); 2043 return PyValue(pyOwner, operand); 2044 } 2045 2046 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2047 return PyOpOperandList(operation, startIndex, length, step); 2048 } 2049 2050 void dunderSetItem(intptr_t index, PyValue value) { 2051 index = wrapIndex(index); 2052 mlirOperationSetOperand(operation->get(), index, value.get()); 2053 } 2054 2055 static void bindDerived(ClassTy &c) { 2056 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2057 } 2058 2059 private: 2060 PyOperationRef operation; 2061 }; 2062 2063 /// A list of operation results. Internally, these are stored as consecutive 2064 /// elements, random access is cheap. The result list is associated with the 2065 /// operation whose results these are, and extends the lifetime of this 2066 /// operation. 2067 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 2068 public: 2069 static constexpr const char *pyClassName = "OpResultList"; 2070 2071 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 2072 intptr_t length = -1, intptr_t step = 1) 2073 : Sliceable(startIndex, 2074 length == -1 ? 
mlirOperationGetNumResults(operation->get()) 2075 : length, 2076 step), 2077 operation(operation) {} 2078 2079 intptr_t getNumElements() { 2080 operation->checkValid(); 2081 return mlirOperationGetNumResults(operation->get()); 2082 } 2083 2084 PyOpResult getElement(intptr_t index) { 2085 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 2086 return PyOpResult(value); 2087 } 2088 2089 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2090 return PyOpResultList(operation, startIndex, length, step); 2091 } 2092 2093 static void bindDerived(ClassTy &c) { 2094 c.def_property_readonly("types", [](PyOpResultList &self) { 2095 return getValueTypes(self, self.operation->getContext()); 2096 }); 2097 } 2098 2099 private: 2100 PyOperationRef operation; 2101 }; 2102 2103 /// A list of operation attributes. Can be indexed by name, producing 2104 /// attributes, or by index, producing named attributes. 2105 class PyOpAttributeMap { 2106 public: 2107 PyOpAttributeMap(PyOperationRef operation) 2108 : operation(std::move(operation)) {} 2109 2110 PyAttribute dunderGetItemNamed(const std::string &name) { 2111 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2112 toMlirStringRef(name)); 2113 if (mlirAttributeIsNull(attr)) { 2114 throw SetPyError(PyExc_KeyError, 2115 "attempt to access a non-existent attribute"); 2116 } 2117 return PyAttribute(operation->getContext(), attr); 2118 } 2119 2120 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2121 if (index < 0 || index >= dunderLen()) { 2122 throw SetPyError(PyExc_IndexError, 2123 "attempt to access out of bounds attribute"); 2124 } 2125 MlirNamedAttribute namedAttr = 2126 mlirOperationGetAttribute(operation->get(), index); 2127 return PyNamedAttribute( 2128 namedAttr.attribute, 2129 std::string(mlirIdentifierStr(namedAttr.name).data, 2130 mlirIdentifierStr(namedAttr.name).length)); 2131 } 2132 2133 void dunderSetItem(const std::string &name, const PyAttribute 
&attr) { 2134 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2135 attr); 2136 } 2137 2138 void dunderDelItem(const std::string &name) { 2139 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2140 toMlirStringRef(name)); 2141 if (!removed) 2142 throw SetPyError(PyExc_KeyError, 2143 "attempt to delete a non-existent attribute"); 2144 } 2145 2146 intptr_t dunderLen() { 2147 return mlirOperationGetNumAttributes(operation->get()); 2148 } 2149 2150 bool dunderContains(const std::string &name) { 2151 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2152 operation->get(), toMlirStringRef(name))); 2153 } 2154 2155 static void bind(py::module &m) { 2156 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2157 .def("__contains__", &PyOpAttributeMap::dunderContains) 2158 .def("__len__", &PyOpAttributeMap::dunderLen) 2159 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2160 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2161 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2162 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2163 } 2164 2165 private: 2166 PyOperationRef operation; 2167 }; 2168 2169 } // namespace 2170 2171 //------------------------------------------------------------------------------ 2172 // Populates the core exports of the 'ir' submodule. 2173 //------------------------------------------------------------------------------ 2174 2175 void mlir::python::populateIRCore(py::module &m) { 2176 //---------------------------------------------------------------------------- 2177 // Enums. 
2178 //---------------------------------------------------------------------------- 2179 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local()) 2180 .value("ERROR", MlirDiagnosticError) 2181 .value("WARNING", MlirDiagnosticWarning) 2182 .value("NOTE", MlirDiagnosticNote) 2183 .value("REMARK", MlirDiagnosticRemark); 2184 2185 //---------------------------------------------------------------------------- 2186 // Mapping of Diagnostics. 2187 //---------------------------------------------------------------------------- 2188 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local()) 2189 .def_property_readonly("severity", &PyDiagnostic::getSeverity) 2190 .def_property_readonly("location", &PyDiagnostic::getLocation) 2191 .def_property_readonly("message", &PyDiagnostic::getMessage) 2192 .def_property_readonly("notes", &PyDiagnostic::getNotes) 2193 .def("__str__", [](PyDiagnostic &self) -> py::str { 2194 if (!self.isValid()) 2195 return "<Invalid Diagnostic>"; 2196 return self.getMessage(); 2197 }); 2198 2199 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) 2200 .def("detach", &PyDiagnosticHandler::detach) 2201 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) 2202 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 2203 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2204 .def("__exit__", &PyDiagnosticHandler::contextExit); 2205 2206 //---------------------------------------------------------------------------- 2207 // Mapping of MlirContext. 
2208 //---------------------------------------------------------------------------- 2209 py::class_<PyMlirContext>(m, "Context", py::module_local()) 2210 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2211 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2212 .def("_get_context_again", 2213 [](PyMlirContext &self) { 2214 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2215 return ref.releaseObject(); 2216 }) 2217 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2218 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) 2219 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2220 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2221 &PyMlirContext::getCapsule) 2222 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2223 .def("__enter__", &PyMlirContext::contextEnter) 2224 .def("__exit__", &PyMlirContext::contextExit) 2225 .def_property_readonly_static( 2226 "current", 2227 [](py::object & /*class*/) { 2228 auto *context = PyThreadContextEntry::getDefaultContext(); 2229 if (!context) 2230 throw SetPyError(PyExc_ValueError, "No current Context"); 2231 return context; 2232 }, 2233 "Gets the Context bound to the current thread or raises ValueError") 2234 .def_property_readonly( 2235 "dialects", 2236 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2237 "Gets a container for accessing dialects by name") 2238 .def_property_readonly( 2239 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2240 "Alias for 'dialect'") 2241 .def( 2242 "get_dialect_descriptor", 2243 [=](PyMlirContext &self, std::string &name) { 2244 MlirDialect dialect = mlirContextGetOrLoadDialect( 2245 self.get(), {name.data(), name.size()}); 2246 if (mlirDialectIsNull(dialect)) { 2247 throw SetPyError(PyExc_ValueError, 2248 Twine("Dialect '") + name + "' not found"); 2249 } 2250 return PyDialectDescriptor(self.getRef(), dialect); 2251 }, 2252 py::arg("dialect_name"), 
2253 "Gets or loads a dialect by name, returning its descriptor object") 2254 .def_property( 2255 "allow_unregistered_dialects", 2256 [](PyMlirContext &self) -> bool { 2257 return mlirContextGetAllowUnregisteredDialects(self.get()); 2258 }, 2259 [](PyMlirContext &self, bool value) { 2260 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2261 }) 2262 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2263 py::arg("callback"), 2264 "Attaches a diagnostic handler that will receive callbacks") 2265 .def( 2266 "enable_multithreading", 2267 [](PyMlirContext &self, bool enable) { 2268 mlirContextEnableMultithreading(self.get(), enable); 2269 }, 2270 py::arg("enable")) 2271 .def( 2272 "is_registered_operation", 2273 [](PyMlirContext &self, std::string &name) { 2274 return mlirContextIsRegisteredOperation( 2275 self.get(), MlirStringRef{name.data(), name.size()}); 2276 }, 2277 py::arg("operation_name")); 2278 2279 //---------------------------------------------------------------------------- 2280 // Mapping of PyDialectDescriptor 2281 //---------------------------------------------------------------------------- 2282 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2283 .def_property_readonly("namespace", 2284 [](PyDialectDescriptor &self) { 2285 MlirStringRef ns = 2286 mlirDialectGetNamespace(self.get()); 2287 return py::str(ns.data, ns.length); 2288 }) 2289 .def("__repr__", [](PyDialectDescriptor &self) { 2290 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2291 std::string repr("<DialectDescriptor "); 2292 repr.append(ns.data, ns.length); 2293 repr.append(">"); 2294 return repr; 2295 }); 2296 2297 //---------------------------------------------------------------------------- 2298 // Mapping of PyDialects 2299 //---------------------------------------------------------------------------- 2300 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2301 .def("__getitem__", 2302 [=](PyDialects &self, 
std::string keyName) { 2303 MlirDialect dialect = 2304 self.getDialectForKey(keyName, /*attrError=*/false); 2305 py::object descriptor = 2306 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2307 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2308 }) 2309 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2310 MlirDialect dialect = 2311 self.getDialectForKey(attrName, /*attrError=*/true); 2312 py::object descriptor = 2313 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2314 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2315 }); 2316 2317 //---------------------------------------------------------------------------- 2318 // Mapping of PyDialect 2319 //---------------------------------------------------------------------------- 2320 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2321 .def(py::init<py::object>(), py::arg("descriptor")) 2322 .def_property_readonly( 2323 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2324 .def("__repr__", [](py::object self) { 2325 auto clazz = self.attr("__class__"); 2326 return py::str("<Dialect ") + 2327 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2328 clazz.attr("__module__") + py::str(".") + 2329 clazz.attr("__name__") + py::str(")>"); 2330 }); 2331 2332 //---------------------------------------------------------------------------- 2333 // Mapping of Location 2334 //---------------------------------------------------------------------------- 2335 py::class_<PyLocation>(m, "Location", py::module_local()) 2336 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2337 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2338 .def("__enter__", &PyLocation::contextEnter) 2339 .def("__exit__", &PyLocation::contextExit) 2340 .def("__eq__", 2341 [](PyLocation &self, PyLocation &other) -> bool { 2342 return mlirLocationEqual(self, other); 2343 }) 2344 .def("__eq__", 
[](PyLocation &self, py::object other) { return false; }) 2345 .def_property_readonly_static( 2346 "current", 2347 [](py::object & /*class*/) { 2348 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2349 if (!loc) 2350 throw SetPyError(PyExc_ValueError, "No current Location"); 2351 return loc; 2352 }, 2353 "Gets the Location bound to the current thread or raises ValueError") 2354 .def_static( 2355 "unknown", 2356 [](DefaultingPyMlirContext context) { 2357 return PyLocation(context->getRef(), 2358 mlirLocationUnknownGet(context->get())); 2359 }, 2360 py::arg("context") = py::none(), 2361 "Gets a Location representing an unknown location") 2362 .def_static( 2363 "callsite", 2364 [](PyLocation callee, const std::vector<PyLocation> &frames, 2365 DefaultingPyMlirContext context) { 2366 if (frames.empty()) 2367 throw py::value_error("No caller frames provided"); 2368 MlirLocation caller = frames.back().get(); 2369 for (const PyLocation &frame : 2370 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2371 caller = mlirLocationCallSiteGet(frame.get(), caller); 2372 return PyLocation(context->getRef(), 2373 mlirLocationCallSiteGet(callee.get(), caller)); 2374 }, 2375 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2376 kContextGetCallSiteLocationDocstring) 2377 .def_static( 2378 "file", 2379 [](std::string filename, int line, int col, 2380 DefaultingPyMlirContext context) { 2381 return PyLocation( 2382 context->getRef(), 2383 mlirLocationFileLineColGet( 2384 context->get(), toMlirStringRef(filename), line, col)); 2385 }, 2386 py::arg("filename"), py::arg("line"), py::arg("col"), 2387 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2388 .def_static( 2389 "fused", 2390 [](const std::vector<PyLocation> &pyLocations, 2391 llvm::Optional<PyAttribute> metadata, 2392 DefaultingPyMlirContext context) { 2393 llvm::SmallVector<MlirLocation, 4> locations; 2394 locations.reserve(pyLocations.size()); 2395 for (auto &pyLocation : 
pyLocations) 2396 locations.push_back(pyLocation.get()); 2397 MlirLocation location = mlirLocationFusedGet( 2398 context->get(), locations.size(), locations.data(), 2399 metadata ? metadata->get() : MlirAttribute{0}); 2400 return PyLocation(context->getRef(), location); 2401 }, 2402 py::arg("locations"), py::arg("metadata") = py::none(), 2403 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2404 .def_static( 2405 "name", 2406 [](std::string name, llvm::Optional<PyLocation> childLoc, 2407 DefaultingPyMlirContext context) { 2408 return PyLocation( 2409 context->getRef(), 2410 mlirLocationNameGet( 2411 context->get(), toMlirStringRef(name), 2412 childLoc ? childLoc->get() 2413 : mlirLocationUnknownGet(context->get()))); 2414 }, 2415 py::arg("name"), py::arg("childLoc") = py::none(), 2416 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2417 .def_property_readonly( 2418 "context", 2419 [](PyLocation &self) { return self.getContext().getObject(); }, 2420 "Context that owns the Location") 2421 .def( 2422 "emit_error", 2423 [](PyLocation &self, std::string message) { 2424 mlirEmitError(self, message.c_str()); 2425 }, 2426 py::arg("message"), "Emits an error at this location") 2427 .def("__repr__", [](PyLocation &self) { 2428 PyPrintAccumulator printAccum; 2429 mlirLocationPrint(self, printAccum.getCallback(), 2430 printAccum.getUserData()); 2431 return printAccum.join(); 2432 }); 2433 2434 //---------------------------------------------------------------------------- 2435 // Mapping of Module 2436 //---------------------------------------------------------------------------- 2437 py::class_<PyModule>(m, "Module", py::module_local()) 2438 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2439 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2440 .def_static( 2441 "parse", 2442 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2443 MlirModule module = mlirModuleCreateParse( 2444 
context->get(), toMlirStringRef(moduleAsm)); 2445 // TODO: Rework error reporting once diagnostic engine is exposed 2446 // in C API. 2447 if (mlirModuleIsNull(module)) { 2448 throw SetPyError( 2449 PyExc_ValueError, 2450 "Unable to parse module assembly (see diagnostics)"); 2451 } 2452 return PyModule::forModule(module).releaseObject(); 2453 }, 2454 py::arg("asm"), py::arg("context") = py::none(), 2455 kModuleParseDocstring) 2456 .def_static( 2457 "create", 2458 [](DefaultingPyLocation loc) { 2459 MlirModule module = mlirModuleCreateEmpty(loc); 2460 return PyModule::forModule(module).releaseObject(); 2461 }, 2462 py::arg("loc") = py::none(), "Creates an empty module") 2463 .def_property_readonly( 2464 "context", 2465 [](PyModule &self) { return self.getContext().getObject(); }, 2466 "Context that created the Module") 2467 .def_property_readonly( 2468 "operation", 2469 [](PyModule &self) { 2470 return PyOperation::forOperation(self.getContext(), 2471 mlirModuleGetOperation(self.get()), 2472 self.getRef().releaseObject()) 2473 .releaseObject(); 2474 }, 2475 "Accesses the module as an operation") 2476 .def_property_readonly( 2477 "body", 2478 [](PyModule &self) { 2479 PyOperationRef moduleOp = PyOperation::forOperation( 2480 self.getContext(), mlirModuleGetOperation(self.get()), 2481 self.getRef().releaseObject()); 2482 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 2483 return returnBlock; 2484 }, 2485 "Return the block for this module") 2486 .def( 2487 "dump", 2488 [](PyModule &self) { 2489 mlirOperationDump(mlirModuleGetOperation(self.get())); 2490 }, 2491 kDumpDocstring) 2492 .def( 2493 "__str__", 2494 [](py::object self) { 2495 // Defer to the operation's __str__. 2496 return self.attr("operation").attr("__str__")(); 2497 }, 2498 kOperationStrDunderDocstring); 2499 2500 //---------------------------------------------------------------------------- 2501 // Mapping of Operation. 
2502 //---------------------------------------------------------------------------- 2503 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2504 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2505 [](PyOperationBase &self) { 2506 return self.getOperation().getCapsule(); 2507 }) 2508 .def("__eq__", 2509 [](PyOperationBase &self, PyOperationBase &other) { 2510 return &self.getOperation() == &other.getOperation(); 2511 }) 2512 .def("__eq__", 2513 [](PyOperationBase &self, py::object other) { return false; }) 2514 .def("__hash__", 2515 [](PyOperationBase &self) { 2516 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2517 }) 2518 .def_property_readonly("attributes", 2519 [](PyOperationBase &self) { 2520 return PyOpAttributeMap( 2521 self.getOperation().getRef()); 2522 }) 2523 .def_property_readonly("operands", 2524 [](PyOperationBase &self) { 2525 return PyOpOperandList( 2526 self.getOperation().getRef()); 2527 }) 2528 .def_property_readonly("regions", 2529 [](PyOperationBase &self) { 2530 return PyRegionList( 2531 self.getOperation().getRef()); 2532 }) 2533 .def_property_readonly( 2534 "results", 2535 [](PyOperationBase &self) { 2536 return PyOpResultList(self.getOperation().getRef()); 2537 }, 2538 "Returns the list of Operation results.") 2539 .def_property_readonly( 2540 "result", 2541 [](PyOperationBase &self) { 2542 auto &operation = self.getOperation(); 2543 auto numResults = mlirOperationGetNumResults(operation); 2544 if (numResults != 1) { 2545 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2546 throw SetPyError( 2547 PyExc_ValueError, 2548 Twine("Cannot call .result on operation ") + 2549 StringRef(name.data, name.length) + " which has " + 2550 Twine(numResults) + 2551 " results (it is only valid for operations with a " 2552 "single result)"); 2553 } 2554 return PyOpResult(operation.getRef(), 2555 mlirOperationGetResult(operation, 0)); 2556 }, 2557 "Shortcut to get an op result if it has only one (throws 
an error " 2558 "otherwise).") 2559 .def_property_readonly( 2560 "location", 2561 [](PyOperationBase &self) { 2562 PyOperation &operation = self.getOperation(); 2563 return PyLocation(operation.getContext(), 2564 mlirOperationGetLocation(operation.get())); 2565 }, 2566 "Returns the source location the operation was defined or derived " 2567 "from.") 2568 .def( 2569 "__str__", 2570 [](PyOperationBase &self) { 2571 return self.getAsm(/*binary=*/false, 2572 /*largeElementsLimit=*/llvm::None, 2573 /*enableDebugInfo=*/false, 2574 /*prettyDebugInfo=*/false, 2575 /*printGenericOpForm=*/false, 2576 /*useLocalScope=*/false, 2577 /*assumeVerified=*/false); 2578 }, 2579 "Returns the assembly form of the operation.") 2580 .def("print", &PyOperationBase::print, 2581 // Careful: Lots of arguments must match up with print method. 2582 py::arg("file") = py::none(), py::arg("binary") = false, 2583 py::arg("large_elements_limit") = py::none(), 2584 py::arg("enable_debug_info") = false, 2585 py::arg("pretty_debug_info") = false, 2586 py::arg("print_generic_op_form") = false, 2587 py::arg("use_local_scope") = false, 2588 py::arg("assume_verified") = false, kOperationPrintDocstring) 2589 .def("get_asm", &PyOperationBase::getAsm, 2590 // Careful: Lots of arguments must match up with get_asm method. 
2591 py::arg("binary") = false, 2592 py::arg("large_elements_limit") = py::none(), 2593 py::arg("enable_debug_info") = false, 2594 py::arg("pretty_debug_info") = false, 2595 py::arg("print_generic_op_form") = false, 2596 py::arg("use_local_scope") = false, 2597 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2598 .def( 2599 "verify", 2600 [](PyOperationBase &self) { 2601 return mlirOperationVerify(self.getOperation()); 2602 }, 2603 "Verify the operation and return true if it passes, false if it " 2604 "fails.") 2605 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2606 "Puts self immediately after the other operation in its parent " 2607 "block.") 2608 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2609 "Puts self immediately before the other operation in its parent " 2610 "block.") 2611 .def( 2612 "detach_from_parent", 2613 [](PyOperationBase &self) { 2614 PyOperation &operation = self.getOperation(); 2615 operation.checkValid(); 2616 if (!operation.isAttached()) 2617 throw py::value_error("Detached operation has no parent."); 2618 2619 operation.detachFromParent(); 2620 return operation.createOpView(); 2621 }, 2622 "Detaches the operation from its parent block."); 2623 2624 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2625 .def_static("create", &PyOperation::create, py::arg("name"), 2626 py::arg("results") = py::none(), 2627 py::arg("operands") = py::none(), 2628 py::arg("attributes") = py::none(), 2629 py::arg("successors") = py::none(), py::arg("regions") = 0, 2630 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2631 kOperationCreateDocstring) 2632 .def_property_readonly("parent", 2633 [](PyOperation &self) -> py::object { 2634 auto parent = self.getParentOperation(); 2635 if (parent) 2636 return parent->getObject(); 2637 return py::none(); 2638 }) 2639 .def("erase", &PyOperation::erase) 2640 .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) 2641 
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2642 &PyOperation::getCapsule) 2643 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2644 .def_property_readonly("name", 2645 [](PyOperation &self) { 2646 self.checkValid(); 2647 MlirOperation operation = self.get(); 2648 MlirStringRef name = mlirIdentifierStr( 2649 mlirOperationGetName(operation)); 2650 return py::str(name.data, name.length); 2651 }) 2652 .def_property_readonly( 2653 "context", 2654 [](PyOperation &self) { 2655 self.checkValid(); 2656 return self.getContext().getObject(); 2657 }, 2658 "Context that owns the Operation") 2659 .def_property_readonly("opview", &PyOperation::createOpView); 2660 2661 auto opViewClass = 2662 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2663 .def(py::init<py::object>(), py::arg("operation")) 2664 .def_property_readonly("operation", &PyOpView::getOperationObject) 2665 .def_property_readonly( 2666 "context", 2667 [](PyOpView &self) { 2668 return self.getOperation().getContext().getObject(); 2669 }, 2670 "Context that owns the Operation") 2671 .def("__str__", [](PyOpView &self) { 2672 return py::str(self.getOperationObject()); 2673 }); 2674 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2675 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2676 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2677 opViewClass.attr("build_generic") = classmethod( 2678 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2679 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2680 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2681 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2682 "Builds a specific, generated OpView based on class level attributes."); 2683 2684 //---------------------------------------------------------------------------- 2685 // Mapping of PyRegion. 
2686 //---------------------------------------------------------------------------- 2687 py::class_<PyRegion>(m, "Region", py::module_local()) 2688 .def_property_readonly( 2689 "blocks", 2690 [](PyRegion &self) { 2691 return PyBlockList(self.getParentOperation(), self.get()); 2692 }, 2693 "Returns a forward-optimized sequence of blocks.") 2694 .def_property_readonly( 2695 "owner", 2696 [](PyRegion &self) { 2697 return self.getParentOperation()->createOpView(); 2698 }, 2699 "Returns the operation owning this region.") 2700 .def( 2701 "__iter__", 2702 [](PyRegion &self) { 2703 self.checkValid(); 2704 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2705 return PyBlockIterator(self.getParentOperation(), firstBlock); 2706 }, 2707 "Iterates over blocks in the region.") 2708 .def("__eq__", 2709 [](PyRegion &self, PyRegion &other) { 2710 return self.get().ptr == other.get().ptr; 2711 }) 2712 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2713 2714 //---------------------------------------------------------------------------- 2715 // Mapping of PyBlock. 
2716 //---------------------------------------------------------------------------- 2717 py::class_<PyBlock>(m, "Block", py::module_local()) 2718 .def_property_readonly( 2719 "owner", 2720 [](PyBlock &self) { 2721 return self.getParentOperation()->createOpView(); 2722 }, 2723 "Returns the owning operation of this block.") 2724 .def_property_readonly( 2725 "region", 2726 [](PyBlock &self) { 2727 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2728 return PyRegion(self.getParentOperation(), region); 2729 }, 2730 "Returns the owning region of this block.") 2731 .def_property_readonly( 2732 "arguments", 2733 [](PyBlock &self) { 2734 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2735 }, 2736 "Returns a list of block arguments.") 2737 .def_property_readonly( 2738 "operations", 2739 [](PyBlock &self) { 2740 return PyOperationList(self.getParentOperation(), self.get()); 2741 }, 2742 "Returns a forward-optimized sequence of operations.") 2743 .def_static( 2744 "create_at_start", 2745 [](PyRegion &parent, py::list pyArgTypes) { 2746 parent.checkValid(); 2747 llvm::SmallVector<MlirType, 4> argTypes; 2748 llvm::SmallVector<MlirLocation, 4> argLocs; 2749 argTypes.reserve(pyArgTypes.size()); 2750 argLocs.reserve(pyArgTypes.size()); 2751 for (auto &pyArg : pyArgTypes) { 2752 argTypes.push_back(pyArg.cast<PyType &>()); 2753 // TODO: Pass in a proper location here. 
2754 argLocs.push_back( 2755 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2756 } 2757 2758 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2759 argLocs.data()); 2760 mlirRegionInsertOwnedBlock(parent, 0, block); 2761 return PyBlock(parent.getParentOperation(), block); 2762 }, 2763 py::arg("parent"), py::arg("arg_types") = py::list(), 2764 "Creates and returns a new Block at the beginning of the given " 2765 "region (with given argument types).") 2766 .def( 2767 "append_to", 2768 [](PyBlock &self, PyRegion ®ion) { 2769 MlirBlock b = self.get(); 2770 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) 2771 mlirBlockDetach(b); 2772 mlirRegionAppendOwnedBlock(region.get(), b); 2773 }, 2774 "Append this block to a region, transferring ownership if necessary") 2775 .def( 2776 "create_before", 2777 [](PyBlock &self, py::args pyArgTypes) { 2778 self.checkValid(); 2779 llvm::SmallVector<MlirType, 4> argTypes; 2780 llvm::SmallVector<MlirLocation, 4> argLocs; 2781 argTypes.reserve(pyArgTypes.size()); 2782 argLocs.reserve(pyArgTypes.size()); 2783 for (auto &pyArg : pyArgTypes) { 2784 argTypes.push_back(pyArg.cast<PyType &>()); 2785 // TODO: Pass in a proper location here. 
2786 argLocs.push_back( 2787 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2788 } 2789 2790 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2791 argLocs.data()); 2792 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2793 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2794 return PyBlock(self.getParentOperation(), block); 2795 }, 2796 "Creates and returns a new Block before this block " 2797 "(with given argument types).") 2798 .def( 2799 "create_after", 2800 [](PyBlock &self, py::args pyArgTypes) { 2801 self.checkValid(); 2802 llvm::SmallVector<MlirType, 4> argTypes; 2803 llvm::SmallVector<MlirLocation, 4> argLocs; 2804 argTypes.reserve(pyArgTypes.size()); 2805 argLocs.reserve(pyArgTypes.size()); 2806 for (auto &pyArg : pyArgTypes) { 2807 argTypes.push_back(pyArg.cast<PyType &>()); 2808 2809 // TODO: Pass in a proper location here. 2810 argLocs.push_back( 2811 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2812 } 2813 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2814 argLocs.data()); 2815 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2816 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2817 return PyBlock(self.getParentOperation(), block); 2818 }, 2819 "Creates and returns a new Block after this block " 2820 "(with given argument types).") 2821 .def( 2822 "__iter__", 2823 [](PyBlock &self) { 2824 self.checkValid(); 2825 MlirOperation firstOperation = 2826 mlirBlockGetFirstOperation(self.get()); 2827 return PyOperationIterator(self.getParentOperation(), 2828 firstOperation); 2829 }, 2830 "Iterates over operations in the block.") 2831 .def("__eq__", 2832 [](PyBlock &self, PyBlock &other) { 2833 return self.get().ptr == other.get().ptr; 2834 }) 2835 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2836 .def( 2837 "__str__", 2838 [](PyBlock &self) { 2839 self.checkValid(); 2840 PyPrintAccumulator printAccum; 2841 
mlirBlockPrint(self.get(), printAccum.getCallback(), 2842 printAccum.getUserData()); 2843 return printAccum.join(); 2844 }, 2845 "Returns the assembly form of the block.") 2846 .def( 2847 "append", 2848 [](PyBlock &self, PyOperationBase &operation) { 2849 if (operation.getOperation().isAttached()) 2850 operation.getOperation().detachFromParent(); 2851 2852 MlirOperation mlirOperation = operation.getOperation().get(); 2853 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2854 operation.getOperation().setAttached( 2855 self.getParentOperation().getObject()); 2856 }, 2857 py::arg("operation"), 2858 "Appends an operation to this block. If the operation is currently " 2859 "in another block, it will be moved."); 2860 2861 //---------------------------------------------------------------------------- 2862 // Mapping of PyInsertionPoint. 2863 //---------------------------------------------------------------------------- 2864 2865 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2866 .def(py::init<PyBlock &>(), py::arg("block"), 2867 "Inserts after the last operation but still inside the block.") 2868 .def("__enter__", &PyInsertionPoint::contextEnter) 2869 .def("__exit__", &PyInsertionPoint::contextExit) 2870 .def_property_readonly_static( 2871 "current", 2872 [](py::object & /*class*/) { 2873 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2874 if (!ip) 2875 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2876 return ip; 2877 }, 2878 "Gets the InsertionPoint bound to the current thread or raises " 2879 "ValueError if none has been set") 2880 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2881 "Inserts before a referenced operation.") 2882 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2883 py::arg("block"), "Inserts at the beginning of the block.") 2884 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2885 py::arg("block"), "Inserts before the block 
terminator.") 2886 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2887 "Inserts an operation.") 2888 .def_property_readonly( 2889 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2890 "Returns the block that this InsertionPoint points to."); 2891 2892 //---------------------------------------------------------------------------- 2893 // Mapping of PyAttribute. 2894 //---------------------------------------------------------------------------- 2895 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2896 // Delegate to the PyAttribute copy constructor, which will also lifetime 2897 // extend the backing context which owns the MlirAttribute. 2898 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2899 "Casts the passed attribute to the generic Attribute") 2900 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2901 &PyAttribute::getCapsule) 2902 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2903 .def_static( 2904 "parse", 2905 [](std::string attrSpec, DefaultingPyMlirContext context) { 2906 MlirAttribute type = mlirAttributeParseGet( 2907 context->get(), toMlirStringRef(attrSpec)); 2908 // TODO: Rework error reporting once diagnostic engine is exposed 2909 // in C API. 
2910 if (mlirAttributeIsNull(type)) { 2911 throw SetPyError(PyExc_ValueError, 2912 Twine("Unable to parse attribute: '") + 2913 attrSpec + "'"); 2914 } 2915 return PyAttribute(context->getRef(), type); 2916 }, 2917 py::arg("asm"), py::arg("context") = py::none(), 2918 "Parses an attribute from an assembly form") 2919 .def_property_readonly( 2920 "context", 2921 [](PyAttribute &self) { return self.getContext().getObject(); }, 2922 "Context that owns the Attribute") 2923 .def_property_readonly("type", 2924 [](PyAttribute &self) { 2925 return PyType(self.getContext()->getRef(), 2926 mlirAttributeGetType(self)); 2927 }) 2928 .def( 2929 "get_named", 2930 [](PyAttribute &self, std::string name) { 2931 return PyNamedAttribute(self, std::move(name)); 2932 }, 2933 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2934 .def("__eq__", 2935 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2936 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2937 .def("__hash__", 2938 [](PyAttribute &self) { 2939 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2940 }) 2941 .def( 2942 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2943 kDumpDocstring) 2944 .def( 2945 "__str__", 2946 [](PyAttribute &self) { 2947 PyPrintAccumulator printAccum; 2948 mlirAttributePrint(self, printAccum.getCallback(), 2949 printAccum.getUserData()); 2950 return printAccum.join(); 2951 }, 2952 "Returns the assembly form of the Attribute.") 2953 .def("__repr__", [](PyAttribute &self) { 2954 // Generally, assembly formats are not printed for __repr__ because 2955 // this can cause exceptionally long debug output and exceptions. 2956 // However, attribute values are generally considered useful and are 2957 // printed. This may need to be re-evaluated if debug dumps end up 2958 // being excessive. 
2959 PyPrintAccumulator printAccum; 2960 printAccum.parts.append("Attribute("); 2961 mlirAttributePrint(self, printAccum.getCallback(), 2962 printAccum.getUserData()); 2963 printAccum.parts.append(")"); 2964 return printAccum.join(); 2965 }); 2966 2967 //---------------------------------------------------------------------------- 2968 // Mapping of PyNamedAttribute 2969 //---------------------------------------------------------------------------- 2970 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2971 .def("__repr__", 2972 [](PyNamedAttribute &self) { 2973 PyPrintAccumulator printAccum; 2974 printAccum.parts.append("NamedAttribute("); 2975 printAccum.parts.append( 2976 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2977 mlirIdentifierStr(self.namedAttr.name).length)); 2978 printAccum.parts.append("="); 2979 mlirAttributePrint(self.namedAttr.attribute, 2980 printAccum.getCallback(), 2981 printAccum.getUserData()); 2982 printAccum.parts.append(")"); 2983 return printAccum.join(); 2984 }) 2985 .def_property_readonly( 2986 "name", 2987 [](PyNamedAttribute &self) { 2988 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2989 mlirIdentifierStr(self.namedAttr.name).length); 2990 }, 2991 "The name of the NamedAttribute binding") 2992 .def_property_readonly( 2993 "attr", 2994 [](PyNamedAttribute &self) { 2995 // TODO: When named attribute is removed/refactored, also remove 2996 // this constructor (it does an inefficient table lookup). 2997 auto contextRef = PyMlirContext::forContext( 2998 mlirAttributeGetContext(self.namedAttr.attribute)); 2999 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 3000 }, 3001 py::keep_alive<0, 1>(), 3002 "The underlying generic attribute of the NamedAttribute binding"); 3003 3004 //---------------------------------------------------------------------------- 3005 // Mapping of PyType. 
3006 //---------------------------------------------------------------------------- 3007 py::class_<PyType>(m, "Type", py::module_local()) 3008 // Delegate to the PyType copy constructor, which will also lifetime 3009 // extend the backing context which owns the MlirType. 3010 .def(py::init<PyType &>(), py::arg("cast_from_type"), 3011 "Casts the passed type to the generic Type") 3012 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 3013 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 3014 .def_static( 3015 "parse", 3016 [](std::string typeSpec, DefaultingPyMlirContext context) { 3017 MlirType type = 3018 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 3019 // TODO: Rework error reporting once diagnostic engine is exposed 3020 // in C API. 3021 if (mlirTypeIsNull(type)) { 3022 throw SetPyError(PyExc_ValueError, 3023 Twine("Unable to parse type: '") + typeSpec + 3024 "'"); 3025 } 3026 return PyType(context->getRef(), type); 3027 }, 3028 py::arg("asm"), py::arg("context") = py::none(), 3029 kContextParseTypeDocstring) 3030 .def_property_readonly( 3031 "context", [](PyType &self) { return self.getContext().getObject(); }, 3032 "Context that owns the Type") 3033 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3034 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 3035 .def("__hash__", 3036 [](PyType &self) { 3037 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3038 }) 3039 .def( 3040 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3041 .def( 3042 "__str__", 3043 [](PyType &self) { 3044 PyPrintAccumulator printAccum; 3045 mlirTypePrint(self, printAccum.getCallback(), 3046 printAccum.getUserData()); 3047 return printAccum.join(); 3048 }, 3049 "Returns the assembly form of the type.") 3050 .def("__repr__", [](PyType &self) { 3051 // Generally, assembly formats are not printed for __repr__ because 3052 // this can cause exceptionally long debug 
output and exceptions. 3053 // However, types are an exception as they typically have compact 3054 // assembly forms and printing them is useful. 3055 PyPrintAccumulator printAccum; 3056 printAccum.parts.append("Type("); 3057 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 3058 printAccum.parts.append(")"); 3059 return printAccum.join(); 3060 }); 3061 3062 //---------------------------------------------------------------------------- 3063 // Mapping of Value. 3064 //---------------------------------------------------------------------------- 3065 py::class_<PyValue>(m, "Value", py::module_local()) 3066 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3067 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3068 .def_property_readonly( 3069 "context", 3070 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3071 "Context in which the value lives.") 3072 .def( 3073 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3074 kDumpDocstring) 3075 .def_property_readonly( 3076 "owner", 3077 [](PyValue &self) { 3078 assert(mlirOperationEqual(self.getParentOperation()->get(), 3079 mlirOpResultGetOwner(self.get())) && 3080 "expected the owner of the value in Python to match that in " 3081 "the IR"); 3082 return self.getParentOperation().getObject(); 3083 }) 3084 .def("__eq__", 3085 [](PyValue &self, PyValue &other) { 3086 return self.get().ptr == other.get().ptr; 3087 }) 3088 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 3089 .def("__hash__", 3090 [](PyValue &self) { 3091 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3092 }) 3093 .def( 3094 "__str__", 3095 [](PyValue &self) { 3096 PyPrintAccumulator printAccum; 3097 printAccum.parts.append("Value("); 3098 mlirValuePrint(self.get(), printAccum.getCallback(), 3099 printAccum.getUserData()); 3100 printAccum.parts.append(")"); 3101 return printAccum.join(); 3102 }, 3103 kValueDunderStrDocstring) 3104 
.def_property_readonly("type", [](PyValue &self) { 3105 return PyType(self.getParentOperation()->getContext(), 3106 mlirValueGetType(self.get())); 3107 }); 3108 PyBlockArgument::bind(m); 3109 PyOpResult::bind(m); 3110 3111 //---------------------------------------------------------------------------- 3112 // Mapping of SymbolTable. 3113 //---------------------------------------------------------------------------- 3114 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 3115 .def(py::init<PyOperationBase &>()) 3116 .def("__getitem__", &PySymbolTable::dunderGetItem) 3117 .def("insert", &PySymbolTable::insert, py::arg("operation")) 3118 .def("erase", &PySymbolTable::erase, py::arg("operation")) 3119 .def("__delitem__", &PySymbolTable::dunderDel) 3120 .def("__contains__", 3121 [](PySymbolTable &table, const std::string &name) { 3122 return !mlirOperationIsNull(mlirSymbolTableLookup( 3123 table, mlirStringRefCreate(name.data(), name.length()))); 3124 }) 3125 // Static helpers. 3126 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3127 py::arg("symbol"), py::arg("name")) 3128 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3129 py::arg("symbol")) 3130 .def_static("get_visibility", &PySymbolTable::getVisibility, 3131 py::arg("symbol")) 3132 .def_static("set_visibility", &PySymbolTable::setVisibility, 3133 py::arg("symbol"), py::arg("visibility")) 3134 .def_static("replace_all_symbol_uses", 3135 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 3136 py::arg("new_symbol"), py::arg("from_op")) 3137 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3138 py::arg("from_op"), py::arg("all_sym_uses_visible"), 3139 py::arg("callback")); 3140 3141 // Container bindings. 
// Each helper type exposes a static bind(m) that registers it on module `m`.
// NOTE(review): the call order is kept as-is in case later registrations
// depend on earlier ones; confirm before reordering.
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyRegionIterator::bind(m);
PyRegionList::bind(m);

// Debug bindings.
PyGlobalDebugFlag::bind(m);
}