1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 23 #include <utility> 24 25 namespace py = pybind11; 26 using namespace mlir; 27 using namespace mlir::python; 28 29 using llvm::SmallVector; 30 using llvm::StringRef; 31 using llvm::Twine; 32 33 //------------------------------------------------------------------------------ 34 // Docstrings (trivial, non-duplicated docstrings are included inline). 35 //------------------------------------------------------------------------------ 36 37 static const char kContextParseTypeDocstring[] = 38 R"(Parses the assembly form of a type. 39 40 Returns a Type object or raises a ValueError if the type cannot be parsed. 
41 42 See also: https://mlir.llvm.org/docs/LangRef/#type-system 43 )"; 44 45 static const char kContextGetCallSiteLocationDocstring[] = 46 R"(Gets a Location representing a caller and callsite)"; 47 48 static const char kContextGetFileLocationDocstring[] = 49 R"(Gets a Location representing a file, line and column)"; 50 51 static const char kContextGetFusedLocationDocstring[] = 52 R"(Gets a Location representing a fused location with optional metadata)"; 53 54 static const char kContextGetNameLocationDocString[] = 55 R"(Gets a Location representing a named location with optional child location)"; 56 57 static const char kModuleParseDocstring[] = 58 R"(Parses a module's assembly format from a string. 59 60 Returns a new MlirModule or raises a ValueError if the parsing fails. 61 62 See also: https://mlir.llvm.org/docs/LangRef/ 63 )"; 64 65 static const char kOperationCreateDocstring[] = 66 R"(Creates a new operation. 67 68 Args: 69 name: Operation name (e.g. "dialect.operation"). 70 results: Sequence of Type representing op result types. 71 attributes: Dict of str:Attribute. 72 successors: List of Block for the operation's successors. 73 regions: Number of regions to create. 74 location: A Location object (defaults to resolve from context manager). 75 ip: An InsertionPoint (defaults to resolve from context manager or set to 76 False to disable insertion, even with an insertion point set in the 77 context manager). 78 Returns: 79 A new "detached" Operation object. Detached operations can be added 80 to blocks, which causes them to become "attached." 81 )"; 82 83 static const char kOperationPrintDocstring[] = 84 R"(Prints the assembly form of the operation to a file like object. 85 86 Args: 87 file: The file like object to write to. Defaults to sys.stdout. 88 binary: Whether to write bytes (True) or str (False). Defaults to False. 89 large_elements_limit: Whether to elide elements attributes above this 90 number of elements. Defaults to None (no limit). 
91 enable_debug_info: Whether to print debug/location information. Defaults 92 to False. 93 pretty_debug_info: Whether to format debug information for easier reading 94 by a human (warning: the result is unparseable). 95 print_generic_op_form: Whether to print the generic assembly forms of all 96 ops. Defaults to False. 97 use_local_Scope: Whether to print in a way that is more optimized for 98 multi-threaded access but may not be consistent with how the overall 99 module prints. 100 assume_verified: By default, if not printing generic form, the verifier 101 will be run and if it fails, generic form will be printed with a comment 102 about failed verification. While a reasonable default for interactive use, 103 for systematic use, it is often better for the caller to verify explicitly 104 and report failures in a more robust fashion. Set this to True if doing this 105 in order to avoid running a redundant verification. If the IR is actually 106 invalid, behavior is undefined. 107 )"; 108 109 static const char kOperationGetAsmDocstring[] = 110 R"(Gets the assembly form of the operation with all options available. 111 112 Args: 113 binary: Whether to return a bytes (True) or str (False) object. Defaults to 114 False. 115 ... others ...: See the print() method for common keyword arguments for 116 configuring the printout. 117 Returns: 118 Either a bytes or str object, depending on the setting of the 'binary' 119 argument. 120 )"; 121 122 static const char kOperationStrDunderDocstring[] = 123 R"(Gets the assembly form of the operation with default options. 124 125 If more advanced control over the assembly formatting or I/O options is needed, 126 use the dedicated print or get_asm method, which supports keyword arguments to 127 customize behavior. 
128 )"; 129 130 static const char kDumpDocstring[] = 131 R"(Dumps a debug representation of the object to stderr.)"; 132 133 static const char kAppendBlockDocstring[] = 134 R"(Appends a new block, with argument types as positional args. 135 136 Returns: 137 The created block. 138 )"; 139 140 static const char kValueDunderStrDocstring[] = 141 R"(Returns the string form of the value. 142 143 If the value is a block argument, this is the assembly form of its type and the 144 position in the argument list. If the value is an operation result, this is 145 equivalent to printing the operation that produced it. 146 )"; 147 148 //------------------------------------------------------------------------------ 149 // Utilities. 150 //------------------------------------------------------------------------------ 151 152 /// Helper for creating an @classmethod. 153 template <class Func, typename... Args> 154 py::object classmethod(Func f, Args... args) { 155 py::object cf = py::cpp_function(f, args...); 156 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 157 } 158 159 static py::object 160 createCustomDialectWrapper(const std::string &dialectNamespace, 161 py::object dialectDescriptor) { 162 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 163 if (!dialectClass) { 164 // Use the base class. 165 return py::cast(PyDialect(std::move(dialectDescriptor))); 166 } 167 168 // Create the custom implementation. 169 return (*dialectClass)(std::move(dialectDescriptor)); 170 } 171 172 static MlirStringRef toMlirStringRef(const std::string &s) { 173 return mlirStringRefCreate(s.data(), s.size()); 174 } 175 176 /// Wrapper for the global LLVM debugging flag. 177 struct PyGlobalDebugFlag { 178 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 179 180 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } 181 182 static void bind(py::module &m) { 183 // Debug flags. 
184 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 185 .def_property_static("flag", &PyGlobalDebugFlag::get, 186 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 187 } 188 }; 189 190 //------------------------------------------------------------------------------ 191 // Collections. 192 //------------------------------------------------------------------------------ 193 194 namespace { 195 196 class PyRegionIterator { 197 public: 198 PyRegionIterator(PyOperationRef operation) 199 : operation(std::move(operation)) {} 200 201 PyRegionIterator &dunderIter() { return *this; } 202 203 PyRegion dunderNext() { 204 operation->checkValid(); 205 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 206 throw py::stop_iteration(); 207 } 208 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 209 return PyRegion(operation, region); 210 } 211 212 static void bind(py::module &m) { 213 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 214 .def("__iter__", &PyRegionIterator::dunderIter) 215 .def("__next__", &PyRegionIterator::dunderNext); 216 } 217 218 private: 219 PyOperationRef operation; 220 int nextIndex = 0; 221 }; 222 223 /// Regions of an op are fixed length and indexed numerically so are represented 224 /// with a sequence-like container. 225 class PyRegionList { 226 public: 227 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 228 229 intptr_t dunderLen() { 230 operation->checkValid(); 231 return mlirOperationGetNumRegions(operation->get()); 232 } 233 234 PyRegion dunderGetItem(intptr_t index) { 235 // dunderLen checks validity. 
236 if (index < 0 || index >= dunderLen()) { 237 throw SetPyError(PyExc_IndexError, 238 "attempt to access out of bounds region"); 239 } 240 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 241 return PyRegion(operation, region); 242 } 243 244 static void bind(py::module &m) { 245 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 246 .def("__len__", &PyRegionList::dunderLen) 247 .def("__getitem__", &PyRegionList::dunderGetItem); 248 } 249 250 private: 251 PyOperationRef operation; 252 }; 253 254 class PyBlockIterator { 255 public: 256 PyBlockIterator(PyOperationRef operation, MlirBlock next) 257 : operation(std::move(operation)), next(next) {} 258 259 PyBlockIterator &dunderIter() { return *this; } 260 261 PyBlock dunderNext() { 262 operation->checkValid(); 263 if (mlirBlockIsNull(next)) { 264 throw py::stop_iteration(); 265 } 266 267 PyBlock returnBlock(operation, next); 268 next = mlirBlockGetNextInRegion(next); 269 return returnBlock; 270 } 271 272 static void bind(py::module &m) { 273 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 274 .def("__iter__", &PyBlockIterator::dunderIter) 275 .def("__next__", &PyBlockIterator::dunderNext); 276 } 277 278 private: 279 PyOperationRef operation; 280 MlirBlock next; 281 }; 282 283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 284 /// we present them as a more full-featured list-like container but optimize 285 /// it for forward iteration. Blocks are always owned by a region. 
286 class PyBlockList { 287 public: 288 PyBlockList(PyOperationRef operation, MlirRegion region) 289 : operation(std::move(operation)), region(region) {} 290 291 PyBlockIterator dunderIter() { 292 operation->checkValid(); 293 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 294 } 295 296 intptr_t dunderLen() { 297 operation->checkValid(); 298 intptr_t count = 0; 299 MlirBlock block = mlirRegionGetFirstBlock(region); 300 while (!mlirBlockIsNull(block)) { 301 count += 1; 302 block = mlirBlockGetNextInRegion(block); 303 } 304 return count; 305 } 306 307 PyBlock dunderGetItem(intptr_t index) { 308 operation->checkValid(); 309 if (index < 0) { 310 throw SetPyError(PyExc_IndexError, 311 "attempt to access out of bounds block"); 312 } 313 MlirBlock block = mlirRegionGetFirstBlock(region); 314 while (!mlirBlockIsNull(block)) { 315 if (index == 0) { 316 return PyBlock(operation, block); 317 } 318 block = mlirBlockGetNextInRegion(block); 319 index -= 1; 320 } 321 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 322 } 323 324 PyBlock appendBlock(const py::args &pyArgTypes) { 325 operation->checkValid(); 326 llvm::SmallVector<MlirType, 4> argTypes; 327 llvm::SmallVector<MlirLocation, 4> argLocs; 328 argTypes.reserve(pyArgTypes.size()); 329 argLocs.reserve(pyArgTypes.size()); 330 for (auto &pyArg : pyArgTypes) { 331 argTypes.push_back(pyArg.cast<PyType &>()); 332 // TODO: Pass in a proper location here. 
333 argLocs.push_back( 334 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 335 } 336 337 MlirBlock block = 338 mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 339 mlirRegionAppendOwnedBlock(region, block); 340 return PyBlock(operation, block); 341 } 342 343 static void bind(py::module &m) { 344 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 345 .def("__getitem__", &PyBlockList::dunderGetItem) 346 .def("__iter__", &PyBlockList::dunderIter) 347 .def("__len__", &PyBlockList::dunderLen) 348 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 349 } 350 351 private: 352 PyOperationRef operation; 353 MlirRegion region; 354 }; 355 356 class PyOperationIterator { 357 public: 358 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 359 : parentOperation(std::move(parentOperation)), next(next) {} 360 361 PyOperationIterator &dunderIter() { return *this; } 362 363 py::object dunderNext() { 364 parentOperation->checkValid(); 365 if (mlirOperationIsNull(next)) { 366 throw py::stop_iteration(); 367 } 368 369 PyOperationRef returnOperation = 370 PyOperation::forOperation(parentOperation->getContext(), next); 371 next = mlirOperationGetNextInBlock(next); 372 return returnOperation->createOpView(); 373 } 374 375 static void bind(py::module &m) { 376 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 377 .def("__iter__", &PyOperationIterator::dunderIter) 378 .def("__next__", &PyOperationIterator::dunderNext); 379 } 380 381 private: 382 PyOperationRef parentOperation; 383 MlirOperation next; 384 }; 385 386 /// Operations are exposed by the C-API as a forward-only linked list. In 387 /// Python, we present them as a more full-featured list-like container but 388 /// optimize it for forward iteration. Iterable operations are always owned 389 /// by a block. 
390 class PyOperationList { 391 public: 392 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 393 : parentOperation(std::move(parentOperation)), block(block) {} 394 395 PyOperationIterator dunderIter() { 396 parentOperation->checkValid(); 397 return PyOperationIterator(parentOperation, 398 mlirBlockGetFirstOperation(block)); 399 } 400 401 intptr_t dunderLen() { 402 parentOperation->checkValid(); 403 intptr_t count = 0; 404 MlirOperation childOp = mlirBlockGetFirstOperation(block); 405 while (!mlirOperationIsNull(childOp)) { 406 count += 1; 407 childOp = mlirOperationGetNextInBlock(childOp); 408 } 409 return count; 410 } 411 412 py::object dunderGetItem(intptr_t index) { 413 parentOperation->checkValid(); 414 if (index < 0) { 415 throw SetPyError(PyExc_IndexError, 416 "attempt to access out of bounds operation"); 417 } 418 MlirOperation childOp = mlirBlockGetFirstOperation(block); 419 while (!mlirOperationIsNull(childOp)) { 420 if (index == 0) { 421 return PyOperation::forOperation(parentOperation->getContext(), childOp) 422 ->createOpView(); 423 } 424 childOp = mlirOperationGetNextInBlock(childOp); 425 index -= 1; 426 } 427 throw SetPyError(PyExc_IndexError, 428 "attempt to access out of bounds operation"); 429 } 430 431 static void bind(py::module &m) { 432 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 433 .def("__getitem__", &PyOperationList::dunderGetItem) 434 .def("__iter__", &PyOperationList::dunderIter) 435 .def("__len__", &PyOperationList::dunderLen); 436 } 437 438 private: 439 PyOperationRef parentOperation; 440 MlirBlock block; 441 }; 442 443 } // namespace 444 445 //------------------------------------------------------------------------------ 446 // PyMlirContext 447 //------------------------------------------------------------------------------ 448 449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 450 py::gil_scoped_acquire acquire; 451 auto &liveContexts = getLiveContexts(); 452 
liveContexts[context.ptr] = this; 453 } 454 455 PyMlirContext::~PyMlirContext() { 456 // Note that the only public way to construct an instance is via the 457 // forContext method, which always puts the associated handle into 458 // liveContexts. 459 py::gil_scoped_acquire acquire; 460 getLiveContexts().erase(context.ptr); 461 mlirContextDestroy(context); 462 } 463 464 py::object PyMlirContext::getCapsule() { 465 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 466 } 467 468 py::object PyMlirContext::createFromCapsule(py::object capsule) { 469 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 470 if (mlirContextIsNull(rawContext)) 471 throw py::error_already_set(); 472 return forContext(rawContext).releaseObject(); 473 } 474 475 PyMlirContext *PyMlirContext::createNewContextForInit() { 476 MlirContext context = mlirContextCreate(); 477 mlirRegisterAllDialects(context); 478 return new PyMlirContext(context); 479 } 480 481 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 482 py::gil_scoped_acquire acquire; 483 auto &liveContexts = getLiveContexts(); 484 auto it = liveContexts.find(context.ptr); 485 if (it == liveContexts.end()) { 486 // Create. 487 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 488 py::object pyRef = py::cast(unownedContextWrapper); 489 assert(pyRef && "cast to py::object failed"); 490 liveContexts[context.ptr] = unownedContextWrapper; 491 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 492 } 493 // Use existing. 
494 py::object pyRef = py::cast(it->second); 495 return PyMlirContextRef(it->second, std::move(pyRef)); 496 } 497 498 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 499 static LiveContextMap liveContexts; 500 return liveContexts; 501 } 502 503 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 504 505 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 506 507 size_t PyMlirContext::clearLiveOperations() { 508 for (auto &op : liveOperations) 509 op.second.second->setInvalid(); 510 size_t numInvalidated = liveOperations.size(); 511 liveOperations.clear(); 512 return numInvalidated; 513 } 514 515 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 516 517 pybind11::object PyMlirContext::contextEnter() { 518 return PyThreadContextEntry::pushContext(*this); 519 } 520 521 void PyMlirContext::contextExit(const pybind11::object &excType, 522 const pybind11::object &excVal, 523 const pybind11::object &excTb) { 524 PyThreadContextEntry::popContext(*this); 525 } 526 527 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { 528 // Note that ownership is transferred to the delete callback below by way of 529 // an explicit inc_ref (borrow). 530 PyDiagnosticHandler *pyHandler = 531 new PyDiagnosticHandler(get(), std::move(callback)); 532 py::object pyHandlerObject = 533 py::cast(pyHandler, py::return_value_policy::take_ownership); 534 pyHandlerObject.inc_ref(); 535 536 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 537 // guaranteed to be known to pybind. 
538 auto handlerCallback = 539 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 540 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 541 py::object pyDiagnosticObject = 542 py::cast(pyDiagnostic, py::return_value_policy::take_ownership); 543 544 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 545 bool result = false; 546 { 547 // Since this can be called from arbitrary C++ contexts, always get the 548 // gil. 549 py::gil_scoped_acquire gil; 550 try { 551 result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); 552 } catch (std::exception &e) { 553 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 554 e.what()); 555 pyHandler->hadError = true; 556 } 557 } 558 559 pyDiagnostic->invalidate(); 560 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 561 }; 562 auto deleteCallback = +[](void *userData) { 563 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 564 assert(pyHandler->registeredID && "handler is not registered"); 565 pyHandler->registeredID.reset(); 566 567 // Decrement reference, balancing the inc_ref() above. 568 py::object pyHandlerObject = 569 py::cast(pyHandler, py::return_value_policy::reference); 570 pyHandlerObject.dec_ref(); 571 }; 572 573 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 574 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 575 return pyHandlerObject; 576 } 577 578 PyMlirContext &DefaultingPyMlirContext::resolve() { 579 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 580 if (!context) { 581 throw SetPyError( 582 PyExc_RuntimeError, 583 "An MLIR function requires a Context but none was provided in the call " 584 "or from the surrounding environment. 
Either pass to the function with " 585 "a 'context=' argument or establish a default using 'with Context():'"); 586 } 587 return *context; 588 } 589 590 //------------------------------------------------------------------------------ 591 // PyThreadContextEntry management 592 //------------------------------------------------------------------------------ 593 594 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 595 static thread_local std::vector<PyThreadContextEntry> stack; 596 return stack; 597 } 598 599 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 600 auto &stack = getStack(); 601 if (stack.empty()) 602 return nullptr; 603 return &stack.back(); 604 } 605 606 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 607 py::object insertionPoint, 608 py::object location) { 609 auto &stack = getStack(); 610 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 611 std::move(location)); 612 // If the new stack has more than one entry and the context of the new top 613 // entry matches the previous, copy the insertionPoint and location from the 614 // previous entry if missing from the new top entry. 615 if (stack.size() > 1) { 616 auto &prev = *(stack.rbegin() + 1); 617 auto ¤t = stack.back(); 618 if (current.context.is(prev.context)) { 619 // Default non-context objects from the previous entry. 
620 if (!current.insertionPoint) 621 current.insertionPoint = prev.insertionPoint; 622 if (!current.location) 623 current.location = prev.location; 624 } 625 } 626 } 627 628 PyMlirContext *PyThreadContextEntry::getContext() { 629 if (!context) 630 return nullptr; 631 return py::cast<PyMlirContext *>(context); 632 } 633 634 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 635 if (!insertionPoint) 636 return nullptr; 637 return py::cast<PyInsertionPoint *>(insertionPoint); 638 } 639 640 PyLocation *PyThreadContextEntry::getLocation() { 641 if (!location) 642 return nullptr; 643 return py::cast<PyLocation *>(location); 644 } 645 646 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 647 auto *tos = getTopOfStack(); 648 return tos ? tos->getContext() : nullptr; 649 } 650 651 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 652 auto *tos = getTopOfStack(); 653 return tos ? tos->getInsertionPoint() : nullptr; 654 } 655 656 PyLocation *PyThreadContextEntry::getDefaultLocation() { 657 auto *tos = getTopOfStack(); 658 return tos ? 
tos->getLocation() : nullptr; 659 } 660 661 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 662 py::object contextObj = py::cast(context); 663 push(FrameKind::Context, /*context=*/contextObj, 664 /*insertionPoint=*/py::object(), 665 /*location=*/py::object()); 666 return contextObj; 667 } 668 669 void PyThreadContextEntry::popContext(PyMlirContext &context) { 670 auto &stack = getStack(); 671 if (stack.empty()) 672 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 673 auto &tos = stack.back(); 674 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 675 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 676 stack.pop_back(); 677 } 678 679 py::object 680 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 681 py::object contextObj = 682 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 683 py::object insertionPointObj = py::cast(insertionPoint); 684 push(FrameKind::InsertionPoint, 685 /*context=*/contextObj, 686 /*insertionPoint=*/insertionPointObj, 687 /*location=*/py::object()); 688 return insertionPointObj; 689 } 690 691 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 692 auto &stack = getStack(); 693 if (stack.empty()) 694 throw SetPyError(PyExc_RuntimeError, 695 "Unbalanced InsertionPoint enter/exit"); 696 auto &tos = stack.back(); 697 if (tos.frameKind != FrameKind::InsertionPoint && 698 tos.getInsertionPoint() != &insertionPoint) 699 throw SetPyError(PyExc_RuntimeError, 700 "Unbalanced InsertionPoint enter/exit"); 701 stack.pop_back(); 702 } 703 704 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 705 py::object contextObj = location.getContext().getObject(); 706 py::object locationObj = py::cast(location); 707 push(FrameKind::Location, /*context=*/contextObj, 708 /*insertionPoint=*/py::object(), 709 /*location=*/locationObj); 710 return locationObj; 711 } 712 713 void 
PyThreadContextEntry::popLocation(PyLocation &location) { 714 auto &stack = getStack(); 715 if (stack.empty()) 716 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 717 auto &tos = stack.back(); 718 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 719 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 720 stack.pop_back(); 721 } 722 723 //------------------------------------------------------------------------------ 724 // PyDiagnostic* 725 //------------------------------------------------------------------------------ 726 727 void PyDiagnostic::invalidate() { 728 valid = false; 729 if (materializedNotes) { 730 for (auto ¬eObject : *materializedNotes) { 731 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); 732 note->invalidate(); 733 } 734 } 735 } 736 737 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 738 py::object callback) 739 : context(context), callback(std::move(callback)) {} 740 741 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 742 743 void PyDiagnosticHandler::detach() { 744 if (!registeredID) 745 return; 746 MlirDiagnosticHandlerID localID = *registeredID; 747 mlirContextDetachDiagnosticHandler(context, localID); 748 assert(!registeredID && "should have unregistered"); 749 // Not strictly necessary but keeps stale pointers from being around to cause 750 // issues. 
751 context = {nullptr}; 752 } 753 754 void PyDiagnostic::checkValid() { 755 if (!valid) { 756 throw std::invalid_argument( 757 "Diagnostic is invalid (used outside of callback)"); 758 } 759 } 760 761 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 762 checkValid(); 763 return mlirDiagnosticGetSeverity(diagnostic); 764 } 765 766 PyLocation PyDiagnostic::getLocation() { 767 checkValid(); 768 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 769 MlirContext context = mlirLocationGetContext(loc); 770 return PyLocation(PyMlirContext::forContext(context), loc); 771 } 772 773 py::str PyDiagnostic::getMessage() { 774 checkValid(); 775 py::object fileObject = py::module::import("io").attr("StringIO")(); 776 PyFileAccumulator accum(fileObject, /*binary=*/false); 777 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 778 return fileObject.attr("getvalue")(); 779 } 780 781 py::tuple PyDiagnostic::getNotes() { 782 checkValid(); 783 if (materializedNotes) 784 return *materializedNotes; 785 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 786 materializedNotes = py::tuple(numNotes); 787 for (intptr_t i = 0; i < numNotes; ++i) { 788 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 789 py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); 790 PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); 791 } 792 return *materializedNotes; 793 } 794 795 //------------------------------------------------------------------------------ 796 // PyDialect, PyDialectDescriptor, PyDialects 797 //------------------------------------------------------------------------------ 798 799 MlirDialect PyDialects::getDialectForKey(const std::string &key, 800 bool attrError) { 801 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 802 {key.data(), key.size()}); 803 if (mlirDialectIsNull(dialect)) { 804 throw SetPyError(attrError ? 
PyExc_AttributeError : PyExc_IndexError, 805 Twine("Dialect '") + key + "' not found"); 806 } 807 return dialect; 808 } 809 810 //------------------------------------------------------------------------------ 811 // PyLocation 812 //------------------------------------------------------------------------------ 813 814 py::object PyLocation::getCapsule() { 815 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 816 } 817 818 PyLocation PyLocation::createFromCapsule(py::object capsule) { 819 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 820 if (mlirLocationIsNull(rawLoc)) 821 throw py::error_already_set(); 822 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 823 rawLoc); 824 } 825 826 py::object PyLocation::contextEnter() { 827 return PyThreadContextEntry::pushLocation(*this); 828 } 829 830 void PyLocation::contextExit(const pybind11::object &excType, 831 const pybind11::object &excVal, 832 const pybind11::object &excTb) { 833 PyThreadContextEntry::popLocation(*this); 834 } 835 836 PyLocation &DefaultingPyLocation::resolve() { 837 auto *location = PyThreadContextEntry::getDefaultLocation(); 838 if (!location) { 839 throw SetPyError( 840 PyExc_RuntimeError, 841 "An MLIR function requires a Location but none was provided in the " 842 "call or from the surrounding environment. 
Either pass to the function " 843 "with a 'loc=' argument or establish a default using 'with loc:'"); 844 } 845 return *location; 846 } 847 848 //------------------------------------------------------------------------------ 849 // PyModule 850 //------------------------------------------------------------------------------ 851 852 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 853 : BaseContextObject(std::move(contextRef)), module(module) {} 854 855 PyModule::~PyModule() { 856 py::gil_scoped_acquire acquire; 857 auto &liveModules = getContext()->liveModules; 858 assert(liveModules.count(module.ptr) == 1 && 859 "destroying module not in live map"); 860 liveModules.erase(module.ptr); 861 mlirModuleDestroy(module); 862 } 863 864 PyModuleRef PyModule::forModule(MlirModule module) { 865 MlirContext context = mlirModuleGetContext(module); 866 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 867 868 py::gil_scoped_acquire acquire; 869 auto &liveModules = contextRef->liveModules; 870 auto it = liveModules.find(module.ptr); 871 if (it == liveModules.end()) { 872 // Create. 873 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 874 // Note that the default return value policy on cast is automatic_reference, 875 // which does not take ownership (delete will not be called). 876 // Just be explicit. 877 py::object pyRef = 878 py::cast(unownedModule, py::return_value_policy::take_ownership); 879 unownedModule->handle = pyRef; 880 liveModules[module.ptr] = 881 std::make_pair(unownedModule->handle, unownedModule); 882 return PyModuleRef(unownedModule, std::move(pyRef)); 883 } 884 // Use existing. 
885 PyModule *existing = it->second.second; 886 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 887 return PyModuleRef(existing, std::move(pyRef)); 888 } 889 890 py::object PyModule::createFromCapsule(py::object capsule) { 891 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 892 if (mlirModuleIsNull(rawModule)) 893 throw py::error_already_set(); 894 return forModule(rawModule).releaseObject(); 895 } 896 897 py::object PyModule::getCapsule() { 898 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 899 } 900 901 //------------------------------------------------------------------------------ 902 // PyOperation 903 //------------------------------------------------------------------------------ 904 905 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 906 : BaseContextObject(std::move(contextRef)), operation(operation) {} 907 908 PyOperation::~PyOperation() { 909 // If the operation has already been invalidated there is nothing to do. 910 if (!valid) 911 return; 912 auto &liveOperations = getContext()->liveOperations; 913 assert(liveOperations.count(operation.ptr) == 1 && 914 "destroying operation not in live map"); 915 liveOperations.erase(operation.ptr); 916 if (!isAttached()) { 917 mlirOperationDestroy(operation); 918 } 919 } 920 921 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 922 MlirOperation operation, 923 py::object parentKeepAlive) { 924 auto &liveOperations = contextRef->liveOperations; 925 // Create. 926 PyOperation *unownedOperation = 927 new PyOperation(std::move(contextRef), operation); 928 // Note that the default return value policy on cast is automatic_reference, 929 // which does not take ownership (delete will not be called). 930 // Just be explicit. 
931 py::object pyRef = 932 py::cast(unownedOperation, py::return_value_policy::take_ownership); 933 unownedOperation->handle = pyRef; 934 if (parentKeepAlive) { 935 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 936 } 937 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 938 return PyOperationRef(unownedOperation, std::move(pyRef)); 939 } 940 941 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 942 MlirOperation operation, 943 py::object parentKeepAlive) { 944 auto &liveOperations = contextRef->liveOperations; 945 auto it = liveOperations.find(operation.ptr); 946 if (it == liveOperations.end()) { 947 // Create. 948 return createInstance(std::move(contextRef), operation, 949 std::move(parentKeepAlive)); 950 } 951 // Use existing. 952 PyOperation *existing = it->second.second; 953 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 954 return PyOperationRef(existing, std::move(pyRef)); 955 } 956 957 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 958 MlirOperation operation, 959 py::object parentKeepAlive) { 960 auto &liveOperations = contextRef->liveOperations; 961 assert(liveOperations.count(operation.ptr) == 0 && 962 "cannot create detached operation that already exists"); 963 (void)liveOperations; 964 965 PyOperationRef created = createInstance(std::move(contextRef), operation, 966 std::move(parentKeepAlive)); 967 created->attached = false; 968 return created; 969 } 970 971 void PyOperation::checkValid() const { 972 if (!valid) { 973 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 974 } 975 } 976 977 void PyOperationBase::print(py::object fileObject, bool binary, 978 llvm::Optional<int64_t> largeElementsLimit, 979 bool enableDebugInfo, bool prettyDebugInfo, 980 bool printGenericOpForm, bool useLocalScope, 981 bool assumeVerified) { 982 PyOperation &operation = getOperation(); 983 operation.checkValid(); 984 if 
(fileObject.is_none()) 985 fileObject = py::module::import("sys").attr("stdout"); 986 987 if (!assumeVerified && !printGenericOpForm && 988 !mlirOperationVerify(operation)) { 989 std::string message("// Verification failed, printing generic form\n"); 990 if (binary) { 991 fileObject.attr("write")(py::bytes(message)); 992 } else { 993 fileObject.attr("write")(py::str(message)); 994 } 995 printGenericOpForm = true; 996 } 997 998 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 999 if (largeElementsLimit) 1000 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 1001 if (enableDebugInfo) 1002 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 1003 if (printGenericOpForm) 1004 mlirOpPrintingFlagsPrintGenericOpForm(flags); 1005 if (useLocalScope) 1006 mlirOpPrintingFlagsUseLocalScope(flags); 1007 1008 PyFileAccumulator accum(fileObject, binary); 1009 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1010 accum.getUserData()); 1011 mlirOpPrintingFlagsDestroy(flags); 1012 } 1013 1014 py::object PyOperationBase::getAsm(bool binary, 1015 llvm::Optional<int64_t> largeElementsLimit, 1016 bool enableDebugInfo, bool prettyDebugInfo, 1017 bool printGenericOpForm, bool useLocalScope, 1018 bool assumeVerified) { 1019 py::object fileObject; 1020 if (binary) { 1021 fileObject = py::module::import("io").attr("BytesIO")(); 1022 } else { 1023 fileObject = py::module::import("io").attr("StringIO")(); 1024 } 1025 print(fileObject, /*binary=*/binary, 1026 /*largeElementsLimit=*/largeElementsLimit, 1027 /*enableDebugInfo=*/enableDebugInfo, 1028 /*prettyDebugInfo=*/prettyDebugInfo, 1029 /*printGenericOpForm=*/printGenericOpForm, 1030 /*useLocalScope=*/useLocalScope, 1031 /*assumeVerified=*/assumeVerified); 1032 1033 return fileObject.attr("getvalue")(); 1034 } 1035 1036 void PyOperationBase::moveAfter(PyOperationBase &other) { 1037 PyOperation &operation = getOperation(); 1038 PyOperation &otherOp = other.getOperation(); 
1039 operation.checkValid(); 1040 otherOp.checkValid(); 1041 mlirOperationMoveAfter(operation, otherOp); 1042 operation.parentKeepAlive = otherOp.parentKeepAlive; 1043 } 1044 1045 void PyOperationBase::moveBefore(PyOperationBase &other) { 1046 PyOperation &operation = getOperation(); 1047 PyOperation &otherOp = other.getOperation(); 1048 operation.checkValid(); 1049 otherOp.checkValid(); 1050 mlirOperationMoveBefore(operation, otherOp); 1051 operation.parentKeepAlive = otherOp.parentKeepAlive; 1052 } 1053 1054 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 1055 checkValid(); 1056 if (!isAttached()) 1057 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 1058 MlirOperation operation = mlirOperationGetParentOperation(get()); 1059 if (mlirOperationIsNull(operation)) 1060 return {}; 1061 return PyOperation::forOperation(getContext(), operation); 1062 } 1063 1064 PyBlock PyOperation::getBlock() { 1065 checkValid(); 1066 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 1067 MlirBlock block = mlirOperationGetBlock(get()); 1068 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1069 assert(parentOperation && "Operation has no parent"); 1070 return PyBlock{std::move(*parentOperation), block}; 1071 } 1072 1073 py::object PyOperation::getCapsule() { 1074 checkValid(); 1075 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 1076 } 1077 1078 py::object PyOperation::createFromCapsule(py::object capsule) { 1079 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1080 if (mlirOperationIsNull(rawOperation)) 1081 throw py::error_already_set(); 1082 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1083 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1084 .releaseObject(); 1085 } 1086 1087 static void maybeInsertOperation(PyOperationRef &op, 1088 const py::object &maybeIp) { 1089 // InsertPoint active? 
1090 if (!maybeIp.is(py::cast(false))) { 1091 PyInsertionPoint *ip; 1092 if (maybeIp.is_none()) { 1093 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1094 } else { 1095 ip = py::cast<PyInsertionPoint *>(maybeIp); 1096 } 1097 if (ip) 1098 ip->insert(*op.get()); 1099 } 1100 } 1101 1102 py::object PyOperation::create( 1103 const std::string &name, llvm::Optional<std::vector<PyType *>> results, 1104 llvm::Optional<std::vector<PyValue *>> operands, 1105 llvm::Optional<py::dict> attributes, 1106 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 1107 DefaultingPyLocation location, const py::object &maybeIp) { 1108 llvm::SmallVector<MlirValue, 4> mlirOperands; 1109 llvm::SmallVector<MlirType, 4> mlirResults; 1110 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1111 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1112 1113 // General parameter validation. 1114 if (regions < 0) 1115 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 1116 1117 // Unpack/validate operands. 1118 if (operands) { 1119 mlirOperands.reserve(operands->size()); 1120 for (PyValue *operand : *operands) { 1121 if (!operand) 1122 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 1123 mlirOperands.push_back(operand->get()); 1124 } 1125 } 1126 1127 // Unpack/validate results. 1128 if (results) { 1129 mlirResults.reserve(results->size()); 1130 for (PyType *result : *results) { 1131 // TODO: Verify result type originate from the same context. 1132 if (!result) 1133 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 1134 mlirResults.push_back(*result); 1135 } 1136 } 1137 // Unpack/validate attributes. 
1138 if (attributes) { 1139 mlirAttributes.reserve(attributes->size()); 1140 for (auto &it : *attributes) { 1141 std::string key; 1142 try { 1143 key = it.first.cast<std::string>(); 1144 } catch (py::cast_error &err) { 1145 std::string msg = "Invalid attribute key (not a string) when " 1146 "attempting to create the operation \"" + 1147 name + "\" (" + err.what() + ")"; 1148 throw py::cast_error(msg); 1149 } 1150 try { 1151 auto &attribute = it.second.cast<PyAttribute &>(); 1152 // TODO: Verify attribute originates from the same context. 1153 mlirAttributes.emplace_back(std::move(key), attribute); 1154 } catch (py::reference_cast_error &) { 1155 // This exception seems thrown when the value is "None". 1156 std::string msg = 1157 "Found an invalid (`None`?) attribute value for the key \"" + key + 1158 "\" when attempting to create the operation \"" + name + "\""; 1159 throw py::cast_error(msg); 1160 } catch (py::cast_error &err) { 1161 std::string msg = "Invalid attribute value for the key \"" + key + 1162 "\" when attempting to create the operation \"" + 1163 name + "\" (" + err.what() + ")"; 1164 throw py::cast_error(msg); 1165 } 1166 } 1167 } 1168 // Unpack/validate successors. 1169 if (successors) { 1170 mlirSuccessors.reserve(successors->size()); 1171 for (auto *successor : *successors) { 1172 // TODO: Verify successor originate from the same context. 1173 if (!successor) 1174 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1175 mlirSuccessors.push_back(successor->get()); 1176 } 1177 } 1178 1179 // Apply unpacked/validated to the operation state. Beyond this 1180 // point, exceptions cannot be thrown or else the state will leak. 
1181 MlirOperationState state = 1182 mlirOperationStateGet(toMlirStringRef(name), location); 1183 if (!mlirOperands.empty()) 1184 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1185 mlirOperands.data()); 1186 if (!mlirResults.empty()) 1187 mlirOperationStateAddResults(&state, mlirResults.size(), 1188 mlirResults.data()); 1189 if (!mlirAttributes.empty()) { 1190 // Note that the attribute names directly reference bytes in 1191 // mlirAttributes, so that vector must not be changed from here 1192 // on. 1193 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1194 mlirNamedAttributes.reserve(mlirAttributes.size()); 1195 for (auto &it : mlirAttributes) 1196 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1197 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1198 toMlirStringRef(it.first)), 1199 it.second)); 1200 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1201 mlirNamedAttributes.data()); 1202 } 1203 if (!mlirSuccessors.empty()) 1204 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1205 mlirSuccessors.data()); 1206 if (regions) { 1207 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1208 mlirRegions.resize(regions); 1209 for (int i = 0; i < regions; ++i) 1210 mlirRegions[i] = mlirRegionCreate(); 1211 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1212 mlirRegions.data()); 1213 } 1214 1215 // Construct the operation. 
1216 MlirOperation operation = mlirOperationCreate(&state); 1217 PyOperationRef created = 1218 PyOperation::createDetached(location->getContext(), operation); 1219 maybeInsertOperation(created, maybeIp); 1220 1221 return created->createOpView(); 1222 } 1223 1224 py::object PyOperation::clone(const py::object &maybeIp) { 1225 MlirOperation clonedOperation = mlirOperationClone(operation); 1226 PyOperationRef cloned = 1227 PyOperation::createDetached(getContext(), clonedOperation); 1228 maybeInsertOperation(cloned, maybeIp); 1229 1230 return cloned->createOpView(); 1231 } 1232 1233 py::object PyOperation::createOpView() { 1234 checkValid(); 1235 MlirIdentifier ident = mlirOperationGetName(get()); 1236 MlirStringRef identStr = mlirIdentifierStr(ident); 1237 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1238 StringRef(identStr.data, identStr.length)); 1239 if (opViewClass) 1240 return (*opViewClass)(getRef().getObject()); 1241 return py::cast(PyOpView(getRef().getObject())); 1242 } 1243 1244 void PyOperation::erase() { 1245 checkValid(); 1246 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1247 // Python reference to a child operation is live. All children should also 1248 // have their `valid` bit set to false. 
1249 auto &liveOperations = getContext()->liveOperations; 1250 if (liveOperations.count(operation.ptr)) 1251 liveOperations.erase(operation.ptr); 1252 mlirOperationDestroy(operation); 1253 valid = false; 1254 } 1255 1256 //------------------------------------------------------------------------------ 1257 // PyOpView 1258 //------------------------------------------------------------------------------ 1259 1260 py::object PyOpView::buildGeneric( 1261 const py::object &cls, py::list resultTypeList, py::list operandList, 1262 llvm::Optional<py::dict> attributes, 1263 llvm::Optional<std::vector<PyBlock *>> successors, 1264 llvm::Optional<int> regions, DefaultingPyLocation location, 1265 const py::object &maybeIp) { 1266 PyMlirContextRef context = location->getContext(); 1267 // Class level operation construction metadata. 1268 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1269 // Operand and result segment specs are either none, which does no 1270 // variadic unpacking, or a list of ints with segment sizes, where each 1271 // element is either a positive number (typically 1 for a scalar) or -1 to 1272 // indicate that it is derived from the length of the same-indexed operand 1273 // or result (implying that it is a list at that position). 1274 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1275 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1276 1277 std::vector<uint32_t> operandSegmentLengths; 1278 std::vector<uint32_t> resultSegmentLengths; 1279 1280 // Validate/determine region count. 
1281 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1282 int opMinRegionCount = std::get<0>(opRegionSpec); 1283 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1284 if (!regions) { 1285 regions = opMinRegionCount; 1286 } 1287 if (*regions < opMinRegionCount) { 1288 throw py::value_error( 1289 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1290 llvm::Twine(opMinRegionCount) + 1291 " regions but was built with regions=" + llvm::Twine(*regions)) 1292 .str()); 1293 } 1294 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1295 throw py::value_error( 1296 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1297 llvm::Twine(opMinRegionCount) + 1298 " regions but was built with regions=" + llvm::Twine(*regions)) 1299 .str()); 1300 } 1301 1302 // Unpack results. 1303 std::vector<PyType *> resultTypes; 1304 resultTypes.reserve(resultTypeList.size()); 1305 if (resultSegmentSpecObj.is_none()) { 1306 // Non-variadic result unpacking. 1307 for (const auto &it : llvm::enumerate(resultTypeList)) { 1308 try { 1309 resultTypes.push_back(py::cast<PyType *>(it.value())); 1310 if (!resultTypes.back()) 1311 throw py::cast_error(); 1312 } catch (py::cast_error &err) { 1313 throw py::value_error((llvm::Twine("Result ") + 1314 llvm::Twine(it.index()) + " of operation \"" + 1315 name + "\" must be a Type (" + err.what() + ")") 1316 .str()); 1317 } 1318 } 1319 } else { 1320 // Sized result unpacking. 
1321 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1322 if (resultSegmentSpec.size() != resultTypeList.size()) { 1323 throw py::value_error((llvm::Twine("Operation \"") + name + 1324 "\" requires " + 1325 llvm::Twine(resultSegmentSpec.size()) + 1326 " result segments but was provided " + 1327 llvm::Twine(resultTypeList.size())) 1328 .str()); 1329 } 1330 resultSegmentLengths.reserve(resultTypeList.size()); 1331 for (const auto &it : 1332 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1333 int segmentSpec = std::get<1>(it.value()); 1334 if (segmentSpec == 1 || segmentSpec == 0) { 1335 // Unpack unary element. 1336 try { 1337 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1338 if (resultType) { 1339 resultTypes.push_back(resultType); 1340 resultSegmentLengths.push_back(1); 1341 } else if (segmentSpec == 0) { 1342 // Allowed to be optional. 1343 resultSegmentLengths.push_back(0); 1344 } else { 1345 throw py::cast_error("was None and result is not optional"); 1346 } 1347 } catch (py::cast_error &err) { 1348 throw py::value_error((llvm::Twine("Result ") + 1349 llvm::Twine(it.index()) + " of operation \"" + 1350 name + "\" must be a Type (" + err.what() + 1351 ")") 1352 .str()); 1353 } 1354 } else if (segmentSpec == -1) { 1355 // Unpack sequence by appending. 1356 try { 1357 if (std::get<0>(it.value()).is_none()) { 1358 // Treat it as an empty list. 1359 resultSegmentLengths.push_back(0); 1360 } else { 1361 // Unpack the list. 
1362 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1363 for (py::object segmentItem : segment) { 1364 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1365 if (!resultTypes.back()) { 1366 throw py::cast_error("contained a None item"); 1367 } 1368 } 1369 resultSegmentLengths.push_back(segment.size()); 1370 } 1371 } catch (std::exception &err) { 1372 // NOTE: Sloppy to be using a catch-all here, but there are at least 1373 // three different unrelated exceptions that can be thrown in the 1374 // above "casts". Just keep the scope above small and catch them all. 1375 throw py::value_error((llvm::Twine("Result ") + 1376 llvm::Twine(it.index()) + " of operation \"" + 1377 name + "\" must be a Sequence of Types (" + 1378 err.what() + ")") 1379 .str()); 1380 } 1381 } else { 1382 throw py::value_error("Unexpected segment spec"); 1383 } 1384 } 1385 } 1386 1387 // Unpack operands. 1388 std::vector<PyValue *> operands; 1389 operands.reserve(operands.size()); 1390 if (operandSegmentSpecObj.is_none()) { 1391 // Non-sized operand unpacking. 1392 for (const auto &it : llvm::enumerate(operandList)) { 1393 try { 1394 operands.push_back(py::cast<PyValue *>(it.value())); 1395 if (!operands.back()) 1396 throw py::cast_error(); 1397 } catch (py::cast_error &err) { 1398 throw py::value_error((llvm::Twine("Operand ") + 1399 llvm::Twine(it.index()) + " of operation \"" + 1400 name + "\" must be a Value (" + err.what() + ")") 1401 .str()); 1402 } 1403 } 1404 } else { 1405 // Sized operand unpacking. 
1406 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1407 if (operandSegmentSpec.size() != operandList.size()) { 1408 throw py::value_error((llvm::Twine("Operation \"") + name + 1409 "\" requires " + 1410 llvm::Twine(operandSegmentSpec.size()) + 1411 "operand segments but was provided " + 1412 llvm::Twine(operandList.size())) 1413 .str()); 1414 } 1415 operandSegmentLengths.reserve(operandList.size()); 1416 for (const auto &it : 1417 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1418 int segmentSpec = std::get<1>(it.value()); 1419 if (segmentSpec == 1 || segmentSpec == 0) { 1420 // Unpack unary element. 1421 try { 1422 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1423 if (operandValue) { 1424 operands.push_back(operandValue); 1425 operandSegmentLengths.push_back(1); 1426 } else if (segmentSpec == 0) { 1427 // Allowed to be optional. 1428 operandSegmentLengths.push_back(0); 1429 } else { 1430 throw py::cast_error("was None and operand is not optional"); 1431 } 1432 } catch (py::cast_error &err) { 1433 throw py::value_error((llvm::Twine("Operand ") + 1434 llvm::Twine(it.index()) + " of operation \"" + 1435 name + "\" must be a Value (" + err.what() + 1436 ")") 1437 .str()); 1438 } 1439 } else if (segmentSpec == -1) { 1440 // Unpack sequence by appending. 1441 try { 1442 if (std::get<0>(it.value()).is_none()) { 1443 // Treat it as an empty list. 1444 operandSegmentLengths.push_back(0); 1445 } else { 1446 // Unpack the list. 
1447 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1448 for (py::object segmentItem : segment) { 1449 operands.push_back(py::cast<PyValue *>(segmentItem)); 1450 if (!operands.back()) { 1451 throw py::cast_error("contained a None item"); 1452 } 1453 } 1454 operandSegmentLengths.push_back(segment.size()); 1455 } 1456 } catch (std::exception &err) { 1457 // NOTE: Sloppy to be using a catch-all here, but there are at least 1458 // three different unrelated exceptions that can be thrown in the 1459 // above "casts". Just keep the scope above small and catch them all. 1460 throw py::value_error((llvm::Twine("Operand ") + 1461 llvm::Twine(it.index()) + " of operation \"" + 1462 name + "\" must be a Sequence of Values (" + 1463 err.what() + ")") 1464 .str()); 1465 } 1466 } else { 1467 throw py::value_error("Unexpected segment spec"); 1468 } 1469 } 1470 } 1471 1472 // Merge operand/result segment lengths into attributes if needed. 1473 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1474 // Dup. 1475 if (attributes) { 1476 attributes = py::dict(*attributes); 1477 } else { 1478 attributes = py::dict(); 1479 } 1480 if (attributes->contains("result_segment_sizes") || 1481 attributes->contains("operand_segment_sizes")) { 1482 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1483 "'operand_segment_sizes' attribute is unsupported. " 1484 "Use Operation.create for such low-level access."); 1485 } 1486 1487 // Add result_segment_sizes attribute. 1488 if (!resultSegmentLengths.empty()) { 1489 int64_t size = resultSegmentLengths.size(); 1490 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1491 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1492 resultSegmentLengths.size(), resultSegmentLengths.data()); 1493 (*attributes)["result_segment_sizes"] = 1494 PyAttribute(context, segmentLengthAttr); 1495 } 1496 1497 // Add operand_segment_sizes attribute. 
1498 if (!operandSegmentLengths.empty()) { 1499 int64_t size = operandSegmentLengths.size(); 1500 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1501 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1502 operandSegmentLengths.size(), operandSegmentLengths.data()); 1503 (*attributes)["operand_segment_sizes"] = 1504 PyAttribute(context, segmentLengthAttr); 1505 } 1506 } 1507 1508 // Delegate to create. 1509 return PyOperation::create(name, 1510 /*results=*/std::move(resultTypes), 1511 /*operands=*/std::move(operands), 1512 /*attributes=*/std::move(attributes), 1513 /*successors=*/std::move(successors), 1514 /*regions=*/*regions, location, maybeIp); 1515 } 1516 1517 PyOpView::PyOpView(const py::object &operationObject) 1518 // Casting through the PyOperationBase base-class and then back to the 1519 // Operation lets us accept any PyOperationBase subclass. 1520 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1521 operationObject(operation.getRef().getObject()) {} 1522 1523 py::object PyOpView::createRawSubclass(const py::object &userClass) { 1524 // This is... a little gross. The typical pattern is to have a pure python 1525 // class that extends OpView like: 1526 // class AddFOp(_cext.ir.OpView): 1527 // def __init__(self, loc, lhs, rhs): 1528 // operation = loc.context.create_operation( 1529 // "addf", lhs, rhs, results=[lhs.type]) 1530 // super().__init__(operation) 1531 // 1532 // I.e. The goal of the user facing type is to provide a nice constructor 1533 // that has complete freedom for the op under construction. This is at odds 1534 // with our other desire to sometimes create this object by just passing an 1535 // operation (to initialize the base class). 
We could do *arg and **kwargs 1536 // munging to try to make it work, but instead, we synthesize a new class 1537 // on the fly which extends this user class (AddFOp in this example) and 1538 // *give it* the base class's __init__ method, thus bypassing the 1539 // intermediate subclass's __init__ method entirely. While slightly, 1540 // underhanded, this is safe/legal because the type hierarchy has not changed 1541 // (we just added a new leaf) and we aren't mucking around with __new__. 1542 // Typically, this new class will be stored on the original as "_Raw" and will 1543 // be used for casts and other things that need a variant of the class that 1544 // is initialized purely from an operation. 1545 py::object parentMetaclass = 1546 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1547 py::dict attributes; 1548 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1549 // now. 1550 // auto opViewType = py::type::of<PyOpView>(); 1551 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1552 attributes["__init__"] = opViewType.attr("__init__"); 1553 py::str origName = userClass.attr("__name__"); 1554 py::str newName = py::str("_") + origName; 1555 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1556 } 1557 1558 //------------------------------------------------------------------------------ 1559 // PyInsertionPoint. 
1560 //------------------------------------------------------------------------------ 1561 1562 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1563 1564 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1565 : refOperation(beforeOperationBase.getOperation().getRef()), 1566 block((*refOperation)->getBlock()) {} 1567 1568 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1569 PyOperation &operation = operationBase.getOperation(); 1570 if (operation.isAttached()) 1571 throw SetPyError(PyExc_ValueError, 1572 "Attempt to insert operation that is already attached"); 1573 block.getParentOperation()->checkValid(); 1574 MlirOperation beforeOp = {nullptr}; 1575 if (refOperation) { 1576 // Insert before operation. 1577 (*refOperation)->checkValid(); 1578 beforeOp = (*refOperation)->get(); 1579 } else { 1580 // Insert at end (before null) is only valid if the block does not 1581 // already end in a known terminator (violating this will cause assertion 1582 // failures later). 1583 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1584 throw py::index_error("Cannot insert operation at the end of a block " 1585 "that already has a terminator. Did you mean to " 1586 "use 'InsertionPoint.at_block_terminator(block)' " 1587 "versus 'InsertionPoint(block)'?"); 1588 } 1589 } 1590 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1591 operation.setAttached(); 1592 } 1593 1594 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1595 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1596 if (mlirOperationIsNull(firstOp)) { 1597 // Just insert at end. 1598 return PyInsertionPoint(block); 1599 } 1600 1601 // Insert before first op. 
1602 PyOperationRef firstOpRef = PyOperation::forOperation( 1603 block.getParentOperation()->getContext(), firstOp); 1604 return PyInsertionPoint{block, std::move(firstOpRef)}; 1605 } 1606 1607 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1608 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1609 if (mlirOperationIsNull(terminator)) 1610 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1611 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1612 block.getParentOperation()->getContext(), terminator); 1613 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1614 } 1615 1616 py::object PyInsertionPoint::contextEnter() { 1617 return PyThreadContextEntry::pushInsertionPoint(*this); 1618 } 1619 1620 void PyInsertionPoint::contextExit(const pybind11::object &excType, 1621 const pybind11::object &excVal, 1622 const pybind11::object &excTb) { 1623 PyThreadContextEntry::popInsertionPoint(*this); 1624 } 1625 1626 //------------------------------------------------------------------------------ 1627 // PyAttribute. 1628 //------------------------------------------------------------------------------ 1629 1630 bool PyAttribute::operator==(const PyAttribute &other) { 1631 return mlirAttributeEqual(attr, other.attr); 1632 } 1633 1634 py::object PyAttribute::getCapsule() { 1635 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1636 } 1637 1638 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1639 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1640 if (mlirAttributeIsNull(rawAttr)) 1641 throw py::error_already_set(); 1642 return PyAttribute( 1643 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1644 } 1645 1646 //------------------------------------------------------------------------------ 1647 // PyNamedAttribute. 
1648 //------------------------------------------------------------------------------ 1649 1650 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1651 : ownedName(new std::string(std::move(ownedName))) { 1652 namedAttr = mlirNamedAttributeGet( 1653 mlirIdentifierGet(mlirAttributeGetContext(attr), 1654 toMlirStringRef(*this->ownedName)), 1655 attr); 1656 } 1657 1658 //------------------------------------------------------------------------------ 1659 // PyType. 1660 //------------------------------------------------------------------------------ 1661 1662 bool PyType::operator==(const PyType &other) { 1663 return mlirTypeEqual(type, other.type); 1664 } 1665 1666 py::object PyType::getCapsule() { 1667 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1668 } 1669 1670 PyType PyType::createFromCapsule(py::object capsule) { 1671 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1672 if (mlirTypeIsNull(rawType)) 1673 throw py::error_already_set(); 1674 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1675 rawType); 1676 } 1677 1678 //------------------------------------------------------------------------------ 1679 // PyValue and subclases. 
1680 //------------------------------------------------------------------------------ 1681 1682 pybind11::object PyValue::getCapsule() { 1683 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1684 } 1685 1686 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1687 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1688 if (mlirValueIsNull(value)) 1689 throw py::error_already_set(); 1690 MlirOperation owner; 1691 if (mlirValueIsAOpResult(value)) 1692 owner = mlirOpResultGetOwner(value); 1693 if (mlirValueIsABlockArgument(value)) 1694 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1695 if (mlirOperationIsNull(owner)) 1696 throw py::error_already_set(); 1697 MlirContext ctx = mlirOperationGetContext(owner); 1698 PyOperationRef ownerRef = 1699 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1700 return PyValue(ownerRef, value); 1701 } 1702 1703 //------------------------------------------------------------------------------ 1704 // PySymbolTable. 
1705 //------------------------------------------------------------------------------ 1706 1707 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1708 : operation(operation.getOperation().getRef()) { 1709 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1710 if (mlirSymbolTableIsNull(symbolTable)) { 1711 throw py::cast_error("Operation is not a Symbol Table."); 1712 } 1713 } 1714 1715 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1716 operation->checkValid(); 1717 MlirOperation symbol = mlirSymbolTableLookup( 1718 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1719 if (mlirOperationIsNull(symbol)) 1720 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1721 1722 return PyOperation::forOperation(operation->getContext(), symbol, 1723 operation.getObject()) 1724 ->createOpView(); 1725 } 1726 1727 void PySymbolTable::erase(PyOperationBase &symbol) { 1728 operation->checkValid(); 1729 symbol.getOperation().checkValid(); 1730 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1731 // The operation is also erased, so we must invalidate it. There may be Python 1732 // references to this operation so we don't want to delete it from the list of 1733 // live operations here. 
1734 symbol.getOperation().valid = false; 1735 } 1736 1737 void PySymbolTable::dunderDel(const std::string &name) { 1738 py::object operation = dunderGetItem(name); 1739 erase(py::cast<PyOperationBase &>(operation)); 1740 } 1741 1742 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1743 operation->checkValid(); 1744 symbol.getOperation().checkValid(); 1745 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1746 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1747 if (mlirAttributeIsNull(symbolAttr)) 1748 throw py::value_error("Expected operation to have a symbol name."); 1749 return PyAttribute( 1750 symbol.getOperation().getContext(), 1751 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1752 } 1753 1754 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1755 // Op must already be a symbol. 1756 PyOperation &operation = symbol.getOperation(); 1757 operation.checkValid(); 1758 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1759 MlirAttribute existingNameAttr = 1760 mlirOperationGetAttributeByName(operation.get(), attrName); 1761 if (mlirAttributeIsNull(existingNameAttr)) 1762 throw py::value_error("Expected operation to have a symbol name."); 1763 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1764 } 1765 1766 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1767 const std::string &name) { 1768 // Op must already be a symbol. 
1769 PyOperation &operation = symbol.getOperation(); 1770 operation.checkValid(); 1771 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1772 MlirAttribute existingNameAttr = 1773 mlirOperationGetAttributeByName(operation.get(), attrName); 1774 if (mlirAttributeIsNull(existingNameAttr)) 1775 throw py::value_error("Expected operation to have a symbol name."); 1776 MlirAttribute newNameAttr = 1777 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1778 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1779 } 1780 1781 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1782 PyOperation &operation = symbol.getOperation(); 1783 operation.checkValid(); 1784 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1785 MlirAttribute existingVisAttr = 1786 mlirOperationGetAttributeByName(operation.get(), attrName); 1787 if (mlirAttributeIsNull(existingVisAttr)) 1788 throw py::value_error("Expected operation to have a symbol visibility."); 1789 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1790 } 1791 1792 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1793 const std::string &visibility) { 1794 if (visibility != "public" && visibility != "private" && 1795 visibility != "nested") 1796 throw py::value_error( 1797 "Expected visibility to be 'public', 'private' or 'nested'"); 1798 PyOperation &operation = symbol.getOperation(); 1799 operation.checkValid(); 1800 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1801 MlirAttribute existingVisAttr = 1802 mlirOperationGetAttributeByName(operation.get(), attrName); 1803 if (mlirAttributeIsNull(existingVisAttr)) 1804 throw py::value_error("Expected operation to have a symbol visibility."); 1805 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1806 toMlirStringRef(visibility)); 1807 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1808 } 
1809 1810 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1811 const std::string &newSymbol, 1812 PyOperationBase &from) { 1813 PyOperation &fromOperation = from.getOperation(); 1814 fromOperation.checkValid(); 1815 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1816 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1817 from.getOperation()))) 1818 1819 throw py::value_error("Symbol rename failed"); 1820 } 1821 1822 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1823 bool allSymUsesVisible, 1824 py::object callback) { 1825 PyOperation &fromOperation = from.getOperation(); 1826 fromOperation.checkValid(); 1827 struct UserData { 1828 PyMlirContextRef context; 1829 py::object callback; 1830 bool gotException; 1831 std::string exceptionWhat; 1832 py::object exceptionType; 1833 }; 1834 UserData userData{ 1835 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1836 mlirSymbolTableWalkSymbolTables( 1837 fromOperation.get(), allSymUsesVisible, 1838 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1839 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1840 auto pyFoundOp = 1841 PyOperation::forOperation(calleeUserData->context, foundOp); 1842 if (calleeUserData->gotException) 1843 return; 1844 try { 1845 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1846 } catch (py::error_already_set &e) { 1847 calleeUserData->gotException = true; 1848 calleeUserData->exceptionWhat = e.what(); 1849 calleeUserData->exceptionType = e.type(); 1850 } 1851 }, 1852 static_cast<void *>(&userData)); 1853 if (userData.gotException) { 1854 std::string message("Exception raised in callback: "); 1855 message.append(userData.exceptionWhat); 1856 throw std::runtime_error(message); 1857 } 1858 } 1859 1860 namespace { 1861 /// CRTP base class for Python MLIR values that subclass Value and should be 1862 /// castable from it. 
The value hierarchy is one level deep and is not supposed 1863 /// to accommodate other levels unless core MLIR changes. 1864 template <typename DerivedTy> 1865 class PyConcreteValue : public PyValue { 1866 public: 1867 // Derived classes must define statics for: 1868 // IsAFunctionTy isaFunction 1869 // const char *pyClassName 1870 // and redefine bindDerived. 1871 using ClassTy = py::class_<DerivedTy, PyValue>; 1872 using IsAFunctionTy = bool (*)(MlirValue); 1873 1874 PyConcreteValue() = default; 1875 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1876 : PyValue(operationRef, value) {} 1877 PyConcreteValue(PyValue &orig) 1878 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1879 1880 /// Attempts to cast the original value to the derived type and throws on 1881 /// type mismatches. 1882 static MlirValue castFrom(PyValue &orig) { 1883 if (!DerivedTy::isaFunction(orig.get())) { 1884 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1885 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1886 DerivedTy::pyClassName + 1887 " (from " + origRepr + ")"); 1888 } 1889 return orig.get(); 1890 } 1891 1892 /// Binds the Python module objects to functions of this class. 1893 static void bind(py::module &m) { 1894 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1895 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1896 cls.def_static( 1897 "isinstance", 1898 [](PyValue &otherValue) -> bool { 1899 return DerivedTy::isaFunction(otherValue); 1900 }, 1901 py::arg("other_value")); 1902 DerivedTy::bindDerived(cls); 1903 } 1904 1905 /// Implemented by derived classes to add methods to the Python subclass. 1906 static void bindDerived(ClassTy &m) {} 1907 }; 1908 1909 /// Python wrapper for MlirBlockArgument. 
1910 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1911 public: 1912 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1913 static constexpr const char *pyClassName = "BlockArgument"; 1914 using PyConcreteValue::PyConcreteValue; 1915 1916 static void bindDerived(ClassTy &c) { 1917 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1918 return PyBlock(self.getParentOperation(), 1919 mlirBlockArgumentGetOwner(self.get())); 1920 }); 1921 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1922 return mlirBlockArgumentGetArgNumber(self.get()); 1923 }); 1924 c.def( 1925 "set_type", 1926 [](PyBlockArgument &self, PyType type) { 1927 return mlirBlockArgumentSetType(self.get(), type); 1928 }, 1929 py::arg("type")); 1930 } 1931 }; 1932 1933 /// Python wrapper for MlirOpResult. 1934 class PyOpResult : public PyConcreteValue<PyOpResult> { 1935 public: 1936 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1937 static constexpr const char *pyClassName = "OpResult"; 1938 using PyConcreteValue::PyConcreteValue; 1939 1940 static void bindDerived(ClassTy &c) { 1941 c.def_property_readonly("owner", [](PyOpResult &self) { 1942 assert( 1943 mlirOperationEqual(self.getParentOperation()->get(), 1944 mlirOpResultGetOwner(self.get())) && 1945 "expected the owner of the value in Python to match that in the IR"); 1946 return self.getParentOperation().getObject(); 1947 }); 1948 c.def_property_readonly("result_number", [](PyOpResult &self) { 1949 return mlirOpResultGetResultNumber(self.get()); 1950 }); 1951 } 1952 }; 1953 1954 /// Returns the list of types of the values held by container. 
1955 template <typename Container> 1956 static std::vector<PyType> getValueTypes(Container &container, 1957 PyMlirContextRef &context) { 1958 std::vector<PyType> result; 1959 result.reserve(container.getNumElements()); 1960 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1961 result.push_back( 1962 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1963 } 1964 return result; 1965 } 1966 1967 /// A list of block arguments. Internally, these are stored as consecutive 1968 /// elements, random access is cheap. The argument list is associated with the 1969 /// operation that contains the block (detached blocks are not allowed in 1970 /// Python bindings) and extends its lifetime. 1971 class PyBlockArgumentList 1972 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1973 public: 1974 static constexpr const char *pyClassName = "BlockArgumentList"; 1975 1976 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1977 intptr_t startIndex = 0, intptr_t length = -1, 1978 intptr_t step = 1) 1979 : Sliceable(startIndex, 1980 length == -1 ? mlirBlockGetNumArguments(block) : length, 1981 step), 1982 operation(std::move(operation)), block(block) {} 1983 1984 /// Returns the number of arguments in the list. 1985 intptr_t getNumElements() { 1986 operation->checkValid(); 1987 return mlirBlockGetNumArguments(block); 1988 } 1989 1990 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1991 PyBlockArgument getElement(intptr_t pos) { 1992 MlirValue argument = mlirBlockGetArgument(block, pos); 1993 return PyBlockArgument(operation, argument); 1994 } 1995 1996 /// Returns a sublist of this list. 
1997 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1998 intptr_t step) { 1999 return PyBlockArgumentList(operation, block, startIndex, length, step); 2000 } 2001 2002 static void bindDerived(ClassTy &c) { 2003 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 2004 return getValueTypes(self, self.operation->getContext()); 2005 }); 2006 } 2007 2008 private: 2009 PyOperationRef operation; 2010 MlirBlock block; 2011 }; 2012 2013 /// A list of operation operands. Internally, these are stored as consecutive 2014 /// elements, random access is cheap. The result list is associated with the 2015 /// operation whose results these are, and extends the lifetime of this 2016 /// operation. 2017 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 2018 public: 2019 static constexpr const char *pyClassName = "OpOperandList"; 2020 2021 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2022 intptr_t length = -1, intptr_t step = 1) 2023 : Sliceable(startIndex, 2024 length == -1 ? 
mlirOperationGetNumOperands(operation->get()) 2025 : length, 2026 step), 2027 operation(operation) {} 2028 2029 intptr_t getNumElements() { 2030 operation->checkValid(); 2031 return mlirOperationGetNumOperands(operation->get()); 2032 } 2033 2034 PyValue getElement(intptr_t pos) { 2035 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2036 MlirOperation owner; 2037 if (mlirValueIsAOpResult(operand)) 2038 owner = mlirOpResultGetOwner(operand); 2039 else if (mlirValueIsABlockArgument(operand)) 2040 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2041 else 2042 assert(false && "Value must be an block arg or op result."); 2043 PyOperationRef pyOwner = 2044 PyOperation::forOperation(operation->getContext(), owner); 2045 return PyValue(pyOwner, operand); 2046 } 2047 2048 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2049 return PyOpOperandList(operation, startIndex, length, step); 2050 } 2051 2052 void dunderSetItem(intptr_t index, PyValue value) { 2053 index = wrapIndex(index); 2054 mlirOperationSetOperand(operation->get(), index, value.get()); 2055 } 2056 2057 static void bindDerived(ClassTy &c) { 2058 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2059 } 2060 2061 private: 2062 PyOperationRef operation; 2063 }; 2064 2065 /// A list of operation results. Internally, these are stored as consecutive 2066 /// elements, random access is cheap. The result list is associated with the 2067 /// operation whose results these are, and extends the lifetime of this 2068 /// operation. 2069 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 2070 public: 2071 static constexpr const char *pyClassName = "OpResultList"; 2072 2073 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 2074 intptr_t length = -1, intptr_t step = 1) 2075 : Sliceable(startIndex, 2076 length == -1 ? 
mlirOperationGetNumResults(operation->get()) 2077 : length, 2078 step), 2079 operation(operation) {} 2080 2081 intptr_t getNumElements() { 2082 operation->checkValid(); 2083 return mlirOperationGetNumResults(operation->get()); 2084 } 2085 2086 PyOpResult getElement(intptr_t index) { 2087 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 2088 return PyOpResult(value); 2089 } 2090 2091 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2092 return PyOpResultList(operation, startIndex, length, step); 2093 } 2094 2095 static void bindDerived(ClassTy &c) { 2096 c.def_property_readonly("types", [](PyOpResultList &self) { 2097 return getValueTypes(self, self.operation->getContext()); 2098 }); 2099 } 2100 2101 private: 2102 PyOperationRef operation; 2103 }; 2104 2105 /// A list of operation attributes. Can be indexed by name, producing 2106 /// attributes, or by index, producing named attributes. 2107 class PyOpAttributeMap { 2108 public: 2109 PyOpAttributeMap(PyOperationRef operation) 2110 : operation(std::move(operation)) {} 2111 2112 PyAttribute dunderGetItemNamed(const std::string &name) { 2113 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2114 toMlirStringRef(name)); 2115 if (mlirAttributeIsNull(attr)) { 2116 throw SetPyError(PyExc_KeyError, 2117 "attempt to access a non-existent attribute"); 2118 } 2119 return PyAttribute(operation->getContext(), attr); 2120 } 2121 2122 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2123 if (index < 0 || index >= dunderLen()) { 2124 throw SetPyError(PyExc_IndexError, 2125 "attempt to access out of bounds attribute"); 2126 } 2127 MlirNamedAttribute namedAttr = 2128 mlirOperationGetAttribute(operation->get(), index); 2129 return PyNamedAttribute( 2130 namedAttr.attribute, 2131 std::string(mlirIdentifierStr(namedAttr.name).data, 2132 mlirIdentifierStr(namedAttr.name).length)); 2133 } 2134 2135 void dunderSetItem(const std::string &name, const PyAttribute 
&attr) { 2136 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2137 attr); 2138 } 2139 2140 void dunderDelItem(const std::string &name) { 2141 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2142 toMlirStringRef(name)); 2143 if (!removed) 2144 throw SetPyError(PyExc_KeyError, 2145 "attempt to delete a non-existent attribute"); 2146 } 2147 2148 intptr_t dunderLen() { 2149 return mlirOperationGetNumAttributes(operation->get()); 2150 } 2151 2152 bool dunderContains(const std::string &name) { 2153 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2154 operation->get(), toMlirStringRef(name))); 2155 } 2156 2157 static void bind(py::module &m) { 2158 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2159 .def("__contains__", &PyOpAttributeMap::dunderContains) 2160 .def("__len__", &PyOpAttributeMap::dunderLen) 2161 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2162 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2163 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2164 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2165 } 2166 2167 private: 2168 PyOperationRef operation; 2169 }; 2170 2171 } // namespace 2172 2173 //------------------------------------------------------------------------------ 2174 // Populates the core exports of the 'ir' submodule. 2175 //------------------------------------------------------------------------------ 2176 2177 void mlir::python::populateIRCore(py::module &m) { 2178 //---------------------------------------------------------------------------- 2179 // Enums. 
2180 //---------------------------------------------------------------------------- 2181 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local()) 2182 .value("ERROR", MlirDiagnosticError) 2183 .value("WARNING", MlirDiagnosticWarning) 2184 .value("NOTE", MlirDiagnosticNote) 2185 .value("REMARK", MlirDiagnosticRemark); 2186 2187 //---------------------------------------------------------------------------- 2188 // Mapping of Diagnostics. 2189 //---------------------------------------------------------------------------- 2190 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local()) 2191 .def_property_readonly("severity", &PyDiagnostic::getSeverity) 2192 .def_property_readonly("location", &PyDiagnostic::getLocation) 2193 .def_property_readonly("message", &PyDiagnostic::getMessage) 2194 .def_property_readonly("notes", &PyDiagnostic::getNotes) 2195 .def("__str__", [](PyDiagnostic &self) -> py::str { 2196 if (!self.isValid()) 2197 return "<Invalid Diagnostic>"; 2198 return self.getMessage(); 2199 }); 2200 2201 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) 2202 .def("detach", &PyDiagnosticHandler::detach) 2203 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) 2204 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 2205 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2206 .def("__exit__", &PyDiagnosticHandler::contextExit); 2207 2208 //---------------------------------------------------------------------------- 2209 // Mapping of MlirContext. 
2210 //---------------------------------------------------------------------------- 2211 py::class_<PyMlirContext>(m, "Context", py::module_local()) 2212 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2213 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2214 .def("_get_context_again", 2215 [](PyMlirContext &self) { 2216 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2217 return ref.releaseObject(); 2218 }) 2219 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2220 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) 2221 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2222 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2223 &PyMlirContext::getCapsule) 2224 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2225 .def("__enter__", &PyMlirContext::contextEnter) 2226 .def("__exit__", &PyMlirContext::contextExit) 2227 .def_property_readonly_static( 2228 "current", 2229 [](py::object & /*class*/) { 2230 auto *context = PyThreadContextEntry::getDefaultContext(); 2231 if (!context) 2232 throw SetPyError(PyExc_ValueError, "No current Context"); 2233 return context; 2234 }, 2235 "Gets the Context bound to the current thread or raises ValueError") 2236 .def_property_readonly( 2237 "dialects", 2238 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2239 "Gets a container for accessing dialects by name") 2240 .def_property_readonly( 2241 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2242 "Alias for 'dialect'") 2243 .def( 2244 "get_dialect_descriptor", 2245 [=](PyMlirContext &self, std::string &name) { 2246 MlirDialect dialect = mlirContextGetOrLoadDialect( 2247 self.get(), {name.data(), name.size()}); 2248 if (mlirDialectIsNull(dialect)) { 2249 throw SetPyError(PyExc_ValueError, 2250 Twine("Dialect '") + name + "' not found"); 2251 } 2252 return PyDialectDescriptor(self.getRef(), dialect); 2253 }, 2254 py::arg("dialect_name"), 
2255 "Gets or loads a dialect by name, returning its descriptor object") 2256 .def_property( 2257 "allow_unregistered_dialects", 2258 [](PyMlirContext &self) -> bool { 2259 return mlirContextGetAllowUnregisteredDialects(self.get()); 2260 }, 2261 [](PyMlirContext &self, bool value) { 2262 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2263 }) 2264 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2265 py::arg("callback"), 2266 "Attaches a diagnostic handler that will receive callbacks") 2267 .def( 2268 "enable_multithreading", 2269 [](PyMlirContext &self, bool enable) { 2270 mlirContextEnableMultithreading(self.get(), enable); 2271 }, 2272 py::arg("enable")) 2273 .def( 2274 "is_registered_operation", 2275 [](PyMlirContext &self, std::string &name) { 2276 return mlirContextIsRegisteredOperation( 2277 self.get(), MlirStringRef{name.data(), name.size()}); 2278 }, 2279 py::arg("operation_name")); 2280 2281 //---------------------------------------------------------------------------- 2282 // Mapping of PyDialectDescriptor 2283 //---------------------------------------------------------------------------- 2284 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2285 .def_property_readonly("namespace", 2286 [](PyDialectDescriptor &self) { 2287 MlirStringRef ns = 2288 mlirDialectGetNamespace(self.get()); 2289 return py::str(ns.data, ns.length); 2290 }) 2291 .def("__repr__", [](PyDialectDescriptor &self) { 2292 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2293 std::string repr("<DialectDescriptor "); 2294 repr.append(ns.data, ns.length); 2295 repr.append(">"); 2296 return repr; 2297 }); 2298 2299 //---------------------------------------------------------------------------- 2300 // Mapping of PyDialects 2301 //---------------------------------------------------------------------------- 2302 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2303 .def("__getitem__", 2304 [=](PyDialects &self, 
std::string keyName) { 2305 MlirDialect dialect = 2306 self.getDialectForKey(keyName, /*attrError=*/false); 2307 py::object descriptor = 2308 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2309 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2310 }) 2311 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2312 MlirDialect dialect = 2313 self.getDialectForKey(attrName, /*attrError=*/true); 2314 py::object descriptor = 2315 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2316 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2317 }); 2318 2319 //---------------------------------------------------------------------------- 2320 // Mapping of PyDialect 2321 //---------------------------------------------------------------------------- 2322 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2323 .def(py::init<py::object>(), py::arg("descriptor")) 2324 .def_property_readonly( 2325 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2326 .def("__repr__", [](py::object self) { 2327 auto clazz = self.attr("__class__"); 2328 return py::str("<Dialect ") + 2329 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2330 clazz.attr("__module__") + py::str(".") + 2331 clazz.attr("__name__") + py::str(")>"); 2332 }); 2333 2334 //---------------------------------------------------------------------------- 2335 // Mapping of Location 2336 //---------------------------------------------------------------------------- 2337 py::class_<PyLocation>(m, "Location", py::module_local()) 2338 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2339 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2340 .def("__enter__", &PyLocation::contextEnter) 2341 .def("__exit__", &PyLocation::contextExit) 2342 .def("__eq__", 2343 [](PyLocation &self, PyLocation &other) -> bool { 2344 return mlirLocationEqual(self, other); 2345 }) 2346 .def("__eq__", 
[](PyLocation &self, py::object other) { return false; }) 2347 .def_property_readonly_static( 2348 "current", 2349 [](py::object & /*class*/) { 2350 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2351 if (!loc) 2352 throw SetPyError(PyExc_ValueError, "No current Location"); 2353 return loc; 2354 }, 2355 "Gets the Location bound to the current thread or raises ValueError") 2356 .def_static( 2357 "unknown", 2358 [](DefaultingPyMlirContext context) { 2359 return PyLocation(context->getRef(), 2360 mlirLocationUnknownGet(context->get())); 2361 }, 2362 py::arg("context") = py::none(), 2363 "Gets a Location representing an unknown location") 2364 .def_static( 2365 "callsite", 2366 [](PyLocation callee, const std::vector<PyLocation> &frames, 2367 DefaultingPyMlirContext context) { 2368 if (frames.empty()) 2369 throw py::value_error("No caller frames provided"); 2370 MlirLocation caller = frames.back().get(); 2371 for (const PyLocation &frame : 2372 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2373 caller = mlirLocationCallSiteGet(frame.get(), caller); 2374 return PyLocation(context->getRef(), 2375 mlirLocationCallSiteGet(callee.get(), caller)); 2376 }, 2377 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2378 kContextGetCallSiteLocationDocstring) 2379 .def_static( 2380 "file", 2381 [](std::string filename, int line, int col, 2382 DefaultingPyMlirContext context) { 2383 return PyLocation( 2384 context->getRef(), 2385 mlirLocationFileLineColGet( 2386 context->get(), toMlirStringRef(filename), line, col)); 2387 }, 2388 py::arg("filename"), py::arg("line"), py::arg("col"), 2389 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2390 .def_static( 2391 "fused", 2392 [](const std::vector<PyLocation> &pyLocations, 2393 llvm::Optional<PyAttribute> metadata, 2394 DefaultingPyMlirContext context) { 2395 llvm::SmallVector<MlirLocation, 4> locations; 2396 locations.reserve(pyLocations.size()); 2397 for (auto &pyLocation : 
pyLocations) 2398 locations.push_back(pyLocation.get()); 2399 MlirLocation location = mlirLocationFusedGet( 2400 context->get(), locations.size(), locations.data(), 2401 metadata ? metadata->get() : MlirAttribute{0}); 2402 return PyLocation(context->getRef(), location); 2403 }, 2404 py::arg("locations"), py::arg("metadata") = py::none(), 2405 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2406 .def_static( 2407 "name", 2408 [](std::string name, llvm::Optional<PyLocation> childLoc, 2409 DefaultingPyMlirContext context) { 2410 return PyLocation( 2411 context->getRef(), 2412 mlirLocationNameGet( 2413 context->get(), toMlirStringRef(name), 2414 childLoc ? childLoc->get() 2415 : mlirLocationUnknownGet(context->get()))); 2416 }, 2417 py::arg("name"), py::arg("childLoc") = py::none(), 2418 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2419 .def_property_readonly( 2420 "context", 2421 [](PyLocation &self) { return self.getContext().getObject(); }, 2422 "Context that owns the Location") 2423 .def( 2424 "emit_error", 2425 [](PyLocation &self, std::string message) { 2426 mlirEmitError(self, message.c_str()); 2427 }, 2428 py::arg("message"), "Emits an error at this location") 2429 .def("__repr__", [](PyLocation &self) { 2430 PyPrintAccumulator printAccum; 2431 mlirLocationPrint(self, printAccum.getCallback(), 2432 printAccum.getUserData()); 2433 return printAccum.join(); 2434 }); 2435 2436 //---------------------------------------------------------------------------- 2437 // Mapping of Module 2438 //---------------------------------------------------------------------------- 2439 py::class_<PyModule>(m, "Module", py::module_local()) 2440 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2441 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2442 .def_static( 2443 "parse", 2444 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2445 MlirModule module = mlirModuleCreateParse( 2446 
context->get(), toMlirStringRef(moduleAsm)); 2447 // TODO: Rework error reporting once diagnostic engine is exposed 2448 // in C API. 2449 if (mlirModuleIsNull(module)) { 2450 throw SetPyError( 2451 PyExc_ValueError, 2452 "Unable to parse module assembly (see diagnostics)"); 2453 } 2454 return PyModule::forModule(module).releaseObject(); 2455 }, 2456 py::arg("asm"), py::arg("context") = py::none(), 2457 kModuleParseDocstring) 2458 .def_static( 2459 "create", 2460 [](DefaultingPyLocation loc) { 2461 MlirModule module = mlirModuleCreateEmpty(loc); 2462 return PyModule::forModule(module).releaseObject(); 2463 }, 2464 py::arg("loc") = py::none(), "Creates an empty module") 2465 .def_property_readonly( 2466 "context", 2467 [](PyModule &self) { return self.getContext().getObject(); }, 2468 "Context that created the Module") 2469 .def_property_readonly( 2470 "operation", 2471 [](PyModule &self) { 2472 return PyOperation::forOperation(self.getContext(), 2473 mlirModuleGetOperation(self.get()), 2474 self.getRef().releaseObject()) 2475 .releaseObject(); 2476 }, 2477 "Accesses the module as an operation") 2478 .def_property_readonly( 2479 "body", 2480 [](PyModule &self) { 2481 PyOperationRef moduleOp = PyOperation::forOperation( 2482 self.getContext(), mlirModuleGetOperation(self.get()), 2483 self.getRef().releaseObject()); 2484 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 2485 return returnBlock; 2486 }, 2487 "Return the block for this module") 2488 .def( 2489 "dump", 2490 [](PyModule &self) { 2491 mlirOperationDump(mlirModuleGetOperation(self.get())); 2492 }, 2493 kDumpDocstring) 2494 .def( 2495 "__str__", 2496 [](py::object self) { 2497 // Defer to the operation's __str__. 2498 return self.attr("operation").attr("__str__")(); 2499 }, 2500 kOperationStrDunderDocstring); 2501 2502 //---------------------------------------------------------------------------- 2503 // Mapping of Operation. 
2504 //---------------------------------------------------------------------------- 2505 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2506 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2507 [](PyOperationBase &self) { 2508 return self.getOperation().getCapsule(); 2509 }) 2510 .def("__eq__", 2511 [](PyOperationBase &self, PyOperationBase &other) { 2512 return &self.getOperation() == &other.getOperation(); 2513 }) 2514 .def("__eq__", 2515 [](PyOperationBase &self, py::object other) { return false; }) 2516 .def("__hash__", 2517 [](PyOperationBase &self) { 2518 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2519 }) 2520 .def_property_readonly("attributes", 2521 [](PyOperationBase &self) { 2522 return PyOpAttributeMap( 2523 self.getOperation().getRef()); 2524 }) 2525 .def_property_readonly("operands", 2526 [](PyOperationBase &self) { 2527 return PyOpOperandList( 2528 self.getOperation().getRef()); 2529 }) 2530 .def_property_readonly("regions", 2531 [](PyOperationBase &self) { 2532 return PyRegionList( 2533 self.getOperation().getRef()); 2534 }) 2535 .def_property_readonly( 2536 "results", 2537 [](PyOperationBase &self) { 2538 return PyOpResultList(self.getOperation().getRef()); 2539 }, 2540 "Returns the list of Operation results.") 2541 .def_property_readonly( 2542 "result", 2543 [](PyOperationBase &self) { 2544 auto &operation = self.getOperation(); 2545 auto numResults = mlirOperationGetNumResults(operation); 2546 if (numResults != 1) { 2547 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2548 throw SetPyError( 2549 PyExc_ValueError, 2550 Twine("Cannot call .result on operation ") + 2551 StringRef(name.data, name.length) + " which has " + 2552 Twine(numResults) + 2553 " results (it is only valid for operations with a " 2554 "single result)"); 2555 } 2556 return PyOpResult(operation.getRef(), 2557 mlirOperationGetResult(operation, 0)); 2558 }, 2559 "Shortcut to get an op result if it has only one (throws 
an error " 2560 "otherwise).") 2561 .def_property_readonly( 2562 "location", 2563 [](PyOperationBase &self) { 2564 PyOperation &operation = self.getOperation(); 2565 return PyLocation(operation.getContext(), 2566 mlirOperationGetLocation(operation.get())); 2567 }, 2568 "Returns the source location the operation was defined or derived " 2569 "from.") 2570 .def( 2571 "__str__", 2572 [](PyOperationBase &self) { 2573 return self.getAsm(/*binary=*/false, 2574 /*largeElementsLimit=*/llvm::None, 2575 /*enableDebugInfo=*/false, 2576 /*prettyDebugInfo=*/false, 2577 /*printGenericOpForm=*/false, 2578 /*useLocalScope=*/false, 2579 /*assumeVerified=*/false); 2580 }, 2581 "Returns the assembly form of the operation.") 2582 .def("print", &PyOperationBase::print, 2583 // Careful: Lots of arguments must match up with print method. 2584 py::arg("file") = py::none(), py::arg("binary") = false, 2585 py::arg("large_elements_limit") = py::none(), 2586 py::arg("enable_debug_info") = false, 2587 py::arg("pretty_debug_info") = false, 2588 py::arg("print_generic_op_form") = false, 2589 py::arg("use_local_scope") = false, 2590 py::arg("assume_verified") = false, kOperationPrintDocstring) 2591 .def("get_asm", &PyOperationBase::getAsm, 2592 // Careful: Lots of arguments must match up with get_asm method. 
2593 py::arg("binary") = false, 2594 py::arg("large_elements_limit") = py::none(), 2595 py::arg("enable_debug_info") = false, 2596 py::arg("pretty_debug_info") = false, 2597 py::arg("print_generic_op_form") = false, 2598 py::arg("use_local_scope") = false, 2599 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2600 .def( 2601 "verify", 2602 [](PyOperationBase &self) { 2603 return mlirOperationVerify(self.getOperation()); 2604 }, 2605 "Verify the operation and return true if it passes, false if it " 2606 "fails.") 2607 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2608 "Puts self immediately after the other operation in its parent " 2609 "block.") 2610 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2611 "Puts self immediately before the other operation in its parent " 2612 "block.") 2613 .def( 2614 "detach_from_parent", 2615 [](PyOperationBase &self) { 2616 PyOperation &operation = self.getOperation(); 2617 operation.checkValid(); 2618 if (!operation.isAttached()) 2619 throw py::value_error("Detached operation has no parent."); 2620 2621 operation.detachFromParent(); 2622 return operation.createOpView(); 2623 }, 2624 "Detaches the operation from its parent block."); 2625 2626 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2627 .def_static("create", &PyOperation::create, py::arg("name"), 2628 py::arg("results") = py::none(), 2629 py::arg("operands") = py::none(), 2630 py::arg("attributes") = py::none(), 2631 py::arg("successors") = py::none(), py::arg("regions") = 0, 2632 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2633 kOperationCreateDocstring) 2634 .def_property_readonly("parent", 2635 [](PyOperation &self) -> py::object { 2636 auto parent = self.getParentOperation(); 2637 if (parent) 2638 return parent->getObject(); 2639 return py::none(); 2640 }) 2641 .def("erase", &PyOperation::erase) 2642 .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) 2643 
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2644 &PyOperation::getCapsule) 2645 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2646 .def_property_readonly("name", 2647 [](PyOperation &self) { 2648 self.checkValid(); 2649 MlirOperation operation = self.get(); 2650 MlirStringRef name = mlirIdentifierStr( 2651 mlirOperationGetName(operation)); 2652 return py::str(name.data, name.length); 2653 }) 2654 .def_property_readonly( 2655 "context", 2656 [](PyOperation &self) { 2657 self.checkValid(); 2658 return self.getContext().getObject(); 2659 }, 2660 "Context that owns the Operation") 2661 .def_property_readonly("opview", &PyOperation::createOpView); 2662 2663 auto opViewClass = 2664 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2665 .def(py::init<py::object>(), py::arg("operation")) 2666 .def_property_readonly("operation", &PyOpView::getOperationObject) 2667 .def_property_readonly( 2668 "context", 2669 [](PyOpView &self) { 2670 return self.getOperation().getContext().getObject(); 2671 }, 2672 "Context that owns the Operation") 2673 .def("__str__", [](PyOpView &self) { 2674 return py::str(self.getOperationObject()); 2675 }); 2676 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2677 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2678 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2679 opViewClass.attr("build_generic") = classmethod( 2680 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2681 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2682 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2683 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2684 "Builds a specific, generated OpView based on class level attributes."); 2685 2686 //---------------------------------------------------------------------------- 2687 // Mapping of PyRegion. 
2688 //---------------------------------------------------------------------------- 2689 py::class_<PyRegion>(m, "Region", py::module_local()) 2690 .def_property_readonly( 2691 "blocks", 2692 [](PyRegion &self) { 2693 return PyBlockList(self.getParentOperation(), self.get()); 2694 }, 2695 "Returns a forward-optimized sequence of blocks.") 2696 .def_property_readonly( 2697 "owner", 2698 [](PyRegion &self) { 2699 return self.getParentOperation()->createOpView(); 2700 }, 2701 "Returns the operation owning this region.") 2702 .def( 2703 "__iter__", 2704 [](PyRegion &self) { 2705 self.checkValid(); 2706 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2707 return PyBlockIterator(self.getParentOperation(), firstBlock); 2708 }, 2709 "Iterates over blocks in the region.") 2710 .def("__eq__", 2711 [](PyRegion &self, PyRegion &other) { 2712 return self.get().ptr == other.get().ptr; 2713 }) 2714 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2715 2716 //---------------------------------------------------------------------------- 2717 // Mapping of PyBlock. 
2718 //---------------------------------------------------------------------------- 2719 py::class_<PyBlock>(m, "Block", py::module_local()) 2720 .def_property_readonly( 2721 "owner", 2722 [](PyBlock &self) { 2723 return self.getParentOperation()->createOpView(); 2724 }, 2725 "Returns the owning operation of this block.") 2726 .def_property_readonly( 2727 "region", 2728 [](PyBlock &self) { 2729 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2730 return PyRegion(self.getParentOperation(), region); 2731 }, 2732 "Returns the owning region of this block.") 2733 .def_property_readonly( 2734 "arguments", 2735 [](PyBlock &self) { 2736 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2737 }, 2738 "Returns a list of block arguments.") 2739 .def_property_readonly( 2740 "operations", 2741 [](PyBlock &self) { 2742 return PyOperationList(self.getParentOperation(), self.get()); 2743 }, 2744 "Returns a forward-optimized sequence of operations.") 2745 .def_static( 2746 "create_at_start", 2747 [](PyRegion &parent, py::list pyArgTypes) { 2748 parent.checkValid(); 2749 llvm::SmallVector<MlirType, 4> argTypes; 2750 llvm::SmallVector<MlirLocation, 4> argLocs; 2751 argTypes.reserve(pyArgTypes.size()); 2752 argLocs.reserve(pyArgTypes.size()); 2753 for (auto &pyArg : pyArgTypes) { 2754 argTypes.push_back(pyArg.cast<PyType &>()); 2755 // TODO: Pass in a proper location here. 
2756 argLocs.push_back( 2757 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2758 } 2759 2760 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2761 argLocs.data()); 2762 mlirRegionInsertOwnedBlock(parent, 0, block); 2763 return PyBlock(parent.getParentOperation(), block); 2764 }, 2765 py::arg("parent"), py::arg("arg_types") = py::list(), 2766 "Creates and returns a new Block at the beginning of the given " 2767 "region (with given argument types).") 2768 .def( 2769 "append_to", 2770 [](PyBlock &self, PyRegion ®ion) { 2771 MlirBlock b = self.get(); 2772 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) 2773 mlirBlockDetach(b); 2774 mlirRegionAppendOwnedBlock(region.get(), b); 2775 }, 2776 "Append this block to a region, transferring ownership if necessary") 2777 .def( 2778 "create_before", 2779 [](PyBlock &self, py::args pyArgTypes) { 2780 self.checkValid(); 2781 llvm::SmallVector<MlirType, 4> argTypes; 2782 llvm::SmallVector<MlirLocation, 4> argLocs; 2783 argTypes.reserve(pyArgTypes.size()); 2784 argLocs.reserve(pyArgTypes.size()); 2785 for (auto &pyArg : pyArgTypes) { 2786 argTypes.push_back(pyArg.cast<PyType &>()); 2787 // TODO: Pass in a proper location here. 
2788 argLocs.push_back( 2789 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2790 } 2791 2792 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2793 argLocs.data()); 2794 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2795 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2796 return PyBlock(self.getParentOperation(), block); 2797 }, 2798 "Creates and returns a new Block before this block " 2799 "(with given argument types).") 2800 .def( 2801 "create_after", 2802 [](PyBlock &self, py::args pyArgTypes) { 2803 self.checkValid(); 2804 llvm::SmallVector<MlirType, 4> argTypes; 2805 llvm::SmallVector<MlirLocation, 4> argLocs; 2806 argTypes.reserve(pyArgTypes.size()); 2807 argLocs.reserve(pyArgTypes.size()); 2808 for (auto &pyArg : pyArgTypes) { 2809 argTypes.push_back(pyArg.cast<PyType &>()); 2810 2811 // TODO: Pass in a proper location here. 2812 argLocs.push_back( 2813 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2814 } 2815 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2816 argLocs.data()); 2817 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2818 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2819 return PyBlock(self.getParentOperation(), block); 2820 }, 2821 "Creates and returns a new Block after this block " 2822 "(with given argument types).") 2823 .def( 2824 "__iter__", 2825 [](PyBlock &self) { 2826 self.checkValid(); 2827 MlirOperation firstOperation = 2828 mlirBlockGetFirstOperation(self.get()); 2829 return PyOperationIterator(self.getParentOperation(), 2830 firstOperation); 2831 }, 2832 "Iterates over operations in the block.") 2833 .def("__eq__", 2834 [](PyBlock &self, PyBlock &other) { 2835 return self.get().ptr == other.get().ptr; 2836 }) 2837 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2838 .def( 2839 "__str__", 2840 [](PyBlock &self) { 2841 self.checkValid(); 2842 PyPrintAccumulator printAccum; 2843 
mlirBlockPrint(self.get(), printAccum.getCallback(), 2844 printAccum.getUserData()); 2845 return printAccum.join(); 2846 }, 2847 "Returns the assembly form of the block.") 2848 .def( 2849 "append", 2850 [](PyBlock &self, PyOperationBase &operation) { 2851 if (operation.getOperation().isAttached()) 2852 operation.getOperation().detachFromParent(); 2853 2854 MlirOperation mlirOperation = operation.getOperation().get(); 2855 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2856 operation.getOperation().setAttached( 2857 self.getParentOperation().getObject()); 2858 }, 2859 py::arg("operation"), 2860 "Appends an operation to this block. If the operation is currently " 2861 "in another block, it will be moved."); 2862 2863 //---------------------------------------------------------------------------- 2864 // Mapping of PyInsertionPoint. 2865 //---------------------------------------------------------------------------- 2866 2867 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2868 .def(py::init<PyBlock &>(), py::arg("block"), 2869 "Inserts after the last operation but still inside the block.") 2870 .def("__enter__", &PyInsertionPoint::contextEnter) 2871 .def("__exit__", &PyInsertionPoint::contextExit) 2872 .def_property_readonly_static( 2873 "current", 2874 [](py::object & /*class*/) { 2875 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2876 if (!ip) 2877 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2878 return ip; 2879 }, 2880 "Gets the InsertionPoint bound to the current thread or raises " 2881 "ValueError if none has been set") 2882 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2883 "Inserts before a referenced operation.") 2884 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2885 py::arg("block"), "Inserts at the beginning of the block.") 2886 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2887 py::arg("block"), "Inserts before the block 
terminator.") 2888 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2889 "Inserts an operation.") 2890 .def_property_readonly( 2891 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2892 "Returns the block that this InsertionPoint points to."); 2893 2894 //---------------------------------------------------------------------------- 2895 // Mapping of PyAttribute. 2896 //---------------------------------------------------------------------------- 2897 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2898 // Delegate to the PyAttribute copy constructor, which will also lifetime 2899 // extend the backing context which owns the MlirAttribute. 2900 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2901 "Casts the passed attribute to the generic Attribute") 2902 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2903 &PyAttribute::getCapsule) 2904 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2905 .def_static( 2906 "parse", 2907 [](std::string attrSpec, DefaultingPyMlirContext context) { 2908 MlirAttribute type = mlirAttributeParseGet( 2909 context->get(), toMlirStringRef(attrSpec)); 2910 // TODO: Rework error reporting once diagnostic engine is exposed 2911 // in C API. 
2912 if (mlirAttributeIsNull(type)) { 2913 throw SetPyError(PyExc_ValueError, 2914 Twine("Unable to parse attribute: '") + 2915 attrSpec + "'"); 2916 } 2917 return PyAttribute(context->getRef(), type); 2918 }, 2919 py::arg("asm"), py::arg("context") = py::none(), 2920 "Parses an attribute from an assembly form") 2921 .def_property_readonly( 2922 "context", 2923 [](PyAttribute &self) { return self.getContext().getObject(); }, 2924 "Context that owns the Attribute") 2925 .def_property_readonly("type", 2926 [](PyAttribute &self) { 2927 return PyType(self.getContext()->getRef(), 2928 mlirAttributeGetType(self)); 2929 }) 2930 .def( 2931 "get_named", 2932 [](PyAttribute &self, std::string name) { 2933 return PyNamedAttribute(self, std::move(name)); 2934 }, 2935 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2936 .def("__eq__", 2937 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2938 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2939 .def("__hash__", 2940 [](PyAttribute &self) { 2941 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2942 }) 2943 .def( 2944 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2945 kDumpDocstring) 2946 .def( 2947 "__str__", 2948 [](PyAttribute &self) { 2949 PyPrintAccumulator printAccum; 2950 mlirAttributePrint(self, printAccum.getCallback(), 2951 printAccum.getUserData()); 2952 return printAccum.join(); 2953 }, 2954 "Returns the assembly form of the Attribute.") 2955 .def("__repr__", [](PyAttribute &self) { 2956 // Generally, assembly formats are not printed for __repr__ because 2957 // this can cause exceptionally long debug output and exceptions. 2958 // However, attribute values are generally considered useful and are 2959 // printed. This may need to be re-evaluated if debug dumps end up 2960 // being excessive. 
2961 PyPrintAccumulator printAccum; 2962 printAccum.parts.append("Attribute("); 2963 mlirAttributePrint(self, printAccum.getCallback(), 2964 printAccum.getUserData()); 2965 printAccum.parts.append(")"); 2966 return printAccum.join(); 2967 }); 2968 2969 //---------------------------------------------------------------------------- 2970 // Mapping of PyNamedAttribute 2971 //---------------------------------------------------------------------------- 2972 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2973 .def("__repr__", 2974 [](PyNamedAttribute &self) { 2975 PyPrintAccumulator printAccum; 2976 printAccum.parts.append("NamedAttribute("); 2977 printAccum.parts.append( 2978 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2979 mlirIdentifierStr(self.namedAttr.name).length)); 2980 printAccum.parts.append("="); 2981 mlirAttributePrint(self.namedAttr.attribute, 2982 printAccum.getCallback(), 2983 printAccum.getUserData()); 2984 printAccum.parts.append(")"); 2985 return printAccum.join(); 2986 }) 2987 .def_property_readonly( 2988 "name", 2989 [](PyNamedAttribute &self) { 2990 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2991 mlirIdentifierStr(self.namedAttr.name).length); 2992 }, 2993 "The name of the NamedAttribute binding") 2994 .def_property_readonly( 2995 "attr", 2996 [](PyNamedAttribute &self) { 2997 // TODO: When named attribute is removed/refactored, also remove 2998 // this constructor (it does an inefficient table lookup). 2999 auto contextRef = PyMlirContext::forContext( 3000 mlirAttributeGetContext(self.namedAttr.attribute)); 3001 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 3002 }, 3003 py::keep_alive<0, 1>(), 3004 "The underlying generic attribute of the NamedAttribute binding"); 3005 3006 //---------------------------------------------------------------------------- 3007 // Mapping of PyType. 
3008 //---------------------------------------------------------------------------- 3009 py::class_<PyType>(m, "Type", py::module_local()) 3010 // Delegate to the PyType copy constructor, which will also lifetime 3011 // extend the backing context which owns the MlirType. 3012 .def(py::init<PyType &>(), py::arg("cast_from_type"), 3013 "Casts the passed type to the generic Type") 3014 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 3015 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 3016 .def_static( 3017 "parse", 3018 [](std::string typeSpec, DefaultingPyMlirContext context) { 3019 MlirType type = 3020 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 3021 // TODO: Rework error reporting once diagnostic engine is exposed 3022 // in C API. 3023 if (mlirTypeIsNull(type)) { 3024 throw SetPyError(PyExc_ValueError, 3025 Twine("Unable to parse type: '") + typeSpec + 3026 "'"); 3027 } 3028 return PyType(context->getRef(), type); 3029 }, 3030 py::arg("asm"), py::arg("context") = py::none(), 3031 kContextParseTypeDocstring) 3032 .def_property_readonly( 3033 "context", [](PyType &self) { return self.getContext().getObject(); }, 3034 "Context that owns the Type") 3035 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3036 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 3037 .def("__hash__", 3038 [](PyType &self) { 3039 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3040 }) 3041 .def( 3042 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3043 .def( 3044 "__str__", 3045 [](PyType &self) { 3046 PyPrintAccumulator printAccum; 3047 mlirTypePrint(self, printAccum.getCallback(), 3048 printAccum.getUserData()); 3049 return printAccum.join(); 3050 }, 3051 "Returns the assembly form of the type.") 3052 .def("__repr__", [](PyType &self) { 3053 // Generally, assembly formats are not printed for __repr__ because 3054 // this can cause exceptionally long debug 
output and exceptions. 3055 // However, types are an exception as they typically have compact 3056 // assembly forms and printing them is useful. 3057 PyPrintAccumulator printAccum; 3058 printAccum.parts.append("Type("); 3059 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 3060 printAccum.parts.append(")"); 3061 return printAccum.join(); 3062 }); 3063 3064 //---------------------------------------------------------------------------- 3065 // Mapping of Value. 3066 //---------------------------------------------------------------------------- 3067 py::class_<PyValue>(m, "Value", py::module_local()) 3068 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3069 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3070 .def_property_readonly( 3071 "context", 3072 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3073 "Context in which the value lives.") 3074 .def( 3075 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3076 kDumpDocstring) 3077 .def_property_readonly( 3078 "owner", 3079 [](PyValue &self) { 3080 assert(mlirOperationEqual(self.getParentOperation()->get(), 3081 mlirOpResultGetOwner(self.get())) && 3082 "expected the owner of the value in Python to match that in " 3083 "the IR"); 3084 return self.getParentOperation().getObject(); 3085 }) 3086 .def("__eq__", 3087 [](PyValue &self, PyValue &other) { 3088 return self.get().ptr == other.get().ptr; 3089 }) 3090 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 3091 .def("__hash__", 3092 [](PyValue &self) { 3093 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3094 }) 3095 .def( 3096 "__str__", 3097 [](PyValue &self) { 3098 PyPrintAccumulator printAccum; 3099 printAccum.parts.append("Value("); 3100 mlirValuePrint(self.get(), printAccum.getCallback(), 3101 printAccum.getUserData()); 3102 printAccum.parts.append(")"); 3103 return printAccum.join(); 3104 }, 3105 kValueDunderStrDocstring) 3106 
.def_property_readonly("type", [](PyValue &self) { 3107 return PyType(self.getParentOperation()->getContext(), 3108 mlirValueGetType(self.get())); 3109 }); 3110 PyBlockArgument::bind(m); 3111 PyOpResult::bind(m); 3112 3113 //---------------------------------------------------------------------------- 3114 // Mapping of SymbolTable. 3115 //---------------------------------------------------------------------------- 3116 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 3117 .def(py::init<PyOperationBase &>()) 3118 .def("__getitem__", &PySymbolTable::dunderGetItem) 3119 .def("insert", &PySymbolTable::insert, py::arg("operation")) 3120 .def("erase", &PySymbolTable::erase, py::arg("operation")) 3121 .def("__delitem__", &PySymbolTable::dunderDel) 3122 .def("__contains__", 3123 [](PySymbolTable &table, const std::string &name) { 3124 return !mlirOperationIsNull(mlirSymbolTableLookup( 3125 table, mlirStringRefCreate(name.data(), name.length()))); 3126 }) 3127 // Static helpers. 3128 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3129 py::arg("symbol"), py::arg("name")) 3130 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3131 py::arg("symbol")) 3132 .def_static("get_visibility", &PySymbolTable::getVisibility, 3133 py::arg("symbol")) 3134 .def_static("set_visibility", &PySymbolTable::setVisibility, 3135 py::arg("symbol"), py::arg("visibility")) 3136 .def_static("replace_all_symbol_uses", 3137 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 3138 py::arg("new_symbol"), py::arg("from_op")) 3139 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3140 py::arg("from_op"), py::arg("all_sym_uses_visible"), 3141 py::arg("callback")); 3142 3143 // Container bindings. 
// Each list/iterator/map helper type is bound through its static bind(),
  // called with the module `m` (presumably registering the matching Python
  // class - confirm against each bind() definition). The call order is left
  // untouched.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
}