1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 #include <utility> 25 26 namespace py = pybind11; 27 using namespace mlir; 28 using namespace mlir::python; 29 30 using llvm::SmallVector; 31 using llvm::StringRef; 32 using llvm::Twine; 33 34 //------------------------------------------------------------------------------ 35 // Docstrings (trivial, non-duplicated docstrings are included inline). 36 //------------------------------------------------------------------------------ 37 38 static const char kContextParseTypeDocstring[] = 39 R"(Parses the assembly form of a type. 40 41 Returns a Type object or raises a ValueError if the type cannot be parsed. 
42 43 See also: https://mlir.llvm.org/docs/LangRef/#type-system 44 )"; 45 46 static const char kContextGetCallSiteLocationDocstring[] = 47 R"(Gets a Location representing a caller and callsite)"; 48 49 static const char kContextGetFileLocationDocstring[] = 50 R"(Gets a Location representing a file, line and column)"; 51 52 static const char kContextGetFusedLocationDocstring[] = 53 R"(Gets a Location representing a fused location with optional metadata)"; 54 55 static const char kContextGetNameLocationDocString[] = 56 R"(Gets a Location representing a named location with optional child location)"; 57 58 static const char kModuleParseDocstring[] = 59 R"(Parses a module's assembly format from a string. 60 61 Returns a new MlirModule or raises a ValueError if the parsing fails. 62 63 See also: https://mlir.llvm.org/docs/LangRef/ 64 )"; 65 66 static const char kOperationCreateDocstring[] = 67 R"(Creates a new operation. 68 69 Args: 70 name: Operation name (e.g. "dialect.operation"). 71 results: Sequence of Type representing op result types. 72 attributes: Dict of str:Attribute. 73 successors: List of Block for the operation's successors. 74 regions: Number of regions to create. 75 location: A Location object (defaults to resolve from context manager). 76 ip: An InsertionPoint (defaults to resolve from context manager or set to 77 False to disable insertion, even with an insertion point set in the 78 context manager). 79 Returns: 80 A new "detached" Operation object. Detached operations can be added 81 to blocks, which causes them to become "attached." 82 )"; 83 84 static const char kOperationPrintDocstring[] = 85 R"(Prints the assembly form of the operation to a file like object. 86 87 Args: 88 file: The file like object to write to. Defaults to sys.stdout. 89 binary: Whether to write bytes (True) or str (False). Defaults to False. 90 large_elements_limit: Whether to elide elements attributes above this 91 number of elements. Defaults to None (no limit). 
92 enable_debug_info: Whether to print debug/location information. Defaults 93 to False. 94 pretty_debug_info: Whether to format debug information for easier reading 95 by a human (warning: the result is unparseable). 96 print_generic_op_form: Whether to print the generic assembly forms of all 97 ops. Defaults to False. 98 use_local_Scope: Whether to print in a way that is more optimized for 99 multi-threaded access but may not be consistent with how the overall 100 module prints. 101 assume_verified: By default, if not printing generic form, the verifier 102 will be run and if it fails, generic form will be printed with a comment 103 about failed verification. While a reasonable default for interactive use, 104 for systematic use, it is often better for the caller to verify explicitly 105 and report failures in a more robust fashion. Set this to True if doing this 106 in order to avoid running a redundant verification. If the IR is actually 107 invalid, behavior is undefined. 108 )"; 109 110 static const char kOperationGetAsmDocstring[] = 111 R"(Gets the assembly form of the operation with all options available. 112 113 Args: 114 binary: Whether to return a bytes (True) or str (False) object. Defaults to 115 False. 116 ... others ...: See the print() method for common keyword arguments for 117 configuring the printout. 118 Returns: 119 Either a bytes or str object, depending on the setting of the 'binary' 120 argument. 121 )"; 122 123 static const char kOperationStrDunderDocstring[] = 124 R"(Gets the assembly form of the operation with default options. 125 126 If more advanced control over the assembly formatting or I/O options is needed, 127 use the dedicated print or get_asm method, which supports keyword arguments to 128 customize behavior. 
129 )"; 130 131 static const char kDumpDocstring[] = 132 R"(Dumps a debug representation of the object to stderr.)"; 133 134 static const char kAppendBlockDocstring[] = 135 R"(Appends a new block, with argument types as positional args. 136 137 Returns: 138 The created block. 139 )"; 140 141 static const char kValueDunderStrDocstring[] = 142 R"(Returns the string form of the value. 143 144 If the value is a block argument, this is the assembly form of its type and the 145 position in the argument list. If the value is an operation result, this is 146 equivalent to printing the operation that produced it. 147 )"; 148 149 //------------------------------------------------------------------------------ 150 // Utilities. 151 //------------------------------------------------------------------------------ 152 153 /// Helper for creating an @classmethod. 154 template <class Func, typename... Args> 155 py::object classmethod(Func f, Args... args) { 156 py::object cf = py::cpp_function(f, args...); 157 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 158 } 159 160 static py::object 161 createCustomDialectWrapper(const std::string &dialectNamespace, 162 py::object dialectDescriptor) { 163 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 164 if (!dialectClass) { 165 // Use the base class. 166 return py::cast(PyDialect(std::move(dialectDescriptor))); 167 } 168 169 // Create the custom implementation. 170 return (*dialectClass)(std::move(dialectDescriptor)); 171 } 172 173 static MlirStringRef toMlirStringRef(const std::string &s) { 174 return mlirStringRefCreate(s.data(), s.size()); 175 } 176 177 /// Wrapper for the global LLVM debugging flag. 178 struct PyGlobalDebugFlag { 179 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 180 181 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } 182 183 static void bind(py::module &m) { 184 // Debug flags. 
185 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 186 .def_property_static("flag", &PyGlobalDebugFlag::get, 187 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 188 } 189 }; 190 191 //------------------------------------------------------------------------------ 192 // Collections. 193 //------------------------------------------------------------------------------ 194 195 namespace { 196 197 class PyRegionIterator { 198 public: 199 PyRegionIterator(PyOperationRef operation) 200 : operation(std::move(operation)) {} 201 202 PyRegionIterator &dunderIter() { return *this; } 203 204 PyRegion dunderNext() { 205 operation->checkValid(); 206 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 207 throw py::stop_iteration(); 208 } 209 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 210 return PyRegion(operation, region); 211 } 212 213 static void bind(py::module &m) { 214 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 215 .def("__iter__", &PyRegionIterator::dunderIter) 216 .def("__next__", &PyRegionIterator::dunderNext); 217 } 218 219 private: 220 PyOperationRef operation; 221 int nextIndex = 0; 222 }; 223 224 /// Regions of an op are fixed length and indexed numerically so are represented 225 /// with a sequence-like container. 226 class PyRegionList { 227 public: 228 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 229 230 intptr_t dunderLen() { 231 operation->checkValid(); 232 return mlirOperationGetNumRegions(operation->get()); 233 } 234 235 PyRegion dunderGetItem(intptr_t index) { 236 // dunderLen checks validity. 
237 if (index < 0 || index >= dunderLen()) { 238 throw SetPyError(PyExc_IndexError, 239 "attempt to access out of bounds region"); 240 } 241 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 242 return PyRegion(operation, region); 243 } 244 245 static void bind(py::module &m) { 246 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 247 .def("__len__", &PyRegionList::dunderLen) 248 .def("__getitem__", &PyRegionList::dunderGetItem); 249 } 250 251 private: 252 PyOperationRef operation; 253 }; 254 255 class PyBlockIterator { 256 public: 257 PyBlockIterator(PyOperationRef operation, MlirBlock next) 258 : operation(std::move(operation)), next(next) {} 259 260 PyBlockIterator &dunderIter() { return *this; } 261 262 PyBlock dunderNext() { 263 operation->checkValid(); 264 if (mlirBlockIsNull(next)) { 265 throw py::stop_iteration(); 266 } 267 268 PyBlock returnBlock(operation, next); 269 next = mlirBlockGetNextInRegion(next); 270 return returnBlock; 271 } 272 273 static void bind(py::module &m) { 274 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 275 .def("__iter__", &PyBlockIterator::dunderIter) 276 .def("__next__", &PyBlockIterator::dunderNext); 277 } 278 279 private: 280 PyOperationRef operation; 281 MlirBlock next; 282 }; 283 284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 285 /// we present them as a more full-featured list-like container but optimize 286 /// it for forward iteration. Blocks are always owned by a region. 
287 class PyBlockList { 288 public: 289 PyBlockList(PyOperationRef operation, MlirRegion region) 290 : operation(std::move(operation)), region(region) {} 291 292 PyBlockIterator dunderIter() { 293 operation->checkValid(); 294 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 295 } 296 297 intptr_t dunderLen() { 298 operation->checkValid(); 299 intptr_t count = 0; 300 MlirBlock block = mlirRegionGetFirstBlock(region); 301 while (!mlirBlockIsNull(block)) { 302 count += 1; 303 block = mlirBlockGetNextInRegion(block); 304 } 305 return count; 306 } 307 308 PyBlock dunderGetItem(intptr_t index) { 309 operation->checkValid(); 310 if (index < 0) { 311 throw SetPyError(PyExc_IndexError, 312 "attempt to access out of bounds block"); 313 } 314 MlirBlock block = mlirRegionGetFirstBlock(region); 315 while (!mlirBlockIsNull(block)) { 316 if (index == 0) { 317 return PyBlock(operation, block); 318 } 319 block = mlirBlockGetNextInRegion(block); 320 index -= 1; 321 } 322 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 323 } 324 325 PyBlock appendBlock(const py::args &pyArgTypes) { 326 operation->checkValid(); 327 llvm::SmallVector<MlirType, 4> argTypes; 328 llvm::SmallVector<MlirLocation, 4> argLocs; 329 argTypes.reserve(pyArgTypes.size()); 330 argLocs.reserve(pyArgTypes.size()); 331 for (auto &pyArg : pyArgTypes) { 332 argTypes.push_back(pyArg.cast<PyType &>()); 333 // TODO: Pass in a proper location here. 
334 argLocs.push_back( 335 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 336 } 337 338 MlirBlock block = 339 mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 340 mlirRegionAppendOwnedBlock(region, block); 341 return PyBlock(operation, block); 342 } 343 344 static void bind(py::module &m) { 345 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 346 .def("__getitem__", &PyBlockList::dunderGetItem) 347 .def("__iter__", &PyBlockList::dunderIter) 348 .def("__len__", &PyBlockList::dunderLen) 349 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 350 } 351 352 private: 353 PyOperationRef operation; 354 MlirRegion region; 355 }; 356 357 class PyOperationIterator { 358 public: 359 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 360 : parentOperation(std::move(parentOperation)), next(next) {} 361 362 PyOperationIterator &dunderIter() { return *this; } 363 364 py::object dunderNext() { 365 parentOperation->checkValid(); 366 if (mlirOperationIsNull(next)) { 367 throw py::stop_iteration(); 368 } 369 370 PyOperationRef returnOperation = 371 PyOperation::forOperation(parentOperation->getContext(), next); 372 next = mlirOperationGetNextInBlock(next); 373 return returnOperation->createOpView(); 374 } 375 376 static void bind(py::module &m) { 377 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 378 .def("__iter__", &PyOperationIterator::dunderIter) 379 .def("__next__", &PyOperationIterator::dunderNext); 380 } 381 382 private: 383 PyOperationRef parentOperation; 384 MlirOperation next; 385 }; 386 387 /// Operations are exposed by the C-API as a forward-only linked list. In 388 /// Python, we present them as a more full-featured list-like container but 389 /// optimize it for forward iteration. Iterable operations are always owned 390 /// by a block. 
391 class PyOperationList { 392 public: 393 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 394 : parentOperation(std::move(parentOperation)), block(block) {} 395 396 PyOperationIterator dunderIter() { 397 parentOperation->checkValid(); 398 return PyOperationIterator(parentOperation, 399 mlirBlockGetFirstOperation(block)); 400 } 401 402 intptr_t dunderLen() { 403 parentOperation->checkValid(); 404 intptr_t count = 0; 405 MlirOperation childOp = mlirBlockGetFirstOperation(block); 406 while (!mlirOperationIsNull(childOp)) { 407 count += 1; 408 childOp = mlirOperationGetNextInBlock(childOp); 409 } 410 return count; 411 } 412 413 py::object dunderGetItem(intptr_t index) { 414 parentOperation->checkValid(); 415 if (index < 0) { 416 throw SetPyError(PyExc_IndexError, 417 "attempt to access out of bounds operation"); 418 } 419 MlirOperation childOp = mlirBlockGetFirstOperation(block); 420 while (!mlirOperationIsNull(childOp)) { 421 if (index == 0) { 422 return PyOperation::forOperation(parentOperation->getContext(), childOp) 423 ->createOpView(); 424 } 425 childOp = mlirOperationGetNextInBlock(childOp); 426 index -= 1; 427 } 428 throw SetPyError(PyExc_IndexError, 429 "attempt to access out of bounds operation"); 430 } 431 432 static void bind(py::module &m) { 433 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 434 .def("__getitem__", &PyOperationList::dunderGetItem) 435 .def("__iter__", &PyOperationList::dunderIter) 436 .def("__len__", &PyOperationList::dunderLen); 437 } 438 439 private: 440 PyOperationRef parentOperation; 441 MlirBlock block; 442 }; 443 444 } // namespace 445 446 //------------------------------------------------------------------------------ 447 // PyMlirContext 448 //------------------------------------------------------------------------------ 449 450 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 451 py::gil_scoped_acquire acquire; 452 auto &liveContexts = getLiveContexts(); 453 
liveContexts[context.ptr] = this; 454 } 455 456 PyMlirContext::~PyMlirContext() { 457 // Note that the only public way to construct an instance is via the 458 // forContext method, which always puts the associated handle into 459 // liveContexts. 460 py::gil_scoped_acquire acquire; 461 getLiveContexts().erase(context.ptr); 462 mlirContextDestroy(context); 463 } 464 465 py::object PyMlirContext::getCapsule() { 466 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 467 } 468 469 py::object PyMlirContext::createFromCapsule(py::object capsule) { 470 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 471 if (mlirContextIsNull(rawContext)) 472 throw py::error_already_set(); 473 return forContext(rawContext).releaseObject(); 474 } 475 476 PyMlirContext *PyMlirContext::createNewContextForInit() { 477 MlirContext context = mlirContextCreate(); 478 mlirRegisterAllDialects(context); 479 return new PyMlirContext(context); 480 } 481 482 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 483 py::gil_scoped_acquire acquire; 484 auto &liveContexts = getLiveContexts(); 485 auto it = liveContexts.find(context.ptr); 486 if (it == liveContexts.end()) { 487 // Create. 488 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 489 py::object pyRef = py::cast(unownedContextWrapper); 490 assert(pyRef && "cast to py::object failed"); 491 liveContexts[context.ptr] = unownedContextWrapper; 492 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 493 } 494 // Use existing. 
495 py::object pyRef = py::cast(it->second); 496 return PyMlirContextRef(it->second, std::move(pyRef)); 497 } 498 499 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 500 static LiveContextMap liveContexts; 501 return liveContexts; 502 } 503 504 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 505 506 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 507 508 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 509 510 pybind11::object PyMlirContext::contextEnter() { 511 return PyThreadContextEntry::pushContext(*this); 512 } 513 514 void PyMlirContext::contextExit(const pybind11::object &excType, 515 const pybind11::object &excVal, 516 const pybind11::object &excTb) { 517 PyThreadContextEntry::popContext(*this); 518 } 519 520 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { 521 // Note that ownership is transferred to the delete callback below by way of 522 // an explicit inc_ref (borrow). 523 PyDiagnosticHandler *pyHandler = 524 new PyDiagnosticHandler(get(), std::move(callback)); 525 py::object pyHandlerObject = 526 py::cast(pyHandler, py::return_value_policy::take_ownership); 527 pyHandlerObject.inc_ref(); 528 529 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 530 // guaranteed to be known to pybind. 531 auto handlerCallback = 532 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 533 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 534 py::object pyDiagnosticObject = 535 py::cast(pyDiagnostic, py::return_value_policy::take_ownership); 536 537 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 538 bool result = false; 539 { 540 // Since this can be called from arbitrary C++ contexts, always get the 541 // gil. 
542 py::gil_scoped_acquire gil; 543 try { 544 result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); 545 } catch (std::exception &e) { 546 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 547 e.what()); 548 pyHandler->hadError = true; 549 } 550 } 551 552 pyDiagnostic->invalidate(); 553 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 554 }; 555 auto deleteCallback = +[](void *userData) { 556 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 557 assert(pyHandler->registeredID && "handler is not registered"); 558 pyHandler->registeredID.reset(); 559 560 // Decrement reference, balancing the inc_ref() above. 561 py::object pyHandlerObject = 562 py::cast(pyHandler, py::return_value_policy::reference); 563 pyHandlerObject.dec_ref(); 564 }; 565 566 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 567 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 568 return pyHandlerObject; 569 } 570 571 PyMlirContext &DefaultingPyMlirContext::resolve() { 572 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 573 if (!context) { 574 throw SetPyError( 575 PyExc_RuntimeError, 576 "An MLIR function requires a Context but none was provided in the call " 577 "or from the surrounding environment. 
Either pass to the function with " 578 "a 'context=' argument or establish a default using 'with Context():'"); 579 } 580 return *context; 581 } 582 583 //------------------------------------------------------------------------------ 584 // PyThreadContextEntry management 585 //------------------------------------------------------------------------------ 586 587 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 588 static thread_local std::vector<PyThreadContextEntry> stack; 589 return stack; 590 } 591 592 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 593 auto &stack = getStack(); 594 if (stack.empty()) 595 return nullptr; 596 return &stack.back(); 597 } 598 599 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 600 py::object insertionPoint, 601 py::object location) { 602 auto &stack = getStack(); 603 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 604 std::move(location)); 605 // If the new stack has more than one entry and the context of the new top 606 // entry matches the previous, copy the insertionPoint and location from the 607 // previous entry if missing from the new top entry. 608 if (stack.size() > 1) { 609 auto &prev = *(stack.rbegin() + 1); 610 auto ¤t = stack.back(); 611 if (current.context.is(prev.context)) { 612 // Default non-context objects from the previous entry. 
613 if (!current.insertionPoint) 614 current.insertionPoint = prev.insertionPoint; 615 if (!current.location) 616 current.location = prev.location; 617 } 618 } 619 } 620 621 PyMlirContext *PyThreadContextEntry::getContext() { 622 if (!context) 623 return nullptr; 624 return py::cast<PyMlirContext *>(context); 625 } 626 627 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 628 if (!insertionPoint) 629 return nullptr; 630 return py::cast<PyInsertionPoint *>(insertionPoint); 631 } 632 633 PyLocation *PyThreadContextEntry::getLocation() { 634 if (!location) 635 return nullptr; 636 return py::cast<PyLocation *>(location); 637 } 638 639 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 640 auto *tos = getTopOfStack(); 641 return tos ? tos->getContext() : nullptr; 642 } 643 644 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 645 auto *tos = getTopOfStack(); 646 return tos ? tos->getInsertionPoint() : nullptr; 647 } 648 649 PyLocation *PyThreadContextEntry::getDefaultLocation() { 650 auto *tos = getTopOfStack(); 651 return tos ? 
tos->getLocation() : nullptr; 652 } 653 654 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 655 py::object contextObj = py::cast(context); 656 push(FrameKind::Context, /*context=*/contextObj, 657 /*insertionPoint=*/py::object(), 658 /*location=*/py::object()); 659 return contextObj; 660 } 661 662 void PyThreadContextEntry::popContext(PyMlirContext &context) { 663 auto &stack = getStack(); 664 if (stack.empty()) 665 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 666 auto &tos = stack.back(); 667 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 668 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 669 stack.pop_back(); 670 } 671 672 py::object 673 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 674 py::object contextObj = 675 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 676 py::object insertionPointObj = py::cast(insertionPoint); 677 push(FrameKind::InsertionPoint, 678 /*context=*/contextObj, 679 /*insertionPoint=*/insertionPointObj, 680 /*location=*/py::object()); 681 return insertionPointObj; 682 } 683 684 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 685 auto &stack = getStack(); 686 if (stack.empty()) 687 throw SetPyError(PyExc_RuntimeError, 688 "Unbalanced InsertionPoint enter/exit"); 689 auto &tos = stack.back(); 690 if (tos.frameKind != FrameKind::InsertionPoint && 691 tos.getInsertionPoint() != &insertionPoint) 692 throw SetPyError(PyExc_RuntimeError, 693 "Unbalanced InsertionPoint enter/exit"); 694 stack.pop_back(); 695 } 696 697 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 698 py::object contextObj = location.getContext().getObject(); 699 py::object locationObj = py::cast(location); 700 push(FrameKind::Location, /*context=*/contextObj, 701 /*insertionPoint=*/py::object(), 702 /*location=*/locationObj); 703 return locationObj; 704 } 705 706 void 
PyThreadContextEntry::popLocation(PyLocation &location) { 707 auto &stack = getStack(); 708 if (stack.empty()) 709 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 710 auto &tos = stack.back(); 711 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 712 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 713 stack.pop_back(); 714 } 715 716 //------------------------------------------------------------------------------ 717 // PyDiagnostic* 718 //------------------------------------------------------------------------------ 719 720 void PyDiagnostic::invalidate() { 721 valid = false; 722 if (materializedNotes) { 723 for (auto ¬eObject : *materializedNotes) { 724 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); 725 note->invalidate(); 726 } 727 } 728 } 729 730 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 731 py::object callback) 732 : context(context), callback(std::move(callback)) {} 733 734 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 735 736 void PyDiagnosticHandler::detach() { 737 if (!registeredID) 738 return; 739 MlirDiagnosticHandlerID localID = *registeredID; 740 mlirContextDetachDiagnosticHandler(context, localID); 741 assert(!registeredID && "should have unregistered"); 742 // Not strictly necessary but keeps stale pointers from being around to cause 743 // issues. 
744 context = {nullptr}; 745 } 746 747 void PyDiagnostic::checkValid() { 748 if (!valid) { 749 throw std::invalid_argument( 750 "Diagnostic is invalid (used outside of callback)"); 751 } 752 } 753 754 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 755 checkValid(); 756 return mlirDiagnosticGetSeverity(diagnostic); 757 } 758 759 PyLocation PyDiagnostic::getLocation() { 760 checkValid(); 761 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 762 MlirContext context = mlirLocationGetContext(loc); 763 return PyLocation(PyMlirContext::forContext(context), loc); 764 } 765 766 py::str PyDiagnostic::getMessage() { 767 checkValid(); 768 py::object fileObject = py::module::import("io").attr("StringIO")(); 769 PyFileAccumulator accum(fileObject, /*binary=*/false); 770 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 771 return fileObject.attr("getvalue")(); 772 } 773 774 py::tuple PyDiagnostic::getNotes() { 775 checkValid(); 776 if (materializedNotes) 777 return *materializedNotes; 778 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 779 materializedNotes = py::tuple(numNotes); 780 for (intptr_t i = 0; i < numNotes; ++i) { 781 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 782 py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); 783 PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); 784 } 785 return *materializedNotes; 786 } 787 788 //------------------------------------------------------------------------------ 789 // PyDialect, PyDialectDescriptor, PyDialects 790 //------------------------------------------------------------------------------ 791 792 MlirDialect PyDialects::getDialectForKey(const std::string &key, 793 bool attrError) { 794 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 795 {key.data(), key.size()}); 796 if (mlirDialectIsNull(dialect)) { 797 throw SetPyError(attrError ? 
PyExc_AttributeError : PyExc_IndexError, 798 Twine("Dialect '") + key + "' not found"); 799 } 800 return dialect; 801 } 802 803 //------------------------------------------------------------------------------ 804 // PyLocation 805 //------------------------------------------------------------------------------ 806 807 py::object PyLocation::getCapsule() { 808 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 809 } 810 811 PyLocation PyLocation::createFromCapsule(py::object capsule) { 812 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 813 if (mlirLocationIsNull(rawLoc)) 814 throw py::error_already_set(); 815 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 816 rawLoc); 817 } 818 819 py::object PyLocation::contextEnter() { 820 return PyThreadContextEntry::pushLocation(*this); 821 } 822 823 void PyLocation::contextExit(const pybind11::object &excType, 824 const pybind11::object &excVal, 825 const pybind11::object &excTb) { 826 PyThreadContextEntry::popLocation(*this); 827 } 828 829 PyLocation &DefaultingPyLocation::resolve() { 830 auto *location = PyThreadContextEntry::getDefaultLocation(); 831 if (!location) { 832 throw SetPyError( 833 PyExc_RuntimeError, 834 "An MLIR function requires a Location but none was provided in the " 835 "call or from the surrounding environment. 
Either pass to the function " 836 "with a 'loc=' argument or establish a default using 'with loc:'"); 837 } 838 return *location; 839 } 840 841 //------------------------------------------------------------------------------ 842 // PyModule 843 //------------------------------------------------------------------------------ 844 845 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 846 : BaseContextObject(std::move(contextRef)), module(module) {} 847 848 PyModule::~PyModule() { 849 py::gil_scoped_acquire acquire; 850 auto &liveModules = getContext()->liveModules; 851 assert(liveModules.count(module.ptr) == 1 && 852 "destroying module not in live map"); 853 liveModules.erase(module.ptr); 854 mlirModuleDestroy(module); 855 } 856 857 PyModuleRef PyModule::forModule(MlirModule module) { 858 MlirContext context = mlirModuleGetContext(module); 859 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 860 861 py::gil_scoped_acquire acquire; 862 auto &liveModules = contextRef->liveModules; 863 auto it = liveModules.find(module.ptr); 864 if (it == liveModules.end()) { 865 // Create. 866 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 867 // Note that the default return value policy on cast is automatic_reference, 868 // which does not take ownership (delete will not be called). 869 // Just be explicit. 870 py::object pyRef = 871 py::cast(unownedModule, py::return_value_policy::take_ownership); 872 unownedModule->handle = pyRef; 873 liveModules[module.ptr] = 874 std::make_pair(unownedModule->handle, unownedModule); 875 return PyModuleRef(unownedModule, std::move(pyRef)); 876 } 877 // Use existing. 
878 PyModule *existing = it->second.second; 879 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 880 return PyModuleRef(existing, std::move(pyRef)); 881 } 882 883 py::object PyModule::createFromCapsule(py::object capsule) { 884 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 885 if (mlirModuleIsNull(rawModule)) 886 throw py::error_already_set(); 887 return forModule(rawModule).releaseObject(); 888 } 889 890 py::object PyModule::getCapsule() { 891 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 892 } 893 894 //------------------------------------------------------------------------------ 895 // PyOperation 896 //------------------------------------------------------------------------------ 897 898 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 899 : BaseContextObject(std::move(contextRef)), operation(operation) {} 900 901 PyOperation::~PyOperation() { 902 // If the operation has already been invalidated there is nothing to do. 903 if (!valid) 904 return; 905 auto &liveOperations = getContext()->liveOperations; 906 assert(liveOperations.count(operation.ptr) == 1 && 907 "destroying operation not in live map"); 908 liveOperations.erase(operation.ptr); 909 if (!isAttached()) { 910 mlirOperationDestroy(operation); 911 } 912 } 913 914 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 915 MlirOperation operation, 916 py::object parentKeepAlive) { 917 auto &liveOperations = contextRef->liveOperations; 918 // Create. 919 PyOperation *unownedOperation = 920 new PyOperation(std::move(contextRef), operation); 921 // Note that the default return value policy on cast is automatic_reference, 922 // which does not take ownership (delete will not be called). 923 // Just be explicit. 
924 py::object pyRef = 925 py::cast(unownedOperation, py::return_value_policy::take_ownership); 926 unownedOperation->handle = pyRef; 927 if (parentKeepAlive) { 928 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 929 } 930 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 931 return PyOperationRef(unownedOperation, std::move(pyRef)); 932 } 933 934 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 935 MlirOperation operation, 936 py::object parentKeepAlive) { 937 auto &liveOperations = contextRef->liveOperations; 938 auto it = liveOperations.find(operation.ptr); 939 if (it == liveOperations.end()) { 940 // Create. 941 return createInstance(std::move(contextRef), operation, 942 std::move(parentKeepAlive)); 943 } 944 // Use existing. 945 PyOperation *existing = it->second.second; 946 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 947 return PyOperationRef(existing, std::move(pyRef)); 948 } 949 950 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 951 MlirOperation operation, 952 py::object parentKeepAlive) { 953 auto &liveOperations = contextRef->liveOperations; 954 assert(liveOperations.count(operation.ptr) == 0 && 955 "cannot create detached operation that already exists"); 956 (void)liveOperations; 957 958 PyOperationRef created = createInstance(std::move(contextRef), operation, 959 std::move(parentKeepAlive)); 960 created->attached = false; 961 return created; 962 } 963 964 void PyOperation::checkValid() const { 965 if (!valid) { 966 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 967 } 968 } 969 970 void PyOperationBase::print(py::object fileObject, bool binary, 971 llvm::Optional<int64_t> largeElementsLimit, 972 bool enableDebugInfo, bool prettyDebugInfo, 973 bool printGenericOpForm, bool useLocalScope, 974 bool assumeVerified) { 975 PyOperation &operation = getOperation(); 976 operation.checkValid(); 977 if 
(fileObject.is_none()) 978 fileObject = py::module::import("sys").attr("stdout"); 979 980 if (!assumeVerified && !printGenericOpForm && 981 !mlirOperationVerify(operation)) { 982 std::string message("// Verification failed, printing generic form\n"); 983 if (binary) { 984 fileObject.attr("write")(py::bytes(message)); 985 } else { 986 fileObject.attr("write")(py::str(message)); 987 } 988 printGenericOpForm = true; 989 } 990 991 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 992 if (largeElementsLimit) 993 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 994 if (enableDebugInfo) 995 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 996 if (printGenericOpForm) 997 mlirOpPrintingFlagsPrintGenericOpForm(flags); 998 999 PyFileAccumulator accum(fileObject, binary); 1000 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1001 accum.getUserData()); 1002 mlirOpPrintingFlagsDestroy(flags); 1003 } 1004 1005 py::object PyOperationBase::getAsm(bool binary, 1006 llvm::Optional<int64_t> largeElementsLimit, 1007 bool enableDebugInfo, bool prettyDebugInfo, 1008 bool printGenericOpForm, bool useLocalScope, 1009 bool assumeVerified) { 1010 py::object fileObject; 1011 if (binary) { 1012 fileObject = py::module::import("io").attr("BytesIO")(); 1013 } else { 1014 fileObject = py::module::import("io").attr("StringIO")(); 1015 } 1016 print(fileObject, /*binary=*/binary, 1017 /*largeElementsLimit=*/largeElementsLimit, 1018 /*enableDebugInfo=*/enableDebugInfo, 1019 /*prettyDebugInfo=*/prettyDebugInfo, 1020 /*printGenericOpForm=*/printGenericOpForm, 1021 /*useLocalScope=*/useLocalScope, 1022 /*assumeVerified=*/assumeVerified); 1023 1024 return fileObject.attr("getvalue")(); 1025 } 1026 1027 void PyOperationBase::moveAfter(PyOperationBase &other) { 1028 PyOperation &operation = getOperation(); 1029 PyOperation &otherOp = other.getOperation(); 1030 operation.checkValid(); 1031 otherOp.checkValid(); 1032 
mlirOperationMoveAfter(operation, otherOp); 1033 operation.parentKeepAlive = otherOp.parentKeepAlive; 1034 } 1035 1036 void PyOperationBase::moveBefore(PyOperationBase &other) { 1037 PyOperation &operation = getOperation(); 1038 PyOperation &otherOp = other.getOperation(); 1039 operation.checkValid(); 1040 otherOp.checkValid(); 1041 mlirOperationMoveBefore(operation, otherOp); 1042 operation.parentKeepAlive = otherOp.parentKeepAlive; 1043 } 1044 1045 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 1046 checkValid(); 1047 if (!isAttached()) 1048 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 1049 MlirOperation operation = mlirOperationGetParentOperation(get()); 1050 if (mlirOperationIsNull(operation)) 1051 return {}; 1052 return PyOperation::forOperation(getContext(), operation); 1053 } 1054 1055 PyBlock PyOperation::getBlock() { 1056 checkValid(); 1057 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 1058 MlirBlock block = mlirOperationGetBlock(get()); 1059 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1060 assert(parentOperation && "Operation has no parent"); 1061 return PyBlock{std::move(*parentOperation), block}; 1062 } 1063 1064 py::object PyOperation::getCapsule() { 1065 checkValid(); 1066 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 1067 } 1068 1069 py::object PyOperation::createFromCapsule(py::object capsule) { 1070 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1071 if (mlirOperationIsNull(rawOperation)) 1072 throw py::error_already_set(); 1073 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1074 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1075 .releaseObject(); 1076 } 1077 1078 static void maybeInsertOperation(PyOperationRef &op, 1079 const py::object &maybeIp) { 1080 // InsertPoint active? 
1081 if (!maybeIp.is(py::cast(false))) { 1082 PyInsertionPoint *ip; 1083 if (maybeIp.is_none()) { 1084 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1085 } else { 1086 ip = py::cast<PyInsertionPoint *>(maybeIp); 1087 } 1088 if (ip) 1089 ip->insert(*op.get()); 1090 } 1091 } 1092 1093 py::object PyOperation::create( 1094 const std::string &name, llvm::Optional<std::vector<PyType *>> results, 1095 llvm::Optional<std::vector<PyValue *>> operands, 1096 llvm::Optional<py::dict> attributes, 1097 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 1098 DefaultingPyLocation location, const py::object &maybeIp) { 1099 llvm::SmallVector<MlirValue, 4> mlirOperands; 1100 llvm::SmallVector<MlirType, 4> mlirResults; 1101 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1102 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1103 1104 // General parameter validation. 1105 if (regions < 0) 1106 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 1107 1108 // Unpack/validate operands. 1109 if (operands) { 1110 mlirOperands.reserve(operands->size()); 1111 for (PyValue *operand : *operands) { 1112 if (!operand) 1113 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 1114 mlirOperands.push_back(operand->get()); 1115 } 1116 } 1117 1118 // Unpack/validate results. 1119 if (results) { 1120 mlirResults.reserve(results->size()); 1121 for (PyType *result : *results) { 1122 // TODO: Verify result type originate from the same context. 1123 if (!result) 1124 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 1125 mlirResults.push_back(*result); 1126 } 1127 } 1128 // Unpack/validate attributes. 
1129 if (attributes) { 1130 mlirAttributes.reserve(attributes->size()); 1131 for (auto &it : *attributes) { 1132 std::string key; 1133 try { 1134 key = it.first.cast<std::string>(); 1135 } catch (py::cast_error &err) { 1136 std::string msg = "Invalid attribute key (not a string) when " 1137 "attempting to create the operation \"" + 1138 name + "\" (" + err.what() + ")"; 1139 throw py::cast_error(msg); 1140 } 1141 try { 1142 auto &attribute = it.second.cast<PyAttribute &>(); 1143 // TODO: Verify attribute originates from the same context. 1144 mlirAttributes.emplace_back(std::move(key), attribute); 1145 } catch (py::reference_cast_error &) { 1146 // This exception seems thrown when the value is "None". 1147 std::string msg = 1148 "Found an invalid (`None`?) attribute value for the key \"" + key + 1149 "\" when attempting to create the operation \"" + name + "\""; 1150 throw py::cast_error(msg); 1151 } catch (py::cast_error &err) { 1152 std::string msg = "Invalid attribute value for the key \"" + key + 1153 "\" when attempting to create the operation \"" + 1154 name + "\" (" + err.what() + ")"; 1155 throw py::cast_error(msg); 1156 } 1157 } 1158 } 1159 // Unpack/validate successors. 1160 if (successors) { 1161 mlirSuccessors.reserve(successors->size()); 1162 for (auto *successor : *successors) { 1163 // TODO: Verify successor originate from the same context. 1164 if (!successor) 1165 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1166 mlirSuccessors.push_back(successor->get()); 1167 } 1168 } 1169 1170 // Apply unpacked/validated to the operation state. Beyond this 1171 // point, exceptions cannot be thrown or else the state will leak. 
1172 MlirOperationState state = 1173 mlirOperationStateGet(toMlirStringRef(name), location); 1174 if (!mlirOperands.empty()) 1175 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1176 mlirOperands.data()); 1177 if (!mlirResults.empty()) 1178 mlirOperationStateAddResults(&state, mlirResults.size(), 1179 mlirResults.data()); 1180 if (!mlirAttributes.empty()) { 1181 // Note that the attribute names directly reference bytes in 1182 // mlirAttributes, so that vector must not be changed from here 1183 // on. 1184 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1185 mlirNamedAttributes.reserve(mlirAttributes.size()); 1186 for (auto &it : mlirAttributes) 1187 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1188 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1189 toMlirStringRef(it.first)), 1190 it.second)); 1191 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1192 mlirNamedAttributes.data()); 1193 } 1194 if (!mlirSuccessors.empty()) 1195 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1196 mlirSuccessors.data()); 1197 if (regions) { 1198 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1199 mlirRegions.resize(regions); 1200 for (int i = 0; i < regions; ++i) 1201 mlirRegions[i] = mlirRegionCreate(); 1202 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1203 mlirRegions.data()); 1204 } 1205 1206 // Construct the operation. 
1207 MlirOperation operation = mlirOperationCreate(&state); 1208 PyOperationRef created = 1209 PyOperation::createDetached(location->getContext(), operation); 1210 maybeInsertOperation(created, maybeIp); 1211 1212 return created->createOpView(); 1213 } 1214 1215 py::object PyOperation::clone(const py::object &maybeIp) { 1216 MlirOperation clonedOperation = mlirOperationClone(operation); 1217 PyOperationRef cloned = 1218 PyOperation::createDetached(getContext(), clonedOperation); 1219 maybeInsertOperation(cloned, maybeIp); 1220 1221 return cloned->createOpView(); 1222 } 1223 1224 py::object PyOperation::createOpView() { 1225 checkValid(); 1226 MlirIdentifier ident = mlirOperationGetName(get()); 1227 MlirStringRef identStr = mlirIdentifierStr(ident); 1228 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1229 StringRef(identStr.data, identStr.length)); 1230 if (opViewClass) 1231 return (*opViewClass)(getRef().getObject()); 1232 return py::cast(PyOpView(getRef().getObject())); 1233 } 1234 1235 void PyOperation::erase() { 1236 checkValid(); 1237 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1238 // Python reference to a child operation is live. All children should also 1239 // have their `valid` bit set to false. 
1240 auto &liveOperations = getContext()->liveOperations; 1241 if (liveOperations.count(operation.ptr)) 1242 liveOperations.erase(operation.ptr); 1243 mlirOperationDestroy(operation); 1244 valid = false; 1245 } 1246 1247 //------------------------------------------------------------------------------ 1248 // PyOpView 1249 //------------------------------------------------------------------------------ 1250 1251 py::object PyOpView::buildGeneric( 1252 const py::object &cls, py::list resultTypeList, py::list operandList, 1253 llvm::Optional<py::dict> attributes, 1254 llvm::Optional<std::vector<PyBlock *>> successors, 1255 llvm::Optional<int> regions, DefaultingPyLocation location, 1256 const py::object &maybeIp) { 1257 PyMlirContextRef context = location->getContext(); 1258 // Class level operation construction metadata. 1259 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1260 // Operand and result segment specs are either none, which does no 1261 // variadic unpacking, or a list of ints with segment sizes, where each 1262 // element is either a positive number (typically 1 for a scalar) or -1 to 1263 // indicate that it is derived from the length of the same-indexed operand 1264 // or result (implying that it is a list at that position). 1265 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1266 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1267 1268 std::vector<uint32_t> operandSegmentLengths; 1269 std::vector<uint32_t> resultSegmentLengths; 1270 1271 // Validate/determine region count. 
1272 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1273 int opMinRegionCount = std::get<0>(opRegionSpec); 1274 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1275 if (!regions) { 1276 regions = opMinRegionCount; 1277 } 1278 if (*regions < opMinRegionCount) { 1279 throw py::value_error( 1280 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1281 llvm::Twine(opMinRegionCount) + 1282 " regions but was built with regions=" + llvm::Twine(*regions)) 1283 .str()); 1284 } 1285 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1286 throw py::value_error( 1287 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1288 llvm::Twine(opMinRegionCount) + 1289 " regions but was built with regions=" + llvm::Twine(*regions)) 1290 .str()); 1291 } 1292 1293 // Unpack results. 1294 std::vector<PyType *> resultTypes; 1295 resultTypes.reserve(resultTypeList.size()); 1296 if (resultSegmentSpecObj.is_none()) { 1297 // Non-variadic result unpacking. 1298 for (const auto &it : llvm::enumerate(resultTypeList)) { 1299 try { 1300 resultTypes.push_back(py::cast<PyType *>(it.value())); 1301 if (!resultTypes.back()) 1302 throw py::cast_error(); 1303 } catch (py::cast_error &err) { 1304 throw py::value_error((llvm::Twine("Result ") + 1305 llvm::Twine(it.index()) + " of operation \"" + 1306 name + "\" must be a Type (" + err.what() + ")") 1307 .str()); 1308 } 1309 } 1310 } else { 1311 // Sized result unpacking. 
1312 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1313 if (resultSegmentSpec.size() != resultTypeList.size()) { 1314 throw py::value_error((llvm::Twine("Operation \"") + name + 1315 "\" requires " + 1316 llvm::Twine(resultSegmentSpec.size()) + 1317 " result segments but was provided " + 1318 llvm::Twine(resultTypeList.size())) 1319 .str()); 1320 } 1321 resultSegmentLengths.reserve(resultTypeList.size()); 1322 for (const auto &it : 1323 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1324 int segmentSpec = std::get<1>(it.value()); 1325 if (segmentSpec == 1 || segmentSpec == 0) { 1326 // Unpack unary element. 1327 try { 1328 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1329 if (resultType) { 1330 resultTypes.push_back(resultType); 1331 resultSegmentLengths.push_back(1); 1332 } else if (segmentSpec == 0) { 1333 // Allowed to be optional. 1334 resultSegmentLengths.push_back(0); 1335 } else { 1336 throw py::cast_error("was None and result is not optional"); 1337 } 1338 } catch (py::cast_error &err) { 1339 throw py::value_error((llvm::Twine("Result ") + 1340 llvm::Twine(it.index()) + " of operation \"" + 1341 name + "\" must be a Type (" + err.what() + 1342 ")") 1343 .str()); 1344 } 1345 } else if (segmentSpec == -1) { 1346 // Unpack sequence by appending. 1347 try { 1348 if (std::get<0>(it.value()).is_none()) { 1349 // Treat it as an empty list. 1350 resultSegmentLengths.push_back(0); 1351 } else { 1352 // Unpack the list. 
1353 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1354 for (py::object segmentItem : segment) { 1355 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1356 if (!resultTypes.back()) { 1357 throw py::cast_error("contained a None item"); 1358 } 1359 } 1360 resultSegmentLengths.push_back(segment.size()); 1361 } 1362 } catch (std::exception &err) { 1363 // NOTE: Sloppy to be using a catch-all here, but there are at least 1364 // three different unrelated exceptions that can be thrown in the 1365 // above "casts". Just keep the scope above small and catch them all. 1366 throw py::value_error((llvm::Twine("Result ") + 1367 llvm::Twine(it.index()) + " of operation \"" + 1368 name + "\" must be a Sequence of Types (" + 1369 err.what() + ")") 1370 .str()); 1371 } 1372 } else { 1373 throw py::value_error("Unexpected segment spec"); 1374 } 1375 } 1376 } 1377 1378 // Unpack operands. 1379 std::vector<PyValue *> operands; 1380 operands.reserve(operands.size()); 1381 if (operandSegmentSpecObj.is_none()) { 1382 // Non-sized operand unpacking. 1383 for (const auto &it : llvm::enumerate(operandList)) { 1384 try { 1385 operands.push_back(py::cast<PyValue *>(it.value())); 1386 if (!operands.back()) 1387 throw py::cast_error(); 1388 } catch (py::cast_error &err) { 1389 throw py::value_error((llvm::Twine("Operand ") + 1390 llvm::Twine(it.index()) + " of operation \"" + 1391 name + "\" must be a Value (" + err.what() + ")") 1392 .str()); 1393 } 1394 } 1395 } else { 1396 // Sized operand unpacking. 
1397 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1398 if (operandSegmentSpec.size() != operandList.size()) { 1399 throw py::value_error((llvm::Twine("Operation \"") + name + 1400 "\" requires " + 1401 llvm::Twine(operandSegmentSpec.size()) + 1402 "operand segments but was provided " + 1403 llvm::Twine(operandList.size())) 1404 .str()); 1405 } 1406 operandSegmentLengths.reserve(operandList.size()); 1407 for (const auto &it : 1408 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1409 int segmentSpec = std::get<1>(it.value()); 1410 if (segmentSpec == 1 || segmentSpec == 0) { 1411 // Unpack unary element. 1412 try { 1413 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1414 if (operandValue) { 1415 operands.push_back(operandValue); 1416 operandSegmentLengths.push_back(1); 1417 } else if (segmentSpec == 0) { 1418 // Allowed to be optional. 1419 operandSegmentLengths.push_back(0); 1420 } else { 1421 throw py::cast_error("was None and operand is not optional"); 1422 } 1423 } catch (py::cast_error &err) { 1424 throw py::value_error((llvm::Twine("Operand ") + 1425 llvm::Twine(it.index()) + " of operation \"" + 1426 name + "\" must be a Value (" + err.what() + 1427 ")") 1428 .str()); 1429 } 1430 } else if (segmentSpec == -1) { 1431 // Unpack sequence by appending. 1432 try { 1433 if (std::get<0>(it.value()).is_none()) { 1434 // Treat it as an empty list. 1435 operandSegmentLengths.push_back(0); 1436 } else { 1437 // Unpack the list. 
1438 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1439 for (py::object segmentItem : segment) { 1440 operands.push_back(py::cast<PyValue *>(segmentItem)); 1441 if (!operands.back()) { 1442 throw py::cast_error("contained a None item"); 1443 } 1444 } 1445 operandSegmentLengths.push_back(segment.size()); 1446 } 1447 } catch (std::exception &err) { 1448 // NOTE: Sloppy to be using a catch-all here, but there are at least 1449 // three different unrelated exceptions that can be thrown in the 1450 // above "casts". Just keep the scope above small and catch them all. 1451 throw py::value_error((llvm::Twine("Operand ") + 1452 llvm::Twine(it.index()) + " of operation \"" + 1453 name + "\" must be a Sequence of Values (" + 1454 err.what() + ")") 1455 .str()); 1456 } 1457 } else { 1458 throw py::value_error("Unexpected segment spec"); 1459 } 1460 } 1461 } 1462 1463 // Merge operand/result segment lengths into attributes if needed. 1464 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1465 // Dup. 1466 if (attributes) { 1467 attributes = py::dict(*attributes); 1468 } else { 1469 attributes = py::dict(); 1470 } 1471 if (attributes->contains("result_segment_sizes") || 1472 attributes->contains("operand_segment_sizes")) { 1473 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1474 "'operand_segment_sizes' attribute is unsupported. " 1475 "Use Operation.create for such low-level access."); 1476 } 1477 1478 // Add result_segment_sizes attribute. 1479 if (!resultSegmentLengths.empty()) { 1480 int64_t size = resultSegmentLengths.size(); 1481 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1482 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1483 resultSegmentLengths.size(), resultSegmentLengths.data()); 1484 (*attributes)["result_segment_sizes"] = 1485 PyAttribute(context, segmentLengthAttr); 1486 } 1487 1488 // Add operand_segment_sizes attribute. 
1489 if (!operandSegmentLengths.empty()) { 1490 int64_t size = operandSegmentLengths.size(); 1491 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1492 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1493 operandSegmentLengths.size(), operandSegmentLengths.data()); 1494 (*attributes)["operand_segment_sizes"] = 1495 PyAttribute(context, segmentLengthAttr); 1496 } 1497 } 1498 1499 // Delegate to create. 1500 return PyOperation::create(name, 1501 /*results=*/std::move(resultTypes), 1502 /*operands=*/std::move(operands), 1503 /*attributes=*/std::move(attributes), 1504 /*successors=*/std::move(successors), 1505 /*regions=*/*regions, location, maybeIp); 1506 } 1507 1508 PyOpView::PyOpView(const py::object &operationObject) 1509 // Casting through the PyOperationBase base-class and then back to the 1510 // Operation lets us accept any PyOperationBase subclass. 1511 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1512 operationObject(operation.getRef().getObject()) {} 1513 1514 py::object PyOpView::createRawSubclass(const py::object &userClass) { 1515 // This is... a little gross. The typical pattern is to have a pure python 1516 // class that extends OpView like: 1517 // class AddFOp(_cext.ir.OpView): 1518 // def __init__(self, loc, lhs, rhs): 1519 // operation = loc.context.create_operation( 1520 // "addf", lhs, rhs, results=[lhs.type]) 1521 // super().__init__(operation) 1522 // 1523 // I.e. The goal of the user facing type is to provide a nice constructor 1524 // that has complete freedom for the op under construction. This is at odds 1525 // with our other desire to sometimes create this object by just passing an 1526 // operation (to initialize the base class). 
We could do *arg and **kwargs 1527 // munging to try to make it work, but instead, we synthesize a new class 1528 // on the fly which extends this user class (AddFOp in this example) and 1529 // *give it* the base class's __init__ method, thus bypassing the 1530 // intermediate subclass's __init__ method entirely. While slightly, 1531 // underhanded, this is safe/legal because the type hierarchy has not changed 1532 // (we just added a new leaf) and we aren't mucking around with __new__. 1533 // Typically, this new class will be stored on the original as "_Raw" and will 1534 // be used for casts and other things that need a variant of the class that 1535 // is initialized purely from an operation. 1536 py::object parentMetaclass = 1537 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1538 py::dict attributes; 1539 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1540 // now. 1541 // auto opViewType = py::type::of<PyOpView>(); 1542 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1543 attributes["__init__"] = opViewType.attr("__init__"); 1544 py::str origName = userClass.attr("__name__"); 1545 py::str newName = py::str("_") + origName; 1546 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1547 } 1548 1549 //------------------------------------------------------------------------------ 1550 // PyInsertionPoint. 
1551 //------------------------------------------------------------------------------ 1552 1553 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1554 1555 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1556 : refOperation(beforeOperationBase.getOperation().getRef()), 1557 block((*refOperation)->getBlock()) {} 1558 1559 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1560 PyOperation &operation = operationBase.getOperation(); 1561 if (operation.isAttached()) 1562 throw SetPyError(PyExc_ValueError, 1563 "Attempt to insert operation that is already attached"); 1564 block.getParentOperation()->checkValid(); 1565 MlirOperation beforeOp = {nullptr}; 1566 if (refOperation) { 1567 // Insert before operation. 1568 (*refOperation)->checkValid(); 1569 beforeOp = (*refOperation)->get(); 1570 } else { 1571 // Insert at end (before null) is only valid if the block does not 1572 // already end in a known terminator (violating this will cause assertion 1573 // failures later). 1574 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1575 throw py::index_error("Cannot insert operation at the end of a block " 1576 "that already has a terminator. Did you mean to " 1577 "use 'InsertionPoint.at_block_terminator(block)' " 1578 "versus 'InsertionPoint(block)'?"); 1579 } 1580 } 1581 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1582 operation.setAttached(); 1583 } 1584 1585 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1586 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1587 if (mlirOperationIsNull(firstOp)) { 1588 // Just insert at end. 1589 return PyInsertionPoint(block); 1590 } 1591 1592 // Insert before first op. 
1593 PyOperationRef firstOpRef = PyOperation::forOperation( 1594 block.getParentOperation()->getContext(), firstOp); 1595 return PyInsertionPoint{block, std::move(firstOpRef)}; 1596 } 1597 1598 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1599 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1600 if (mlirOperationIsNull(terminator)) 1601 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1602 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1603 block.getParentOperation()->getContext(), terminator); 1604 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1605 } 1606 1607 py::object PyInsertionPoint::contextEnter() { 1608 return PyThreadContextEntry::pushInsertionPoint(*this); 1609 } 1610 1611 void PyInsertionPoint::contextExit(const pybind11::object &excType, 1612 const pybind11::object &excVal, 1613 const pybind11::object &excTb) { 1614 PyThreadContextEntry::popInsertionPoint(*this); 1615 } 1616 1617 //------------------------------------------------------------------------------ 1618 // PyAttribute. 1619 //------------------------------------------------------------------------------ 1620 1621 bool PyAttribute::operator==(const PyAttribute &other) { 1622 return mlirAttributeEqual(attr, other.attr); 1623 } 1624 1625 py::object PyAttribute::getCapsule() { 1626 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1627 } 1628 1629 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1630 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1631 if (mlirAttributeIsNull(rawAttr)) 1632 throw py::error_already_set(); 1633 return PyAttribute( 1634 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1635 } 1636 1637 //------------------------------------------------------------------------------ 1638 // PyNamedAttribute. 
1639 //------------------------------------------------------------------------------ 1640 1641 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1642 : ownedName(new std::string(std::move(ownedName))) { 1643 namedAttr = mlirNamedAttributeGet( 1644 mlirIdentifierGet(mlirAttributeGetContext(attr), 1645 toMlirStringRef(*this->ownedName)), 1646 attr); 1647 } 1648 1649 //------------------------------------------------------------------------------ 1650 // PyType. 1651 //------------------------------------------------------------------------------ 1652 1653 bool PyType::operator==(const PyType &other) { 1654 return mlirTypeEqual(type, other.type); 1655 } 1656 1657 py::object PyType::getCapsule() { 1658 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1659 } 1660 1661 PyType PyType::createFromCapsule(py::object capsule) { 1662 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1663 if (mlirTypeIsNull(rawType)) 1664 throw py::error_already_set(); 1665 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1666 rawType); 1667 } 1668 1669 //------------------------------------------------------------------------------ 1670 // PyValue and subclases. 
1671 //------------------------------------------------------------------------------ 1672 1673 pybind11::object PyValue::getCapsule() { 1674 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1675 } 1676 1677 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1678 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1679 if (mlirValueIsNull(value)) 1680 throw py::error_already_set(); 1681 MlirOperation owner; 1682 if (mlirValueIsAOpResult(value)) 1683 owner = mlirOpResultGetOwner(value); 1684 if (mlirValueIsABlockArgument(value)) 1685 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1686 if (mlirOperationIsNull(owner)) 1687 throw py::error_already_set(); 1688 MlirContext ctx = mlirOperationGetContext(owner); 1689 PyOperationRef ownerRef = 1690 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1691 return PyValue(ownerRef, value); 1692 } 1693 1694 //------------------------------------------------------------------------------ 1695 // PySymbolTable. 
1696 //------------------------------------------------------------------------------ 1697 1698 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1699 : operation(operation.getOperation().getRef()) { 1700 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1701 if (mlirSymbolTableIsNull(symbolTable)) { 1702 throw py::cast_error("Operation is not a Symbol Table."); 1703 } 1704 } 1705 1706 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1707 operation->checkValid(); 1708 MlirOperation symbol = mlirSymbolTableLookup( 1709 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1710 if (mlirOperationIsNull(symbol)) 1711 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1712 1713 return PyOperation::forOperation(operation->getContext(), symbol, 1714 operation.getObject()) 1715 ->createOpView(); 1716 } 1717 1718 void PySymbolTable::erase(PyOperationBase &symbol) { 1719 operation->checkValid(); 1720 symbol.getOperation().checkValid(); 1721 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1722 // The operation is also erased, so we must invalidate it. There may be Python 1723 // references to this operation so we don't want to delete it from the list of 1724 // live operations here. 
1725 symbol.getOperation().valid = false; 1726 } 1727 1728 void PySymbolTable::dunderDel(const std::string &name) { 1729 py::object operation = dunderGetItem(name); 1730 erase(py::cast<PyOperationBase &>(operation)); 1731 } 1732 1733 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1734 operation->checkValid(); 1735 symbol.getOperation().checkValid(); 1736 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1737 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1738 if (mlirAttributeIsNull(symbolAttr)) 1739 throw py::value_error("Expected operation to have a symbol name."); 1740 return PyAttribute( 1741 symbol.getOperation().getContext(), 1742 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1743 } 1744 1745 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1746 // Op must already be a symbol. 1747 PyOperation &operation = symbol.getOperation(); 1748 operation.checkValid(); 1749 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1750 MlirAttribute existingNameAttr = 1751 mlirOperationGetAttributeByName(operation.get(), attrName); 1752 if (mlirAttributeIsNull(existingNameAttr)) 1753 throw py::value_error("Expected operation to have a symbol name."); 1754 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1755 } 1756 1757 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1758 const std::string &name) { 1759 // Op must already be a symbol. 
1760 PyOperation &operation = symbol.getOperation(); 1761 operation.checkValid(); 1762 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1763 MlirAttribute existingNameAttr = 1764 mlirOperationGetAttributeByName(operation.get(), attrName); 1765 if (mlirAttributeIsNull(existingNameAttr)) 1766 throw py::value_error("Expected operation to have a symbol name."); 1767 MlirAttribute newNameAttr = 1768 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1769 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1770 } 1771 1772 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1773 PyOperation &operation = symbol.getOperation(); 1774 operation.checkValid(); 1775 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1776 MlirAttribute existingVisAttr = 1777 mlirOperationGetAttributeByName(operation.get(), attrName); 1778 if (mlirAttributeIsNull(existingVisAttr)) 1779 throw py::value_error("Expected operation to have a symbol visibility."); 1780 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1781 } 1782 1783 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1784 const std::string &visibility) { 1785 if (visibility != "public" && visibility != "private" && 1786 visibility != "nested") 1787 throw py::value_error( 1788 "Expected visibility to be 'public', 'private' or 'nested'"); 1789 PyOperation &operation = symbol.getOperation(); 1790 operation.checkValid(); 1791 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1792 MlirAttribute existingVisAttr = 1793 mlirOperationGetAttributeByName(operation.get(), attrName); 1794 if (mlirAttributeIsNull(existingVisAttr)) 1795 throw py::value_error("Expected operation to have a symbol visibility."); 1796 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1797 toMlirStringRef(visibility)); 1798 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1799 } 
1800 1801 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1802 const std::string &newSymbol, 1803 PyOperationBase &from) { 1804 PyOperation &fromOperation = from.getOperation(); 1805 fromOperation.checkValid(); 1806 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1807 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1808 from.getOperation()))) 1809 1810 throw py::value_error("Symbol rename failed"); 1811 } 1812 1813 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1814 bool allSymUsesVisible, 1815 py::object callback) { 1816 PyOperation &fromOperation = from.getOperation(); 1817 fromOperation.checkValid(); 1818 struct UserData { 1819 PyMlirContextRef context; 1820 py::object callback; 1821 bool gotException; 1822 std::string exceptionWhat; 1823 py::object exceptionType; 1824 }; 1825 UserData userData{ 1826 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1827 mlirSymbolTableWalkSymbolTables( 1828 fromOperation.get(), allSymUsesVisible, 1829 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1830 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1831 auto pyFoundOp = 1832 PyOperation::forOperation(calleeUserData->context, foundOp); 1833 if (calleeUserData->gotException) 1834 return; 1835 try { 1836 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1837 } catch (py::error_already_set &e) { 1838 calleeUserData->gotException = true; 1839 calleeUserData->exceptionWhat = e.what(); 1840 calleeUserData->exceptionType = e.type(); 1841 } 1842 }, 1843 static_cast<void *>(&userData)); 1844 if (userData.gotException) { 1845 std::string message("Exception raised in callback: "); 1846 message.append(userData.exceptionWhat); 1847 throw std::runtime_error(message); 1848 } 1849 } 1850 1851 namespace { 1852 /// CRTP base class for Python MLIR values that subclass Value and should be 1853 /// castable from it. 
The value hierarchy is one level deep and is not supposed 1854 /// to accommodate other levels unless core MLIR changes. 1855 template <typename DerivedTy> 1856 class PyConcreteValue : public PyValue { 1857 public: 1858 // Derived classes must define statics for: 1859 // IsAFunctionTy isaFunction 1860 // const char *pyClassName 1861 // and redefine bindDerived. 1862 using ClassTy = py::class_<DerivedTy, PyValue>; 1863 using IsAFunctionTy = bool (*)(MlirValue); 1864 1865 PyConcreteValue() = default; 1866 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1867 : PyValue(operationRef, value) {} 1868 PyConcreteValue(PyValue &orig) 1869 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1870 1871 /// Attempts to cast the original value to the derived type and throws on 1872 /// type mismatches. 1873 static MlirValue castFrom(PyValue &orig) { 1874 if (!DerivedTy::isaFunction(orig.get())) { 1875 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1876 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1877 DerivedTy::pyClassName + 1878 " (from " + origRepr + ")"); 1879 } 1880 return orig.get(); 1881 } 1882 1883 /// Binds the Python module objects to functions of this class. 1884 static void bind(py::module &m) { 1885 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1886 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1887 cls.def_static( 1888 "isinstance", 1889 [](PyValue &otherValue) -> bool { 1890 return DerivedTy::isaFunction(otherValue); 1891 }, 1892 py::arg("other_value")); 1893 DerivedTy::bindDerived(cls); 1894 } 1895 1896 /// Implemented by derived classes to add methods to the Python subclass. 1897 static void bindDerived(ClassTy &m) {} 1898 }; 1899 1900 /// Python wrapper for MlirBlockArgument. 
1901 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1902 public: 1903 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1904 static constexpr const char *pyClassName = "BlockArgument"; 1905 using PyConcreteValue::PyConcreteValue; 1906 1907 static void bindDerived(ClassTy &c) { 1908 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1909 return PyBlock(self.getParentOperation(), 1910 mlirBlockArgumentGetOwner(self.get())); 1911 }); 1912 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1913 return mlirBlockArgumentGetArgNumber(self.get()); 1914 }); 1915 c.def( 1916 "set_type", 1917 [](PyBlockArgument &self, PyType type) { 1918 return mlirBlockArgumentSetType(self.get(), type); 1919 }, 1920 py::arg("type")); 1921 } 1922 }; 1923 1924 /// Python wrapper for MlirOpResult. 1925 class PyOpResult : public PyConcreteValue<PyOpResult> { 1926 public: 1927 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1928 static constexpr const char *pyClassName = "OpResult"; 1929 using PyConcreteValue::PyConcreteValue; 1930 1931 static void bindDerived(ClassTy &c) { 1932 c.def_property_readonly("owner", [](PyOpResult &self) { 1933 assert( 1934 mlirOperationEqual(self.getParentOperation()->get(), 1935 mlirOpResultGetOwner(self.get())) && 1936 "expected the owner of the value in Python to match that in the IR"); 1937 return self.getParentOperation().getObject(); 1938 }); 1939 c.def_property_readonly("result_number", [](PyOpResult &self) { 1940 return mlirOpResultGetResultNumber(self.get()); 1941 }); 1942 } 1943 }; 1944 1945 /// Returns the list of types of the values held by container. 
1946 template <typename Container> 1947 static std::vector<PyType> getValueTypes(Container &container, 1948 PyMlirContextRef &context) { 1949 std::vector<PyType> result; 1950 result.reserve(container.getNumElements()); 1951 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1952 result.push_back( 1953 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1954 } 1955 return result; 1956 } 1957 1958 /// A list of block arguments. Internally, these are stored as consecutive 1959 /// elements, random access is cheap. The argument list is associated with the 1960 /// operation that contains the block (detached blocks are not allowed in 1961 /// Python bindings) and extends its lifetime. 1962 class PyBlockArgumentList 1963 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1964 public: 1965 static constexpr const char *pyClassName = "BlockArgumentList"; 1966 1967 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1968 intptr_t startIndex = 0, intptr_t length = -1, 1969 intptr_t step = 1) 1970 : Sliceable(startIndex, 1971 length == -1 ? mlirBlockGetNumArguments(block) : length, 1972 step), 1973 operation(std::move(operation)), block(block) {} 1974 1975 /// Returns the number of arguments in the list. 1976 intptr_t getNumElements() { 1977 operation->checkValid(); 1978 return mlirBlockGetNumArguments(block); 1979 } 1980 1981 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1982 PyBlockArgument getElement(intptr_t pos) { 1983 MlirValue argument = mlirBlockGetArgument(block, pos); 1984 return PyBlockArgument(operation, argument); 1985 } 1986 1987 /// Returns a sublist of this list. 
1988 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1989 intptr_t step) { 1990 return PyBlockArgumentList(operation, block, startIndex, length, step); 1991 } 1992 1993 static void bindDerived(ClassTy &c) { 1994 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1995 return getValueTypes(self, self.operation->getContext()); 1996 }); 1997 } 1998 1999 private: 2000 PyOperationRef operation; 2001 MlirBlock block; 2002 }; 2003 2004 /// A list of operation operands. Internally, these are stored as consecutive 2005 /// elements, random access is cheap. The result list is associated with the 2006 /// operation whose results these are, and extends the lifetime of this 2007 /// operation. 2008 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 2009 public: 2010 static constexpr const char *pyClassName = "OpOperandList"; 2011 2012 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2013 intptr_t length = -1, intptr_t step = 1) 2014 : Sliceable(startIndex, 2015 length == -1 ? 
mlirOperationGetNumOperands(operation->get()) 2016 : length, 2017 step), 2018 operation(operation) {} 2019 2020 intptr_t getNumElements() { 2021 operation->checkValid(); 2022 return mlirOperationGetNumOperands(operation->get()); 2023 } 2024 2025 PyValue getElement(intptr_t pos) { 2026 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2027 MlirOperation owner; 2028 if (mlirValueIsAOpResult(operand)) 2029 owner = mlirOpResultGetOwner(operand); 2030 else if (mlirValueIsABlockArgument(operand)) 2031 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2032 else 2033 assert(false && "Value must be an block arg or op result."); 2034 PyOperationRef pyOwner = 2035 PyOperation::forOperation(operation->getContext(), owner); 2036 return PyValue(pyOwner, operand); 2037 } 2038 2039 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2040 return PyOpOperandList(operation, startIndex, length, step); 2041 } 2042 2043 void dunderSetItem(intptr_t index, PyValue value) { 2044 index = wrapIndex(index); 2045 mlirOperationSetOperand(operation->get(), index, value.get()); 2046 } 2047 2048 static void bindDerived(ClassTy &c) { 2049 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2050 } 2051 2052 private: 2053 PyOperationRef operation; 2054 }; 2055 2056 /// A list of operation results. Internally, these are stored as consecutive 2057 /// elements, random access is cheap. The result list is associated with the 2058 /// operation whose results these are, and extends the lifetime of this 2059 /// operation. 2060 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 2061 public: 2062 static constexpr const char *pyClassName = "OpResultList"; 2063 2064 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 2065 intptr_t length = -1, intptr_t step = 1) 2066 : Sliceable(startIndex, 2067 length == -1 ? 
mlirOperationGetNumResults(operation->get()) 2068 : length, 2069 step), 2070 operation(operation) {} 2071 2072 intptr_t getNumElements() { 2073 operation->checkValid(); 2074 return mlirOperationGetNumResults(operation->get()); 2075 } 2076 2077 PyOpResult getElement(intptr_t index) { 2078 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 2079 return PyOpResult(value); 2080 } 2081 2082 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2083 return PyOpResultList(operation, startIndex, length, step); 2084 } 2085 2086 static void bindDerived(ClassTy &c) { 2087 c.def_property_readonly("types", [](PyOpResultList &self) { 2088 return getValueTypes(self, self.operation->getContext()); 2089 }); 2090 } 2091 2092 private: 2093 PyOperationRef operation; 2094 }; 2095 2096 /// A list of operation attributes. Can be indexed by name, producing 2097 /// attributes, or by index, producing named attributes. 2098 class PyOpAttributeMap { 2099 public: 2100 PyOpAttributeMap(PyOperationRef operation) 2101 : operation(std::move(operation)) {} 2102 2103 PyAttribute dunderGetItemNamed(const std::string &name) { 2104 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2105 toMlirStringRef(name)); 2106 if (mlirAttributeIsNull(attr)) { 2107 throw SetPyError(PyExc_KeyError, 2108 "attempt to access a non-existent attribute"); 2109 } 2110 return PyAttribute(operation->getContext(), attr); 2111 } 2112 2113 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2114 if (index < 0 || index >= dunderLen()) { 2115 throw SetPyError(PyExc_IndexError, 2116 "attempt to access out of bounds attribute"); 2117 } 2118 MlirNamedAttribute namedAttr = 2119 mlirOperationGetAttribute(operation->get(), index); 2120 return PyNamedAttribute( 2121 namedAttr.attribute, 2122 std::string(mlirIdentifierStr(namedAttr.name).data, 2123 mlirIdentifierStr(namedAttr.name).length)); 2124 } 2125 2126 void dunderSetItem(const std::string &name, const PyAttribute 
&attr) { 2127 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2128 attr); 2129 } 2130 2131 void dunderDelItem(const std::string &name) { 2132 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2133 toMlirStringRef(name)); 2134 if (!removed) 2135 throw SetPyError(PyExc_KeyError, 2136 "attempt to delete a non-existent attribute"); 2137 } 2138 2139 intptr_t dunderLen() { 2140 return mlirOperationGetNumAttributes(operation->get()); 2141 } 2142 2143 bool dunderContains(const std::string &name) { 2144 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2145 operation->get(), toMlirStringRef(name))); 2146 } 2147 2148 static void bind(py::module &m) { 2149 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2150 .def("__contains__", &PyOpAttributeMap::dunderContains) 2151 .def("__len__", &PyOpAttributeMap::dunderLen) 2152 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2153 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2154 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2155 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2156 } 2157 2158 private: 2159 PyOperationRef operation; 2160 }; 2161 2162 } // namespace 2163 2164 //------------------------------------------------------------------------------ 2165 // Populates the core exports of the 'ir' submodule. 2166 //------------------------------------------------------------------------------ 2167 2168 void mlir::python::populateIRCore(py::module &m) { 2169 //---------------------------------------------------------------------------- 2170 // Enums. 
2171 //---------------------------------------------------------------------------- 2172 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local()) 2173 .value("ERROR", MlirDiagnosticError) 2174 .value("WARNING", MlirDiagnosticWarning) 2175 .value("NOTE", MlirDiagnosticNote) 2176 .value("REMARK", MlirDiagnosticRemark); 2177 2178 //---------------------------------------------------------------------------- 2179 // Mapping of Diagnostics. 2180 //---------------------------------------------------------------------------- 2181 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local()) 2182 .def_property_readonly("severity", &PyDiagnostic::getSeverity) 2183 .def_property_readonly("location", &PyDiagnostic::getLocation) 2184 .def_property_readonly("message", &PyDiagnostic::getMessage) 2185 .def_property_readonly("notes", &PyDiagnostic::getNotes) 2186 .def("__str__", [](PyDiagnostic &self) -> py::str { 2187 if (!self.isValid()) 2188 return "<Invalid Diagnostic>"; 2189 return self.getMessage(); 2190 }); 2191 2192 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) 2193 .def("detach", &PyDiagnosticHandler::detach) 2194 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) 2195 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 2196 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2197 .def("__exit__", &PyDiagnosticHandler::contextExit); 2198 2199 //---------------------------------------------------------------------------- 2200 // Mapping of MlirContext. 
2201 //---------------------------------------------------------------------------- 2202 py::class_<PyMlirContext>(m, "Context", py::module_local()) 2203 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2204 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2205 .def("_get_context_again", 2206 [](PyMlirContext &self) { 2207 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2208 return ref.releaseObject(); 2209 }) 2210 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2211 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2212 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2213 &PyMlirContext::getCapsule) 2214 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2215 .def("__enter__", &PyMlirContext::contextEnter) 2216 .def("__exit__", &PyMlirContext::contextExit) 2217 .def_property_readonly_static( 2218 "current", 2219 [](py::object & /*class*/) { 2220 auto *context = PyThreadContextEntry::getDefaultContext(); 2221 if (!context) 2222 throw SetPyError(PyExc_ValueError, "No current Context"); 2223 return context; 2224 }, 2225 "Gets the Context bound to the current thread or raises ValueError") 2226 .def_property_readonly( 2227 "dialects", 2228 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2229 "Gets a container for accessing dialects by name") 2230 .def_property_readonly( 2231 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2232 "Alias for 'dialect'") 2233 .def( 2234 "get_dialect_descriptor", 2235 [=](PyMlirContext &self, std::string &name) { 2236 MlirDialect dialect = mlirContextGetOrLoadDialect( 2237 self.get(), {name.data(), name.size()}); 2238 if (mlirDialectIsNull(dialect)) { 2239 throw SetPyError(PyExc_ValueError, 2240 Twine("Dialect '") + name + "' not found"); 2241 } 2242 return PyDialectDescriptor(self.getRef(), dialect); 2243 }, 2244 py::arg("dialect_name"), 2245 "Gets or loads a dialect by name, returning its descriptor object") 
2246 .def_property( 2247 "allow_unregistered_dialects", 2248 [](PyMlirContext &self) -> bool { 2249 return mlirContextGetAllowUnregisteredDialects(self.get()); 2250 }, 2251 [](PyMlirContext &self, bool value) { 2252 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2253 }) 2254 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2255 py::arg("callback"), 2256 "Attaches a diagnostic handler that will receive callbacks") 2257 .def( 2258 "enable_multithreading", 2259 [](PyMlirContext &self, bool enable) { 2260 mlirContextEnableMultithreading(self.get(), enable); 2261 }, 2262 py::arg("enable")) 2263 .def( 2264 "is_registered_operation", 2265 [](PyMlirContext &self, std::string &name) { 2266 return mlirContextIsRegisteredOperation( 2267 self.get(), MlirStringRef{name.data(), name.size()}); 2268 }, 2269 py::arg("operation_name")); 2270 2271 //---------------------------------------------------------------------------- 2272 // Mapping of PyDialectDescriptor 2273 //---------------------------------------------------------------------------- 2274 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2275 .def_property_readonly("namespace", 2276 [](PyDialectDescriptor &self) { 2277 MlirStringRef ns = 2278 mlirDialectGetNamespace(self.get()); 2279 return py::str(ns.data, ns.length); 2280 }) 2281 .def("__repr__", [](PyDialectDescriptor &self) { 2282 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2283 std::string repr("<DialectDescriptor "); 2284 repr.append(ns.data, ns.length); 2285 repr.append(">"); 2286 return repr; 2287 }); 2288 2289 //---------------------------------------------------------------------------- 2290 // Mapping of PyDialects 2291 //---------------------------------------------------------------------------- 2292 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2293 .def("__getitem__", 2294 [=](PyDialects &self, std::string keyName) { 2295 MlirDialect dialect = 2296 
self.getDialectForKey(keyName, /*attrError=*/false); 2297 py::object descriptor = 2298 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2299 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2300 }) 2301 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2302 MlirDialect dialect = 2303 self.getDialectForKey(attrName, /*attrError=*/true); 2304 py::object descriptor = 2305 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2306 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2307 }); 2308 2309 //---------------------------------------------------------------------------- 2310 // Mapping of PyDialect 2311 //---------------------------------------------------------------------------- 2312 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2313 .def(py::init<py::object>(), py::arg("descriptor")) 2314 .def_property_readonly( 2315 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2316 .def("__repr__", [](py::object self) { 2317 auto clazz = self.attr("__class__"); 2318 return py::str("<Dialect ") + 2319 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2320 clazz.attr("__module__") + py::str(".") + 2321 clazz.attr("__name__") + py::str(")>"); 2322 }); 2323 2324 //---------------------------------------------------------------------------- 2325 // Mapping of Location 2326 //---------------------------------------------------------------------------- 2327 py::class_<PyLocation>(m, "Location", py::module_local()) 2328 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2329 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2330 .def("__enter__", &PyLocation::contextEnter) 2331 .def("__exit__", &PyLocation::contextExit) 2332 .def("__eq__", 2333 [](PyLocation &self, PyLocation &other) -> bool { 2334 return mlirLocationEqual(self, other); 2335 }) 2336 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2337 
.def_property_readonly_static( 2338 "current", 2339 [](py::object & /*class*/) { 2340 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2341 if (!loc) 2342 throw SetPyError(PyExc_ValueError, "No current Location"); 2343 return loc; 2344 }, 2345 "Gets the Location bound to the current thread or raises ValueError") 2346 .def_static( 2347 "unknown", 2348 [](DefaultingPyMlirContext context) { 2349 return PyLocation(context->getRef(), 2350 mlirLocationUnknownGet(context->get())); 2351 }, 2352 py::arg("context") = py::none(), 2353 "Gets a Location representing an unknown location") 2354 .def_static( 2355 "callsite", 2356 [](PyLocation callee, const std::vector<PyLocation> &frames, 2357 DefaultingPyMlirContext context) { 2358 if (frames.empty()) 2359 throw py::value_error("No caller frames provided"); 2360 MlirLocation caller = frames.back().get(); 2361 for (const PyLocation &frame : 2362 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2363 caller = mlirLocationCallSiteGet(frame.get(), caller); 2364 return PyLocation(context->getRef(), 2365 mlirLocationCallSiteGet(callee.get(), caller)); 2366 }, 2367 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2368 kContextGetCallSiteLocationDocstring) 2369 .def_static( 2370 "file", 2371 [](std::string filename, int line, int col, 2372 DefaultingPyMlirContext context) { 2373 return PyLocation( 2374 context->getRef(), 2375 mlirLocationFileLineColGet( 2376 context->get(), toMlirStringRef(filename), line, col)); 2377 }, 2378 py::arg("filename"), py::arg("line"), py::arg("col"), 2379 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2380 .def_static( 2381 "fused", 2382 [](const std::vector<PyLocation> &pyLocations, 2383 llvm::Optional<PyAttribute> metadata, 2384 DefaultingPyMlirContext context) { 2385 llvm::SmallVector<MlirLocation, 4> locations; 2386 locations.reserve(pyLocations.size()); 2387 for (auto &pyLocation : pyLocations) 2388 locations.push_back(pyLocation.get()); 2389 
MlirLocation location = mlirLocationFusedGet( 2390 context->get(), locations.size(), locations.data(), 2391 metadata ? metadata->get() : MlirAttribute{0}); 2392 return PyLocation(context->getRef(), location); 2393 }, 2394 py::arg("locations"), py::arg("metadata") = py::none(), 2395 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2396 .def_static( 2397 "name", 2398 [](std::string name, llvm::Optional<PyLocation> childLoc, 2399 DefaultingPyMlirContext context) { 2400 return PyLocation( 2401 context->getRef(), 2402 mlirLocationNameGet( 2403 context->get(), toMlirStringRef(name), 2404 childLoc ? childLoc->get() 2405 : mlirLocationUnknownGet(context->get()))); 2406 }, 2407 py::arg("name"), py::arg("childLoc") = py::none(), 2408 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2409 .def_property_readonly( 2410 "context", 2411 [](PyLocation &self) { return self.getContext().getObject(); }, 2412 "Context that owns the Location") 2413 .def( 2414 "emit_error", 2415 [](PyLocation &self, std::string message) { 2416 mlirEmitError(self, message.c_str()); 2417 }, 2418 py::arg("message"), "Emits an error at this location") 2419 .def("__repr__", [](PyLocation &self) { 2420 PyPrintAccumulator printAccum; 2421 mlirLocationPrint(self, printAccum.getCallback(), 2422 printAccum.getUserData()); 2423 return printAccum.join(); 2424 }); 2425 2426 //---------------------------------------------------------------------------- 2427 // Mapping of Module 2428 //---------------------------------------------------------------------------- 2429 py::class_<PyModule>(m, "Module", py::module_local()) 2430 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2431 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2432 .def_static( 2433 "parse", 2434 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2435 MlirModule module = mlirModuleCreateParse( 2436 context->get(), toMlirStringRef(moduleAsm)); 2437 // TODO: 
Rework error reporting once diagnostic engine is exposed 2438 // in C API. 2439 if (mlirModuleIsNull(module)) { 2440 throw SetPyError( 2441 PyExc_ValueError, 2442 "Unable to parse module assembly (see diagnostics)"); 2443 } 2444 return PyModule::forModule(module).releaseObject(); 2445 }, 2446 py::arg("asm"), py::arg("context") = py::none(), 2447 kModuleParseDocstring) 2448 .def_static( 2449 "create", 2450 [](DefaultingPyLocation loc) { 2451 MlirModule module = mlirModuleCreateEmpty(loc); 2452 return PyModule::forModule(module).releaseObject(); 2453 }, 2454 py::arg("loc") = py::none(), "Creates an empty module") 2455 .def_property_readonly( 2456 "context", 2457 [](PyModule &self) { return self.getContext().getObject(); }, 2458 "Context that created the Module") 2459 .def_property_readonly( 2460 "operation", 2461 [](PyModule &self) { 2462 return PyOperation::forOperation(self.getContext(), 2463 mlirModuleGetOperation(self.get()), 2464 self.getRef().releaseObject()) 2465 .releaseObject(); 2466 }, 2467 "Accesses the module as an operation") 2468 .def_property_readonly( 2469 "body", 2470 [](PyModule &self) { 2471 PyOperationRef moduleOp = PyOperation::forOperation( 2472 self.getContext(), mlirModuleGetOperation(self.get()), 2473 self.getRef().releaseObject()); 2474 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 2475 return returnBlock; 2476 }, 2477 "Return the block for this module") 2478 .def( 2479 "dump", 2480 [](PyModule &self) { 2481 mlirOperationDump(mlirModuleGetOperation(self.get())); 2482 }, 2483 kDumpDocstring) 2484 .def( 2485 "__str__", 2486 [](py::object self) { 2487 // Defer to the operation's __str__. 2488 return self.attr("operation").attr("__str__")(); 2489 }, 2490 kOperationStrDunderDocstring); 2491 2492 //---------------------------------------------------------------------------- 2493 // Mapping of Operation. 
2494 //---------------------------------------------------------------------------- 2495 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2496 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2497 [](PyOperationBase &self) { 2498 return self.getOperation().getCapsule(); 2499 }) 2500 .def("__eq__", 2501 [](PyOperationBase &self, PyOperationBase &other) { 2502 return &self.getOperation() == &other.getOperation(); 2503 }) 2504 .def("__eq__", 2505 [](PyOperationBase &self, py::object other) { return false; }) 2506 .def("__hash__", 2507 [](PyOperationBase &self) { 2508 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2509 }) 2510 .def_property_readonly("attributes", 2511 [](PyOperationBase &self) { 2512 return PyOpAttributeMap( 2513 self.getOperation().getRef()); 2514 }) 2515 .def_property_readonly("operands", 2516 [](PyOperationBase &self) { 2517 return PyOpOperandList( 2518 self.getOperation().getRef()); 2519 }) 2520 .def_property_readonly("regions", 2521 [](PyOperationBase &self) { 2522 return PyRegionList( 2523 self.getOperation().getRef()); 2524 }) 2525 .def_property_readonly( 2526 "results", 2527 [](PyOperationBase &self) { 2528 return PyOpResultList(self.getOperation().getRef()); 2529 }, 2530 "Returns the list of Operation results.") 2531 .def_property_readonly( 2532 "result", 2533 [](PyOperationBase &self) { 2534 auto &operation = self.getOperation(); 2535 auto numResults = mlirOperationGetNumResults(operation); 2536 if (numResults != 1) { 2537 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2538 throw SetPyError( 2539 PyExc_ValueError, 2540 Twine("Cannot call .result on operation ") + 2541 StringRef(name.data, name.length) + " which has " + 2542 Twine(numResults) + 2543 " results (it is only valid for operations with a " 2544 "single result)"); 2545 } 2546 return PyOpResult(operation.getRef(), 2547 mlirOperationGetResult(operation, 0)); 2548 }, 2549 "Shortcut to get an op result if it has only one (throws 
an error " 2550 "otherwise).") 2551 .def_property_readonly( 2552 "location", 2553 [](PyOperationBase &self) { 2554 PyOperation &operation = self.getOperation(); 2555 return PyLocation(operation.getContext(), 2556 mlirOperationGetLocation(operation.get())); 2557 }, 2558 "Returns the source location the operation was defined or derived " 2559 "from.") 2560 .def( 2561 "__str__", 2562 [](PyOperationBase &self) { 2563 return self.getAsm(/*binary=*/false, 2564 /*largeElementsLimit=*/llvm::None, 2565 /*enableDebugInfo=*/false, 2566 /*prettyDebugInfo=*/false, 2567 /*printGenericOpForm=*/false, 2568 /*useLocalScope=*/false, 2569 /*assumeVerified=*/false); 2570 }, 2571 "Returns the assembly form of the operation.") 2572 .def("print", &PyOperationBase::print, 2573 // Careful: Lots of arguments must match up with print method. 2574 py::arg("file") = py::none(), py::arg("binary") = false, 2575 py::arg("large_elements_limit") = py::none(), 2576 py::arg("enable_debug_info") = false, 2577 py::arg("pretty_debug_info") = false, 2578 py::arg("print_generic_op_form") = false, 2579 py::arg("use_local_scope") = false, 2580 py::arg("assume_verified") = false, kOperationPrintDocstring) 2581 .def("get_asm", &PyOperationBase::getAsm, 2582 // Careful: Lots of arguments must match up with get_asm method. 
2583 py::arg("binary") = false, 2584 py::arg("large_elements_limit") = py::none(), 2585 py::arg("enable_debug_info") = false, 2586 py::arg("pretty_debug_info") = false, 2587 py::arg("print_generic_op_form") = false, 2588 py::arg("use_local_scope") = false, 2589 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2590 .def( 2591 "verify", 2592 [](PyOperationBase &self) { 2593 return mlirOperationVerify(self.getOperation()); 2594 }, 2595 "Verify the operation and return true if it passes, false if it " 2596 "fails.") 2597 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2598 "Puts self immediately after the other operation in its parent " 2599 "block.") 2600 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2601 "Puts self immediately before the other operation in its parent " 2602 "block.") 2603 .def( 2604 "detach_from_parent", 2605 [](PyOperationBase &self) { 2606 PyOperation &operation = self.getOperation(); 2607 operation.checkValid(); 2608 if (!operation.isAttached()) 2609 throw py::value_error("Detached operation has no parent."); 2610 2611 operation.detachFromParent(); 2612 return operation.createOpView(); 2613 }, 2614 "Detaches the operation from its parent block."); 2615 2616 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2617 .def_static("create", &PyOperation::create, py::arg("name"), 2618 py::arg("results") = py::none(), 2619 py::arg("operands") = py::none(), 2620 py::arg("attributes") = py::none(), 2621 py::arg("successors") = py::none(), py::arg("regions") = 0, 2622 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2623 kOperationCreateDocstring) 2624 .def_property_readonly("parent", 2625 [](PyOperation &self) -> py::object { 2626 auto parent = self.getParentOperation(); 2627 if (parent) 2628 return parent->getObject(); 2629 return py::none(); 2630 }) 2631 .def("erase", &PyOperation::erase) 2632 .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) 2633 
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2634 &PyOperation::getCapsule) 2635 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2636 .def_property_readonly("name", 2637 [](PyOperation &self) { 2638 self.checkValid(); 2639 MlirOperation operation = self.get(); 2640 MlirStringRef name = mlirIdentifierStr( 2641 mlirOperationGetName(operation)); 2642 return py::str(name.data, name.length); 2643 }) 2644 .def_property_readonly( 2645 "context", 2646 [](PyOperation &self) { 2647 self.checkValid(); 2648 return self.getContext().getObject(); 2649 }, 2650 "Context that owns the Operation") 2651 .def_property_readonly("opview", &PyOperation::createOpView); 2652 2653 auto opViewClass = 2654 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2655 .def(py::init<py::object>(), py::arg("operation")) 2656 .def_property_readonly("operation", &PyOpView::getOperationObject) 2657 .def_property_readonly( 2658 "context", 2659 [](PyOpView &self) { 2660 return self.getOperation().getContext().getObject(); 2661 }, 2662 "Context that owns the Operation") 2663 .def("__str__", [](PyOpView &self) { 2664 return py::str(self.getOperationObject()); 2665 }); 2666 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2667 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2668 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2669 opViewClass.attr("build_generic") = classmethod( 2670 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2671 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2672 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2673 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2674 "Builds a specific, generated OpView based on class level attributes."); 2675 2676 //---------------------------------------------------------------------------- 2677 // Mapping of PyRegion. 
2678 //---------------------------------------------------------------------------- 2679 py::class_<PyRegion>(m, "Region", py::module_local()) 2680 .def_property_readonly( 2681 "blocks", 2682 [](PyRegion &self) { 2683 return PyBlockList(self.getParentOperation(), self.get()); 2684 }, 2685 "Returns a forward-optimized sequence of blocks.") 2686 .def_property_readonly( 2687 "owner", 2688 [](PyRegion &self) { 2689 return self.getParentOperation()->createOpView(); 2690 }, 2691 "Returns the operation owning this region.") 2692 .def( 2693 "__iter__", 2694 [](PyRegion &self) { 2695 self.checkValid(); 2696 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2697 return PyBlockIterator(self.getParentOperation(), firstBlock); 2698 }, 2699 "Iterates over blocks in the region.") 2700 .def("__eq__", 2701 [](PyRegion &self, PyRegion &other) { 2702 return self.get().ptr == other.get().ptr; 2703 }) 2704 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2705 2706 //---------------------------------------------------------------------------- 2707 // Mapping of PyBlock. 
2708 //---------------------------------------------------------------------------- 2709 py::class_<PyBlock>(m, "Block", py::module_local()) 2710 .def_property_readonly( 2711 "owner", 2712 [](PyBlock &self) { 2713 return self.getParentOperation()->createOpView(); 2714 }, 2715 "Returns the owning operation of this block.") 2716 .def_property_readonly( 2717 "region", 2718 [](PyBlock &self) { 2719 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2720 return PyRegion(self.getParentOperation(), region); 2721 }, 2722 "Returns the owning region of this block.") 2723 .def_property_readonly( 2724 "arguments", 2725 [](PyBlock &self) { 2726 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2727 }, 2728 "Returns a list of block arguments.") 2729 .def_property_readonly( 2730 "operations", 2731 [](PyBlock &self) { 2732 return PyOperationList(self.getParentOperation(), self.get()); 2733 }, 2734 "Returns a forward-optimized sequence of operations.") 2735 .def_static( 2736 "create_at_start", 2737 [](PyRegion &parent, py::list pyArgTypes) { 2738 parent.checkValid(); 2739 llvm::SmallVector<MlirType, 4> argTypes; 2740 llvm::SmallVector<MlirLocation, 4> argLocs; 2741 argTypes.reserve(pyArgTypes.size()); 2742 argLocs.reserve(pyArgTypes.size()); 2743 for (auto &pyArg : pyArgTypes) { 2744 argTypes.push_back(pyArg.cast<PyType &>()); 2745 // TODO: Pass in a proper location here. 
2746 argLocs.push_back( 2747 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2748 } 2749 2750 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2751 argLocs.data()); 2752 mlirRegionInsertOwnedBlock(parent, 0, block); 2753 return PyBlock(parent.getParentOperation(), block); 2754 }, 2755 py::arg("parent"), py::arg("arg_types") = py::list(), 2756 "Creates and returns a new Block at the beginning of the given " 2757 "region (with given argument types).") 2758 .def( 2759 "append_to", 2760 [](PyBlock &self, PyRegion ®ion) { 2761 MlirBlock b = self.get(); 2762 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) 2763 mlirBlockDetach(b); 2764 mlirRegionAppendOwnedBlock(region.get(), b); 2765 }, 2766 "Append this block to a region, transferring ownership if necessary") 2767 .def( 2768 "create_before", 2769 [](PyBlock &self, py::args pyArgTypes) { 2770 self.checkValid(); 2771 llvm::SmallVector<MlirType, 4> argTypes; 2772 llvm::SmallVector<MlirLocation, 4> argLocs; 2773 argTypes.reserve(pyArgTypes.size()); 2774 argLocs.reserve(pyArgTypes.size()); 2775 for (auto &pyArg : pyArgTypes) { 2776 argTypes.push_back(pyArg.cast<PyType &>()); 2777 // TODO: Pass in a proper location here. 
2778 argLocs.push_back( 2779 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2780 } 2781 2782 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2783 argLocs.data()); 2784 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2785 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2786 return PyBlock(self.getParentOperation(), block); 2787 }, 2788 "Creates and returns a new Block before this block " 2789 "(with given argument types).") 2790 .def( 2791 "create_after", 2792 [](PyBlock &self, py::args pyArgTypes) { 2793 self.checkValid(); 2794 llvm::SmallVector<MlirType, 4> argTypes; 2795 llvm::SmallVector<MlirLocation, 4> argLocs; 2796 argTypes.reserve(pyArgTypes.size()); 2797 argLocs.reserve(pyArgTypes.size()); 2798 for (auto &pyArg : pyArgTypes) { 2799 argTypes.push_back(pyArg.cast<PyType &>()); 2800 2801 // TODO: Pass in a proper location here. 2802 argLocs.push_back( 2803 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2804 } 2805 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2806 argLocs.data()); 2807 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2808 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2809 return PyBlock(self.getParentOperation(), block); 2810 }, 2811 "Creates and returns a new Block after this block " 2812 "(with given argument types).") 2813 .def( 2814 "__iter__", 2815 [](PyBlock &self) { 2816 self.checkValid(); 2817 MlirOperation firstOperation = 2818 mlirBlockGetFirstOperation(self.get()); 2819 return PyOperationIterator(self.getParentOperation(), 2820 firstOperation); 2821 }, 2822 "Iterates over operations in the block.") 2823 .def("__eq__", 2824 [](PyBlock &self, PyBlock &other) { 2825 return self.get().ptr == other.get().ptr; 2826 }) 2827 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2828 .def( 2829 "__str__", 2830 [](PyBlock &self) { 2831 self.checkValid(); 2832 PyPrintAccumulator printAccum; 2833 
mlirBlockPrint(self.get(), printAccum.getCallback(), 2834 printAccum.getUserData()); 2835 return printAccum.join(); 2836 }, 2837 "Returns the assembly form of the block.") 2838 .def( 2839 "append", 2840 [](PyBlock &self, PyOperationBase &operation) { 2841 if (operation.getOperation().isAttached()) 2842 operation.getOperation().detachFromParent(); 2843 2844 MlirOperation mlirOperation = operation.getOperation().get(); 2845 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2846 operation.getOperation().setAttached( 2847 self.getParentOperation().getObject()); 2848 }, 2849 py::arg("operation"), 2850 "Appends an operation to this block. If the operation is currently " 2851 "in another block, it will be moved."); 2852 2853 //---------------------------------------------------------------------------- 2854 // Mapping of PyInsertionPoint. 2855 //---------------------------------------------------------------------------- 2856 2857 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2858 .def(py::init<PyBlock &>(), py::arg("block"), 2859 "Inserts after the last operation but still inside the block.") 2860 .def("__enter__", &PyInsertionPoint::contextEnter) 2861 .def("__exit__", &PyInsertionPoint::contextExit) 2862 .def_property_readonly_static( 2863 "current", 2864 [](py::object & /*class*/) { 2865 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2866 if (!ip) 2867 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2868 return ip; 2869 }, 2870 "Gets the InsertionPoint bound to the current thread or raises " 2871 "ValueError if none has been set") 2872 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2873 "Inserts before a referenced operation.") 2874 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2875 py::arg("block"), "Inserts at the beginning of the block.") 2876 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2877 py::arg("block"), "Inserts before the block 
terminator.") 2878 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2879 "Inserts an operation.") 2880 .def_property_readonly( 2881 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2882 "Returns the block that this InsertionPoint points to."); 2883 2884 //---------------------------------------------------------------------------- 2885 // Mapping of PyAttribute. 2886 //---------------------------------------------------------------------------- 2887 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2888 // Delegate to the PyAttribute copy constructor, which will also lifetime 2889 // extend the backing context which owns the MlirAttribute. 2890 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2891 "Casts the passed attribute to the generic Attribute") 2892 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2893 &PyAttribute::getCapsule) 2894 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2895 .def_static( 2896 "parse", 2897 [](std::string attrSpec, DefaultingPyMlirContext context) { 2898 MlirAttribute type = mlirAttributeParseGet( 2899 context->get(), toMlirStringRef(attrSpec)); 2900 // TODO: Rework error reporting once diagnostic engine is exposed 2901 // in C API. 
2902 if (mlirAttributeIsNull(type)) { 2903 throw SetPyError(PyExc_ValueError, 2904 Twine("Unable to parse attribute: '") + 2905 attrSpec + "'"); 2906 } 2907 return PyAttribute(context->getRef(), type); 2908 }, 2909 py::arg("asm"), py::arg("context") = py::none(), 2910 "Parses an attribute from an assembly form") 2911 .def_property_readonly( 2912 "context", 2913 [](PyAttribute &self) { return self.getContext().getObject(); }, 2914 "Context that owns the Attribute") 2915 .def_property_readonly("type", 2916 [](PyAttribute &self) { 2917 return PyType(self.getContext()->getRef(), 2918 mlirAttributeGetType(self)); 2919 }) 2920 .def( 2921 "get_named", 2922 [](PyAttribute &self, std::string name) { 2923 return PyNamedAttribute(self, std::move(name)); 2924 }, 2925 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2926 .def("__eq__", 2927 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2928 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2929 .def("__hash__", 2930 [](PyAttribute &self) { 2931 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2932 }) 2933 .def( 2934 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2935 kDumpDocstring) 2936 .def( 2937 "__str__", 2938 [](PyAttribute &self) { 2939 PyPrintAccumulator printAccum; 2940 mlirAttributePrint(self, printAccum.getCallback(), 2941 printAccum.getUserData()); 2942 return printAccum.join(); 2943 }, 2944 "Returns the assembly form of the Attribute.") 2945 .def("__repr__", [](PyAttribute &self) { 2946 // Generally, assembly formats are not printed for __repr__ because 2947 // this can cause exceptionally long debug output and exceptions. 2948 // However, attribute values are generally considered useful and are 2949 // printed. This may need to be re-evaluated if debug dumps end up 2950 // being excessive. 
2951 PyPrintAccumulator printAccum; 2952 printAccum.parts.append("Attribute("); 2953 mlirAttributePrint(self, printAccum.getCallback(), 2954 printAccum.getUserData()); 2955 printAccum.parts.append(")"); 2956 return printAccum.join(); 2957 }); 2958 2959 //---------------------------------------------------------------------------- 2960 // Mapping of PyNamedAttribute 2961 //---------------------------------------------------------------------------- 2962 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2963 .def("__repr__", 2964 [](PyNamedAttribute &self) { 2965 PyPrintAccumulator printAccum; 2966 printAccum.parts.append("NamedAttribute("); 2967 printAccum.parts.append( 2968 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2969 mlirIdentifierStr(self.namedAttr.name).length)); 2970 printAccum.parts.append("="); 2971 mlirAttributePrint(self.namedAttr.attribute, 2972 printAccum.getCallback(), 2973 printAccum.getUserData()); 2974 printAccum.parts.append(")"); 2975 return printAccum.join(); 2976 }) 2977 .def_property_readonly( 2978 "name", 2979 [](PyNamedAttribute &self) { 2980 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2981 mlirIdentifierStr(self.namedAttr.name).length); 2982 }, 2983 "The name of the NamedAttribute binding") 2984 .def_property_readonly( 2985 "attr", 2986 [](PyNamedAttribute &self) { 2987 // TODO: When named attribute is removed/refactored, also remove 2988 // this constructor (it does an inefficient table lookup). 2989 auto contextRef = PyMlirContext::forContext( 2990 mlirAttributeGetContext(self.namedAttr.attribute)); 2991 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2992 }, 2993 py::keep_alive<0, 1>(), 2994 "The underlying generic attribute of the NamedAttribute binding"); 2995 2996 //---------------------------------------------------------------------------- 2997 // Mapping of PyType. 
2998 //---------------------------------------------------------------------------- 2999 py::class_<PyType>(m, "Type", py::module_local()) 3000 // Delegate to the PyType copy constructor, which will also lifetime 3001 // extend the backing context which owns the MlirType. 3002 .def(py::init<PyType &>(), py::arg("cast_from_type"), 3003 "Casts the passed type to the generic Type") 3004 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 3005 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 3006 .def_static( 3007 "parse", 3008 [](std::string typeSpec, DefaultingPyMlirContext context) { 3009 MlirType type = 3010 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 3011 // TODO: Rework error reporting once diagnostic engine is exposed 3012 // in C API. 3013 if (mlirTypeIsNull(type)) { 3014 throw SetPyError(PyExc_ValueError, 3015 Twine("Unable to parse type: '") + typeSpec + 3016 "'"); 3017 } 3018 return PyType(context->getRef(), type); 3019 }, 3020 py::arg("asm"), py::arg("context") = py::none(), 3021 kContextParseTypeDocstring) 3022 .def_property_readonly( 3023 "context", [](PyType &self) { return self.getContext().getObject(); }, 3024 "Context that owns the Type") 3025 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3026 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 3027 .def("__hash__", 3028 [](PyType &self) { 3029 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3030 }) 3031 .def( 3032 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3033 .def( 3034 "__str__", 3035 [](PyType &self) { 3036 PyPrintAccumulator printAccum; 3037 mlirTypePrint(self, printAccum.getCallback(), 3038 printAccum.getUserData()); 3039 return printAccum.join(); 3040 }, 3041 "Returns the assembly form of the type.") 3042 .def("__repr__", [](PyType &self) { 3043 // Generally, assembly formats are not printed for __repr__ because 3044 // this can cause exceptionally long debug 
output and exceptions. 3045 // However, types are an exception as they typically have compact 3046 // assembly forms and printing them is useful. 3047 PyPrintAccumulator printAccum; 3048 printAccum.parts.append("Type("); 3049 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 3050 printAccum.parts.append(")"); 3051 return printAccum.join(); 3052 }); 3053 3054 //---------------------------------------------------------------------------- 3055 // Mapping of Value. 3056 //---------------------------------------------------------------------------- 3057 py::class_<PyValue>(m, "Value", py::module_local()) 3058 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3059 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3060 .def_property_readonly( 3061 "context", 3062 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3063 "Context in which the value lives.") 3064 .def( 3065 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3066 kDumpDocstring) 3067 .def_property_readonly( 3068 "owner", 3069 [](PyValue &self) { 3070 assert(mlirOperationEqual(self.getParentOperation()->get(), 3071 mlirOpResultGetOwner(self.get())) && 3072 "expected the owner of the value in Python to match that in " 3073 "the IR"); 3074 return self.getParentOperation().getObject(); 3075 }) 3076 .def("__eq__", 3077 [](PyValue &self, PyValue &other) { 3078 return self.get().ptr == other.get().ptr; 3079 }) 3080 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 3081 .def("__hash__", 3082 [](PyValue &self) { 3083 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3084 }) 3085 .def( 3086 "__str__", 3087 [](PyValue &self) { 3088 PyPrintAccumulator printAccum; 3089 printAccum.parts.append("Value("); 3090 mlirValuePrint(self.get(), printAccum.getCallback(), 3091 printAccum.getUserData()); 3092 printAccum.parts.append(")"); 3093 return printAccum.join(); 3094 }, 3095 kValueDunderStrDocstring) 3096 
.def_property_readonly("type", [](PyValue &self) { 3097 return PyType(self.getParentOperation()->getContext(), 3098 mlirValueGetType(self.get())); 3099 }); 3100 PyBlockArgument::bind(m); 3101 PyOpResult::bind(m); 3102 3103 //---------------------------------------------------------------------------- 3104 // Mapping of SymbolTable. 3105 //---------------------------------------------------------------------------- 3106 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 3107 .def(py::init<PyOperationBase &>()) 3108 .def("__getitem__", &PySymbolTable::dunderGetItem) 3109 .def("insert", &PySymbolTable::insert, py::arg("operation")) 3110 .def("erase", &PySymbolTable::erase, py::arg("operation")) 3111 .def("__delitem__", &PySymbolTable::dunderDel) 3112 .def("__contains__", 3113 [](PySymbolTable &table, const std::string &name) { 3114 return !mlirOperationIsNull(mlirSymbolTableLookup( 3115 table, mlirStringRefCreate(name.data(), name.length()))); 3116 }) 3117 // Static helpers. 3118 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3119 py::arg("symbol"), py::arg("name")) 3120 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3121 py::arg("symbol")) 3122 .def_static("get_visibility", &PySymbolTable::getVisibility, 3123 py::arg("symbol")) 3124 .def_static("set_visibility", &PySymbolTable::setVisibility, 3125 py::arg("symbol"), py::arg("visibility")) 3126 .def_static("replace_all_symbol_uses", 3127 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 3128 py::arg("new_symbol"), py::arg("from_op")) 3129 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3130 py::arg("from_op"), py::arg("all_sym_uses_visible"), 3131 py::arg("callback")); 3132 3133 // Container bindings. 
// Register the container and iterator helper classes. Several of them are
  // the return types of properties bound above, e.g. Region.blocks returns a
  // PyBlockList and _OperationBase.attributes returns a PyOpAttributeMap.
  // Each bind(m) call registers its class on module `m`; the original
  // registration order is kept.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
}