1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 #include <utility> 25 26 namespace py = pybind11; 27 using namespace mlir; 28 using namespace mlir::python; 29 30 using llvm::SmallVector; 31 using llvm::StringRef; 32 using llvm::Twine; 33 34 //------------------------------------------------------------------------------ 35 // Docstrings (trivial, non-duplicated docstrings are included inline). 36 //------------------------------------------------------------------------------ 37 38 static const char kContextParseTypeDocstring[] = 39 R"(Parses the assembly form of a type. 40 41 Returns a Type object or raises a ValueError if the type cannot be parsed. 42 43 See also: https://mlir.llvm.org/docs/LangRef/#type-system 44 )"; 45 46 static const char kContextGetCallSiteLocationDocstring[] = 47 R"(Gets a Location representing a caller and callsite)"; 48 49 static const char kContextGetFileLocationDocstring[] = 50 R"(Gets a Location representing a file, line and column)"; 51 52 static const char kContextGetFusedLocationDocstring[] = 53 R"(Gets a Location representing a fused location with optional metadata)"; 54 55 static const char kContextGetNameLocationDocString[] = 56 R"(Gets a Location representing a named location with optional child location)"; 57 58 static const char kModuleParseDocstring[] = 59 R"(Parses a module's assembly format from a string. 60 61 Returns a new MlirModule or raises a ValueError if the parsing fails. 62 63 See also: https://mlir.llvm.org/docs/LangRef/ 64 )"; 65 66 static const char kOperationCreateDocstring[] = 67 R"(Creates a new operation. 68 69 Args: 70 name: Operation name (e.g. "dialect.operation"). 71 results: Sequence of Type representing op result types. 72 attributes: Dict of str:Attribute. 73 successors: List of Block for the operation's successors. 74 regions: Number of regions to create. 75 location: A Location object (defaults to resolve from context manager). 76 ip: An InsertionPoint (defaults to resolve from context manager or set to 77 False to disable insertion, even with an insertion point set in the 78 context manager). 79 Returns: 80 A new "detached" Operation object. Detached operations can be added 81 to blocks, which causes them to become "attached." 82 )"; 83 84 static const char kOperationPrintDocstring[] = 85 R"(Prints the assembly form of the operation to a file like object. 86 87 Args: 88 file: The file like object to write to. Defaults to sys.stdout. 89 binary: Whether to write bytes (True) or str (False). Defaults to False. 90 large_elements_limit: Whether to elide elements attributes above this 91 number of elements. Defaults to None (no limit). 92 enable_debug_info: Whether to print debug/location information. Defaults 93 to False. 94 pretty_debug_info: Whether to format debug information for easier reading 95 by a human (warning: the result is unparseable). 96 print_generic_op_form: Whether to print the generic assembly forms of all 97 ops. Defaults to False. 98 use_local_Scope: Whether to print in a way that is more optimized for 99 multi-threaded access but may not be consistent with how the overall 100 module prints. 101 assume_verified: By default, if not printing generic form, the verifier 102 will be run and if it fails, generic form will be printed with a comment 103 about failed verification. While a reasonable default for interactive use, 104 for systematic use, it is often better for the caller to verify explicitly 105 and report failures in a more robust fashion. Set this to True if doing this 106 in order to avoid running a redundant verification. If the IR is actually 107 invalid, behavior is undefined. 108 )"; 109 110 static const char kOperationGetAsmDocstring[] = 111 R"(Gets the assembly form of the operation with all options available. 112 113 Args: 114 binary: Whether to return a bytes (True) or str (False) object. Defaults to 115 False. 116 ... others ...: See the print() method for common keyword arguments for 117 configuring the printout. 118 Returns: 119 Either a bytes or str object, depending on the setting of the 'binary' 120 argument. 121 )"; 122 123 static const char kOperationStrDunderDocstring[] = 124 R"(Gets the assembly form of the operation with default options. 125 126 If more advanced control over the assembly formatting or I/O options is needed, 127 use the dedicated print or get_asm method, which supports keyword arguments to 128 customize behavior. 129 )"; 130 131 static const char kDumpDocstring[] = 132 R"(Dumps a debug representation of the object to stderr.)"; 133 134 static const char kAppendBlockDocstring[] = 135 R"(Appends a new block, with argument types as positional args. 136 137 Returns: 138 The created block. 139 )"; 140 141 static const char kValueDunderStrDocstring[] = 142 R"(Returns the string form of the value. 143 144 If the value is a block argument, this is the assembly form of its type and the 145 position in the argument list. If the value is an operation result, this is 146 equivalent to printing the operation that produced it. 147 )"; 148 149 //------------------------------------------------------------------------------ 150 // Utilities. 151 //------------------------------------------------------------------------------ 152 153 /// Helper for creating an @classmethod. 154 template <class Func, typename... Args> 155 py::object classmethod(Func f, Args... args) { 156 py::object cf = py::cpp_function(f, args...); 157 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 158 } 159 160 static py::object 161 createCustomDialectWrapper(const std::string &dialectNamespace, 162 py::object dialectDescriptor) { 163 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 164 if (!dialectClass) { 165 // Use the base class. 166 return py::cast(PyDialect(std::move(dialectDescriptor))); 167 } 168 169 // Create the custom implementation. 170 return (*dialectClass)(std::move(dialectDescriptor)); 171 } 172 173 static MlirStringRef toMlirStringRef(const std::string &s) { 174 return mlirStringRefCreate(s.data(), s.size()); 175 } 176 177 /// Wrapper for the global LLVM debugging flag. 178 struct PyGlobalDebugFlag { 179 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 180 181 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } 182 183 static void bind(py::module &m) { 184 // Debug flags. 185 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 186 .def_property_static("flag", &PyGlobalDebugFlag::get, 187 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 188 } 189 }; 190 191 //------------------------------------------------------------------------------ 192 // Collections. 193 //------------------------------------------------------------------------------ 194 195 namespace { 196 197 class PyRegionIterator { 198 public: 199 PyRegionIterator(PyOperationRef operation) 200 : operation(std::move(operation)) {} 201 202 PyRegionIterator &dunderIter() { return *this; } 203 204 PyRegion dunderNext() { 205 operation->checkValid(); 206 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 207 throw py::stop_iteration(); 208 } 209 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 210 return PyRegion(operation, region); 211 } 212 213 static void bind(py::module &m) { 214 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 215 .def("__iter__", &PyRegionIterator::dunderIter) 216 .def("__next__", &PyRegionIterator::dunderNext); 217 } 218 219 private: 220 PyOperationRef operation; 221 int nextIndex = 0; 222 }; 223 224 /// Regions of an op are fixed length and indexed numerically so are represented 225 /// with a sequence-like container. 226 class PyRegionList { 227 public: 228 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 229 230 intptr_t dunderLen() { 231 operation->checkValid(); 232 return mlirOperationGetNumRegions(operation->get()); 233 } 234 235 PyRegion dunderGetItem(intptr_t index) { 236 // dunderLen checks validity. 237 if (index < 0 || index >= dunderLen()) { 238 throw SetPyError(PyExc_IndexError, 239 "attempt to access out of bounds region"); 240 } 241 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 242 return PyRegion(operation, region); 243 } 244 245 static void bind(py::module &m) { 246 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 247 .def("__len__", &PyRegionList::dunderLen) 248 .def("__getitem__", &PyRegionList::dunderGetItem); 249 } 250 251 private: 252 PyOperationRef operation; 253 }; 254 255 class PyBlockIterator { 256 public: 257 PyBlockIterator(PyOperationRef operation, MlirBlock next) 258 : operation(std::move(operation)), next(next) {} 259 260 PyBlockIterator &dunderIter() { return *this; } 261 262 PyBlock dunderNext() { 263 operation->checkValid(); 264 if (mlirBlockIsNull(next)) { 265 throw py::stop_iteration(); 266 } 267 268 PyBlock returnBlock(operation, next); 269 next = mlirBlockGetNextInRegion(next); 270 return returnBlock; 271 } 272 273 static void bind(py::module &m) { 274 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 275 .def("__iter__", &PyBlockIterator::dunderIter) 276 .def("__next__", &PyBlockIterator::dunderNext); 277 } 278 279 private: 280 PyOperationRef operation; 281 MlirBlock next; 282 }; 283 284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 285 /// we present them as a more full-featured list-like container but optimize 286 /// it for forward iteration. Blocks are always owned by a region. 287 class PyBlockList { 288 public: 289 PyBlockList(PyOperationRef operation, MlirRegion region) 290 : operation(std::move(operation)), region(region) {} 291 292 PyBlockIterator dunderIter() { 293 operation->checkValid(); 294 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 295 } 296 297 intptr_t dunderLen() { 298 operation->checkValid(); 299 intptr_t count = 0; 300 MlirBlock block = mlirRegionGetFirstBlock(region); 301 while (!mlirBlockIsNull(block)) { 302 count += 1; 303 block = mlirBlockGetNextInRegion(block); 304 } 305 return count; 306 } 307 308 PyBlock dunderGetItem(intptr_t index) { 309 operation->checkValid(); 310 if (index < 0) { 311 throw SetPyError(PyExc_IndexError, 312 "attempt to access out of bounds block"); 313 } 314 MlirBlock block = mlirRegionGetFirstBlock(region); 315 while (!mlirBlockIsNull(block)) { 316 if (index == 0) { 317 return PyBlock(operation, block); 318 } 319 block = mlirBlockGetNextInRegion(block); 320 index -= 1; 321 } 322 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 323 } 324 325 PyBlock appendBlock(const py::args &pyArgTypes) { 326 operation->checkValid(); 327 llvm::SmallVector<MlirType, 4> argTypes; 328 llvm::SmallVector<MlirLocation, 4> argLocs; 329 argTypes.reserve(pyArgTypes.size()); 330 argLocs.reserve(pyArgTypes.size()); 331 for (auto &pyArg : pyArgTypes) { 332 argTypes.push_back(pyArg.cast<PyType &>()); 333 // TODO: Pass in a proper location here. 334 argLocs.push_back( 335 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 336 } 337 338 MlirBlock block = 339 mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 340 mlirRegionAppendOwnedBlock(region, block); 341 return PyBlock(operation, block); 342 } 343 344 static void bind(py::module &m) { 345 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 346 .def("__getitem__", &PyBlockList::dunderGetItem) 347 .def("__iter__", &PyBlockList::dunderIter) 348 .def("__len__", &PyBlockList::dunderLen) 349 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 350 } 351 352 private: 353 PyOperationRef operation; 354 MlirRegion region; 355 }; 356 357 class PyOperationIterator { 358 public: 359 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 360 : parentOperation(std::move(parentOperation)), next(next) {} 361 362 PyOperationIterator &dunderIter() { return *this; } 363 364 py::object dunderNext() { 365 parentOperation->checkValid(); 366 if (mlirOperationIsNull(next)) { 367 throw py::stop_iteration(); 368 } 369 370 PyOperationRef returnOperation = 371 PyOperation::forOperation(parentOperation->getContext(), next); 372 next = mlirOperationGetNextInBlock(next); 373 return returnOperation->createOpView(); 374 } 375 376 static void bind(py::module &m) { 377 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 378 .def("__iter__", &PyOperationIterator::dunderIter) 379 .def("__next__", &PyOperationIterator::dunderNext); 380 } 381 382 private: 383 PyOperationRef parentOperation; 384 MlirOperation next; 385 }; 386 387 /// Operations are exposed by the C-API as a forward-only linked list. In 388 /// Python, we present them as a more full-featured list-like container but 389 /// optimize it for forward iteration. Iterable operations are always owned 390 /// by a block. 391 class PyOperationList { 392 public: 393 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 394 : parentOperation(std::move(parentOperation)), block(block) {} 395 396 PyOperationIterator dunderIter() { 397 parentOperation->checkValid(); 398 return PyOperationIterator(parentOperation, 399 mlirBlockGetFirstOperation(block)); 400 } 401 402 intptr_t dunderLen() { 403 parentOperation->checkValid(); 404 intptr_t count = 0; 405 MlirOperation childOp = mlirBlockGetFirstOperation(block); 406 while (!mlirOperationIsNull(childOp)) { 407 count += 1; 408 childOp = mlirOperationGetNextInBlock(childOp); 409 } 410 return count; 411 } 412 413 py::object dunderGetItem(intptr_t index) { 414 parentOperation->checkValid(); 415 if (index < 0) { 416 throw SetPyError(PyExc_IndexError, 417 "attempt to access out of bounds operation"); 418 } 419 MlirOperation childOp = mlirBlockGetFirstOperation(block); 420 while (!mlirOperationIsNull(childOp)) { 421 if (index == 0) { 422 return PyOperation::forOperation(parentOperation->getContext(), childOp) 423 ->createOpView(); 424 } 425 childOp = mlirOperationGetNextInBlock(childOp); 426 index -= 1; 427 } 428 throw SetPyError(PyExc_IndexError, 429 "attempt to access out of bounds operation"); 430 } 431 432 static void bind(py::module &m) { 433 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 434 .def("__getitem__", &PyOperationList::dunderGetItem) 435 .def("__iter__", &PyOperationList::dunderIter) 436 .def("__len__", &PyOperationList::dunderLen); 437 } 438 439 private: 440 PyOperationRef parentOperation; 441 MlirBlock block; 442 }; 443 444 } // namespace 445 446 //------------------------------------------------------------------------------ 447 // PyMlirContext 448 //------------------------------------------------------------------------------ 449 450 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 451 py::gil_scoped_acquire acquire; 452 auto &liveContexts = getLiveContexts(); 453 liveContexts[context.ptr] = this; 454 } 455 456 PyMlirContext::~PyMlirContext() { 457 // Note that the only public way to construct an instance is via the 458 // forContext method, which always puts the associated handle into 459 // liveContexts. 460 py::gil_scoped_acquire acquire; 461 getLiveContexts().erase(context.ptr); 462 mlirContextDestroy(context); 463 } 464 465 py::object PyMlirContext::getCapsule() { 466 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 467 } 468 469 py::object PyMlirContext::createFromCapsule(py::object capsule) { 470 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 471 if (mlirContextIsNull(rawContext)) 472 throw py::error_already_set(); 473 return forContext(rawContext).releaseObject(); 474 } 475 476 PyMlirContext *PyMlirContext::createNewContextForInit() { 477 MlirContext context = mlirContextCreate(); 478 mlirRegisterAllDialects(context); 479 return new PyMlirContext(context); 480 } 481 482 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 483 py::gil_scoped_acquire acquire; 484 auto &liveContexts = getLiveContexts(); 485 auto it = liveContexts.find(context.ptr); 486 if (it == liveContexts.end()) { 487 // Create. 488 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 489 py::object pyRef = py::cast(unownedContextWrapper); 490 assert(pyRef && "cast to py::object failed"); 491 liveContexts[context.ptr] = unownedContextWrapper; 492 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 493 } 494 // Use existing. 495 py::object pyRef = py::cast(it->second); 496 return PyMlirContextRef(it->second, std::move(pyRef)); 497 } 498 499 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 500 static LiveContextMap liveContexts; 501 return liveContexts; 502 } 503 504 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 505 506 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 507 508 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 509 510 pybind11::object PyMlirContext::contextEnter() { 511 return PyThreadContextEntry::pushContext(*this); 512 } 513 514 void PyMlirContext::contextExit(const pybind11::object &excType, 515 const pybind11::object &excVal, 516 const pybind11::object &excTb) { 517 PyThreadContextEntry::popContext(*this); 518 } 519 520 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { 521 // Note that ownership is transferred to the delete callback below by way of 522 // an explicit inc_ref (borrow). 523 PyDiagnosticHandler *pyHandler = 524 new PyDiagnosticHandler(get(), std::move(callback)); 525 py::object pyHandlerObject = 526 py::cast(pyHandler, py::return_value_policy::take_ownership); 527 pyHandlerObject.inc_ref(); 528 529 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 530 // guaranteed to be known to pybind. 531 auto handlerCallback = 532 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 533 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 534 py::object pyDiagnosticObject = 535 py::cast(pyDiagnostic, py::return_value_policy::take_ownership); 536 537 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 538 bool result = false; 539 { 540 // Since this can be called from arbitrary C++ contexts, always get the 541 // gil. 542 py::gil_scoped_acquire gil; 543 try { 544 result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); 545 } catch (std::exception &e) { 546 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 547 e.what()); 548 pyHandler->hadError = true; 549 } 550 } 551 552 pyDiagnostic->invalidate(); 553 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 554 }; 555 auto deleteCallback = +[](void *userData) { 556 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 557 assert(pyHandler->registeredID && "handler is not registered"); 558 pyHandler->registeredID.reset(); 559 560 // Decrement reference, balancing the inc_ref() above. 561 py::object pyHandlerObject = 562 py::cast(pyHandler, py::return_value_policy::reference); 563 pyHandlerObject.dec_ref(); 564 }; 565 566 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 567 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 568 return pyHandlerObject; 569 } 570 571 PyMlirContext &DefaultingPyMlirContext::resolve() { 572 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 573 if (!context) { 574 throw SetPyError( 575 PyExc_RuntimeError, 576 "An MLIR function requires a Context but none was provided in the call " 577 "or from the surrounding environment. Either pass to the function with " 578 "a 'context=' argument or establish a default using 'with Context():'"); 579 } 580 return *context; 581 } 582 583 //------------------------------------------------------------------------------ 584 // PyThreadContextEntry management 585 //------------------------------------------------------------------------------ 586 587 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 588 static thread_local std::vector<PyThreadContextEntry> stack; 589 return stack; 590 } 591 592 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 593 auto &stack = getStack(); 594 if (stack.empty()) 595 return nullptr; 596 return &stack.back(); 597 } 598 599 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 600 py::object insertionPoint, 601 py::object location) { 602 auto &stack = getStack(); 603 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 604 std::move(location)); 605 // If the new stack has more than one entry and the context of the new top 606 // entry matches the previous, copy the insertionPoint and location from the 607 // previous entry if missing from the new top entry. 608 if (stack.size() > 1) { 609 auto &prev = *(stack.rbegin() + 1); 610 auto ¤t = stack.back(); 611 if (current.context.is(prev.context)) { 612 // Default non-context objects from the previous entry. 613 if (!current.insertionPoint) 614 current.insertionPoint = prev.insertionPoint; 615 if (!current.location) 616 current.location = prev.location; 617 } 618 } 619 } 620 621 PyMlirContext *PyThreadContextEntry::getContext() { 622 if (!context) 623 return nullptr; 624 return py::cast<PyMlirContext *>(context); 625 } 626 627 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 628 if (!insertionPoint) 629 return nullptr; 630 return py::cast<PyInsertionPoint *>(insertionPoint); 631 } 632 633 PyLocation *PyThreadContextEntry::getLocation() { 634 if (!location) 635 return nullptr; 636 return py::cast<PyLocation *>(location); 637 } 638 639 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 640 auto *tos = getTopOfStack(); 641 return tos ? tos->getContext() : nullptr; 642 } 643 644 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 645 auto *tos = getTopOfStack(); 646 return tos ? tos->getInsertionPoint() : nullptr; 647 } 648 649 PyLocation *PyThreadContextEntry::getDefaultLocation() { 650 auto *tos = getTopOfStack(); 651 return tos ? tos->getLocation() : nullptr; 652 } 653 654 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 655 py::object contextObj = py::cast(context); 656 push(FrameKind::Context, /*context=*/contextObj, 657 /*insertionPoint=*/py::object(), 658 /*location=*/py::object()); 659 return contextObj; 660 } 661 662 void PyThreadContextEntry::popContext(PyMlirContext &context) { 663 auto &stack = getStack(); 664 if (stack.empty()) 665 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 666 auto &tos = stack.back(); 667 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 668 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 669 stack.pop_back(); 670 } 671 672 py::object 673 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 674 py::object contextObj = 675 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 676 py::object insertionPointObj = py::cast(insertionPoint); 677 push(FrameKind::InsertionPoint, 678 /*context=*/contextObj, 679 /*insertionPoint=*/insertionPointObj, 680 /*location=*/py::object()); 681 return insertionPointObj; 682 } 683 684 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 685 auto &stack = getStack(); 686 if (stack.empty()) 687 throw SetPyError(PyExc_RuntimeError, 688 "Unbalanced InsertionPoint enter/exit"); 689 auto &tos = stack.back(); 690 if (tos.frameKind != FrameKind::InsertionPoint && 691 tos.getInsertionPoint() != &insertionPoint) 692 throw SetPyError(PyExc_RuntimeError, 693 "Unbalanced InsertionPoint enter/exit"); 694 stack.pop_back(); 695 } 696 697 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 698 py::object contextObj = location.getContext().getObject(); 699 py::object locationObj = py::cast(location); 700 push(FrameKind::Location, /*context=*/contextObj, 701 /*insertionPoint=*/py::object(), 702 /*location=*/locationObj); 703 return locationObj; 704 } 705 706 void PyThreadContextEntry::popLocation(PyLocation &location) { 707 auto &stack = getStack(); 708 if (stack.empty()) 709 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 710 auto &tos = stack.back(); 711 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 712 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 713 stack.pop_back(); 714 } 715 716 //------------------------------------------------------------------------------ 717 // PyDiagnostic* 718 //------------------------------------------------------------------------------ 719 720 void PyDiagnostic::invalidate() { 721 valid = false; 722 if (materializedNotes) { 723 for (auto ¬eObject : *materializedNotes) { 724 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); 725 note->invalidate(); 726 } 727 } 728 } 729 730 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 731 py::object callback) 732 : context(context), callback(std::move(callback)) {} 733 734 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 735 736 void PyDiagnosticHandler::detach() { 737 if (!registeredID) 738 return; 739 MlirDiagnosticHandlerID localID = *registeredID; 740 mlirContextDetachDiagnosticHandler(context, localID); 741 assert(!registeredID && "should have unregistered"); 742 // Not strictly necessary but keeps stale pointers from being around to cause 743 // issues. 744 context = {nullptr}; 745 } 746 747 void PyDiagnostic::checkValid() { 748 if (!valid) { 749 throw std::invalid_argument( 750 "Diagnostic is invalid (used outside of callback)"); 751 } 752 } 753 754 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 755 checkValid(); 756 return mlirDiagnosticGetSeverity(diagnostic); 757 } 758 759 PyLocation PyDiagnostic::getLocation() { 760 checkValid(); 761 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 762 MlirContext context = mlirLocationGetContext(loc); 763 return PyLocation(PyMlirContext::forContext(context), loc); 764 } 765 766 py::str PyDiagnostic::getMessage() { 767 checkValid(); 768 py::object fileObject = py::module::import("io").attr("StringIO")(); 769 PyFileAccumulator accum(fileObject, /*binary=*/false); 770 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 771 return fileObject.attr("getvalue")(); 772 } 773 774 py::tuple PyDiagnostic::getNotes() { 775 checkValid(); 776 if (materializedNotes) 777 return *materializedNotes; 778 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 779 materializedNotes = py::tuple(numNotes); 780 for (intptr_t i = 0; i < numNotes; ++i) { 781 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 782 py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); 783 PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); 784 } 785 return *materializedNotes; 786 } 787 788 //------------------------------------------------------------------------------ 789 // PyDialect, PyDialectDescriptor, PyDialects 790 //------------------------------------------------------------------------------ 791 792 MlirDialect PyDialects::getDialectForKey(const std::string &key, 793 bool attrError) { 794 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 795 {key.data(), key.size()}); 796 if (mlirDialectIsNull(dialect)) { 797 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 798 Twine("Dialect '") + key + "' not found"); 799 } 800 return dialect; 801 } 802 803 //------------------------------------------------------------------------------ 804 // PyLocation 805 //------------------------------------------------------------------------------ 806 807 py::object PyLocation::getCapsule() { 808 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 809 } 810 811 PyLocation PyLocation::createFromCapsule(py::object capsule) { 812 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 813 if (mlirLocationIsNull(rawLoc)) 814 throw py::error_already_set(); 815 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 816 rawLoc); 817 } 818 819 py::object PyLocation::contextEnter() { 820 return PyThreadContextEntry::pushLocation(*this); 821 } 822 823 void PyLocation::contextExit(const pybind11::object &excType, 824 const pybind11::object &excVal, 825 const pybind11::object &excTb) { 826 PyThreadContextEntry::popLocation(*this); 827 } 828 829 PyLocation &DefaultingPyLocation::resolve() { 830 auto *location = PyThreadContextEntry::getDefaultLocation(); 831 if (!location) { 832 throw SetPyError( 833 PyExc_RuntimeError, 834 "An MLIR function requires a Location but none was provided in the " 835 "call or from the surrounding environment. Either pass to the function " 836 "with a 'loc=' argument or establish a default using 'with loc:'"); 837 } 838 return *location; 839 } 840 841 //------------------------------------------------------------------------------ 842 // PyModule 843 //------------------------------------------------------------------------------ 844 845 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 846 : BaseContextObject(std::move(contextRef)), module(module) {} 847 848 PyModule::~PyModule() { 849 py::gil_scoped_acquire acquire; 850 auto &liveModules = getContext()->liveModules; 851 assert(liveModules.count(module.ptr) == 1 && 852 "destroying module not in live map"); 853 liveModules.erase(module.ptr); 854 mlirModuleDestroy(module); 855 } 856 857 PyModuleRef PyModule::forModule(MlirModule module) { 858 MlirContext context = mlirModuleGetContext(module); 859 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 860 861 py::gil_scoped_acquire acquire; 862 auto &liveModules = contextRef->liveModules; 863 auto it = liveModules.find(module.ptr); 864 if (it == liveModules.end()) { 865 // Create. 866 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 867 // Note that the default return value policy on cast is automatic_reference, 868 // which does not take ownership (delete will not be called). 869 // Just be explicit. 870 py::object pyRef = 871 py::cast(unownedModule, py::return_value_policy::take_ownership); 872 unownedModule->handle = pyRef; 873 liveModules[module.ptr] = 874 std::make_pair(unownedModule->handle, unownedModule); 875 return PyModuleRef(unownedModule, std::move(pyRef)); 876 } 877 // Use existing. 878 PyModule *existing = it->second.second; 879 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 880 return PyModuleRef(existing, std::move(pyRef)); 881 } 882 883 py::object PyModule::createFromCapsule(py::object capsule) { 884 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 885 if (mlirModuleIsNull(rawModule)) 886 throw py::error_already_set(); 887 return forModule(rawModule).releaseObject(); 888 } 889 890 py::object PyModule::getCapsule() { 891 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 892 } 893 894 //------------------------------------------------------------------------------ 895 // PyOperation 896 //------------------------------------------------------------------------------ 897 898 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 899 : BaseContextObject(std::move(contextRef)), operation(operation) {} 900 901 PyOperation::~PyOperation() { 902 // If the operation has already been invalidated there is nothing to do. 903 if (!valid) 904 return; 905 auto &liveOperations = getContext()->liveOperations; 906 assert(liveOperations.count(operation.ptr) == 1 && 907 "destroying operation not in live map"); 908 liveOperations.erase(operation.ptr); 909 if (!isAttached()) { 910 mlirOperationDestroy(operation); 911 } 912 } 913 914 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 915 MlirOperation operation, 916 py::object parentKeepAlive) { 917 auto &liveOperations = contextRef->liveOperations; 918 // Create. 919 PyOperation *unownedOperation = 920 new PyOperation(std::move(contextRef), operation); 921 // Note that the default return value policy on cast is automatic_reference, 922 // which does not take ownership (delete will not be called). 923 // Just be explicit. 924 py::object pyRef = 925 py::cast(unownedOperation, py::return_value_policy::take_ownership); 926 unownedOperation->handle = pyRef; 927 if (parentKeepAlive) { 928 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 929 } 930 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 931 return PyOperationRef(unownedOperation, std::move(pyRef)); 932 } 933 934 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 935 MlirOperation operation, 936 py::object parentKeepAlive) { 937 auto &liveOperations = contextRef->liveOperations; 938 auto it = liveOperations.find(operation.ptr); 939 if (it == liveOperations.end()) { 940 // Create. 941 return createInstance(std::move(contextRef), operation, 942 std::move(parentKeepAlive)); 943 } 944 // Use existing. 945 PyOperation *existing = it->second.second; 946 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 947 return PyOperationRef(existing, std::move(pyRef)); 948 } 949 950 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 951 MlirOperation operation, 952 py::object parentKeepAlive) { 953 auto &liveOperations = contextRef->liveOperations; 954 assert(liveOperations.count(operation.ptr) == 0 && 955 "cannot create detached operation that already exists"); 956 (void)liveOperations; 957 958 PyOperationRef created = createInstance(std::move(contextRef), operation, 959 std::move(parentKeepAlive)); 960 created->attached = false; 961 return created; 962 } 963 964 void PyOperation::checkValid() const { 965 if (!valid) { 966 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 967 } 968 } 969 970 void PyOperationBase::print(py::object fileObject, bool binary, 971 llvm::Optional<int64_t> largeElementsLimit, 972 bool enableDebugInfo, bool prettyDebugInfo, 973 bool printGenericOpForm, bool useLocalScope, 974 bool assumeVerified) { 975 PyOperation &operation = getOperation(); 976 operation.checkValid(); 977 if (fileObject.is_none()) 978 fileObject = py::module::import("sys").attr("stdout"); 979 980 if (!assumeVerified && !printGenericOpForm && 981 !mlirOperationVerify(operation)) { 982 std::string message("// Verification failed, printing generic form\n"); 983 if (binary) { 984 fileObject.attr("write")(py::bytes(message)); 985 } else { 986 fileObject.attr("write")(py::str(message)); 987 } 988 printGenericOpForm = true; 989 } 990 991 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 992 if (largeElementsLimit) 993 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 994 if (enableDebugInfo) 995 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 996 if (printGenericOpForm) 997 mlirOpPrintingFlagsPrintGenericOpForm(flags); 998 999 PyFileAccumulator accum(fileObject, binary); 1000 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1001 accum.getUserData()); 1002 mlirOpPrintingFlagsDestroy(flags); 1003 } 1004 1005 py::object PyOperationBase::getAsm(bool binary, 1006 llvm::Optional<int64_t> largeElementsLimit, 1007 bool enableDebugInfo, bool prettyDebugInfo, 1008 bool printGenericOpForm, bool useLocalScope, 1009 bool assumeVerified) { 1010 py::object fileObject; 1011 if (binary) { 1012 fileObject = py::module::import("io").attr("BytesIO")(); 1013 } else { 1014 fileObject = py::module::import("io").attr("StringIO")(); 1015 } 1016 print(fileObject, /*binary=*/binary, 1017 /*largeElementsLimit=*/largeElementsLimit, 1018 /*enableDebugInfo=*/enableDebugInfo, 1019 /*prettyDebugInfo=*/prettyDebugInfo, 1020 /*printGenericOpForm=*/printGenericOpForm, 1021 /*useLocalScope=*/useLocalScope, 1022 /*assumeVerified=*/assumeVerified); 1023 1024 return fileObject.attr("getvalue")(); 1025 } 1026 1027 void PyOperationBase::moveAfter(PyOperationBase &other) { 1028 PyOperation &operation = getOperation(); 1029 PyOperation &otherOp = other.getOperation(); 1030 operation.checkValid(); 1031 otherOp.checkValid(); 1032 mlirOperationMoveAfter(operation, otherOp); 1033 operation.parentKeepAlive = otherOp.parentKeepAlive; 1034 } 1035 1036 void PyOperationBase::moveBefore(PyOperationBase &other) { 1037 PyOperation &operation = getOperation(); 1038 PyOperation &otherOp = other.getOperation(); 1039 operation.checkValid(); 1040 otherOp.checkValid(); 1041 mlirOperationMoveBefore(operation, otherOp); 1042 operation.parentKeepAlive = otherOp.parentKeepAlive; 1043 } 1044 1045 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 1046 checkValid(); 1047 if (!isAttached()) 1048 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 1049 MlirOperation operation = mlirOperationGetParentOperation(get()); 1050 if (mlirOperationIsNull(operation)) 1051 return {}; 1052 return PyOperation::forOperation(getContext(), operation); 1053 } 1054 1055 PyBlock PyOperation::getBlock() { 1056 checkValid(); 1057 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 1058 MlirBlock block = mlirOperationGetBlock(get()); 1059 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1060 assert(parentOperation && "Operation has no parent"); 1061 return PyBlock{std::move(*parentOperation), block}; 1062 } 1063 1064 py::object PyOperation::getCapsule() { 1065 checkValid(); 1066 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 1067 } 1068 1069 py::object PyOperation::createFromCapsule(py::object capsule) { 1070 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1071 if (mlirOperationIsNull(rawOperation)) 1072 throw py::error_already_set(); 1073 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1074 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1075 .releaseObject(); 1076 } 1077 1078 py::object PyOperation::create( 1079 const std::string &name, llvm::Optional<std::vector<PyType *>> results, 1080 llvm::Optional<std::vector<PyValue *>> operands, 1081 llvm::Optional<py::dict> attributes, 1082 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 1083 DefaultingPyLocation location, const py::object &maybeIp) { 1084 llvm::SmallVector<MlirValue, 4> mlirOperands; 1085 llvm::SmallVector<MlirType, 4> mlirResults; 1086 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1087 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1088 1089 // General parameter validation. 1090 if (regions < 0) 1091 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 1092 1093 // Unpack/validate operands. 1094 if (operands) { 1095 mlirOperands.reserve(operands->size()); 1096 for (PyValue *operand : *operands) { 1097 if (!operand) 1098 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 1099 mlirOperands.push_back(operand->get()); 1100 } 1101 } 1102 1103 // Unpack/validate results. 1104 if (results) { 1105 mlirResults.reserve(results->size()); 1106 for (PyType *result : *results) { 1107 // TODO: Verify result type originate from the same context. 1108 if (!result) 1109 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 1110 mlirResults.push_back(*result); 1111 } 1112 } 1113 // Unpack/validate attributes. 1114 if (attributes) { 1115 mlirAttributes.reserve(attributes->size()); 1116 for (auto &it : *attributes) { 1117 std::string key; 1118 try { 1119 key = it.first.cast<std::string>(); 1120 } catch (py::cast_error &err) { 1121 std::string msg = "Invalid attribute key (not a string) when " 1122 "attempting to create the operation \"" + 1123 name + "\" (" + err.what() + ")"; 1124 throw py::cast_error(msg); 1125 } 1126 try { 1127 auto &attribute = it.second.cast<PyAttribute &>(); 1128 // TODO: Verify attribute originates from the same context. 1129 mlirAttributes.emplace_back(std::move(key), attribute); 1130 } catch (py::reference_cast_error &) { 1131 // This exception seems thrown when the value is "None". 1132 std::string msg = 1133 "Found an invalid (`None`?) attribute value for the key \"" + key + 1134 "\" when attempting to create the operation \"" + name + "\""; 1135 throw py::cast_error(msg); 1136 } catch (py::cast_error &err) { 1137 std::string msg = "Invalid attribute value for the key \"" + key + 1138 "\" when attempting to create the operation \"" + 1139 name + "\" (" + err.what() + ")"; 1140 throw py::cast_error(msg); 1141 } 1142 } 1143 } 1144 // Unpack/validate successors. 1145 if (successors) { 1146 mlirSuccessors.reserve(successors->size()); 1147 for (auto *successor : *successors) { 1148 // TODO: Verify successor originate from the same context. 1149 if (!successor) 1150 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1151 mlirSuccessors.push_back(successor->get()); 1152 } 1153 } 1154 1155 // Apply unpacked/validated to the operation state. Beyond this 1156 // point, exceptions cannot be thrown or else the state will leak. 1157 MlirOperationState state = 1158 mlirOperationStateGet(toMlirStringRef(name), location); 1159 if (!mlirOperands.empty()) 1160 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1161 mlirOperands.data()); 1162 if (!mlirResults.empty()) 1163 mlirOperationStateAddResults(&state, mlirResults.size(), 1164 mlirResults.data()); 1165 if (!mlirAttributes.empty()) { 1166 // Note that the attribute names directly reference bytes in 1167 // mlirAttributes, so that vector must not be changed from here 1168 // on. 1169 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1170 mlirNamedAttributes.reserve(mlirAttributes.size()); 1171 for (auto &it : mlirAttributes) 1172 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1173 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1174 toMlirStringRef(it.first)), 1175 it.second)); 1176 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1177 mlirNamedAttributes.data()); 1178 } 1179 if (!mlirSuccessors.empty()) 1180 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1181 mlirSuccessors.data()); 1182 if (regions) { 1183 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1184 mlirRegions.resize(regions); 1185 for (int i = 0; i < regions; ++i) 1186 mlirRegions[i] = mlirRegionCreate(); 1187 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1188 mlirRegions.data()); 1189 } 1190 1191 // Construct the operation. 1192 MlirOperation operation = mlirOperationCreate(&state); 1193 PyOperationRef created = 1194 PyOperation::createDetached(location->getContext(), operation); 1195 1196 // InsertPoint active? 1197 if (!maybeIp.is(py::cast(false))) { 1198 PyInsertionPoint *ip; 1199 if (maybeIp.is_none()) { 1200 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1201 } else { 1202 ip = py::cast<PyInsertionPoint *>(maybeIp); 1203 } 1204 if (ip) 1205 ip->insert(*created.get()); 1206 } 1207 1208 return created->createOpView(); 1209 } 1210 1211 py::object PyOperation::createOpView() { 1212 checkValid(); 1213 MlirIdentifier ident = mlirOperationGetName(get()); 1214 MlirStringRef identStr = mlirIdentifierStr(ident); 1215 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1216 StringRef(identStr.data, identStr.length)); 1217 if (opViewClass) 1218 return (*opViewClass)(getRef().getObject()); 1219 return py::cast(PyOpView(getRef().getObject())); 1220 } 1221 1222 void PyOperation::erase() { 1223 checkValid(); 1224 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1225 // Python reference to a child operation is live. All children should also 1226 // have their `valid` bit set to false. 1227 auto &liveOperations = getContext()->liveOperations; 1228 if (liveOperations.count(operation.ptr)) 1229 liveOperations.erase(operation.ptr); 1230 mlirOperationDestroy(operation); 1231 valid = false; 1232 } 1233 1234 //------------------------------------------------------------------------------ 1235 // PyOpView 1236 //------------------------------------------------------------------------------ 1237 1238 py::object PyOpView::buildGeneric( 1239 const py::object &cls, py::list resultTypeList, py::list operandList, 1240 llvm::Optional<py::dict> attributes, 1241 llvm::Optional<std::vector<PyBlock *>> successors, 1242 llvm::Optional<int> regions, DefaultingPyLocation location, 1243 const py::object &maybeIp) { 1244 PyMlirContextRef context = location->getContext(); 1245 // Class level operation construction metadata. 1246 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1247 // Operand and result segment specs are either none, which does no 1248 // variadic unpacking, or a list of ints with segment sizes, where each 1249 // element is either a positive number (typically 1 for a scalar) or -1 to 1250 // indicate that it is derived from the length of the same-indexed operand 1251 // or result (implying that it is a list at that position). 1252 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1253 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1254 1255 std::vector<uint32_t> operandSegmentLengths; 1256 std::vector<uint32_t> resultSegmentLengths; 1257 1258 // Validate/determine region count. 1259 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1260 int opMinRegionCount = std::get<0>(opRegionSpec); 1261 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1262 if (!regions) { 1263 regions = opMinRegionCount; 1264 } 1265 if (*regions < opMinRegionCount) { 1266 throw py::value_error( 1267 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1268 llvm::Twine(opMinRegionCount) + 1269 " regions but was built with regions=" + llvm::Twine(*regions)) 1270 .str()); 1271 } 1272 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1273 throw py::value_error( 1274 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1275 llvm::Twine(opMinRegionCount) + 1276 " regions but was built with regions=" + llvm::Twine(*regions)) 1277 .str()); 1278 } 1279 1280 // Unpack results. 1281 std::vector<PyType *> resultTypes; 1282 resultTypes.reserve(resultTypeList.size()); 1283 if (resultSegmentSpecObj.is_none()) { 1284 // Non-variadic result unpacking. 1285 for (const auto &it : llvm::enumerate(resultTypeList)) { 1286 try { 1287 resultTypes.push_back(py::cast<PyType *>(it.value())); 1288 if (!resultTypes.back()) 1289 throw py::cast_error(); 1290 } catch (py::cast_error &err) { 1291 throw py::value_error((llvm::Twine("Result ") + 1292 llvm::Twine(it.index()) + " of operation \"" + 1293 name + "\" must be a Type (" + err.what() + ")") 1294 .str()); 1295 } 1296 } 1297 } else { 1298 // Sized result unpacking. 1299 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1300 if (resultSegmentSpec.size() != resultTypeList.size()) { 1301 throw py::value_error((llvm::Twine("Operation \"") + name + 1302 "\" requires " + 1303 llvm::Twine(resultSegmentSpec.size()) + 1304 " result segments but was provided " + 1305 llvm::Twine(resultTypeList.size())) 1306 .str()); 1307 } 1308 resultSegmentLengths.reserve(resultTypeList.size()); 1309 for (const auto &it : 1310 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1311 int segmentSpec = std::get<1>(it.value()); 1312 if (segmentSpec == 1 || segmentSpec == 0) { 1313 // Unpack unary element. 1314 try { 1315 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1316 if (resultType) { 1317 resultTypes.push_back(resultType); 1318 resultSegmentLengths.push_back(1); 1319 } else if (segmentSpec == 0) { 1320 // Allowed to be optional. 1321 resultSegmentLengths.push_back(0); 1322 } else { 1323 throw py::cast_error("was None and result is not optional"); 1324 } 1325 } catch (py::cast_error &err) { 1326 throw py::value_error((llvm::Twine("Result ") + 1327 llvm::Twine(it.index()) + " of operation \"" + 1328 name + "\" must be a Type (" + err.what() + 1329 ")") 1330 .str()); 1331 } 1332 } else if (segmentSpec == -1) { 1333 // Unpack sequence by appending. 1334 try { 1335 if (std::get<0>(it.value()).is_none()) { 1336 // Treat it as an empty list. 1337 resultSegmentLengths.push_back(0); 1338 } else { 1339 // Unpack the list. 1340 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1341 for (py::object segmentItem : segment) { 1342 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1343 if (!resultTypes.back()) { 1344 throw py::cast_error("contained a None item"); 1345 } 1346 } 1347 resultSegmentLengths.push_back(segment.size()); 1348 } 1349 } catch (std::exception &err) { 1350 // NOTE: Sloppy to be using a catch-all here, but there are at least 1351 // three different unrelated exceptions that can be thrown in the 1352 // above "casts". Just keep the scope above small and catch them all. 1353 throw py::value_error((llvm::Twine("Result ") + 1354 llvm::Twine(it.index()) + " of operation \"" + 1355 name + "\" must be a Sequence of Types (" + 1356 err.what() + ")") 1357 .str()); 1358 } 1359 } else { 1360 throw py::value_error("Unexpected segment spec"); 1361 } 1362 } 1363 } 1364 1365 // Unpack operands. 1366 std::vector<PyValue *> operands; 1367 operands.reserve(operands.size()); 1368 if (operandSegmentSpecObj.is_none()) { 1369 // Non-sized operand unpacking. 1370 for (const auto &it : llvm::enumerate(operandList)) { 1371 try { 1372 operands.push_back(py::cast<PyValue *>(it.value())); 1373 if (!operands.back()) 1374 throw py::cast_error(); 1375 } catch (py::cast_error &err) { 1376 throw py::value_error((llvm::Twine("Operand ") + 1377 llvm::Twine(it.index()) + " of operation \"" + 1378 name + "\" must be a Value (" + err.what() + ")") 1379 .str()); 1380 } 1381 } 1382 } else { 1383 // Sized operand unpacking. 1384 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1385 if (operandSegmentSpec.size() != operandList.size()) { 1386 throw py::value_error((llvm::Twine("Operation \"") + name + 1387 "\" requires " + 1388 llvm::Twine(operandSegmentSpec.size()) + 1389 "operand segments but was provided " + 1390 llvm::Twine(operandList.size())) 1391 .str()); 1392 } 1393 operandSegmentLengths.reserve(operandList.size()); 1394 for (const auto &it : 1395 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1396 int segmentSpec = std::get<1>(it.value()); 1397 if (segmentSpec == 1 || segmentSpec == 0) { 1398 // Unpack unary element. 1399 try { 1400 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1401 if (operandValue) { 1402 operands.push_back(operandValue); 1403 operandSegmentLengths.push_back(1); 1404 } else if (segmentSpec == 0) { 1405 // Allowed to be optional. 1406 operandSegmentLengths.push_back(0); 1407 } else { 1408 throw py::cast_error("was None and operand is not optional"); 1409 } 1410 } catch (py::cast_error &err) { 1411 throw py::value_error((llvm::Twine("Operand ") + 1412 llvm::Twine(it.index()) + " of operation \"" + 1413 name + "\" must be a Value (" + err.what() + 1414 ")") 1415 .str()); 1416 } 1417 } else if (segmentSpec == -1) { 1418 // Unpack sequence by appending. 1419 try { 1420 if (std::get<0>(it.value()).is_none()) { 1421 // Treat it as an empty list. 1422 operandSegmentLengths.push_back(0); 1423 } else { 1424 // Unpack the list. 1425 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1426 for (py::object segmentItem : segment) { 1427 operands.push_back(py::cast<PyValue *>(segmentItem)); 1428 if (!operands.back()) { 1429 throw py::cast_error("contained a None item"); 1430 } 1431 } 1432 operandSegmentLengths.push_back(segment.size()); 1433 } 1434 } catch (std::exception &err) { 1435 // NOTE: Sloppy to be using a catch-all here, but there are at least 1436 // three different unrelated exceptions that can be thrown in the 1437 // above "casts". Just keep the scope above small and catch them all. 1438 throw py::value_error((llvm::Twine("Operand ") + 1439 llvm::Twine(it.index()) + " of operation \"" + 1440 name + "\" must be a Sequence of Values (" + 1441 err.what() + ")") 1442 .str()); 1443 } 1444 } else { 1445 throw py::value_error("Unexpected segment spec"); 1446 } 1447 } 1448 } 1449 1450 // Merge operand/result segment lengths into attributes if needed. 1451 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1452 // Dup. 1453 if (attributes) { 1454 attributes = py::dict(*attributes); 1455 } else { 1456 attributes = py::dict(); 1457 } 1458 if (attributes->contains("result_segment_sizes") || 1459 attributes->contains("operand_segment_sizes")) { 1460 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1461 "'operand_segment_sizes' attribute is unsupported. " 1462 "Use Operation.create for such low-level access."); 1463 } 1464 1465 // Add result_segment_sizes attribute. 1466 if (!resultSegmentLengths.empty()) { 1467 int64_t size = resultSegmentLengths.size(); 1468 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1469 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1470 resultSegmentLengths.size(), resultSegmentLengths.data()); 1471 (*attributes)["result_segment_sizes"] = 1472 PyAttribute(context, segmentLengthAttr); 1473 } 1474 1475 // Add operand_segment_sizes attribute. 1476 if (!operandSegmentLengths.empty()) { 1477 int64_t size = operandSegmentLengths.size(); 1478 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1479 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1480 operandSegmentLengths.size(), operandSegmentLengths.data()); 1481 (*attributes)["operand_segment_sizes"] = 1482 PyAttribute(context, segmentLengthAttr); 1483 } 1484 } 1485 1486 // Delegate to create. 1487 return PyOperation::create(name, 1488 /*results=*/std::move(resultTypes), 1489 /*operands=*/std::move(operands), 1490 /*attributes=*/std::move(attributes), 1491 /*successors=*/std::move(successors), 1492 /*regions=*/*regions, location, maybeIp); 1493 } 1494 1495 PyOpView::PyOpView(const py::object &operationObject) 1496 // Casting through the PyOperationBase base-class and then back to the 1497 // Operation lets us accept any PyOperationBase subclass. 1498 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1499 operationObject(operation.getRef().getObject()) {} 1500 1501 py::object PyOpView::createRawSubclass(const py::object &userClass) { 1502 // This is... a little gross. The typical pattern is to have a pure python 1503 // class that extends OpView like: 1504 // class AddFOp(_cext.ir.OpView): 1505 // def __init__(self, loc, lhs, rhs): 1506 // operation = loc.context.create_operation( 1507 // "addf", lhs, rhs, results=[lhs.type]) 1508 // super().__init__(operation) 1509 // 1510 // I.e. The goal of the user facing type is to provide a nice constructor 1511 // that has complete freedom for the op under construction. This is at odds 1512 // with our other desire to sometimes create this object by just passing an 1513 // operation (to initialize the base class). We could do *arg and **kwargs 1514 // munging to try to make it work, but instead, we synthesize a new class 1515 // on the fly which extends this user class (AddFOp in this example) and 1516 // *give it* the base class's __init__ method, thus bypassing the 1517 // intermediate subclass's __init__ method entirely. While slightly, 1518 // underhanded, this is safe/legal because the type hierarchy has not changed 1519 // (we just added a new leaf) and we aren't mucking around with __new__. 1520 // Typically, this new class will be stored on the original as "_Raw" and will 1521 // be used for casts and other things that need a variant of the class that 1522 // is initialized purely from an operation. 1523 py::object parentMetaclass = 1524 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1525 py::dict attributes; 1526 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1527 // now. 1528 // auto opViewType = py::type::of<PyOpView>(); 1529 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1530 attributes["__init__"] = opViewType.attr("__init__"); 1531 py::str origName = userClass.attr("__name__"); 1532 py::str newName = py::str("_") + origName; 1533 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1534 } 1535 1536 //------------------------------------------------------------------------------ 1537 // PyInsertionPoint. 1538 //------------------------------------------------------------------------------ 1539 1540 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1541 1542 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1543 : refOperation(beforeOperationBase.getOperation().getRef()), 1544 block((*refOperation)->getBlock()) {} 1545 1546 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1547 PyOperation &operation = operationBase.getOperation(); 1548 if (operation.isAttached()) 1549 throw SetPyError(PyExc_ValueError, 1550 "Attempt to insert operation that is already attached"); 1551 block.getParentOperation()->checkValid(); 1552 MlirOperation beforeOp = {nullptr}; 1553 if (refOperation) { 1554 // Insert before operation. 1555 (*refOperation)->checkValid(); 1556 beforeOp = (*refOperation)->get(); 1557 } else { 1558 // Insert at end (before null) is only valid if the block does not 1559 // already end in a known terminator (violating this will cause assertion 1560 // failures later). 1561 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1562 throw py::index_error("Cannot insert operation at the end of a block " 1563 "that already has a terminator. Did you mean to " 1564 "use 'InsertionPoint.at_block_terminator(block)' " 1565 "versus 'InsertionPoint(block)'?"); 1566 } 1567 } 1568 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1569 operation.setAttached(); 1570 } 1571 1572 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1573 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1574 if (mlirOperationIsNull(firstOp)) { 1575 // Just insert at end. 1576 return PyInsertionPoint(block); 1577 } 1578 1579 // Insert before first op. 1580 PyOperationRef firstOpRef = PyOperation::forOperation( 1581 block.getParentOperation()->getContext(), firstOp); 1582 return PyInsertionPoint{block, std::move(firstOpRef)}; 1583 } 1584 1585 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1586 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1587 if (mlirOperationIsNull(terminator)) 1588 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1589 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1590 block.getParentOperation()->getContext(), terminator); 1591 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1592 } 1593 1594 py::object PyInsertionPoint::contextEnter() { 1595 return PyThreadContextEntry::pushInsertionPoint(*this); 1596 } 1597 1598 void PyInsertionPoint::contextExit(const pybind11::object &excType, 1599 const pybind11::object &excVal, 1600 const pybind11::object &excTb) { 1601 PyThreadContextEntry::popInsertionPoint(*this); 1602 } 1603 1604 //------------------------------------------------------------------------------ 1605 // PyAttribute. 1606 //------------------------------------------------------------------------------ 1607 1608 bool PyAttribute::operator==(const PyAttribute &other) { 1609 return mlirAttributeEqual(attr, other.attr); 1610 } 1611 1612 py::object PyAttribute::getCapsule() { 1613 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1614 } 1615 1616 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1617 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1618 if (mlirAttributeIsNull(rawAttr)) 1619 throw py::error_already_set(); 1620 return PyAttribute( 1621 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1622 } 1623 1624 //------------------------------------------------------------------------------ 1625 // PyNamedAttribute. 1626 //------------------------------------------------------------------------------ 1627 1628 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1629 : ownedName(new std::string(std::move(ownedName))) { 1630 namedAttr = mlirNamedAttributeGet( 1631 mlirIdentifierGet(mlirAttributeGetContext(attr), 1632 toMlirStringRef(*this->ownedName)), 1633 attr); 1634 } 1635 1636 //------------------------------------------------------------------------------ 1637 // PyType. 1638 //------------------------------------------------------------------------------ 1639 1640 bool PyType::operator==(const PyType &other) { 1641 return mlirTypeEqual(type, other.type); 1642 } 1643 1644 py::object PyType::getCapsule() { 1645 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1646 } 1647 1648 PyType PyType::createFromCapsule(py::object capsule) { 1649 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1650 if (mlirTypeIsNull(rawType)) 1651 throw py::error_already_set(); 1652 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1653 rawType); 1654 } 1655 1656 //------------------------------------------------------------------------------ 1657 // PyValue and subclases. 1658 //------------------------------------------------------------------------------ 1659 1660 pybind11::object PyValue::getCapsule() { 1661 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1662 } 1663 1664 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1665 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1666 if (mlirValueIsNull(value)) 1667 throw py::error_already_set(); 1668 MlirOperation owner; 1669 if (mlirValueIsAOpResult(value)) 1670 owner = mlirOpResultGetOwner(value); 1671 if (mlirValueIsABlockArgument(value)) 1672 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1673 if (mlirOperationIsNull(owner)) 1674 throw py::error_already_set(); 1675 MlirContext ctx = mlirOperationGetContext(owner); 1676 PyOperationRef ownerRef = 1677 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1678 return PyValue(ownerRef, value); 1679 } 1680 1681 //------------------------------------------------------------------------------ 1682 // PySymbolTable. 1683 //------------------------------------------------------------------------------ 1684 1685 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1686 : operation(operation.getOperation().getRef()) { 1687 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1688 if (mlirSymbolTableIsNull(symbolTable)) { 1689 throw py::cast_error("Operation is not a Symbol Table."); 1690 } 1691 } 1692 1693 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1694 operation->checkValid(); 1695 MlirOperation symbol = mlirSymbolTableLookup( 1696 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1697 if (mlirOperationIsNull(symbol)) 1698 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1699 1700 return PyOperation::forOperation(operation->getContext(), symbol, 1701 operation.getObject()) 1702 ->createOpView(); 1703 } 1704 1705 void PySymbolTable::erase(PyOperationBase &symbol) { 1706 operation->checkValid(); 1707 symbol.getOperation().checkValid(); 1708 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1709 // The operation is also erased, so we must invalidate it. There may be Python 1710 // references to this operation so we don't want to delete it from the list of 1711 // live operations here. 1712 symbol.getOperation().valid = false; 1713 } 1714 1715 void PySymbolTable::dunderDel(const std::string &name) { 1716 py::object operation = dunderGetItem(name); 1717 erase(py::cast<PyOperationBase &>(operation)); 1718 } 1719 1720 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1721 operation->checkValid(); 1722 symbol.getOperation().checkValid(); 1723 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1724 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1725 if (mlirAttributeIsNull(symbolAttr)) 1726 throw py::value_error("Expected operation to have a symbol name."); 1727 return PyAttribute( 1728 symbol.getOperation().getContext(), 1729 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1730 } 1731 1732 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1733 // Op must already be a symbol. 1734 PyOperation &operation = symbol.getOperation(); 1735 operation.checkValid(); 1736 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1737 MlirAttribute existingNameAttr = 1738 mlirOperationGetAttributeByName(operation.get(), attrName); 1739 if (mlirAttributeIsNull(existingNameAttr)) 1740 throw py::value_error("Expected operation to have a symbol name."); 1741 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1742 } 1743 1744 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1745 const std::string &name) { 1746 // Op must already be a symbol. 1747 PyOperation &operation = symbol.getOperation(); 1748 operation.checkValid(); 1749 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1750 MlirAttribute existingNameAttr = 1751 mlirOperationGetAttributeByName(operation.get(), attrName); 1752 if (mlirAttributeIsNull(existingNameAttr)) 1753 throw py::value_error("Expected operation to have a symbol name."); 1754 MlirAttribute newNameAttr = 1755 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1756 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1757 } 1758 1759 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1760 PyOperation &operation = symbol.getOperation(); 1761 operation.checkValid(); 1762 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1763 MlirAttribute existingVisAttr = 1764 mlirOperationGetAttributeByName(operation.get(), attrName); 1765 if (mlirAttributeIsNull(existingVisAttr)) 1766 throw py::value_error("Expected operation to have a symbol visibility."); 1767 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1768 } 1769 1770 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1771 const std::string &visibility) { 1772 if (visibility != "public" && visibility != "private" && 1773 visibility != "nested") 1774 throw py::value_error( 1775 "Expected visibility to be 'public', 'private' or 'nested'"); 1776 PyOperation &operation = symbol.getOperation(); 1777 operation.checkValid(); 1778 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1779 MlirAttribute existingVisAttr = 1780 mlirOperationGetAttributeByName(operation.get(), attrName); 1781 if (mlirAttributeIsNull(existingVisAttr)) 1782 throw py::value_error("Expected operation to have a symbol visibility."); 1783 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1784 toMlirStringRef(visibility)); 1785 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1786 } 1787 1788 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1789 const std::string &newSymbol, 1790 PyOperationBase &from) { 1791 PyOperation &fromOperation = from.getOperation(); 1792 fromOperation.checkValid(); 1793 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1794 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1795 from.getOperation()))) 1796 1797 throw py::value_error("Symbol rename failed"); 1798 } 1799 1800 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1801 bool allSymUsesVisible, 1802 py::object callback) { 1803 PyOperation &fromOperation = from.getOperation(); 1804 fromOperation.checkValid(); 1805 struct UserData { 1806 PyMlirContextRef context; 1807 py::object callback; 1808 bool gotException; 1809 std::string exceptionWhat; 1810 py::object exceptionType; 1811 }; 1812 UserData userData{ 1813 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1814 mlirSymbolTableWalkSymbolTables( 1815 fromOperation.get(), allSymUsesVisible, 1816 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1817 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1818 auto pyFoundOp = 1819 PyOperation::forOperation(calleeUserData->context, foundOp); 1820 if (calleeUserData->gotException) 1821 return; 1822 try { 1823 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1824 } catch (py::error_already_set &e) { 1825 calleeUserData->gotException = true; 1826 calleeUserData->exceptionWhat = e.what(); 1827 calleeUserData->exceptionType = e.type(); 1828 } 1829 }, 1830 static_cast<void *>(&userData)); 1831 if (userData.gotException) { 1832 std::string message("Exception raised in callback: "); 1833 message.append(userData.exceptionWhat); 1834 throw std::runtime_error(message); 1835 } 1836 } 1837 1838 namespace { 1839 /// CRTP base class for Python MLIR values that subclass Value and should be 1840 /// castable from it. The value hierarchy is one level deep and is not supposed 1841 /// to accommodate other levels unless core MLIR changes. 1842 template <typename DerivedTy> 1843 class PyConcreteValue : public PyValue { 1844 public: 1845 // Derived classes must define statics for: 1846 // IsAFunctionTy isaFunction 1847 // const char *pyClassName 1848 // and redefine bindDerived. 1849 using ClassTy = py::class_<DerivedTy, PyValue>; 1850 using IsAFunctionTy = bool (*)(MlirValue); 1851 1852 PyConcreteValue() = default; 1853 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1854 : PyValue(operationRef, value) {} 1855 PyConcreteValue(PyValue &orig) 1856 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1857 1858 /// Attempts to cast the original value to the derived type and throws on 1859 /// type mismatches. 1860 static MlirValue castFrom(PyValue &orig) { 1861 if (!DerivedTy::isaFunction(orig.get())) { 1862 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1863 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1864 DerivedTy::pyClassName + 1865 " (from " + origRepr + ")"); 1866 } 1867 return orig.get(); 1868 } 1869 1870 /// Binds the Python module objects to functions of this class. 1871 static void bind(py::module &m) { 1872 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1873 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1874 cls.def_static( 1875 "isinstance", 1876 [](PyValue &otherValue) -> bool { 1877 return DerivedTy::isaFunction(otherValue); 1878 }, 1879 py::arg("other_value")); 1880 DerivedTy::bindDerived(cls); 1881 } 1882 1883 /// Implemented by derived classes to add methods to the Python subclass. 1884 static void bindDerived(ClassTy &m) {} 1885 }; 1886 1887 /// Python wrapper for MlirBlockArgument. 1888 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1889 public: 1890 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1891 static constexpr const char *pyClassName = "BlockArgument"; 1892 using PyConcreteValue::PyConcreteValue; 1893 1894 static void bindDerived(ClassTy &c) { 1895 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1896 return PyBlock(self.getParentOperation(), 1897 mlirBlockArgumentGetOwner(self.get())); 1898 }); 1899 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1900 return mlirBlockArgumentGetArgNumber(self.get()); 1901 }); 1902 c.def( 1903 "set_type", 1904 [](PyBlockArgument &self, PyType type) { 1905 return mlirBlockArgumentSetType(self.get(), type); 1906 }, 1907 py::arg("type")); 1908 } 1909 }; 1910 1911 /// Python wrapper for MlirOpResult. 1912 class PyOpResult : public PyConcreteValue<PyOpResult> { 1913 public: 1914 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1915 static constexpr const char *pyClassName = "OpResult"; 1916 using PyConcreteValue::PyConcreteValue; 1917 1918 static void bindDerived(ClassTy &c) { 1919 c.def_property_readonly("owner", [](PyOpResult &self) { 1920 assert( 1921 mlirOperationEqual(self.getParentOperation()->get(), 1922 mlirOpResultGetOwner(self.get())) && 1923 "expected the owner of the value in Python to match that in the IR"); 1924 return self.getParentOperation().getObject(); 1925 }); 1926 c.def_property_readonly("result_number", [](PyOpResult &self) { 1927 return mlirOpResultGetResultNumber(self.get()); 1928 }); 1929 } 1930 }; 1931 1932 /// Returns the list of types of the values held by container. 1933 template <typename Container> 1934 static std::vector<PyType> getValueTypes(Container &container, 1935 PyMlirContextRef &context) { 1936 std::vector<PyType> result; 1937 result.reserve(container.getNumElements()); 1938 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1939 result.push_back( 1940 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1941 } 1942 return result; 1943 } 1944 1945 /// A list of block arguments. Internally, these are stored as consecutive 1946 /// elements, random access is cheap. The argument list is associated with the 1947 /// operation that contains the block (detached blocks are not allowed in 1948 /// Python bindings) and extends its lifetime. 1949 class PyBlockArgumentList 1950 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1951 public: 1952 static constexpr const char *pyClassName = "BlockArgumentList"; 1953 1954 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1955 intptr_t startIndex = 0, intptr_t length = -1, 1956 intptr_t step = 1) 1957 : Sliceable(startIndex, 1958 length == -1 ? mlirBlockGetNumArguments(block) : length, 1959 step), 1960 operation(std::move(operation)), block(block) {} 1961 1962 /// Returns the number of arguments in the list. 1963 intptr_t getNumElements() { 1964 operation->checkValid(); 1965 return mlirBlockGetNumArguments(block); 1966 } 1967 1968 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1969 PyBlockArgument getElement(intptr_t pos) { 1970 MlirValue argument = mlirBlockGetArgument(block, pos); 1971 return PyBlockArgument(operation, argument); 1972 } 1973 1974 /// Returns a sublist of this list. 1975 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1976 intptr_t step) { 1977 return PyBlockArgumentList(operation, block, startIndex, length, step); 1978 } 1979 1980 static void bindDerived(ClassTy &c) { 1981 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1982 return getValueTypes(self, self.operation->getContext()); 1983 }); 1984 } 1985 1986 private: 1987 PyOperationRef operation; 1988 MlirBlock block; 1989 }; 1990 1991 /// A list of operation operands. Internally, these are stored as consecutive 1992 /// elements, random access is cheap. The result list is associated with the 1993 /// operation whose results these are, and extends the lifetime of this 1994 /// operation. 1995 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1996 public: 1997 static constexpr const char *pyClassName = "OpOperandList"; 1998 1999 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2000 intptr_t length = -1, intptr_t step = 1) 2001 : Sliceable(startIndex, 2002 length == -1 ? mlirOperationGetNumOperands(operation->get()) 2003 : length, 2004 step), 2005 operation(operation) {} 2006 2007 intptr_t getNumElements() { 2008 operation->checkValid(); 2009 return mlirOperationGetNumOperands(operation->get()); 2010 } 2011 2012 PyValue getElement(intptr_t pos) { 2013 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2014 MlirOperation owner; 2015 if (mlirValueIsAOpResult(operand)) 2016 owner = mlirOpResultGetOwner(operand); 2017 else if (mlirValueIsABlockArgument(operand)) 2018 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2019 else 2020 assert(false && "Value must be an block arg or op result."); 2021 PyOperationRef pyOwner = 2022 PyOperation::forOperation(operation->getContext(), owner); 2023 return PyValue(pyOwner, operand); 2024 } 2025 2026 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2027 return PyOpOperandList(operation, startIndex, length, step); 2028 } 2029 2030 void dunderSetItem(intptr_t index, PyValue value) { 2031 index = wrapIndex(index); 2032 mlirOperationSetOperand(operation->get(), index, value.get()); 2033 } 2034 2035 static void bindDerived(ClassTy &c) { 2036 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2037 } 2038 2039 private: 2040 PyOperationRef operation; 2041 }; 2042 2043 /// A list of operation results. Internally, these are stored as consecutive 2044 /// elements, random access is cheap. The result list is associated with the 2045 /// operation whose results these are, and extends the lifetime of this 2046 /// operation. 2047 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 2048 public: 2049 static constexpr const char *pyClassName = "OpResultList"; 2050 2051 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 2052 intptr_t length = -1, intptr_t step = 1) 2053 : Sliceable(startIndex, 2054 length == -1 ? mlirOperationGetNumResults(operation->get()) 2055 : length, 2056 step), 2057 operation(operation) {} 2058 2059 intptr_t getNumElements() { 2060 operation->checkValid(); 2061 return mlirOperationGetNumResults(operation->get()); 2062 } 2063 2064 PyOpResult getElement(intptr_t index) { 2065 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 2066 return PyOpResult(value); 2067 } 2068 2069 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2070 return PyOpResultList(operation, startIndex, length, step); 2071 } 2072 2073 static void bindDerived(ClassTy &c) { 2074 c.def_property_readonly("types", [](PyOpResultList &self) { 2075 return getValueTypes(self, self.operation->getContext()); 2076 }); 2077 } 2078 2079 private: 2080 PyOperationRef operation; 2081 }; 2082 2083 /// A list of operation attributes. Can be indexed by name, producing 2084 /// attributes, or by index, producing named attributes. 2085 class PyOpAttributeMap { 2086 public: 2087 PyOpAttributeMap(PyOperationRef operation) 2088 : operation(std::move(operation)) {} 2089 2090 PyAttribute dunderGetItemNamed(const std::string &name) { 2091 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2092 toMlirStringRef(name)); 2093 if (mlirAttributeIsNull(attr)) { 2094 throw SetPyError(PyExc_KeyError, 2095 "attempt to access a non-existent attribute"); 2096 } 2097 return PyAttribute(operation->getContext(), attr); 2098 } 2099 2100 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2101 if (index < 0 || index >= dunderLen()) { 2102 throw SetPyError(PyExc_IndexError, 2103 "attempt to access out of bounds attribute"); 2104 } 2105 MlirNamedAttribute namedAttr = 2106 mlirOperationGetAttribute(operation->get(), index); 2107 return PyNamedAttribute( 2108 namedAttr.attribute, 2109 std::string(mlirIdentifierStr(namedAttr.name).data, 2110 mlirIdentifierStr(namedAttr.name).length)); 2111 } 2112 2113 void dunderSetItem(const std::string &name, const PyAttribute &attr) { 2114 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2115 attr); 2116 } 2117 2118 void dunderDelItem(const std::string &name) { 2119 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2120 toMlirStringRef(name)); 2121 if (!removed) 2122 throw SetPyError(PyExc_KeyError, 2123 "attempt to delete a non-existent attribute"); 2124 } 2125 2126 intptr_t dunderLen() { 2127 return mlirOperationGetNumAttributes(operation->get()); 2128 } 2129 2130 bool dunderContains(const std::string &name) { 2131 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2132 operation->get(), toMlirStringRef(name))); 2133 } 2134 2135 static void bind(py::module &m) { 2136 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2137 .def("__contains__", &PyOpAttributeMap::dunderContains) 2138 .def("__len__", &PyOpAttributeMap::dunderLen) 2139 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2140 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2141 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2142 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2143 } 2144 2145 private: 2146 PyOperationRef operation; 2147 }; 2148 2149 } // namespace 2150 2151 //------------------------------------------------------------------------------ 2152 // Populates the core exports of the 'ir' submodule. 2153 //------------------------------------------------------------------------------ 2154 2155 void mlir::python::populateIRCore(py::module &m) { 2156 //---------------------------------------------------------------------------- 2157 // Enums. 2158 //---------------------------------------------------------------------------- 2159 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local()) 2160 .value("ERROR", MlirDiagnosticError) 2161 .value("WARNING", MlirDiagnosticWarning) 2162 .value("NOTE", MlirDiagnosticNote) 2163 .value("REMARK", MlirDiagnosticRemark); 2164 2165 //---------------------------------------------------------------------------- 2166 // Mapping of Diagnostics. 2167 //---------------------------------------------------------------------------- 2168 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local()) 2169 .def_property_readonly("severity", &PyDiagnostic::getSeverity) 2170 .def_property_readonly("location", &PyDiagnostic::getLocation) 2171 .def_property_readonly("message", &PyDiagnostic::getMessage) 2172 .def_property_readonly("notes", &PyDiagnostic::getNotes) 2173 .def("__str__", [](PyDiagnostic &self) -> py::str { 2174 if (!self.isValid()) 2175 return "<Invalid Diagnostic>"; 2176 return self.getMessage(); 2177 }); 2178 2179 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) 2180 .def("detach", &PyDiagnosticHandler::detach) 2181 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) 2182 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 2183 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2184 .def("__exit__", &PyDiagnosticHandler::contextExit); 2185 2186 //---------------------------------------------------------------------------- 2187 // Mapping of MlirContext. 2188 //---------------------------------------------------------------------------- 2189 py::class_<PyMlirContext>(m, "Context", py::module_local()) 2190 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2191 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2192 .def("_get_context_again", 2193 [](PyMlirContext &self) { 2194 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2195 return ref.releaseObject(); 2196 }) 2197 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2198 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2199 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2200 &PyMlirContext::getCapsule) 2201 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2202 .def("__enter__", &PyMlirContext::contextEnter) 2203 .def("__exit__", &PyMlirContext::contextExit) 2204 .def_property_readonly_static( 2205 "current", 2206 [](py::object & /*class*/) { 2207 auto *context = PyThreadContextEntry::getDefaultContext(); 2208 if (!context) 2209 throw SetPyError(PyExc_ValueError, "No current Context"); 2210 return context; 2211 }, 2212 "Gets the Context bound to the current thread or raises ValueError") 2213 .def_property_readonly( 2214 "dialects", 2215 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2216 "Gets a container for accessing dialects by name") 2217 .def_property_readonly( 2218 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2219 "Alias for 'dialect'") 2220 .def( 2221 "get_dialect_descriptor", 2222 [=](PyMlirContext &self, std::string &name) { 2223 MlirDialect dialect = mlirContextGetOrLoadDialect( 2224 self.get(), {name.data(), name.size()}); 2225 if (mlirDialectIsNull(dialect)) { 2226 throw SetPyError(PyExc_ValueError, 2227 Twine("Dialect '") + name + "' not found"); 2228 } 2229 return PyDialectDescriptor(self.getRef(), dialect); 2230 }, 2231 py::arg("dialect_name"), 2232 "Gets or loads a dialect by name, returning its descriptor object") 2233 .def_property( 2234 "allow_unregistered_dialects", 2235 [](PyMlirContext &self) -> bool { 2236 return mlirContextGetAllowUnregisteredDialects(self.get()); 2237 }, 2238 [](PyMlirContext &self, bool value) { 2239 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2240 }) 2241 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2242 py::arg("callback"), 2243 "Attaches a diagnostic handler that will receive callbacks") 2244 .def( 2245 "enable_multithreading", 2246 [](PyMlirContext &self, bool enable) { 2247 mlirContextEnableMultithreading(self.get(), enable); 2248 }, 2249 py::arg("enable")) 2250 .def( 2251 "is_registered_operation", 2252 [](PyMlirContext &self, std::string &name) { 2253 return mlirContextIsRegisteredOperation( 2254 self.get(), MlirStringRef{name.data(), name.size()}); 2255 }, 2256 py::arg("operation_name")); 2257 2258 //---------------------------------------------------------------------------- 2259 // Mapping of PyDialectDescriptor 2260 //---------------------------------------------------------------------------- 2261 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2262 .def_property_readonly("namespace", 2263 [](PyDialectDescriptor &self) { 2264 MlirStringRef ns = 2265 mlirDialectGetNamespace(self.get()); 2266 return py::str(ns.data, ns.length); 2267 }) 2268 .def("__repr__", [](PyDialectDescriptor &self) { 2269 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2270 std::string repr("<DialectDescriptor "); 2271 repr.append(ns.data, ns.length); 2272 repr.append(">"); 2273 return repr; 2274 }); 2275 2276 //---------------------------------------------------------------------------- 2277 // Mapping of PyDialects 2278 //---------------------------------------------------------------------------- 2279 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2280 .def("__getitem__", 2281 [=](PyDialects &self, std::string keyName) { 2282 MlirDialect dialect = 2283 self.getDialectForKey(keyName, /*attrError=*/false); 2284 py::object descriptor = 2285 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2286 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2287 }) 2288 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2289 MlirDialect dialect = 2290 self.getDialectForKey(attrName, /*attrError=*/true); 2291 py::object descriptor = 2292 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2293 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2294 }); 2295 2296 //---------------------------------------------------------------------------- 2297 // Mapping of PyDialect 2298 //---------------------------------------------------------------------------- 2299 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2300 .def(py::init<py::object>(), py::arg("descriptor")) 2301 .def_property_readonly( 2302 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2303 .def("__repr__", [](py::object self) { 2304 auto clazz = self.attr("__class__"); 2305 return py::str("<Dialect ") + 2306 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2307 clazz.attr("__module__") + py::str(".") + 2308 clazz.attr("__name__") + py::str(")>"); 2309 }); 2310 2311 //---------------------------------------------------------------------------- 2312 // Mapping of Location 2313 //---------------------------------------------------------------------------- 2314 py::class_<PyLocation>(m, "Location", py::module_local()) 2315 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2316 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2317 .def("__enter__", &PyLocation::contextEnter) 2318 .def("__exit__", &PyLocation::contextExit) 2319 .def("__eq__", 2320 [](PyLocation &self, PyLocation &other) -> bool { 2321 return mlirLocationEqual(self, other); 2322 }) 2323 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2324 .def_property_readonly_static( 2325 "current", 2326 [](py::object & /*class*/) { 2327 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2328 if (!loc) 2329 throw SetPyError(PyExc_ValueError, "No current Location"); 2330 return loc; 2331 }, 2332 "Gets the Location bound to the current thread or raises ValueError") 2333 .def_static( 2334 "unknown", 2335 [](DefaultingPyMlirContext context) { 2336 return PyLocation(context->getRef(), 2337 mlirLocationUnknownGet(context->get())); 2338 }, 2339 py::arg("context") = py::none(), 2340 "Gets a Location representing an unknown location") 2341 .def_static( 2342 "callsite", 2343 [](PyLocation callee, const std::vector<PyLocation> &frames, 2344 DefaultingPyMlirContext context) { 2345 if (frames.empty()) 2346 throw py::value_error("No caller frames provided"); 2347 MlirLocation caller = frames.back().get(); 2348 for (const PyLocation &frame : 2349 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2350 caller = mlirLocationCallSiteGet(frame.get(), caller); 2351 return PyLocation(context->getRef(), 2352 mlirLocationCallSiteGet(callee.get(), caller)); 2353 }, 2354 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2355 kContextGetCallSiteLocationDocstring) 2356 .def_static( 2357 "file", 2358 [](std::string filename, int line, int col, 2359 DefaultingPyMlirContext context) { 2360 return PyLocation( 2361 context->getRef(), 2362 mlirLocationFileLineColGet( 2363 context->get(), toMlirStringRef(filename), line, col)); 2364 }, 2365 py::arg("filename"), py::arg("line"), py::arg("col"), 2366 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2367 .def_static( 2368 "fused", 2369 [](const std::vector<PyLocation> &pyLocations, 2370 llvm::Optional<PyAttribute> metadata, 2371 DefaultingPyMlirContext context) { 2372 llvm::SmallVector<MlirLocation, 4> locations; 2373 locations.reserve(pyLocations.size()); 2374 for (auto &pyLocation : pyLocations) 2375 locations.push_back(pyLocation.get()); 2376 MlirLocation location = mlirLocationFusedGet( 2377 context->get(), locations.size(), locations.data(), 2378 metadata ? metadata->get() : MlirAttribute{0}); 2379 return PyLocation(context->getRef(), location); 2380 }, 2381 py::arg("locations"), py::arg("metadata") = py::none(), 2382 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2383 .def_static( 2384 "name", 2385 [](std::string name, llvm::Optional<PyLocation> childLoc, 2386 DefaultingPyMlirContext context) { 2387 return PyLocation( 2388 context->getRef(), 2389 mlirLocationNameGet( 2390 context->get(), toMlirStringRef(name), 2391 childLoc ? childLoc->get() 2392 : mlirLocationUnknownGet(context->get()))); 2393 }, 2394 py::arg("name"), py::arg("childLoc") = py::none(), 2395 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2396 .def_property_readonly( 2397 "context", 2398 [](PyLocation &self) { return self.getContext().getObject(); }, 2399 "Context that owns the Location") 2400 .def( 2401 "emit_error", 2402 [](PyLocation &self, std::string message) { 2403 mlirEmitError(self, message.c_str()); 2404 }, 2405 py::arg("message"), "Emits an error at this location") 2406 .def("__repr__", [](PyLocation &self) { 2407 PyPrintAccumulator printAccum; 2408 mlirLocationPrint(self, printAccum.getCallback(), 2409 printAccum.getUserData()); 2410 return printAccum.join(); 2411 }); 2412 2413 //---------------------------------------------------------------------------- 2414 // Mapping of Module 2415 //---------------------------------------------------------------------------- 2416 py::class_<PyModule>(m, "Module", py::module_local()) 2417 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2418 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2419 .def_static( 2420 "parse", 2421 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2422 MlirModule module = mlirModuleCreateParse( 2423 context->get(), toMlirStringRef(moduleAsm)); 2424 // TODO: Rework error reporting once diagnostic engine is exposed 2425 // in C API. 2426 if (mlirModuleIsNull(module)) { 2427 throw SetPyError( 2428 PyExc_ValueError, 2429 "Unable to parse module assembly (see diagnostics)"); 2430 } 2431 return PyModule::forModule(module).releaseObject(); 2432 }, 2433 py::arg("asm"), py::arg("context") = py::none(), 2434 kModuleParseDocstring) 2435 .def_static( 2436 "create", 2437 [](DefaultingPyLocation loc) { 2438 MlirModule module = mlirModuleCreateEmpty(loc); 2439 return PyModule::forModule(module).releaseObject(); 2440 }, 2441 py::arg("loc") = py::none(), "Creates an empty module") 2442 .def_property_readonly( 2443 "context", 2444 [](PyModule &self) { return self.getContext().getObject(); }, 2445 "Context that created the Module") 2446 .def_property_readonly( 2447 "operation", 2448 [](PyModule &self) { 2449 return PyOperation::forOperation(self.getContext(), 2450 mlirModuleGetOperation(self.get()), 2451 self.getRef().releaseObject()) 2452 .releaseObject(); 2453 }, 2454 "Accesses the module as an operation") 2455 .def_property_readonly( 2456 "body", 2457 [](PyModule &self) { 2458 PyOperationRef moduleOp = PyOperation::forOperation( 2459 self.getContext(), mlirModuleGetOperation(self.get()), 2460 self.getRef().releaseObject()); 2461 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 2462 return returnBlock; 2463 }, 2464 "Return the block for this module") 2465 .def( 2466 "dump", 2467 [](PyModule &self) { 2468 mlirOperationDump(mlirModuleGetOperation(self.get())); 2469 }, 2470 kDumpDocstring) 2471 .def( 2472 "__str__", 2473 [](py::object self) { 2474 // Defer to the operation's __str__. 2475 return self.attr("operation").attr("__str__")(); 2476 }, 2477 kOperationStrDunderDocstring); 2478 2479 //---------------------------------------------------------------------------- 2480 // Mapping of Operation. 2481 //---------------------------------------------------------------------------- 2482 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2483 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2484 [](PyOperationBase &self) { 2485 return self.getOperation().getCapsule(); 2486 }) 2487 .def("__eq__", 2488 [](PyOperationBase &self, PyOperationBase &other) { 2489 return &self.getOperation() == &other.getOperation(); 2490 }) 2491 .def("__eq__", 2492 [](PyOperationBase &self, py::object other) { return false; }) 2493 .def("__hash__", 2494 [](PyOperationBase &self) { 2495 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2496 }) 2497 .def_property_readonly("attributes", 2498 [](PyOperationBase &self) { 2499 return PyOpAttributeMap( 2500 self.getOperation().getRef()); 2501 }) 2502 .def_property_readonly("operands", 2503 [](PyOperationBase &self) { 2504 return PyOpOperandList( 2505 self.getOperation().getRef()); 2506 }) 2507 .def_property_readonly("regions", 2508 [](PyOperationBase &self) { 2509 return PyRegionList( 2510 self.getOperation().getRef()); 2511 }) 2512 .def_property_readonly( 2513 "results", 2514 [](PyOperationBase &self) { 2515 return PyOpResultList(self.getOperation().getRef()); 2516 }, 2517 "Returns the list of Operation results.") 2518 .def_property_readonly( 2519 "result", 2520 [](PyOperationBase &self) { 2521 auto &operation = self.getOperation(); 2522 auto numResults = mlirOperationGetNumResults(operation); 2523 if (numResults != 1) { 2524 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2525 throw SetPyError( 2526 PyExc_ValueError, 2527 Twine("Cannot call .result on operation ") + 2528 StringRef(name.data, name.length) + " which has " + 2529 Twine(numResults) + 2530 " results (it is only valid for operations with a " 2531 "single result)"); 2532 } 2533 return PyOpResult(operation.getRef(), 2534 mlirOperationGetResult(operation, 0)); 2535 }, 2536 "Shortcut to get an op result if it has only one (throws an error " 2537 "otherwise).") 2538 .def_property_readonly( 2539 "location", 2540 [](PyOperationBase &self) { 2541 PyOperation &operation = self.getOperation(); 2542 return PyLocation(operation.getContext(), 2543 mlirOperationGetLocation(operation.get())); 2544 }, 2545 "Returns the source location the operation was defined or derived " 2546 "from.") 2547 .def( 2548 "__str__", 2549 [](PyOperationBase &self) { 2550 return self.getAsm(/*binary=*/false, 2551 /*largeElementsLimit=*/llvm::None, 2552 /*enableDebugInfo=*/false, 2553 /*prettyDebugInfo=*/false, 2554 /*printGenericOpForm=*/false, 2555 /*useLocalScope=*/false, 2556 /*assumeVerified=*/false); 2557 }, 2558 "Returns the assembly form of the operation.") 2559 .def("print", &PyOperationBase::print, 2560 // Careful: Lots of arguments must match up with print method. 2561 py::arg("file") = py::none(), py::arg("binary") = false, 2562 py::arg("large_elements_limit") = py::none(), 2563 py::arg("enable_debug_info") = false, 2564 py::arg("pretty_debug_info") = false, 2565 py::arg("print_generic_op_form") = false, 2566 py::arg("use_local_scope") = false, 2567 py::arg("assume_verified") = false, kOperationPrintDocstring) 2568 .def("get_asm", &PyOperationBase::getAsm, 2569 // Careful: Lots of arguments must match up with get_asm method. 2570 py::arg("binary") = false, 2571 py::arg("large_elements_limit") = py::none(), 2572 py::arg("enable_debug_info") = false, 2573 py::arg("pretty_debug_info") = false, 2574 py::arg("print_generic_op_form") = false, 2575 py::arg("use_local_scope") = false, 2576 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2577 .def( 2578 "verify", 2579 [](PyOperationBase &self) { 2580 return mlirOperationVerify(self.getOperation()); 2581 }, 2582 "Verify the operation and return true if it passes, false if it " 2583 "fails.") 2584 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2585 "Puts self immediately after the other operation in its parent " 2586 "block.") 2587 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2588 "Puts self immediately before the other operation in its parent " 2589 "block.") 2590 .def( 2591 "detach_from_parent", 2592 [](PyOperationBase &self) { 2593 PyOperation &operation = self.getOperation(); 2594 operation.checkValid(); 2595 if (!operation.isAttached()) 2596 throw py::value_error("Detached operation has no parent."); 2597 2598 operation.detachFromParent(); 2599 return operation.createOpView(); 2600 }, 2601 "Detaches the operation from its parent block."); 2602 2603 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2604 .def_static("create", &PyOperation::create, py::arg("name"), 2605 py::arg("results") = py::none(), 2606 py::arg("operands") = py::none(), 2607 py::arg("attributes") = py::none(), 2608 py::arg("successors") = py::none(), py::arg("regions") = 0, 2609 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2610 kOperationCreateDocstring) 2611 .def_property_readonly("parent", 2612 [](PyOperation &self) -> py::object { 2613 auto parent = self.getParentOperation(); 2614 if (parent) 2615 return parent->getObject(); 2616 return py::none(); 2617 }) 2618 .def("erase", &PyOperation::erase) 2619 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2620 &PyOperation::getCapsule) 2621 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2622 .def_property_readonly("name", 2623 [](PyOperation &self) { 2624 self.checkValid(); 2625 MlirOperation operation = self.get(); 2626 MlirStringRef name = mlirIdentifierStr( 2627 mlirOperationGetName(operation)); 2628 return py::str(name.data, name.length); 2629 }) 2630 .def_property_readonly( 2631 "context", 2632 [](PyOperation &self) { 2633 self.checkValid(); 2634 return self.getContext().getObject(); 2635 }, 2636 "Context that owns the Operation") 2637 .def_property_readonly("opview", &PyOperation::createOpView); 2638 2639 auto opViewClass = 2640 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2641 .def(py::init<py::object>(), py::arg("operation")) 2642 .def_property_readonly("operation", &PyOpView::getOperationObject) 2643 .def_property_readonly( 2644 "context", 2645 [](PyOpView &self) { 2646 return self.getOperation().getContext().getObject(); 2647 }, 2648 "Context that owns the Operation") 2649 .def("__str__", [](PyOpView &self) { 2650 return py::str(self.getOperationObject()); 2651 }); 2652 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2653 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2654 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2655 opViewClass.attr("build_generic") = classmethod( 2656 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2657 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2658 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2659 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2660 "Builds a specific, generated OpView based on class level attributes."); 2661 2662 //---------------------------------------------------------------------------- 2663 // Mapping of PyRegion. 2664 //---------------------------------------------------------------------------- 2665 py::class_<PyRegion>(m, "Region", py::module_local()) 2666 .def_property_readonly( 2667 "blocks", 2668 [](PyRegion &self) { 2669 return PyBlockList(self.getParentOperation(), self.get()); 2670 }, 2671 "Returns a forward-optimized sequence of blocks.") 2672 .def_property_readonly( 2673 "owner", 2674 [](PyRegion &self) { 2675 return self.getParentOperation()->createOpView(); 2676 }, 2677 "Returns the operation owning this region.") 2678 .def( 2679 "__iter__", 2680 [](PyRegion &self) { 2681 self.checkValid(); 2682 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2683 return PyBlockIterator(self.getParentOperation(), firstBlock); 2684 }, 2685 "Iterates over blocks in the region.") 2686 .def("__eq__", 2687 [](PyRegion &self, PyRegion &other) { 2688 return self.get().ptr == other.get().ptr; 2689 }) 2690 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2691 2692 //---------------------------------------------------------------------------- 2693 // Mapping of PyBlock. 2694 //---------------------------------------------------------------------------- 2695 py::class_<PyBlock>(m, "Block", py::module_local()) 2696 .def_property_readonly( 2697 "owner", 2698 [](PyBlock &self) { 2699 return self.getParentOperation()->createOpView(); 2700 }, 2701 "Returns the owning operation of this block.") 2702 .def_property_readonly( 2703 "region", 2704 [](PyBlock &self) { 2705 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2706 return PyRegion(self.getParentOperation(), region); 2707 }, 2708 "Returns the owning region of this block.") 2709 .def_property_readonly( 2710 "arguments", 2711 [](PyBlock &self) { 2712 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2713 }, 2714 "Returns a list of block arguments.") 2715 .def_property_readonly( 2716 "operations", 2717 [](PyBlock &self) { 2718 return PyOperationList(self.getParentOperation(), self.get()); 2719 }, 2720 "Returns a forward-optimized sequence of operations.") 2721 .def_static( 2722 "create_at_start", 2723 [](PyRegion &parent, py::list pyArgTypes) { 2724 parent.checkValid(); 2725 llvm::SmallVector<MlirType, 4> argTypes; 2726 llvm::SmallVector<MlirLocation, 4> argLocs; 2727 argTypes.reserve(pyArgTypes.size()); 2728 argLocs.reserve(pyArgTypes.size()); 2729 for (auto &pyArg : pyArgTypes) { 2730 argTypes.push_back(pyArg.cast<PyType &>()); 2731 // TODO: Pass in a proper location here. 2732 argLocs.push_back( 2733 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2734 } 2735 2736 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2737 argLocs.data()); 2738 mlirRegionInsertOwnedBlock(parent, 0, block); 2739 return PyBlock(parent.getParentOperation(), block); 2740 }, 2741 py::arg("parent"), py::arg("arg_types") = py::list(), 2742 "Creates and returns a new Block at the beginning of the given " 2743 "region (with given argument types).") 2744 .def( 2745 "create_before", 2746 [](PyBlock &self, py::args pyArgTypes) { 2747 self.checkValid(); 2748 llvm::SmallVector<MlirType, 4> argTypes; 2749 llvm::SmallVector<MlirLocation, 4> argLocs; 2750 argTypes.reserve(pyArgTypes.size()); 2751 argLocs.reserve(pyArgTypes.size()); 2752 for (auto &pyArg : pyArgTypes) { 2753 argTypes.push_back(pyArg.cast<PyType &>()); 2754 // TODO: Pass in a proper location here. 2755 argLocs.push_back( 2756 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2757 } 2758 2759 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2760 argLocs.data()); 2761 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2762 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2763 return PyBlock(self.getParentOperation(), block); 2764 }, 2765 "Creates and returns a new Block before this block " 2766 "(with given argument types).") 2767 .def( 2768 "create_after", 2769 [](PyBlock &self, py::args pyArgTypes) { 2770 self.checkValid(); 2771 llvm::SmallVector<MlirType, 4> argTypes; 2772 llvm::SmallVector<MlirLocation, 4> argLocs; 2773 argTypes.reserve(pyArgTypes.size()); 2774 argLocs.reserve(pyArgTypes.size()); 2775 for (auto &pyArg : pyArgTypes) { 2776 argTypes.push_back(pyArg.cast<PyType &>()); 2777 2778 // TODO: Pass in a proper location here. 2779 argLocs.push_back( 2780 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2781 } 2782 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2783 argLocs.data()); 2784 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2785 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2786 return PyBlock(self.getParentOperation(), block); 2787 }, 2788 "Creates and returns a new Block after this block " 2789 "(with given argument types).") 2790 .def( 2791 "__iter__", 2792 [](PyBlock &self) { 2793 self.checkValid(); 2794 MlirOperation firstOperation = 2795 mlirBlockGetFirstOperation(self.get()); 2796 return PyOperationIterator(self.getParentOperation(), 2797 firstOperation); 2798 }, 2799 "Iterates over operations in the block.") 2800 .def("__eq__", 2801 [](PyBlock &self, PyBlock &other) { 2802 return self.get().ptr == other.get().ptr; 2803 }) 2804 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2805 .def( 2806 "__str__", 2807 [](PyBlock &self) { 2808 self.checkValid(); 2809 PyPrintAccumulator printAccum; 2810 mlirBlockPrint(self.get(), printAccum.getCallback(), 2811 printAccum.getUserData()); 2812 return printAccum.join(); 2813 }, 2814 "Returns the assembly form of the block.") 2815 .def( 2816 "append", 2817 [](PyBlock &self, PyOperationBase &operation) { 2818 if (operation.getOperation().isAttached()) 2819 operation.getOperation().detachFromParent(); 2820 2821 MlirOperation mlirOperation = operation.getOperation().get(); 2822 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2823 operation.getOperation().setAttached( 2824 self.getParentOperation().getObject()); 2825 }, 2826 py::arg("operation"), 2827 "Appends an operation to this block. If the operation is currently " 2828 "in another block, it will be moved."); 2829 2830 //---------------------------------------------------------------------------- 2831 // Mapping of PyInsertionPoint. 2832 //---------------------------------------------------------------------------- 2833 2834 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2835 .def(py::init<PyBlock &>(), py::arg("block"), 2836 "Inserts after the last operation but still inside the block.") 2837 .def("__enter__", &PyInsertionPoint::contextEnter) 2838 .def("__exit__", &PyInsertionPoint::contextExit) 2839 .def_property_readonly_static( 2840 "current", 2841 [](py::object & /*class*/) { 2842 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2843 if (!ip) 2844 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2845 return ip; 2846 }, 2847 "Gets the InsertionPoint bound to the current thread or raises " 2848 "ValueError if none has been set") 2849 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2850 "Inserts before a referenced operation.") 2851 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2852 py::arg("block"), "Inserts at the beginning of the block.") 2853 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2854 py::arg("block"), "Inserts before the block terminator.") 2855 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2856 "Inserts an operation.") 2857 .def_property_readonly( 2858 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2859 "Returns the block that this InsertionPoint points to."); 2860 2861 //---------------------------------------------------------------------------- 2862 // Mapping of PyAttribute. 2863 //---------------------------------------------------------------------------- 2864 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2865 // Delegate to the PyAttribute copy constructor, which will also lifetime 2866 // extend the backing context which owns the MlirAttribute. 2867 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2868 "Casts the passed attribute to the generic Attribute") 2869 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2870 &PyAttribute::getCapsule) 2871 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2872 .def_static( 2873 "parse", 2874 [](std::string attrSpec, DefaultingPyMlirContext context) { 2875 MlirAttribute type = mlirAttributeParseGet( 2876 context->get(), toMlirStringRef(attrSpec)); 2877 // TODO: Rework error reporting once diagnostic engine is exposed 2878 // in C API. 2879 if (mlirAttributeIsNull(type)) { 2880 throw SetPyError(PyExc_ValueError, 2881 Twine("Unable to parse attribute: '") + 2882 attrSpec + "'"); 2883 } 2884 return PyAttribute(context->getRef(), type); 2885 }, 2886 py::arg("asm"), py::arg("context") = py::none(), 2887 "Parses an attribute from an assembly form") 2888 .def_property_readonly( 2889 "context", 2890 [](PyAttribute &self) { return self.getContext().getObject(); }, 2891 "Context that owns the Attribute") 2892 .def_property_readonly("type", 2893 [](PyAttribute &self) { 2894 return PyType(self.getContext()->getRef(), 2895 mlirAttributeGetType(self)); 2896 }) 2897 .def( 2898 "get_named", 2899 [](PyAttribute &self, std::string name) { 2900 return PyNamedAttribute(self, std::move(name)); 2901 }, 2902 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2903 .def("__eq__", 2904 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2905 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2906 .def("__hash__", 2907 [](PyAttribute &self) { 2908 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2909 }) 2910 .def( 2911 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2912 kDumpDocstring) 2913 .def( 2914 "__str__", 2915 [](PyAttribute &self) { 2916 PyPrintAccumulator printAccum; 2917 mlirAttributePrint(self, printAccum.getCallback(), 2918 printAccum.getUserData()); 2919 return printAccum.join(); 2920 }, 2921 "Returns the assembly form of the Attribute.") 2922 .def("__repr__", [](PyAttribute &self) { 2923 // Generally, assembly formats are not printed for __repr__ because 2924 // this can cause exceptionally long debug output and exceptions. 2925 // However, attribute values are generally considered useful and are 2926 // printed. This may need to be re-evaluated if debug dumps end up 2927 // being excessive. 2928 PyPrintAccumulator printAccum; 2929 printAccum.parts.append("Attribute("); 2930 mlirAttributePrint(self, printAccum.getCallback(), 2931 printAccum.getUserData()); 2932 printAccum.parts.append(")"); 2933 return printAccum.join(); 2934 }); 2935 2936 //---------------------------------------------------------------------------- 2937 // Mapping of PyNamedAttribute 2938 //---------------------------------------------------------------------------- 2939 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2940 .def("__repr__", 2941 [](PyNamedAttribute &self) { 2942 PyPrintAccumulator printAccum; 2943 printAccum.parts.append("NamedAttribute("); 2944 printAccum.parts.append( 2945 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2946 mlirIdentifierStr(self.namedAttr.name).length)); 2947 printAccum.parts.append("="); 2948 mlirAttributePrint(self.namedAttr.attribute, 2949 printAccum.getCallback(), 2950 printAccum.getUserData()); 2951 printAccum.parts.append(")"); 2952 return printAccum.join(); 2953 }) 2954 .def_property_readonly( 2955 "name", 2956 [](PyNamedAttribute &self) { 2957 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2958 mlirIdentifierStr(self.namedAttr.name).length); 2959 }, 2960 "The name of the NamedAttribute binding") 2961 .def_property_readonly( 2962 "attr", 2963 [](PyNamedAttribute &self) { 2964 // TODO: When named attribute is removed/refactored, also remove 2965 // this constructor (it does an inefficient table lookup). 2966 auto contextRef = PyMlirContext::forContext( 2967 mlirAttributeGetContext(self.namedAttr.attribute)); 2968 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2969 }, 2970 py::keep_alive<0, 1>(), 2971 "The underlying generic attribute of the NamedAttribute binding"); 2972 2973 //---------------------------------------------------------------------------- 2974 // Mapping of PyType. 2975 //---------------------------------------------------------------------------- 2976 py::class_<PyType>(m, "Type", py::module_local()) 2977 // Delegate to the PyType copy constructor, which will also lifetime 2978 // extend the backing context which owns the MlirType. 2979 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2980 "Casts the passed type to the generic Type") 2981 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2982 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2983 .def_static( 2984 "parse", 2985 [](std::string typeSpec, DefaultingPyMlirContext context) { 2986 MlirType type = 2987 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2988 // TODO: Rework error reporting once diagnostic engine is exposed 2989 // in C API. 2990 if (mlirTypeIsNull(type)) { 2991 throw SetPyError(PyExc_ValueError, 2992 Twine("Unable to parse type: '") + typeSpec + 2993 "'"); 2994 } 2995 return PyType(context->getRef(), type); 2996 }, 2997 py::arg("asm"), py::arg("context") = py::none(), 2998 kContextParseTypeDocstring) 2999 .def_property_readonly( 3000 "context", [](PyType &self) { return self.getContext().getObject(); }, 3001 "Context that owns the Type") 3002 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3003 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 3004 .def("__hash__", 3005 [](PyType &self) { 3006 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3007 }) 3008 .def( 3009 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3010 .def( 3011 "__str__", 3012 [](PyType &self) { 3013 PyPrintAccumulator printAccum; 3014 mlirTypePrint(self, printAccum.getCallback(), 3015 printAccum.getUserData()); 3016 return printAccum.join(); 3017 }, 3018 "Returns the assembly form of the type.") 3019 .def("__repr__", [](PyType &self) { 3020 // Generally, assembly formats are not printed for __repr__ because 3021 // this can cause exceptionally long debug output and exceptions. 3022 // However, types are an exception as they typically have compact 3023 // assembly forms and printing them is useful. 3024 PyPrintAccumulator printAccum; 3025 printAccum.parts.append("Type("); 3026 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 3027 printAccum.parts.append(")"); 3028 return printAccum.join(); 3029 }); 3030 3031 //---------------------------------------------------------------------------- 3032 // Mapping of Value. 3033 //---------------------------------------------------------------------------- 3034 py::class_<PyValue>(m, "Value", py::module_local()) 3035 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3036 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3037 .def_property_readonly( 3038 "context", 3039 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3040 "Context in which the value lives.") 3041 .def( 3042 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3043 kDumpDocstring) 3044 .def_property_readonly( 3045 "owner", 3046 [](PyValue &self) { 3047 assert(mlirOperationEqual(self.getParentOperation()->get(), 3048 mlirOpResultGetOwner(self.get())) && 3049 "expected the owner of the value in Python to match that in " 3050 "the IR"); 3051 return self.getParentOperation().getObject(); 3052 }) 3053 .def("__eq__", 3054 [](PyValue &self, PyValue &other) { 3055 return self.get().ptr == other.get().ptr; 3056 }) 3057 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 3058 .def("__hash__", 3059 [](PyValue &self) { 3060 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3061 }) 3062 .def( 3063 "__str__", 3064 [](PyValue &self) { 3065 PyPrintAccumulator printAccum; 3066 printAccum.parts.append("Value("); 3067 mlirValuePrint(self.get(), printAccum.getCallback(), 3068 printAccum.getUserData()); 3069 printAccum.parts.append(")"); 3070 return printAccum.join(); 3071 }, 3072 kValueDunderStrDocstring) 3073 .def_property_readonly("type", [](PyValue &self) { 3074 return PyType(self.getParentOperation()->getContext(), 3075 mlirValueGetType(self.get())); 3076 }); 3077 PyBlockArgument::bind(m); 3078 PyOpResult::bind(m); 3079 3080 //---------------------------------------------------------------------------- 3081 // Mapping of SymbolTable. 3082 //---------------------------------------------------------------------------- 3083 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 3084 .def(py::init<PyOperationBase &>()) 3085 .def("__getitem__", &PySymbolTable::dunderGetItem) 3086 .def("insert", &PySymbolTable::insert, py::arg("operation")) 3087 .def("erase", &PySymbolTable::erase, py::arg("operation")) 3088 .def("__delitem__", &PySymbolTable::dunderDel) 3089 .def("__contains__", 3090 [](PySymbolTable &table, const std::string &name) { 3091 return !mlirOperationIsNull(mlirSymbolTableLookup( 3092 table, mlirStringRefCreate(name.data(), name.length()))); 3093 }) 3094 // Static helpers. 3095 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3096 py::arg("symbol"), py::arg("name")) 3097 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3098 py::arg("symbol")) 3099 .def_static("get_visibility", &PySymbolTable::getVisibility, 3100 py::arg("symbol")) 3101 .def_static("set_visibility", &PySymbolTable::setVisibility, 3102 py::arg("symbol"), py::arg("visibility")) 3103 .def_static("replace_all_symbol_uses", 3104 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 3105 py::arg("new_symbol"), py::arg("from_op")) 3106 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3107 py::arg("from_op"), py::arg("all_sym_uses_visible"), 3108 py::arg("callback")); 3109 3110 // Container bindings. 3111 PyBlockArgumentList::bind(m); 3112 PyBlockIterator::bind(m); 3113 PyBlockList::bind(m); 3114 PyOperationIterator::bind(m); 3115 PyOperationList::bind(m); 3116 PyOpAttributeMap::bind(m); 3117 PyOpOperandList::bind(m); 3118 PyOpResultList::bind(m); 3119 PyRegionIterator::bind(m); 3120 PyRegionList::bind(m); 3121 3122 // Debug bindings. 3123 PyGlobalDebugFlag::bind(m); 3124 } 3125