1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 //#include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 23 #include <utility> 24 25 namespace py = pybind11; 26 using namespace mlir; 27 using namespace mlir::python; 28 29 using llvm::SmallVector; 30 using llvm::StringRef; 31 using llvm::Twine; 32 33 //------------------------------------------------------------------------------ 34 // Docstrings (trivial, non-duplicated docstrings are included inline). 35 //------------------------------------------------------------------------------ 36 37 static const char kContextParseTypeDocstring[] = 38 R"(Parses the assembly form of a type. 39 40 Returns a Type object or raises a ValueError if the type cannot be parsed. 41 42 See also: https://mlir.llvm.org/docs/LangRef/#type-system 43 )"; 44 45 static const char kContextGetCallSiteLocationDocstring[] = 46 R"(Gets a Location representing a caller and callsite)"; 47 48 static const char kContextGetFileLocationDocstring[] = 49 R"(Gets a Location representing a file, line and column)"; 50 51 static const char kContextGetFusedLocationDocstring[] = 52 R"(Gets a Location representing a fused location with optional metadata)"; 53 54 static const char kContextGetNameLocationDocString[] = 55 R"(Gets a Location representing a named location with optional child location)"; 56 57 static const char kModuleParseDocstring[] = 58 R"(Parses a module's assembly format from a string. 59 60 Returns a new MlirModule or raises a ValueError if the parsing fails. 61 62 See also: https://mlir.llvm.org/docs/LangRef/ 63 )"; 64 65 static const char kOperationCreateDocstring[] = 66 R"(Creates a new operation. 67 68 Args: 69 name: Operation name (e.g. "dialect.operation"). 70 results: Sequence of Type representing op result types. 71 attributes: Dict of str:Attribute. 72 successors: List of Block for the operation's successors. 73 regions: Number of regions to create. 74 location: A Location object (defaults to resolve from context manager). 75 ip: An InsertionPoint (defaults to resolve from context manager or set to 76 False to disable insertion, even with an insertion point set in the 77 context manager). 78 Returns: 79 A new "detached" Operation object. Detached operations can be added 80 to blocks, which causes them to become "attached." 81 )"; 82 83 static const char kOperationPrintDocstring[] = 84 R"(Prints the assembly form of the operation to a file like object. 85 86 Args: 87 file: The file like object to write to. Defaults to sys.stdout. 88 binary: Whether to write bytes (True) or str (False). Defaults to False. 89 large_elements_limit: Whether to elide elements attributes above this 90 number of elements. Defaults to None (no limit). 91 enable_debug_info: Whether to print debug/location information. Defaults 92 to False. 93 pretty_debug_info: Whether to format debug information for easier reading 94 by a human (warning: the result is unparseable). 95 print_generic_op_form: Whether to print the generic assembly forms of all 96 ops. Defaults to False. 97 use_local_Scope: Whether to print in a way that is more optimized for 98 multi-threaded access but may not be consistent with how the overall 99 module prints. 100 assume_verified: By default, if not printing generic form, the verifier 101 will be run and if it fails, generic form will be printed with a comment 102 about failed verification. While a reasonable default for interactive use, 103 for systematic use, it is often better for the caller to verify explicitly 104 and report failures in a more robust fashion. Set this to True if doing this 105 in order to avoid running a redundant verification. If the IR is actually 106 invalid, behavior is undefined. 107 )"; 108 109 static const char kOperationGetAsmDocstring[] = 110 R"(Gets the assembly form of the operation with all options available. 111 112 Args: 113 binary: Whether to return a bytes (True) or str (False) object. Defaults to 114 False. 115 ... others ...: See the print() method for common keyword arguments for 116 configuring the printout. 117 Returns: 118 Either a bytes or str object, depending on the setting of the 'binary' 119 argument. 120 )"; 121 122 static const char kOperationStrDunderDocstring[] = 123 R"(Gets the assembly form of the operation with default options. 124 125 If more advanced control over the assembly formatting or I/O options is needed, 126 use the dedicated print or get_asm method, which supports keyword arguments to 127 customize behavior. 128 )"; 129 130 static const char kDumpDocstring[] = 131 R"(Dumps a debug representation of the object to stderr.)"; 132 133 static const char kAppendBlockDocstring[] = 134 R"(Appends a new block, with argument types as positional args. 135 136 Returns: 137 The created block. 138 )"; 139 140 static const char kValueDunderStrDocstring[] = 141 R"(Returns the string form of the value. 142 143 If the value is a block argument, this is the assembly form of its type and the 144 position in the argument list. If the value is an operation result, this is 145 equivalent to printing the operation that produced it. 146 )"; 147 148 //------------------------------------------------------------------------------ 149 // Utilities. 150 //------------------------------------------------------------------------------ 151 152 /// Helper for creating an @classmethod. 153 template <class Func, typename... Args> 154 py::object classmethod(Func f, Args... args) { 155 py::object cf = py::cpp_function(f, args...); 156 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 157 } 158 159 static py::object 160 createCustomDialectWrapper(const std::string &dialectNamespace, 161 py::object dialectDescriptor) { 162 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 163 if (!dialectClass) { 164 // Use the base class. 165 return py::cast(PyDialect(std::move(dialectDescriptor))); 166 } 167 168 // Create the custom implementation. 169 return (*dialectClass)(std::move(dialectDescriptor)); 170 } 171 172 static MlirStringRef toMlirStringRef(const std::string &s) { 173 return mlirStringRefCreate(s.data(), s.size()); 174 } 175 176 /// Wrapper for the global LLVM debugging flag. 177 struct PyGlobalDebugFlag { 178 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 179 180 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } 181 182 static void bind(py::module &m) { 183 // Debug flags. 184 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 185 .def_property_static("flag", &PyGlobalDebugFlag::get, 186 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 187 } 188 }; 189 190 //------------------------------------------------------------------------------ 191 // Collections. 192 //------------------------------------------------------------------------------ 193 194 namespace { 195 196 class PyRegionIterator { 197 public: 198 PyRegionIterator(PyOperationRef operation) 199 : operation(std::move(operation)) {} 200 201 PyRegionIterator &dunderIter() { return *this; } 202 203 PyRegion dunderNext() { 204 operation->checkValid(); 205 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 206 throw py::stop_iteration(); 207 } 208 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 209 return PyRegion(operation, region); 210 } 211 212 static void bind(py::module &m) { 213 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 214 .def("__iter__", &PyRegionIterator::dunderIter) 215 .def("__next__", &PyRegionIterator::dunderNext); 216 } 217 218 private: 219 PyOperationRef operation; 220 int nextIndex = 0; 221 }; 222 223 /// Regions of an op are fixed length and indexed numerically so are represented 224 /// with a sequence-like container. 225 class PyRegionList { 226 public: 227 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 228 229 intptr_t dunderLen() { 230 operation->checkValid(); 231 return mlirOperationGetNumRegions(operation->get()); 232 } 233 234 PyRegion dunderGetItem(intptr_t index) { 235 // dunderLen checks validity. 236 if (index < 0 || index >= dunderLen()) { 237 throw SetPyError(PyExc_IndexError, 238 "attempt to access out of bounds region"); 239 } 240 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 241 return PyRegion(operation, region); 242 } 243 244 static void bind(py::module &m) { 245 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 246 .def("__len__", &PyRegionList::dunderLen) 247 .def("__getitem__", &PyRegionList::dunderGetItem); 248 } 249 250 private: 251 PyOperationRef operation; 252 }; 253 254 class PyBlockIterator { 255 public: 256 PyBlockIterator(PyOperationRef operation, MlirBlock next) 257 : operation(std::move(operation)), next(next) {} 258 259 PyBlockIterator &dunderIter() { return *this; } 260 261 PyBlock dunderNext() { 262 operation->checkValid(); 263 if (mlirBlockIsNull(next)) { 264 throw py::stop_iteration(); 265 } 266 267 PyBlock returnBlock(operation, next); 268 next = mlirBlockGetNextInRegion(next); 269 return returnBlock; 270 } 271 272 static void bind(py::module &m) { 273 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 274 .def("__iter__", &PyBlockIterator::dunderIter) 275 .def("__next__", &PyBlockIterator::dunderNext); 276 } 277 278 private: 279 PyOperationRef operation; 280 MlirBlock next; 281 }; 282 283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 284 /// we present them as a more full-featured list-like container but optimize 285 /// it for forward iteration. Blocks are always owned by a region. 286 class PyBlockList { 287 public: 288 PyBlockList(PyOperationRef operation, MlirRegion region) 289 : operation(std::move(operation)), region(region) {} 290 291 PyBlockIterator dunderIter() { 292 operation->checkValid(); 293 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 294 } 295 296 intptr_t dunderLen() { 297 operation->checkValid(); 298 intptr_t count = 0; 299 MlirBlock block = mlirRegionGetFirstBlock(region); 300 while (!mlirBlockIsNull(block)) { 301 count += 1; 302 block = mlirBlockGetNextInRegion(block); 303 } 304 return count; 305 } 306 307 PyBlock dunderGetItem(intptr_t index) { 308 operation->checkValid(); 309 if (index < 0) { 310 throw SetPyError(PyExc_IndexError, 311 "attempt to access out of bounds block"); 312 } 313 MlirBlock block = mlirRegionGetFirstBlock(region); 314 while (!mlirBlockIsNull(block)) { 315 if (index == 0) { 316 return PyBlock(operation, block); 317 } 318 block = mlirBlockGetNextInRegion(block); 319 index -= 1; 320 } 321 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 322 } 323 324 PyBlock appendBlock(const py::args &pyArgTypes) { 325 operation->checkValid(); 326 llvm::SmallVector<MlirType, 4> argTypes; 327 llvm::SmallVector<MlirLocation, 4> argLocs; 328 argTypes.reserve(pyArgTypes.size()); 329 argLocs.reserve(pyArgTypes.size()); 330 for (auto &pyArg : pyArgTypes) { 331 argTypes.push_back(pyArg.cast<PyType &>()); 332 // TODO: Pass in a proper location here. 333 argLocs.push_back( 334 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 335 } 336 337 MlirBlock block = 338 mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 339 mlirRegionAppendOwnedBlock(region, block); 340 return PyBlock(operation, block); 341 } 342 343 static void bind(py::module &m) { 344 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 345 .def("__getitem__", &PyBlockList::dunderGetItem) 346 .def("__iter__", &PyBlockList::dunderIter) 347 .def("__len__", &PyBlockList::dunderLen) 348 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 349 } 350 351 private: 352 PyOperationRef operation; 353 MlirRegion region; 354 }; 355 356 class PyOperationIterator { 357 public: 358 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 359 : parentOperation(std::move(parentOperation)), next(next) {} 360 361 PyOperationIterator &dunderIter() { return *this; } 362 363 py::object dunderNext() { 364 parentOperation->checkValid(); 365 if (mlirOperationIsNull(next)) { 366 throw py::stop_iteration(); 367 } 368 369 PyOperationRef returnOperation = 370 PyOperation::forOperation(parentOperation->getContext(), next); 371 next = mlirOperationGetNextInBlock(next); 372 return returnOperation->createOpView(); 373 } 374 375 static void bind(py::module &m) { 376 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 377 .def("__iter__", &PyOperationIterator::dunderIter) 378 .def("__next__", &PyOperationIterator::dunderNext); 379 } 380 381 private: 382 PyOperationRef parentOperation; 383 MlirOperation next; 384 }; 385 386 /// Operations are exposed by the C-API as a forward-only linked list. In 387 /// Python, we present them as a more full-featured list-like container but 388 /// optimize it for forward iteration. Iterable operations are always owned 389 /// by a block. 390 class PyOperationList { 391 public: 392 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 393 : parentOperation(std::move(parentOperation)), block(block) {} 394 395 PyOperationIterator dunderIter() { 396 parentOperation->checkValid(); 397 return PyOperationIterator(parentOperation, 398 mlirBlockGetFirstOperation(block)); 399 } 400 401 intptr_t dunderLen() { 402 parentOperation->checkValid(); 403 intptr_t count = 0; 404 MlirOperation childOp = mlirBlockGetFirstOperation(block); 405 while (!mlirOperationIsNull(childOp)) { 406 count += 1; 407 childOp = mlirOperationGetNextInBlock(childOp); 408 } 409 return count; 410 } 411 412 py::object dunderGetItem(intptr_t index) { 413 parentOperation->checkValid(); 414 if (index < 0) { 415 throw SetPyError(PyExc_IndexError, 416 "attempt to access out of bounds operation"); 417 } 418 MlirOperation childOp = mlirBlockGetFirstOperation(block); 419 while (!mlirOperationIsNull(childOp)) { 420 if (index == 0) { 421 return PyOperation::forOperation(parentOperation->getContext(), childOp) 422 ->createOpView(); 423 } 424 childOp = mlirOperationGetNextInBlock(childOp); 425 index -= 1; 426 } 427 throw SetPyError(PyExc_IndexError, 428 "attempt to access out of bounds operation"); 429 } 430 431 static void bind(py::module &m) { 432 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 433 .def("__getitem__", &PyOperationList::dunderGetItem) 434 .def("__iter__", &PyOperationList::dunderIter) 435 .def("__len__", &PyOperationList::dunderLen); 436 } 437 438 private: 439 PyOperationRef parentOperation; 440 MlirBlock block; 441 }; 442 443 } // namespace 444 445 //------------------------------------------------------------------------------ 446 // PyMlirContext 447 //------------------------------------------------------------------------------ 448 449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 450 py::gil_scoped_acquire acquire; 451 auto &liveContexts = getLiveContexts(); 452 liveContexts[context.ptr] = this; 453 } 454 455 PyMlirContext::~PyMlirContext() { 456 // Note that the only public way to construct an instance is via the 457 // forContext method, which always puts the associated handle into 458 // liveContexts. 459 py::gil_scoped_acquire acquire; 460 getLiveContexts().erase(context.ptr); 461 mlirContextDestroy(context); 462 } 463 464 py::object PyMlirContext::getCapsule() { 465 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 466 } 467 468 py::object PyMlirContext::createFromCapsule(py::object capsule) { 469 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 470 if (mlirContextIsNull(rawContext)) 471 throw py::error_already_set(); 472 return forContext(rawContext).releaseObject(); 473 } 474 475 PyMlirContext *PyMlirContext::createNewContextForInit() { 476 MlirContext context = mlirContextCreate(); 477 return new PyMlirContext(context); 478 } 479 480 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 481 py::gil_scoped_acquire acquire; 482 auto &liveContexts = getLiveContexts(); 483 auto it = liveContexts.find(context.ptr); 484 if (it == liveContexts.end()) { 485 // Create. 486 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 487 py::object pyRef = py::cast(unownedContextWrapper); 488 assert(pyRef && "cast to py::object failed"); 489 liveContexts[context.ptr] = unownedContextWrapper; 490 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 491 } 492 // Use existing. 493 py::object pyRef = py::cast(it->second); 494 return PyMlirContextRef(it->second, std::move(pyRef)); 495 } 496 497 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 498 static LiveContextMap liveContexts; 499 return liveContexts; 500 } 501 502 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 503 504 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 505 506 size_t PyMlirContext::clearLiveOperations() { 507 for (auto &op : liveOperations) 508 op.second.second->setInvalid(); 509 size_t numInvalidated = liveOperations.size(); 510 liveOperations.clear(); 511 return numInvalidated; 512 } 513 514 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 515 516 pybind11::object PyMlirContext::contextEnter() { 517 return PyThreadContextEntry::pushContext(*this); 518 } 519 520 void PyMlirContext::contextExit(const pybind11::object &excType, 521 const pybind11::object &excVal, 522 const pybind11::object &excTb) { 523 PyThreadContextEntry::popContext(*this); 524 } 525 526 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { 527 // Note that ownership is transferred to the delete callback below by way of 528 // an explicit inc_ref (borrow). 529 PyDiagnosticHandler *pyHandler = 530 new PyDiagnosticHandler(get(), std::move(callback)); 531 py::object pyHandlerObject = 532 py::cast(pyHandler, py::return_value_policy::take_ownership); 533 pyHandlerObject.inc_ref(); 534 535 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 536 // guaranteed to be known to pybind. 537 auto handlerCallback = 538 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 539 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 540 py::object pyDiagnosticObject = 541 py::cast(pyDiagnostic, py::return_value_policy::take_ownership); 542 543 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 544 bool result = false; 545 { 546 // Since this can be called from arbitrary C++ contexts, always get the 547 // gil. 548 py::gil_scoped_acquire gil; 549 try { 550 result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); 551 } catch (std::exception &e) { 552 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 553 e.what()); 554 pyHandler->hadError = true; 555 } 556 } 557 558 pyDiagnostic->invalidate(); 559 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 560 }; 561 auto deleteCallback = +[](void *userData) { 562 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 563 assert(pyHandler->registeredID && "handler is not registered"); 564 pyHandler->registeredID.reset(); 565 566 // Decrement reference, balancing the inc_ref() above. 567 py::object pyHandlerObject = 568 py::cast(pyHandler, py::return_value_policy::reference); 569 pyHandlerObject.dec_ref(); 570 }; 571 572 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 573 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 574 return pyHandlerObject; 575 } 576 577 PyMlirContext &DefaultingPyMlirContext::resolve() { 578 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 579 if (!context) { 580 throw SetPyError( 581 PyExc_RuntimeError, 582 "An MLIR function requires a Context but none was provided in the call " 583 "or from the surrounding environment. Either pass to the function with " 584 "a 'context=' argument or establish a default using 'with Context():'"); 585 } 586 return *context; 587 } 588 589 //------------------------------------------------------------------------------ 590 // PyThreadContextEntry management 591 //------------------------------------------------------------------------------ 592 593 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 594 static thread_local std::vector<PyThreadContextEntry> stack; 595 return stack; 596 } 597 598 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 599 auto &stack = getStack(); 600 if (stack.empty()) 601 return nullptr; 602 return &stack.back(); 603 } 604 605 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 606 py::object insertionPoint, 607 py::object location) { 608 auto &stack = getStack(); 609 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 610 std::move(location)); 611 // If the new stack has more than one entry and the context of the new top 612 // entry matches the previous, copy the insertionPoint and location from the 613 // previous entry if missing from the new top entry. 614 if (stack.size() > 1) { 615 auto &prev = *(stack.rbegin() + 1); 616 auto ¤t = stack.back(); 617 if (current.context.is(prev.context)) { 618 // Default non-context objects from the previous entry. 619 if (!current.insertionPoint) 620 current.insertionPoint = prev.insertionPoint; 621 if (!current.location) 622 current.location = prev.location; 623 } 624 } 625 } 626 627 PyMlirContext *PyThreadContextEntry::getContext() { 628 if (!context) 629 return nullptr; 630 return py::cast<PyMlirContext *>(context); 631 } 632 633 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 634 if (!insertionPoint) 635 return nullptr; 636 return py::cast<PyInsertionPoint *>(insertionPoint); 637 } 638 639 PyLocation *PyThreadContextEntry::getLocation() { 640 if (!location) 641 return nullptr; 642 return py::cast<PyLocation *>(location); 643 } 644 645 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 646 auto *tos = getTopOfStack(); 647 return tos ? tos->getContext() : nullptr; 648 } 649 650 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 651 auto *tos = getTopOfStack(); 652 return tos ? tos->getInsertionPoint() : nullptr; 653 } 654 655 PyLocation *PyThreadContextEntry::getDefaultLocation() { 656 auto *tos = getTopOfStack(); 657 return tos ? tos->getLocation() : nullptr; 658 } 659 660 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 661 py::object contextObj = py::cast(context); 662 push(FrameKind::Context, /*context=*/contextObj, 663 /*insertionPoint=*/py::object(), 664 /*location=*/py::object()); 665 return contextObj; 666 } 667 668 void PyThreadContextEntry::popContext(PyMlirContext &context) { 669 auto &stack = getStack(); 670 if (stack.empty()) 671 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 672 auto &tos = stack.back(); 673 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 674 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 675 stack.pop_back(); 676 } 677 678 py::object 679 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 680 py::object contextObj = 681 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 682 py::object insertionPointObj = py::cast(insertionPoint); 683 push(FrameKind::InsertionPoint, 684 /*context=*/contextObj, 685 /*insertionPoint=*/insertionPointObj, 686 /*location=*/py::object()); 687 return insertionPointObj; 688 } 689 690 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 691 auto &stack = getStack(); 692 if (stack.empty()) 693 throw SetPyError(PyExc_RuntimeError, 694 "Unbalanced InsertionPoint enter/exit"); 695 auto &tos = stack.back(); 696 if (tos.frameKind != FrameKind::InsertionPoint && 697 tos.getInsertionPoint() != &insertionPoint) 698 throw SetPyError(PyExc_RuntimeError, 699 "Unbalanced InsertionPoint enter/exit"); 700 stack.pop_back(); 701 } 702 703 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 704 py::object contextObj = location.getContext().getObject(); 705 py::object locationObj = py::cast(location); 706 push(FrameKind::Location, /*context=*/contextObj, 707 /*insertionPoint=*/py::object(), 708 /*location=*/locationObj); 709 return locationObj; 710 } 711 712 void PyThreadContextEntry::popLocation(PyLocation &location) { 713 auto &stack = getStack(); 714 if (stack.empty()) 715 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 716 auto &tos = stack.back(); 717 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 718 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 719 stack.pop_back(); 720 } 721 722 //------------------------------------------------------------------------------ 723 // PyDiagnostic* 724 //------------------------------------------------------------------------------ 725 726 void PyDiagnostic::invalidate() { 727 valid = false; 728 if (materializedNotes) { 729 for (auto ¬eObject : *materializedNotes) { 730 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); 731 note->invalidate(); 732 } 733 } 734 } 735 736 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 737 py::object callback) 738 : context(context), callback(std::move(callback)) {} 739 740 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 741 742 void PyDiagnosticHandler::detach() { 743 if (!registeredID) 744 return; 745 MlirDiagnosticHandlerID localID = *registeredID; 746 mlirContextDetachDiagnosticHandler(context, localID); 747 assert(!registeredID && "should have unregistered"); 748 // Not strictly necessary but keeps stale pointers from being around to cause 749 // issues. 750 context = {nullptr}; 751 } 752 753 void PyDiagnostic::checkValid() { 754 if (!valid) { 755 throw std::invalid_argument( 756 "Diagnostic is invalid (used outside of callback)"); 757 } 758 } 759 760 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 761 checkValid(); 762 return mlirDiagnosticGetSeverity(diagnostic); 763 } 764 765 PyLocation PyDiagnostic::getLocation() { 766 checkValid(); 767 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 768 MlirContext context = mlirLocationGetContext(loc); 769 return PyLocation(PyMlirContext::forContext(context), loc); 770 } 771 772 py::str PyDiagnostic::getMessage() { 773 checkValid(); 774 py::object fileObject = py::module::import("io").attr("StringIO")(); 775 PyFileAccumulator accum(fileObject, /*binary=*/false); 776 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 777 return fileObject.attr("getvalue")(); 778 } 779 780 py::tuple PyDiagnostic::getNotes() { 781 checkValid(); 782 if (materializedNotes) 783 return *materializedNotes; 784 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 785 materializedNotes = py::tuple(numNotes); 786 for (intptr_t i = 0; i < numNotes; ++i) { 787 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 788 py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); 789 PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); 790 } 791 return *materializedNotes; 792 } 793 794 //------------------------------------------------------------------------------ 795 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry 796 //------------------------------------------------------------------------------ 797 798 MlirDialect PyDialects::getDialectForKey(const std::string &key, 799 bool attrError) { 800 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 801 {key.data(), key.size()}); 802 if (mlirDialectIsNull(dialect)) { 803 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 804 Twine("Dialect '") + key + "' not found"); 805 } 806 return dialect; 807 } 808 809 py::object PyDialectRegistry::getCapsule() { 810 return py::reinterpret_steal<py::object>( 811 mlirPythonDialectRegistryToCapsule(*this)); 812 } 813 814 PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { 815 MlirDialectRegistry rawRegistry = 816 mlirPythonCapsuleToDialectRegistry(capsule.ptr()); 817 if (mlirDialectRegistryIsNull(rawRegistry)) 818 throw py::error_already_set(); 819 return PyDialectRegistry(rawRegistry); 820 } 821 822 //------------------------------------------------------------------------------ 823 // PyLocation 824 //------------------------------------------------------------------------------ 825 826 py::object PyLocation::getCapsule() { 827 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 828 } 829 830 PyLocation PyLocation::createFromCapsule(py::object capsule) { 831 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 832 if (mlirLocationIsNull(rawLoc)) 833 throw py::error_already_set(); 834 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 835 rawLoc); 836 } 837 838 py::object PyLocation::contextEnter() { 839 return PyThreadContextEntry::pushLocation(*this); 840 } 841 842 void PyLocation::contextExit(const pybind11::object &excType, 843 const pybind11::object &excVal, 844 const pybind11::object &excTb) { 845 PyThreadContextEntry::popLocation(*this); 846 } 847 848 PyLocation &DefaultingPyLocation::resolve() { 849 auto *location = PyThreadContextEntry::getDefaultLocation(); 850 if (!location) { 851 throw SetPyError( 852 PyExc_RuntimeError, 853 "An MLIR function requires a Location but none was provided in the " 854 "call or from the surrounding environment. Either pass to the function " 855 "with a 'loc=' argument or establish a default using 'with loc:'"); 856 } 857 return *location; 858 } 859 860 //------------------------------------------------------------------------------ 861 // PyModule 862 //------------------------------------------------------------------------------ 863 864 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 865 : BaseContextObject(std::move(contextRef)), module(module) {} 866 867 PyModule::~PyModule() { 868 py::gil_scoped_acquire acquire; 869 auto &liveModules = getContext()->liveModules; 870 assert(liveModules.count(module.ptr) == 1 && 871 "destroying module not in live map"); 872 liveModules.erase(module.ptr); 873 mlirModuleDestroy(module); 874 } 875 876 PyModuleRef PyModule::forModule(MlirModule module) { 877 MlirContext context = mlirModuleGetContext(module); 878 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 879 880 py::gil_scoped_acquire acquire; 881 auto &liveModules = contextRef->liveModules; 882 auto it = liveModules.find(module.ptr); 883 if (it == liveModules.end()) { 884 // Create. 885 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 886 // Note that the default return value policy on cast is automatic_reference, 887 // which does not take ownership (delete will not be called). 888 // Just be explicit. 889 py::object pyRef = 890 py::cast(unownedModule, py::return_value_policy::take_ownership); 891 unownedModule->handle = pyRef; 892 liveModules[module.ptr] = 893 std::make_pair(unownedModule->handle, unownedModule); 894 return PyModuleRef(unownedModule, std::move(pyRef)); 895 } 896 // Use existing. 897 PyModule *existing = it->second.second; 898 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 899 return PyModuleRef(existing, std::move(pyRef)); 900 } 901 902 py::object PyModule::createFromCapsule(py::object capsule) { 903 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 904 if (mlirModuleIsNull(rawModule)) 905 throw py::error_already_set(); 906 return forModule(rawModule).releaseObject(); 907 } 908 909 py::object PyModule::getCapsule() { 910 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 911 } 912 913 //------------------------------------------------------------------------------ 914 // PyOperation 915 //------------------------------------------------------------------------------ 916 917 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 918 : BaseContextObject(std::move(contextRef)), operation(operation) {} 919 920 PyOperation::~PyOperation() { 921 // If the operation has already been invalidated there is nothing to do. 922 if (!valid) 923 return; 924 auto &liveOperations = getContext()->liveOperations; 925 assert(liveOperations.count(operation.ptr) == 1 && 926 "destroying operation not in live map"); 927 liveOperations.erase(operation.ptr); 928 if (!isAttached()) { 929 mlirOperationDestroy(operation); 930 } 931 } 932 933 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 934 MlirOperation operation, 935 py::object parentKeepAlive) { 936 auto &liveOperations = contextRef->liveOperations; 937 // Create. 938 PyOperation *unownedOperation = 939 new PyOperation(std::move(contextRef), operation); 940 // Note that the default return value policy on cast is automatic_reference, 941 // which does not take ownership (delete will not be called). 942 // Just be explicit. 943 py::object pyRef = 944 py::cast(unownedOperation, py::return_value_policy::take_ownership); 945 unownedOperation->handle = pyRef; 946 if (parentKeepAlive) { 947 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 948 } 949 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 950 return PyOperationRef(unownedOperation, std::move(pyRef)); 951 } 952 953 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 954 MlirOperation operation, 955 py::object parentKeepAlive) { 956 auto &liveOperations = contextRef->liveOperations; 957 auto it = liveOperations.find(operation.ptr); 958 if (it == liveOperations.end()) { 959 // Create. 960 return createInstance(std::move(contextRef), operation, 961 std::move(parentKeepAlive)); 962 } 963 // Use existing. 964 PyOperation *existing = it->second.second; 965 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 966 return PyOperationRef(existing, std::move(pyRef)); 967 } 968 969 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 970 MlirOperation operation, 971 py::object parentKeepAlive) { 972 auto &liveOperations = contextRef->liveOperations; 973 assert(liveOperations.count(operation.ptr) == 0 && 974 "cannot create detached operation that already exists"); 975 (void)liveOperations; 976 977 PyOperationRef created = createInstance(std::move(contextRef), operation, 978 std::move(parentKeepAlive)); 979 created->attached = false; 980 return created; 981 } 982 983 void PyOperation::checkValid() const { 984 if (!valid) { 985 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 986 } 987 } 988 989 void PyOperationBase::print(py::object fileObject, bool binary, 990 llvm::Optional<int64_t> largeElementsLimit, 991 bool enableDebugInfo, bool prettyDebugInfo, 992 bool printGenericOpForm, bool useLocalScope, 993 bool assumeVerified) { 994 PyOperation &operation = getOperation(); 995 operation.checkValid(); 996 if (fileObject.is_none()) 997 fileObject = py::module::import("sys").attr("stdout"); 998 999 if (!assumeVerified && !printGenericOpForm && 1000 !mlirOperationVerify(operation)) { 1001 std::string message("// Verification failed, printing generic form\n"); 1002 if (binary) { 1003 fileObject.attr("write")(py::bytes(message)); 1004 } else { 1005 fileObject.attr("write")(py::str(message)); 1006 } 1007 printGenericOpForm = true; 1008 } 1009 1010 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 1011 if (largeElementsLimit) 1012 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 1013 if (enableDebugInfo) 1014 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 1015 if (printGenericOpForm) 1016 mlirOpPrintingFlagsPrintGenericOpForm(flags); 1017 if (useLocalScope) 1018 mlirOpPrintingFlagsUseLocalScope(flags); 1019 1020 PyFileAccumulator accum(fileObject, binary); 1021 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1022 accum.getUserData()); 1023 mlirOpPrintingFlagsDestroy(flags); 1024 } 1025 1026 py::object PyOperationBase::getAsm(bool binary, 1027 llvm::Optional<int64_t> largeElementsLimit, 1028 bool enableDebugInfo, bool prettyDebugInfo, 1029 bool printGenericOpForm, bool useLocalScope, 1030 bool assumeVerified) { 1031 py::object fileObject; 1032 if (binary) { 1033 fileObject = py::module::import("io").attr("BytesIO")(); 1034 } else { 1035 fileObject = py::module::import("io").attr("StringIO")(); 1036 } 1037 print(fileObject, /*binary=*/binary, 1038 /*largeElementsLimit=*/largeElementsLimit, 1039 /*enableDebugInfo=*/enableDebugInfo, 1040 /*prettyDebugInfo=*/prettyDebugInfo, 1041 /*printGenericOpForm=*/printGenericOpForm, 1042 /*useLocalScope=*/useLocalScope, 1043 /*assumeVerified=*/assumeVerified); 1044 1045 return fileObject.attr("getvalue")(); 1046 } 1047 1048 void PyOperationBase::moveAfter(PyOperationBase &other) { 1049 PyOperation &operation = getOperation(); 1050 PyOperation &otherOp = other.getOperation(); 1051 operation.checkValid(); 1052 otherOp.checkValid(); 1053 mlirOperationMoveAfter(operation, otherOp); 1054 operation.parentKeepAlive = otherOp.parentKeepAlive; 1055 } 1056 1057 void PyOperationBase::moveBefore(PyOperationBase &other) { 1058 PyOperation &operation = getOperation(); 1059 PyOperation &otherOp = other.getOperation(); 1060 operation.checkValid(); 1061 otherOp.checkValid(); 1062 mlirOperationMoveBefore(operation, otherOp); 1063 operation.parentKeepAlive = otherOp.parentKeepAlive; 1064 } 1065 1066 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 1067 checkValid(); 1068 if (!isAttached()) 1069 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 1070 MlirOperation operation = mlirOperationGetParentOperation(get()); 1071 if (mlirOperationIsNull(operation)) 1072 return {}; 1073 return PyOperation::forOperation(getContext(), operation); 1074 } 1075 1076 PyBlock PyOperation::getBlock() { 1077 checkValid(); 1078 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 1079 MlirBlock block = mlirOperationGetBlock(get()); 1080 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1081 assert(parentOperation && "Operation has no parent"); 1082 return PyBlock{std::move(*parentOperation), block}; 1083 } 1084 1085 py::object PyOperation::getCapsule() { 1086 checkValid(); 1087 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 1088 } 1089 1090 py::object PyOperation::createFromCapsule(py::object capsule) { 1091 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1092 if (mlirOperationIsNull(rawOperation)) 1093 throw py::error_already_set(); 1094 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1095 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1096 .releaseObject(); 1097 } 1098 1099 static void maybeInsertOperation(PyOperationRef &op, 1100 const py::object &maybeIp) { 1101 // InsertPoint active? 1102 if (!maybeIp.is(py::cast(false))) { 1103 PyInsertionPoint *ip; 1104 if (maybeIp.is_none()) { 1105 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1106 } else { 1107 ip = py::cast<PyInsertionPoint *>(maybeIp); 1108 } 1109 if (ip) 1110 ip->insert(*op.get()); 1111 } 1112 } 1113 1114 py::object PyOperation::create( 1115 const std::string &name, llvm::Optional<std::vector<PyType *>> results, 1116 llvm::Optional<std::vector<PyValue *>> operands, 1117 llvm::Optional<py::dict> attributes, 1118 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 1119 DefaultingPyLocation location, const py::object &maybeIp) { 1120 llvm::SmallVector<MlirValue, 4> mlirOperands; 1121 llvm::SmallVector<MlirType, 4> mlirResults; 1122 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1123 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1124 1125 // General parameter validation. 1126 if (regions < 0) 1127 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 1128 1129 // Unpack/validate operands. 1130 if (operands) { 1131 mlirOperands.reserve(operands->size()); 1132 for (PyValue *operand : *operands) { 1133 if (!operand) 1134 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 1135 mlirOperands.push_back(operand->get()); 1136 } 1137 } 1138 1139 // Unpack/validate results. 1140 if (results) { 1141 mlirResults.reserve(results->size()); 1142 for (PyType *result : *results) { 1143 // TODO: Verify result type originate from the same context. 1144 if (!result) 1145 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 1146 mlirResults.push_back(*result); 1147 } 1148 } 1149 // Unpack/validate attributes. 1150 if (attributes) { 1151 mlirAttributes.reserve(attributes->size()); 1152 for (auto &it : *attributes) { 1153 std::string key; 1154 try { 1155 key = it.first.cast<std::string>(); 1156 } catch (py::cast_error &err) { 1157 std::string msg = "Invalid attribute key (not a string) when " 1158 "attempting to create the operation \"" + 1159 name + "\" (" + err.what() + ")"; 1160 throw py::cast_error(msg); 1161 } 1162 try { 1163 auto &attribute = it.second.cast<PyAttribute &>(); 1164 // TODO: Verify attribute originates from the same context. 1165 mlirAttributes.emplace_back(std::move(key), attribute); 1166 } catch (py::reference_cast_error &) { 1167 // This exception seems thrown when the value is "None". 1168 std::string msg = 1169 "Found an invalid (`None`?) attribute value for the key \"" + key + 1170 "\" when attempting to create the operation \"" + name + "\""; 1171 throw py::cast_error(msg); 1172 } catch (py::cast_error &err) { 1173 std::string msg = "Invalid attribute value for the key \"" + key + 1174 "\" when attempting to create the operation \"" + 1175 name + "\" (" + err.what() + ")"; 1176 throw py::cast_error(msg); 1177 } 1178 } 1179 } 1180 // Unpack/validate successors. 1181 if (successors) { 1182 mlirSuccessors.reserve(successors->size()); 1183 for (auto *successor : *successors) { 1184 // TODO: Verify successor originate from the same context. 1185 if (!successor) 1186 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1187 mlirSuccessors.push_back(successor->get()); 1188 } 1189 } 1190 1191 // Apply unpacked/validated to the operation state. Beyond this 1192 // point, exceptions cannot be thrown or else the state will leak. 1193 MlirOperationState state = 1194 mlirOperationStateGet(toMlirStringRef(name), location); 1195 if (!mlirOperands.empty()) 1196 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1197 mlirOperands.data()); 1198 if (!mlirResults.empty()) 1199 mlirOperationStateAddResults(&state, mlirResults.size(), 1200 mlirResults.data()); 1201 if (!mlirAttributes.empty()) { 1202 // Note that the attribute names directly reference bytes in 1203 // mlirAttributes, so that vector must not be changed from here 1204 // on. 1205 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1206 mlirNamedAttributes.reserve(mlirAttributes.size()); 1207 for (auto &it : mlirAttributes) 1208 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1209 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1210 toMlirStringRef(it.first)), 1211 it.second)); 1212 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1213 mlirNamedAttributes.data()); 1214 } 1215 if (!mlirSuccessors.empty()) 1216 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1217 mlirSuccessors.data()); 1218 if (regions) { 1219 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1220 mlirRegions.resize(regions); 1221 for (int i = 0; i < regions; ++i) 1222 mlirRegions[i] = mlirRegionCreate(); 1223 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1224 mlirRegions.data()); 1225 } 1226 1227 // Construct the operation. 1228 MlirOperation operation = mlirOperationCreate(&state); 1229 PyOperationRef created = 1230 PyOperation::createDetached(location->getContext(), operation); 1231 maybeInsertOperation(created, maybeIp); 1232 1233 return created->createOpView(); 1234 } 1235 1236 py::object PyOperation::clone(const py::object &maybeIp) { 1237 MlirOperation clonedOperation = mlirOperationClone(operation); 1238 PyOperationRef cloned = 1239 PyOperation::createDetached(getContext(), clonedOperation); 1240 maybeInsertOperation(cloned, maybeIp); 1241 1242 return cloned->createOpView(); 1243 } 1244 1245 py::object PyOperation::createOpView() { 1246 checkValid(); 1247 MlirIdentifier ident = mlirOperationGetName(get()); 1248 MlirStringRef identStr = mlirIdentifierStr(ident); 1249 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1250 StringRef(identStr.data, identStr.length)); 1251 if (opViewClass) 1252 return (*opViewClass)(getRef().getObject()); 1253 return py::cast(PyOpView(getRef().getObject())); 1254 } 1255 1256 void PyOperation::erase() { 1257 checkValid(); 1258 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1259 // Python reference to a child operation is live. All children should also 1260 // have their `valid` bit set to false. 1261 auto &liveOperations = getContext()->liveOperations; 1262 if (liveOperations.count(operation.ptr)) 1263 liveOperations.erase(operation.ptr); 1264 mlirOperationDestroy(operation); 1265 valid = false; 1266 } 1267 1268 //------------------------------------------------------------------------------ 1269 // PyOpView 1270 //------------------------------------------------------------------------------ 1271 1272 py::object PyOpView::buildGeneric( 1273 const py::object &cls, py::list resultTypeList, py::list operandList, 1274 llvm::Optional<py::dict> attributes, 1275 llvm::Optional<std::vector<PyBlock *>> successors, 1276 llvm::Optional<int> regions, DefaultingPyLocation location, 1277 const py::object &maybeIp) { 1278 PyMlirContextRef context = location->getContext(); 1279 // Class level operation construction metadata. 1280 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1281 // Operand and result segment specs are either none, which does no 1282 // variadic unpacking, or a list of ints with segment sizes, where each 1283 // element is either a positive number (typically 1 for a scalar) or -1 to 1284 // indicate that it is derived from the length of the same-indexed operand 1285 // or result (implying that it is a list at that position). 1286 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1287 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1288 1289 std::vector<uint32_t> operandSegmentLengths; 1290 std::vector<uint32_t> resultSegmentLengths; 1291 1292 // Validate/determine region count. 1293 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1294 int opMinRegionCount = std::get<0>(opRegionSpec); 1295 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1296 if (!regions) { 1297 regions = opMinRegionCount; 1298 } 1299 if (*regions < opMinRegionCount) { 1300 throw py::value_error( 1301 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1302 llvm::Twine(opMinRegionCount) + 1303 " regions but was built with regions=" + llvm::Twine(*regions)) 1304 .str()); 1305 } 1306 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1307 throw py::value_error( 1308 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1309 llvm::Twine(opMinRegionCount) + 1310 " regions but was built with regions=" + llvm::Twine(*regions)) 1311 .str()); 1312 } 1313 1314 // Unpack results. 1315 std::vector<PyType *> resultTypes; 1316 resultTypes.reserve(resultTypeList.size()); 1317 if (resultSegmentSpecObj.is_none()) { 1318 // Non-variadic result unpacking. 1319 for (const auto &it : llvm::enumerate(resultTypeList)) { 1320 try { 1321 resultTypes.push_back(py::cast<PyType *>(it.value())); 1322 if (!resultTypes.back()) 1323 throw py::cast_error(); 1324 } catch (py::cast_error &err) { 1325 throw py::value_error((llvm::Twine("Result ") + 1326 llvm::Twine(it.index()) + " of operation \"" + 1327 name + "\" must be a Type (" + err.what() + ")") 1328 .str()); 1329 } 1330 } 1331 } else { 1332 // Sized result unpacking. 1333 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1334 if (resultSegmentSpec.size() != resultTypeList.size()) { 1335 throw py::value_error((llvm::Twine("Operation \"") + name + 1336 "\" requires " + 1337 llvm::Twine(resultSegmentSpec.size()) + 1338 " result segments but was provided " + 1339 llvm::Twine(resultTypeList.size())) 1340 .str()); 1341 } 1342 resultSegmentLengths.reserve(resultTypeList.size()); 1343 for (const auto &it : 1344 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1345 int segmentSpec = std::get<1>(it.value()); 1346 if (segmentSpec == 1 || segmentSpec == 0) { 1347 // Unpack unary element. 1348 try { 1349 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1350 if (resultType) { 1351 resultTypes.push_back(resultType); 1352 resultSegmentLengths.push_back(1); 1353 } else if (segmentSpec == 0) { 1354 // Allowed to be optional. 1355 resultSegmentLengths.push_back(0); 1356 } else { 1357 throw py::cast_error("was None and result is not optional"); 1358 } 1359 } catch (py::cast_error &err) { 1360 throw py::value_error((llvm::Twine("Result ") + 1361 llvm::Twine(it.index()) + " of operation \"" + 1362 name + "\" must be a Type (" + err.what() + 1363 ")") 1364 .str()); 1365 } 1366 } else if (segmentSpec == -1) { 1367 // Unpack sequence by appending. 1368 try { 1369 if (std::get<0>(it.value()).is_none()) { 1370 // Treat it as an empty list. 1371 resultSegmentLengths.push_back(0); 1372 } else { 1373 // Unpack the list. 1374 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1375 for (py::object segmentItem : segment) { 1376 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1377 if (!resultTypes.back()) { 1378 throw py::cast_error("contained a None item"); 1379 } 1380 } 1381 resultSegmentLengths.push_back(segment.size()); 1382 } 1383 } catch (std::exception &err) { 1384 // NOTE: Sloppy to be using a catch-all here, but there are at least 1385 // three different unrelated exceptions that can be thrown in the 1386 // above "casts". Just keep the scope above small and catch them all. 1387 throw py::value_error((llvm::Twine("Result ") + 1388 llvm::Twine(it.index()) + " of operation \"" + 1389 name + "\" must be a Sequence of Types (" + 1390 err.what() + ")") 1391 .str()); 1392 } 1393 } else { 1394 throw py::value_error("Unexpected segment spec"); 1395 } 1396 } 1397 } 1398 1399 // Unpack operands. 1400 std::vector<PyValue *> operands; 1401 operands.reserve(operands.size()); 1402 if (operandSegmentSpecObj.is_none()) { 1403 // Non-sized operand unpacking. 1404 for (const auto &it : llvm::enumerate(operandList)) { 1405 try { 1406 operands.push_back(py::cast<PyValue *>(it.value())); 1407 if (!operands.back()) 1408 throw py::cast_error(); 1409 } catch (py::cast_error &err) { 1410 throw py::value_error((llvm::Twine("Operand ") + 1411 llvm::Twine(it.index()) + " of operation \"" + 1412 name + "\" must be a Value (" + err.what() + ")") 1413 .str()); 1414 } 1415 } 1416 } else { 1417 // Sized operand unpacking. 1418 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1419 if (operandSegmentSpec.size() != operandList.size()) { 1420 throw py::value_error((llvm::Twine("Operation \"") + name + 1421 "\" requires " + 1422 llvm::Twine(operandSegmentSpec.size()) + 1423 "operand segments but was provided " + 1424 llvm::Twine(operandList.size())) 1425 .str()); 1426 } 1427 operandSegmentLengths.reserve(operandList.size()); 1428 for (const auto &it : 1429 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1430 int segmentSpec = std::get<1>(it.value()); 1431 if (segmentSpec == 1 || segmentSpec == 0) { 1432 // Unpack unary element. 1433 try { 1434 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1435 if (operandValue) { 1436 operands.push_back(operandValue); 1437 operandSegmentLengths.push_back(1); 1438 } else if (segmentSpec == 0) { 1439 // Allowed to be optional. 1440 operandSegmentLengths.push_back(0); 1441 } else { 1442 throw py::cast_error("was None and operand is not optional"); 1443 } 1444 } catch (py::cast_error &err) { 1445 throw py::value_error((llvm::Twine("Operand ") + 1446 llvm::Twine(it.index()) + " of operation \"" + 1447 name + "\" must be a Value (" + err.what() + 1448 ")") 1449 .str()); 1450 } 1451 } else if (segmentSpec == -1) { 1452 // Unpack sequence by appending. 1453 try { 1454 if (std::get<0>(it.value()).is_none()) { 1455 // Treat it as an empty list. 1456 operandSegmentLengths.push_back(0); 1457 } else { 1458 // Unpack the list. 1459 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1460 for (py::object segmentItem : segment) { 1461 operands.push_back(py::cast<PyValue *>(segmentItem)); 1462 if (!operands.back()) { 1463 throw py::cast_error("contained a None item"); 1464 } 1465 } 1466 operandSegmentLengths.push_back(segment.size()); 1467 } 1468 } catch (std::exception &err) { 1469 // NOTE: Sloppy to be using a catch-all here, but there are at least 1470 // three different unrelated exceptions that can be thrown in the 1471 // above "casts". Just keep the scope above small and catch them all. 1472 throw py::value_error((llvm::Twine("Operand ") + 1473 llvm::Twine(it.index()) + " of operation \"" + 1474 name + "\" must be a Sequence of Values (" + 1475 err.what() + ")") 1476 .str()); 1477 } 1478 } else { 1479 throw py::value_error("Unexpected segment spec"); 1480 } 1481 } 1482 } 1483 1484 // Merge operand/result segment lengths into attributes if needed. 1485 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1486 // Dup. 1487 if (attributes) { 1488 attributes = py::dict(*attributes); 1489 } else { 1490 attributes = py::dict(); 1491 } 1492 if (attributes->contains("result_segment_sizes") || 1493 attributes->contains("operand_segment_sizes")) { 1494 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1495 "'operand_segment_sizes' attribute is unsupported. " 1496 "Use Operation.create for such low-level access."); 1497 } 1498 1499 // Add result_segment_sizes attribute. 1500 if (!resultSegmentLengths.empty()) { 1501 int64_t size = resultSegmentLengths.size(); 1502 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1503 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1504 resultSegmentLengths.size(), resultSegmentLengths.data()); 1505 (*attributes)["result_segment_sizes"] = 1506 PyAttribute(context, segmentLengthAttr); 1507 } 1508 1509 // Add operand_segment_sizes attribute. 1510 if (!operandSegmentLengths.empty()) { 1511 int64_t size = operandSegmentLengths.size(); 1512 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1513 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1514 operandSegmentLengths.size(), operandSegmentLengths.data()); 1515 (*attributes)["operand_segment_sizes"] = 1516 PyAttribute(context, segmentLengthAttr); 1517 } 1518 } 1519 1520 // Delegate to create. 1521 return PyOperation::create(name, 1522 /*results=*/std::move(resultTypes), 1523 /*operands=*/std::move(operands), 1524 /*attributes=*/std::move(attributes), 1525 /*successors=*/std::move(successors), 1526 /*regions=*/*regions, location, maybeIp); 1527 } 1528 1529 PyOpView::PyOpView(const py::object &operationObject) 1530 // Casting through the PyOperationBase base-class and then back to the 1531 // Operation lets us accept any PyOperationBase subclass. 1532 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1533 operationObject(operation.getRef().getObject()) {} 1534 1535 py::object PyOpView::createRawSubclass(const py::object &userClass) { 1536 // This is... a little gross. The typical pattern is to have a pure python 1537 // class that extends OpView like: 1538 // class AddFOp(_cext.ir.OpView): 1539 // def __init__(self, loc, lhs, rhs): 1540 // operation = loc.context.create_operation( 1541 // "addf", lhs, rhs, results=[lhs.type]) 1542 // super().__init__(operation) 1543 // 1544 // I.e. The goal of the user facing type is to provide a nice constructor 1545 // that has complete freedom for the op under construction. This is at odds 1546 // with our other desire to sometimes create this object by just passing an 1547 // operation (to initialize the base class). We could do *arg and **kwargs 1548 // munging to try to make it work, but instead, we synthesize a new class 1549 // on the fly which extends this user class (AddFOp in this example) and 1550 // *give it* the base class's __init__ method, thus bypassing the 1551 // intermediate subclass's __init__ method entirely. While slightly, 1552 // underhanded, this is safe/legal because the type hierarchy has not changed 1553 // (we just added a new leaf) and we aren't mucking around with __new__. 1554 // Typically, this new class will be stored on the original as "_Raw" and will 1555 // be used for casts and other things that need a variant of the class that 1556 // is initialized purely from an operation. 1557 py::object parentMetaclass = 1558 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1559 py::dict attributes; 1560 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1561 // now. 1562 // auto opViewType = py::type::of<PyOpView>(); 1563 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1564 attributes["__init__"] = opViewType.attr("__init__"); 1565 py::str origName = userClass.attr("__name__"); 1566 py::str newName = py::str("_") + origName; 1567 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1568 } 1569 1570 //------------------------------------------------------------------------------ 1571 // PyInsertionPoint. 1572 //------------------------------------------------------------------------------ 1573 1574 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1575 1576 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1577 : refOperation(beforeOperationBase.getOperation().getRef()), 1578 block((*refOperation)->getBlock()) {} 1579 1580 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1581 PyOperation &operation = operationBase.getOperation(); 1582 if (operation.isAttached()) 1583 throw SetPyError(PyExc_ValueError, 1584 "Attempt to insert operation that is already attached"); 1585 block.getParentOperation()->checkValid(); 1586 MlirOperation beforeOp = {nullptr}; 1587 if (refOperation) { 1588 // Insert before operation. 1589 (*refOperation)->checkValid(); 1590 beforeOp = (*refOperation)->get(); 1591 } else { 1592 // Insert at end (before null) is only valid if the block does not 1593 // already end in a known terminator (violating this will cause assertion 1594 // failures later). 1595 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1596 throw py::index_error("Cannot insert operation at the end of a block " 1597 "that already has a terminator. Did you mean to " 1598 "use 'InsertionPoint.at_block_terminator(block)' " 1599 "versus 'InsertionPoint(block)'?"); 1600 } 1601 } 1602 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1603 operation.setAttached(); 1604 } 1605 1606 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1607 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1608 if (mlirOperationIsNull(firstOp)) { 1609 // Just insert at end. 1610 return PyInsertionPoint(block); 1611 } 1612 1613 // Insert before first op. 1614 PyOperationRef firstOpRef = PyOperation::forOperation( 1615 block.getParentOperation()->getContext(), firstOp); 1616 return PyInsertionPoint{block, std::move(firstOpRef)}; 1617 } 1618 1619 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1620 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1621 if (mlirOperationIsNull(terminator)) 1622 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1623 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1624 block.getParentOperation()->getContext(), terminator); 1625 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1626 } 1627 1628 py::object PyInsertionPoint::contextEnter() { 1629 return PyThreadContextEntry::pushInsertionPoint(*this); 1630 } 1631 1632 void PyInsertionPoint::contextExit(const pybind11::object &excType, 1633 const pybind11::object &excVal, 1634 const pybind11::object &excTb) { 1635 PyThreadContextEntry::popInsertionPoint(*this); 1636 } 1637 1638 //------------------------------------------------------------------------------ 1639 // PyAttribute. 1640 //------------------------------------------------------------------------------ 1641 1642 bool PyAttribute::operator==(const PyAttribute &other) { 1643 return mlirAttributeEqual(attr, other.attr); 1644 } 1645 1646 py::object PyAttribute::getCapsule() { 1647 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1648 } 1649 1650 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1651 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1652 if (mlirAttributeIsNull(rawAttr)) 1653 throw py::error_already_set(); 1654 return PyAttribute( 1655 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1656 } 1657 1658 //------------------------------------------------------------------------------ 1659 // PyNamedAttribute. 1660 //------------------------------------------------------------------------------ 1661 1662 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1663 : ownedName(new std::string(std::move(ownedName))) { 1664 namedAttr = mlirNamedAttributeGet( 1665 mlirIdentifierGet(mlirAttributeGetContext(attr), 1666 toMlirStringRef(*this->ownedName)), 1667 attr); 1668 } 1669 1670 //------------------------------------------------------------------------------ 1671 // PyType. 1672 //------------------------------------------------------------------------------ 1673 1674 bool PyType::operator==(const PyType &other) { 1675 return mlirTypeEqual(type, other.type); 1676 } 1677 1678 py::object PyType::getCapsule() { 1679 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1680 } 1681 1682 PyType PyType::createFromCapsule(py::object capsule) { 1683 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1684 if (mlirTypeIsNull(rawType)) 1685 throw py::error_already_set(); 1686 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1687 rawType); 1688 } 1689 1690 //------------------------------------------------------------------------------ 1691 // PyValue and subclases. 1692 //------------------------------------------------------------------------------ 1693 1694 pybind11::object PyValue::getCapsule() { 1695 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1696 } 1697 1698 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1699 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1700 if (mlirValueIsNull(value)) 1701 throw py::error_already_set(); 1702 MlirOperation owner; 1703 if (mlirValueIsAOpResult(value)) 1704 owner = mlirOpResultGetOwner(value); 1705 if (mlirValueIsABlockArgument(value)) 1706 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1707 if (mlirOperationIsNull(owner)) 1708 throw py::error_already_set(); 1709 MlirContext ctx = mlirOperationGetContext(owner); 1710 PyOperationRef ownerRef = 1711 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1712 return PyValue(ownerRef, value); 1713 } 1714 1715 //------------------------------------------------------------------------------ 1716 // PySymbolTable. 1717 //------------------------------------------------------------------------------ 1718 1719 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1720 : operation(operation.getOperation().getRef()) { 1721 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1722 if (mlirSymbolTableIsNull(symbolTable)) { 1723 throw py::cast_error("Operation is not a Symbol Table."); 1724 } 1725 } 1726 1727 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1728 operation->checkValid(); 1729 MlirOperation symbol = mlirSymbolTableLookup( 1730 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1731 if (mlirOperationIsNull(symbol)) 1732 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1733 1734 return PyOperation::forOperation(operation->getContext(), symbol, 1735 operation.getObject()) 1736 ->createOpView(); 1737 } 1738 1739 void PySymbolTable::erase(PyOperationBase &symbol) { 1740 operation->checkValid(); 1741 symbol.getOperation().checkValid(); 1742 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1743 // The operation is also erased, so we must invalidate it. There may be Python 1744 // references to this operation so we don't want to delete it from the list of 1745 // live operations here. 1746 symbol.getOperation().valid = false; 1747 } 1748 1749 void PySymbolTable::dunderDel(const std::string &name) { 1750 py::object operation = dunderGetItem(name); 1751 erase(py::cast<PyOperationBase &>(operation)); 1752 } 1753 1754 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1755 operation->checkValid(); 1756 symbol.getOperation().checkValid(); 1757 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1758 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1759 if (mlirAttributeIsNull(symbolAttr)) 1760 throw py::value_error("Expected operation to have a symbol name."); 1761 return PyAttribute( 1762 symbol.getOperation().getContext(), 1763 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1764 } 1765 1766 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1767 // Op must already be a symbol. 1768 PyOperation &operation = symbol.getOperation(); 1769 operation.checkValid(); 1770 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1771 MlirAttribute existingNameAttr = 1772 mlirOperationGetAttributeByName(operation.get(), attrName); 1773 if (mlirAttributeIsNull(existingNameAttr)) 1774 throw py::value_error("Expected operation to have a symbol name."); 1775 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1776 } 1777 1778 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1779 const std::string &name) { 1780 // Op must already be a symbol. 1781 PyOperation &operation = symbol.getOperation(); 1782 operation.checkValid(); 1783 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1784 MlirAttribute existingNameAttr = 1785 mlirOperationGetAttributeByName(operation.get(), attrName); 1786 if (mlirAttributeIsNull(existingNameAttr)) 1787 throw py::value_error("Expected operation to have a symbol name."); 1788 MlirAttribute newNameAttr = 1789 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1790 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1791 } 1792 1793 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1794 PyOperation &operation = symbol.getOperation(); 1795 operation.checkValid(); 1796 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1797 MlirAttribute existingVisAttr = 1798 mlirOperationGetAttributeByName(operation.get(), attrName); 1799 if (mlirAttributeIsNull(existingVisAttr)) 1800 throw py::value_error("Expected operation to have a symbol visibility."); 1801 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1802 } 1803 1804 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1805 const std::string &visibility) { 1806 if (visibility != "public" && visibility != "private" && 1807 visibility != "nested") 1808 throw py::value_error( 1809 "Expected visibility to be 'public', 'private' or 'nested'"); 1810 PyOperation &operation = symbol.getOperation(); 1811 operation.checkValid(); 1812 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1813 MlirAttribute existingVisAttr = 1814 mlirOperationGetAttributeByName(operation.get(), attrName); 1815 if (mlirAttributeIsNull(existingVisAttr)) 1816 throw py::value_error("Expected operation to have a symbol visibility."); 1817 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1818 toMlirStringRef(visibility)); 1819 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1820 } 1821 1822 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1823 const std::string &newSymbol, 1824 PyOperationBase &from) { 1825 PyOperation &fromOperation = from.getOperation(); 1826 fromOperation.checkValid(); 1827 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1828 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1829 from.getOperation()))) 1830 1831 throw py::value_error("Symbol rename failed"); 1832 } 1833 1834 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1835 bool allSymUsesVisible, 1836 py::object callback) { 1837 PyOperation &fromOperation = from.getOperation(); 1838 fromOperation.checkValid(); 1839 struct UserData { 1840 PyMlirContextRef context; 1841 py::object callback; 1842 bool gotException; 1843 std::string exceptionWhat; 1844 py::object exceptionType; 1845 }; 1846 UserData userData{ 1847 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1848 mlirSymbolTableWalkSymbolTables( 1849 fromOperation.get(), allSymUsesVisible, 1850 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1851 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1852 auto pyFoundOp = 1853 PyOperation::forOperation(calleeUserData->context, foundOp); 1854 if (calleeUserData->gotException) 1855 return; 1856 try { 1857 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1858 } catch (py::error_already_set &e) { 1859 calleeUserData->gotException = true; 1860 calleeUserData->exceptionWhat = e.what(); 1861 calleeUserData->exceptionType = e.type(); 1862 } 1863 }, 1864 static_cast<void *>(&userData)); 1865 if (userData.gotException) { 1866 std::string message("Exception raised in callback: "); 1867 message.append(userData.exceptionWhat); 1868 throw std::runtime_error(message); 1869 } 1870 } 1871 1872 namespace { 1873 /// CRTP base class for Python MLIR values that subclass Value and should be 1874 /// castable from it. The value hierarchy is one level deep and is not supposed 1875 /// to accommodate other levels unless core MLIR changes. 1876 template <typename DerivedTy> 1877 class PyConcreteValue : public PyValue { 1878 public: 1879 // Derived classes must define statics for: 1880 // IsAFunctionTy isaFunction 1881 // const char *pyClassName 1882 // and redefine bindDerived. 1883 using ClassTy = py::class_<DerivedTy, PyValue>; 1884 using IsAFunctionTy = bool (*)(MlirValue); 1885 1886 PyConcreteValue() = default; 1887 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1888 : PyValue(operationRef, value) {} 1889 PyConcreteValue(PyValue &orig) 1890 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1891 1892 /// Attempts to cast the original value to the derived type and throws on 1893 /// type mismatches. 1894 static MlirValue castFrom(PyValue &orig) { 1895 if (!DerivedTy::isaFunction(orig.get())) { 1896 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1897 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1898 DerivedTy::pyClassName + 1899 " (from " + origRepr + ")"); 1900 } 1901 return orig.get(); 1902 } 1903 1904 /// Binds the Python module objects to functions of this class. 1905 static void bind(py::module &m) { 1906 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1907 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1908 cls.def_static( 1909 "isinstance", 1910 [](PyValue &otherValue) -> bool { 1911 return DerivedTy::isaFunction(otherValue); 1912 }, 1913 py::arg("other_value")); 1914 DerivedTy::bindDerived(cls); 1915 } 1916 1917 /// Implemented by derived classes to add methods to the Python subclass. 1918 static void bindDerived(ClassTy &m) {} 1919 }; 1920 1921 /// Python wrapper for MlirBlockArgument. 1922 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1923 public: 1924 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1925 static constexpr const char *pyClassName = "BlockArgument"; 1926 using PyConcreteValue::PyConcreteValue; 1927 1928 static void bindDerived(ClassTy &c) { 1929 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1930 return PyBlock(self.getParentOperation(), 1931 mlirBlockArgumentGetOwner(self.get())); 1932 }); 1933 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1934 return mlirBlockArgumentGetArgNumber(self.get()); 1935 }); 1936 c.def( 1937 "set_type", 1938 [](PyBlockArgument &self, PyType type) { 1939 return mlirBlockArgumentSetType(self.get(), type); 1940 }, 1941 py::arg("type")); 1942 } 1943 }; 1944 1945 /// Python wrapper for MlirOpResult. 1946 class PyOpResult : public PyConcreteValue<PyOpResult> { 1947 public: 1948 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1949 static constexpr const char *pyClassName = "OpResult"; 1950 using PyConcreteValue::PyConcreteValue; 1951 1952 static void bindDerived(ClassTy &c) { 1953 c.def_property_readonly("owner", [](PyOpResult &self) { 1954 assert( 1955 mlirOperationEqual(self.getParentOperation()->get(), 1956 mlirOpResultGetOwner(self.get())) && 1957 "expected the owner of the value in Python to match that in the IR"); 1958 return self.getParentOperation().getObject(); 1959 }); 1960 c.def_property_readonly("result_number", [](PyOpResult &self) { 1961 return mlirOpResultGetResultNumber(self.get()); 1962 }); 1963 } 1964 }; 1965 1966 /// Returns the list of types of the values held by container. 1967 template <typename Container> 1968 static std::vector<PyType> getValueTypes(Container &container, 1969 PyMlirContextRef &context) { 1970 std::vector<PyType> result; 1971 result.reserve(container.size()); 1972 for (int i = 0, e = container.size(); i < e; ++i) { 1973 result.push_back( 1974 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1975 } 1976 return result; 1977 } 1978 1979 /// A list of block arguments. Internally, these are stored as consecutive 1980 /// elements, random access is cheap. The argument list is associated with the 1981 /// operation that contains the block (detached blocks are not allowed in 1982 /// Python bindings) and extends its lifetime. 1983 class PyBlockArgumentList 1984 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1985 public: 1986 static constexpr const char *pyClassName = "BlockArgumentList"; 1987 1988 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1989 intptr_t startIndex = 0, intptr_t length = -1, 1990 intptr_t step = 1) 1991 : Sliceable(startIndex, 1992 length == -1 ? mlirBlockGetNumArguments(block) : length, 1993 step), 1994 operation(std::move(operation)), block(block) {} 1995 1996 static void bindDerived(ClassTy &c) { 1997 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1998 return getValueTypes(self, self.operation->getContext()); 1999 }); 2000 } 2001 2002 private: 2003 /// Give the parent CRTP class access to hook implementations below. 2004 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>; 2005 2006 /// Returns the number of arguments in the list. 2007 intptr_t getRawNumElements() { 2008 operation->checkValid(); 2009 return mlirBlockGetNumArguments(block); 2010 } 2011 2012 /// Returns `pos`-the element in the list. 2013 PyBlockArgument getRawElement(intptr_t pos) { 2014 MlirValue argument = mlirBlockGetArgument(block, pos); 2015 return PyBlockArgument(operation, argument); 2016 } 2017 2018 /// Returns a sublist of this list. 2019 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 2020 intptr_t step) { 2021 return PyBlockArgumentList(operation, block, startIndex, length, step); 2022 } 2023 2024 PyOperationRef operation; 2025 MlirBlock block; 2026 }; 2027 2028 /// A list of operation operands. Internally, these are stored as consecutive 2029 /// elements, random access is cheap. The result list is associated with the 2030 /// operation whose results these are, and extends the lifetime of this 2031 /// operation. 2032 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 2033 public: 2034 static constexpr const char *pyClassName = "OpOperandList"; 2035 2036 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2037 intptr_t length = -1, intptr_t step = 1) 2038 : Sliceable(startIndex, 2039 length == -1 ? mlirOperationGetNumOperands(operation->get()) 2040 : length, 2041 step), 2042 operation(operation) {} 2043 2044 void dunderSetItem(intptr_t index, PyValue value) { 2045 index = wrapIndex(index); 2046 mlirOperationSetOperand(operation->get(), index, value.get()); 2047 } 2048 2049 static void bindDerived(ClassTy &c) { 2050 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2051 } 2052 2053 private: 2054 /// Give the parent CRTP class access to hook implementations below. 2055 friend class Sliceable<PyOpOperandList, PyValue>; 2056 2057 intptr_t getRawNumElements() { 2058 operation->checkValid(); 2059 return mlirOperationGetNumOperands(operation->get()); 2060 } 2061 2062 PyValue getRawElement(intptr_t pos) { 2063 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2064 MlirOperation owner; 2065 if (mlirValueIsAOpResult(operand)) 2066 owner = mlirOpResultGetOwner(operand); 2067 else if (mlirValueIsABlockArgument(operand)) 2068 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2069 else 2070 assert(false && "Value must be an block arg or op result."); 2071 PyOperationRef pyOwner = 2072 PyOperation::forOperation(operation->getContext(), owner); 2073 return PyValue(pyOwner, operand); 2074 } 2075 2076 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2077 return PyOpOperandList(operation, startIndex, length, step); 2078 } 2079 2080 PyOperationRef operation; 2081 }; 2082 2083 /// A list of operation results. Internally, these are stored as consecutive 2084 /// elements, random access is cheap. The result list is associated with the 2085 /// operation whose results these are, and extends the lifetime of this 2086 /// operation. 2087 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 2088 public: 2089 static constexpr const char *pyClassName = "OpResultList"; 2090 2091 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 2092 intptr_t length = -1, intptr_t step = 1) 2093 : Sliceable(startIndex, 2094 length == -1 ? mlirOperationGetNumResults(operation->get()) 2095 : length, 2096 step), 2097 operation(operation) {} 2098 2099 static void bindDerived(ClassTy &c) { 2100 c.def_property_readonly("types", [](PyOpResultList &self) { 2101 return getValueTypes(self, self.operation->getContext()); 2102 }); 2103 } 2104 2105 private: 2106 /// Give the parent CRTP class access to hook implementations below. 2107 friend class Sliceable<PyOpResultList, PyOpResult>; 2108 2109 intptr_t getRawNumElements() { 2110 operation->checkValid(); 2111 return mlirOperationGetNumResults(operation->get()); 2112 } 2113 2114 PyOpResult getRawElement(intptr_t index) { 2115 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 2116 return PyOpResult(value); 2117 } 2118 2119 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2120 return PyOpResultList(operation, startIndex, length, step); 2121 } 2122 2123 PyOperationRef operation; 2124 }; 2125 2126 /// A list of operation attributes. Can be indexed by name, producing 2127 /// attributes, or by index, producing named attributes. 2128 class PyOpAttributeMap { 2129 public: 2130 PyOpAttributeMap(PyOperationRef operation) 2131 : operation(std::move(operation)) {} 2132 2133 PyAttribute dunderGetItemNamed(const std::string &name) { 2134 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2135 toMlirStringRef(name)); 2136 if (mlirAttributeIsNull(attr)) { 2137 throw SetPyError(PyExc_KeyError, 2138 "attempt to access a non-existent attribute"); 2139 } 2140 return PyAttribute(operation->getContext(), attr); 2141 } 2142 2143 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2144 if (index < 0 || index >= dunderLen()) { 2145 throw SetPyError(PyExc_IndexError, 2146 "attempt to access out of bounds attribute"); 2147 } 2148 MlirNamedAttribute namedAttr = 2149 mlirOperationGetAttribute(operation->get(), index); 2150 return PyNamedAttribute( 2151 namedAttr.attribute, 2152 std::string(mlirIdentifierStr(namedAttr.name).data, 2153 mlirIdentifierStr(namedAttr.name).length)); 2154 } 2155 2156 void dunderSetItem(const std::string &name, const PyAttribute &attr) { 2157 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2158 attr); 2159 } 2160 2161 void dunderDelItem(const std::string &name) { 2162 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2163 toMlirStringRef(name)); 2164 if (!removed) 2165 throw SetPyError(PyExc_KeyError, 2166 "attempt to delete a non-existent attribute"); 2167 } 2168 2169 intptr_t dunderLen() { 2170 return mlirOperationGetNumAttributes(operation->get()); 2171 } 2172 2173 bool dunderContains(const std::string &name) { 2174 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2175 operation->get(), toMlirStringRef(name))); 2176 } 2177 2178 static void bind(py::module &m) { 2179 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2180 .def("__contains__", &PyOpAttributeMap::dunderContains) 2181 .def("__len__", &PyOpAttributeMap::dunderLen) 2182 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2183 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2184 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2185 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2186 } 2187 2188 private: 2189 PyOperationRef operation; 2190 }; 2191 2192 } // namespace 2193 2194 //------------------------------------------------------------------------------ 2195 // Populates the core exports of the 'ir' submodule. 2196 //------------------------------------------------------------------------------ 2197 2198 void mlir::python::populateIRCore(py::module &m) { 2199 //---------------------------------------------------------------------------- 2200 // Enums. 2201 //---------------------------------------------------------------------------- 2202 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local()) 2203 .value("ERROR", MlirDiagnosticError) 2204 .value("WARNING", MlirDiagnosticWarning) 2205 .value("NOTE", MlirDiagnosticNote) 2206 .value("REMARK", MlirDiagnosticRemark); 2207 2208 //---------------------------------------------------------------------------- 2209 // Mapping of Diagnostics. 2210 //---------------------------------------------------------------------------- 2211 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local()) 2212 .def_property_readonly("severity", &PyDiagnostic::getSeverity) 2213 .def_property_readonly("location", &PyDiagnostic::getLocation) 2214 .def_property_readonly("message", &PyDiagnostic::getMessage) 2215 .def_property_readonly("notes", &PyDiagnostic::getNotes) 2216 .def("__str__", [](PyDiagnostic &self) -> py::str { 2217 if (!self.isValid()) 2218 return "<Invalid Diagnostic>"; 2219 return self.getMessage(); 2220 }); 2221 2222 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) 2223 .def("detach", &PyDiagnosticHandler::detach) 2224 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) 2225 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 2226 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2227 .def("__exit__", &PyDiagnosticHandler::contextExit); 2228 2229 //---------------------------------------------------------------------------- 2230 // Mapping of MlirContext. 2231 // Note that this is exported as _BaseContext. The containing, Python level 2232 // __init__.py will subclass it with site-specific functionality and set a 2233 // "Context" attribute on this module. 2234 //---------------------------------------------------------------------------- 2235 py::class_<PyMlirContext>(m, "_BaseContext", py::module_local()) 2236 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2237 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2238 .def("_get_context_again", 2239 [](PyMlirContext &self) { 2240 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2241 return ref.releaseObject(); 2242 }) 2243 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2244 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) 2245 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2246 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2247 &PyMlirContext::getCapsule) 2248 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2249 .def("__enter__", &PyMlirContext::contextEnter) 2250 .def("__exit__", &PyMlirContext::contextExit) 2251 .def_property_readonly_static( 2252 "current", 2253 [](py::object & /*class*/) { 2254 auto *context = PyThreadContextEntry::getDefaultContext(); 2255 if (!context) 2256 throw SetPyError(PyExc_ValueError, "No current Context"); 2257 return context; 2258 }, 2259 "Gets the Context bound to the current thread or raises ValueError") 2260 .def_property_readonly( 2261 "dialects", 2262 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2263 "Gets a container for accessing dialects by name") 2264 .def_property_readonly( 2265 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2266 "Alias for 'dialect'") 2267 .def( 2268 "get_dialect_descriptor", 2269 [=](PyMlirContext &self, std::string &name) { 2270 MlirDialect dialect = mlirContextGetOrLoadDialect( 2271 self.get(), {name.data(), name.size()}); 2272 if (mlirDialectIsNull(dialect)) { 2273 throw SetPyError(PyExc_ValueError, 2274 Twine("Dialect '") + name + "' not found"); 2275 } 2276 return PyDialectDescriptor(self.getRef(), dialect); 2277 }, 2278 py::arg("dialect_name"), 2279 "Gets or loads a dialect by name, returning its descriptor object") 2280 .def_property( 2281 "allow_unregistered_dialects", 2282 [](PyMlirContext &self) -> bool { 2283 return mlirContextGetAllowUnregisteredDialects(self.get()); 2284 }, 2285 [](PyMlirContext &self, bool value) { 2286 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2287 }) 2288 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2289 py::arg("callback"), 2290 "Attaches a diagnostic handler that will receive callbacks") 2291 .def( 2292 "enable_multithreading", 2293 [](PyMlirContext &self, bool enable) { 2294 mlirContextEnableMultithreading(self.get(), enable); 2295 }, 2296 py::arg("enable")) 2297 .def( 2298 "is_registered_operation", 2299 [](PyMlirContext &self, std::string &name) { 2300 return mlirContextIsRegisteredOperation( 2301 self.get(), MlirStringRef{name.data(), name.size()}); 2302 }, 2303 py::arg("operation_name")) 2304 .def( 2305 "append_dialect_registry", 2306 [](PyMlirContext &self, PyDialectRegistry ®istry) { 2307 mlirContextAppendDialectRegistry(self.get(), registry); 2308 }, 2309 py::arg("registry")) 2310 .def("load_all_available_dialects", [](PyMlirContext &self) { 2311 mlirContextLoadAllAvailableDialects(self.get()); 2312 }); 2313 2314 //---------------------------------------------------------------------------- 2315 // Mapping of PyDialectDescriptor 2316 //---------------------------------------------------------------------------- 2317 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2318 .def_property_readonly("namespace", 2319 [](PyDialectDescriptor &self) { 2320 MlirStringRef ns = 2321 mlirDialectGetNamespace(self.get()); 2322 return py::str(ns.data, ns.length); 2323 }) 2324 .def("__repr__", [](PyDialectDescriptor &self) { 2325 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2326 std::string repr("<DialectDescriptor "); 2327 repr.append(ns.data, ns.length); 2328 repr.append(">"); 2329 return repr; 2330 }); 2331 2332 //---------------------------------------------------------------------------- 2333 // Mapping of PyDialects 2334 //---------------------------------------------------------------------------- 2335 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2336 .def("__getitem__", 2337 [=](PyDialects &self, std::string keyName) { 2338 MlirDialect dialect = 2339 self.getDialectForKey(keyName, /*attrError=*/false); 2340 py::object descriptor = 2341 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2342 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2343 }) 2344 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2345 MlirDialect dialect = 2346 self.getDialectForKey(attrName, /*attrError=*/true); 2347 py::object descriptor = 2348 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2349 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2350 }); 2351 2352 //---------------------------------------------------------------------------- 2353 // Mapping of PyDialect 2354 //---------------------------------------------------------------------------- 2355 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2356 .def(py::init<py::object>(), py::arg("descriptor")) 2357 .def_property_readonly( 2358 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2359 .def("__repr__", [](py::object self) { 2360 auto clazz = self.attr("__class__"); 2361 return py::str("<Dialect ") + 2362 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2363 clazz.attr("__module__") + py::str(".") + 2364 clazz.attr("__name__") + py::str(")>"); 2365 }); 2366 2367 //---------------------------------------------------------------------------- 2368 // Mapping of PyDialectRegistry 2369 //---------------------------------------------------------------------------- 2370 py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local()) 2371 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2372 &PyDialectRegistry::getCapsule) 2373 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) 2374 .def(py::init<>()); 2375 2376 //---------------------------------------------------------------------------- 2377 // Mapping of Location 2378 //---------------------------------------------------------------------------- 2379 py::class_<PyLocation>(m, "Location", py::module_local()) 2380 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2381 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2382 .def("__enter__", &PyLocation::contextEnter) 2383 .def("__exit__", &PyLocation::contextExit) 2384 .def("__eq__", 2385 [](PyLocation &self, PyLocation &other) -> bool { 2386 return mlirLocationEqual(self, other); 2387 }) 2388 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2389 .def_property_readonly_static( 2390 "current", 2391 [](py::object & /*class*/) { 2392 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2393 if (!loc) 2394 throw SetPyError(PyExc_ValueError, "No current Location"); 2395 return loc; 2396 }, 2397 "Gets the Location bound to the current thread or raises ValueError") 2398 .def_static( 2399 "unknown", 2400 [](DefaultingPyMlirContext context) { 2401 return PyLocation(context->getRef(), 2402 mlirLocationUnknownGet(context->get())); 2403 }, 2404 py::arg("context") = py::none(), 2405 "Gets a Location representing an unknown location") 2406 .def_static( 2407 "callsite", 2408 [](PyLocation callee, const std::vector<PyLocation> &frames, 2409 DefaultingPyMlirContext context) { 2410 if (frames.empty()) 2411 throw py::value_error("No caller frames provided"); 2412 MlirLocation caller = frames.back().get(); 2413 for (const PyLocation &frame : 2414 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2415 caller = mlirLocationCallSiteGet(frame.get(), caller); 2416 return PyLocation(context->getRef(), 2417 mlirLocationCallSiteGet(callee.get(), caller)); 2418 }, 2419 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2420 kContextGetCallSiteLocationDocstring) 2421 .def_static( 2422 "file", 2423 [](std::string filename, int line, int col, 2424 DefaultingPyMlirContext context) { 2425 return PyLocation( 2426 context->getRef(), 2427 mlirLocationFileLineColGet( 2428 context->get(), toMlirStringRef(filename), line, col)); 2429 }, 2430 py::arg("filename"), py::arg("line"), py::arg("col"), 2431 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2432 .def_static( 2433 "fused", 2434 [](const std::vector<PyLocation> &pyLocations, 2435 llvm::Optional<PyAttribute> metadata, 2436 DefaultingPyMlirContext context) { 2437 llvm::SmallVector<MlirLocation, 4> locations; 2438 locations.reserve(pyLocations.size()); 2439 for (auto &pyLocation : pyLocations) 2440 locations.push_back(pyLocation.get()); 2441 MlirLocation location = mlirLocationFusedGet( 2442 context->get(), locations.size(), locations.data(), 2443 metadata ? metadata->get() : MlirAttribute{0}); 2444 return PyLocation(context->getRef(), location); 2445 }, 2446 py::arg("locations"), py::arg("metadata") = py::none(), 2447 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2448 .def_static( 2449 "name", 2450 [](std::string name, llvm::Optional<PyLocation> childLoc, 2451 DefaultingPyMlirContext context) { 2452 return PyLocation( 2453 context->getRef(), 2454 mlirLocationNameGet( 2455 context->get(), toMlirStringRef(name), 2456 childLoc ? childLoc->get() 2457 : mlirLocationUnknownGet(context->get()))); 2458 }, 2459 py::arg("name"), py::arg("childLoc") = py::none(), 2460 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2461 .def_property_readonly( 2462 "context", 2463 [](PyLocation &self) { return self.getContext().getObject(); }, 2464 "Context that owns the Location") 2465 .def( 2466 "emit_error", 2467 [](PyLocation &self, std::string message) { 2468 mlirEmitError(self, message.c_str()); 2469 }, 2470 py::arg("message"), "Emits an error at this location") 2471 .def("__repr__", [](PyLocation &self) { 2472 PyPrintAccumulator printAccum; 2473 mlirLocationPrint(self, printAccum.getCallback(), 2474 printAccum.getUserData()); 2475 return printAccum.join(); 2476 }); 2477 2478 //---------------------------------------------------------------------------- 2479 // Mapping of Module 2480 //---------------------------------------------------------------------------- 2481 py::class_<PyModule>(m, "Module", py::module_local()) 2482 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2483 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2484 .def_static( 2485 "parse", 2486 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2487 MlirModule module = mlirModuleCreateParse( 2488 context->get(), toMlirStringRef(moduleAsm)); 2489 // TODO: Rework error reporting once diagnostic engine is exposed 2490 // in C API. 2491 if (mlirModuleIsNull(module)) { 2492 throw SetPyError( 2493 PyExc_ValueError, 2494 "Unable to parse module assembly (see diagnostics)"); 2495 } 2496 return PyModule::forModule(module).releaseObject(); 2497 }, 2498 py::arg("asm"), py::arg("context") = py::none(), 2499 kModuleParseDocstring) 2500 .def_static( 2501 "create", 2502 [](DefaultingPyLocation loc) { 2503 MlirModule module = mlirModuleCreateEmpty(loc); 2504 return PyModule::forModule(module).releaseObject(); 2505 }, 2506 py::arg("loc") = py::none(), "Creates an empty module") 2507 .def_property_readonly( 2508 "context", 2509 [](PyModule &self) { return self.getContext().getObject(); }, 2510 "Context that created the Module") 2511 .def_property_readonly( 2512 "operation", 2513 [](PyModule &self) { 2514 return PyOperation::forOperation(self.getContext(), 2515 mlirModuleGetOperation(self.get()), 2516 self.getRef().releaseObject()) 2517 .releaseObject(); 2518 }, 2519 "Accesses the module as an operation") 2520 .def_property_readonly( 2521 "body", 2522 [](PyModule &self) { 2523 PyOperationRef moduleOp = PyOperation::forOperation( 2524 self.getContext(), mlirModuleGetOperation(self.get()), 2525 self.getRef().releaseObject()); 2526 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 2527 return returnBlock; 2528 }, 2529 "Return the block for this module") 2530 .def( 2531 "dump", 2532 [](PyModule &self) { 2533 mlirOperationDump(mlirModuleGetOperation(self.get())); 2534 }, 2535 kDumpDocstring) 2536 .def( 2537 "__str__", 2538 [](py::object self) { 2539 // Defer to the operation's __str__. 2540 return self.attr("operation").attr("__str__")(); 2541 }, 2542 kOperationStrDunderDocstring); 2543 2544 //---------------------------------------------------------------------------- 2545 // Mapping of Operation. 2546 //---------------------------------------------------------------------------- 2547 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2548 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2549 [](PyOperationBase &self) { 2550 return self.getOperation().getCapsule(); 2551 }) 2552 .def("__eq__", 2553 [](PyOperationBase &self, PyOperationBase &other) { 2554 return &self.getOperation() == &other.getOperation(); 2555 }) 2556 .def("__eq__", 2557 [](PyOperationBase &self, py::object other) { return false; }) 2558 .def("__hash__", 2559 [](PyOperationBase &self) { 2560 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2561 }) 2562 .def_property_readonly("attributes", 2563 [](PyOperationBase &self) { 2564 return PyOpAttributeMap( 2565 self.getOperation().getRef()); 2566 }) 2567 .def_property_readonly("operands", 2568 [](PyOperationBase &self) { 2569 return PyOpOperandList( 2570 self.getOperation().getRef()); 2571 }) 2572 .def_property_readonly("regions", 2573 [](PyOperationBase &self) { 2574 return PyRegionList( 2575 self.getOperation().getRef()); 2576 }) 2577 .def_property_readonly( 2578 "results", 2579 [](PyOperationBase &self) { 2580 return PyOpResultList(self.getOperation().getRef()); 2581 }, 2582 "Returns the list of Operation results.") 2583 .def_property_readonly( 2584 "result", 2585 [](PyOperationBase &self) { 2586 auto &operation = self.getOperation(); 2587 auto numResults = mlirOperationGetNumResults(operation); 2588 if (numResults != 1) { 2589 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2590 throw SetPyError( 2591 PyExc_ValueError, 2592 Twine("Cannot call .result on operation ") + 2593 StringRef(name.data, name.length) + " which has " + 2594 Twine(numResults) + 2595 " results (it is only valid for operations with a " 2596 "single result)"); 2597 } 2598 return PyOpResult(operation.getRef(), 2599 mlirOperationGetResult(operation, 0)); 2600 }, 2601 "Shortcut to get an op result if it has only one (throws an error " 2602 "otherwise).") 2603 .def_property_readonly( 2604 "location", 2605 [](PyOperationBase &self) { 2606 PyOperation &operation = self.getOperation(); 2607 return PyLocation(operation.getContext(), 2608 mlirOperationGetLocation(operation.get())); 2609 }, 2610 "Returns the source location the operation was defined or derived " 2611 "from.") 2612 .def( 2613 "__str__", 2614 [](PyOperationBase &self) { 2615 return self.getAsm(/*binary=*/false, 2616 /*largeElementsLimit=*/llvm::None, 2617 /*enableDebugInfo=*/false, 2618 /*prettyDebugInfo=*/false, 2619 /*printGenericOpForm=*/false, 2620 /*useLocalScope=*/false, 2621 /*assumeVerified=*/false); 2622 }, 2623 "Returns the assembly form of the operation.") 2624 .def("print", &PyOperationBase::print, 2625 // Careful: Lots of arguments must match up with print method. 2626 py::arg("file") = py::none(), py::arg("binary") = false, 2627 py::arg("large_elements_limit") = py::none(), 2628 py::arg("enable_debug_info") = false, 2629 py::arg("pretty_debug_info") = false, 2630 py::arg("print_generic_op_form") = false, 2631 py::arg("use_local_scope") = false, 2632 py::arg("assume_verified") = false, kOperationPrintDocstring) 2633 .def("get_asm", &PyOperationBase::getAsm, 2634 // Careful: Lots of arguments must match up with get_asm method. 2635 py::arg("binary") = false, 2636 py::arg("large_elements_limit") = py::none(), 2637 py::arg("enable_debug_info") = false, 2638 py::arg("pretty_debug_info") = false, 2639 py::arg("print_generic_op_form") = false, 2640 py::arg("use_local_scope") = false, 2641 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2642 .def( 2643 "verify", 2644 [](PyOperationBase &self) { 2645 return mlirOperationVerify(self.getOperation()); 2646 }, 2647 "Verify the operation and return true if it passes, false if it " 2648 "fails.") 2649 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2650 "Puts self immediately after the other operation in its parent " 2651 "block.") 2652 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2653 "Puts self immediately before the other operation in its parent " 2654 "block.") 2655 .def( 2656 "detach_from_parent", 2657 [](PyOperationBase &self) { 2658 PyOperation &operation = self.getOperation(); 2659 operation.checkValid(); 2660 if (!operation.isAttached()) 2661 throw py::value_error("Detached operation has no parent."); 2662 2663 operation.detachFromParent(); 2664 return operation.createOpView(); 2665 }, 2666 "Detaches the operation from its parent block."); 2667 2668 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2669 .def_static("create", &PyOperation::create, py::arg("name"), 2670 py::arg("results") = py::none(), 2671 py::arg("operands") = py::none(), 2672 py::arg("attributes") = py::none(), 2673 py::arg("successors") = py::none(), py::arg("regions") = 0, 2674 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2675 kOperationCreateDocstring) 2676 .def_property_readonly("parent", 2677 [](PyOperation &self) -> py::object { 2678 auto parent = self.getParentOperation(); 2679 if (parent) 2680 return parent->getObject(); 2681 return py::none(); 2682 }) 2683 .def("erase", &PyOperation::erase) 2684 .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) 2685 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2686 &PyOperation::getCapsule) 2687 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2688 .def_property_readonly("name", 2689 [](PyOperation &self) { 2690 self.checkValid(); 2691 MlirOperation operation = self.get(); 2692 MlirStringRef name = mlirIdentifierStr( 2693 mlirOperationGetName(operation)); 2694 return py::str(name.data, name.length); 2695 }) 2696 .def_property_readonly( 2697 "context", 2698 [](PyOperation &self) { 2699 self.checkValid(); 2700 return self.getContext().getObject(); 2701 }, 2702 "Context that owns the Operation") 2703 .def_property_readonly("opview", &PyOperation::createOpView); 2704 2705 auto opViewClass = 2706 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2707 .def(py::init<py::object>(), py::arg("operation")) 2708 .def_property_readonly("operation", &PyOpView::getOperationObject) 2709 .def_property_readonly( 2710 "context", 2711 [](PyOpView &self) { 2712 return self.getOperation().getContext().getObject(); 2713 }, 2714 "Context that owns the Operation") 2715 .def("__str__", [](PyOpView &self) { 2716 return py::str(self.getOperationObject()); 2717 }); 2718 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2719 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2720 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2721 opViewClass.attr("build_generic") = classmethod( 2722 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2723 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2724 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2725 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2726 "Builds a specific, generated OpView based on class level attributes."); 2727 2728 //---------------------------------------------------------------------------- 2729 // Mapping of PyRegion. 2730 //---------------------------------------------------------------------------- 2731 py::class_<PyRegion>(m, "Region", py::module_local()) 2732 .def_property_readonly( 2733 "blocks", 2734 [](PyRegion &self) { 2735 return PyBlockList(self.getParentOperation(), self.get()); 2736 }, 2737 "Returns a forward-optimized sequence of blocks.") 2738 .def_property_readonly( 2739 "owner", 2740 [](PyRegion &self) { 2741 return self.getParentOperation()->createOpView(); 2742 }, 2743 "Returns the operation owning this region.") 2744 .def( 2745 "__iter__", 2746 [](PyRegion &self) { 2747 self.checkValid(); 2748 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2749 return PyBlockIterator(self.getParentOperation(), firstBlock); 2750 }, 2751 "Iterates over blocks in the region.") 2752 .def("__eq__", 2753 [](PyRegion &self, PyRegion &other) { 2754 return self.get().ptr == other.get().ptr; 2755 }) 2756 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2757 2758 //---------------------------------------------------------------------------- 2759 // Mapping of PyBlock. 2760 //---------------------------------------------------------------------------- 2761 py::class_<PyBlock>(m, "Block", py::module_local()) 2762 .def_property_readonly( 2763 "owner", 2764 [](PyBlock &self) { 2765 return self.getParentOperation()->createOpView(); 2766 }, 2767 "Returns the owning operation of this block.") 2768 .def_property_readonly( 2769 "region", 2770 [](PyBlock &self) { 2771 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2772 return PyRegion(self.getParentOperation(), region); 2773 }, 2774 "Returns the owning region of this block.") 2775 .def_property_readonly( 2776 "arguments", 2777 [](PyBlock &self) { 2778 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2779 }, 2780 "Returns a list of block arguments.") 2781 .def_property_readonly( 2782 "operations", 2783 [](PyBlock &self) { 2784 return PyOperationList(self.getParentOperation(), self.get()); 2785 }, 2786 "Returns a forward-optimized sequence of operations.") 2787 .def_static( 2788 "create_at_start", 2789 [](PyRegion &parent, py::list pyArgTypes) { 2790 parent.checkValid(); 2791 llvm::SmallVector<MlirType, 4> argTypes; 2792 llvm::SmallVector<MlirLocation, 4> argLocs; 2793 argTypes.reserve(pyArgTypes.size()); 2794 argLocs.reserve(pyArgTypes.size()); 2795 for (auto &pyArg : pyArgTypes) { 2796 argTypes.push_back(pyArg.cast<PyType &>()); 2797 // TODO: Pass in a proper location here. 2798 argLocs.push_back( 2799 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2800 } 2801 2802 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2803 argLocs.data()); 2804 mlirRegionInsertOwnedBlock(parent, 0, block); 2805 return PyBlock(parent.getParentOperation(), block); 2806 }, 2807 py::arg("parent"), py::arg("arg_types") = py::list(), 2808 "Creates and returns a new Block at the beginning of the given " 2809 "region (with given argument types).") 2810 .def( 2811 "append_to", 2812 [](PyBlock &self, PyRegion ®ion) { 2813 MlirBlock b = self.get(); 2814 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) 2815 mlirBlockDetach(b); 2816 mlirRegionAppendOwnedBlock(region.get(), b); 2817 }, 2818 "Append this block to a region, transferring ownership if necessary") 2819 .def( 2820 "create_before", 2821 [](PyBlock &self, py::args pyArgTypes) { 2822 self.checkValid(); 2823 llvm::SmallVector<MlirType, 4> argTypes; 2824 llvm::SmallVector<MlirLocation, 4> argLocs; 2825 argTypes.reserve(pyArgTypes.size()); 2826 argLocs.reserve(pyArgTypes.size()); 2827 for (auto &pyArg : pyArgTypes) { 2828 argTypes.push_back(pyArg.cast<PyType &>()); 2829 // TODO: Pass in a proper location here. 2830 argLocs.push_back( 2831 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2832 } 2833 2834 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2835 argLocs.data()); 2836 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2837 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2838 return PyBlock(self.getParentOperation(), block); 2839 }, 2840 "Creates and returns a new Block before this block " 2841 "(with given argument types).") 2842 .def( 2843 "create_after", 2844 [](PyBlock &self, py::args pyArgTypes) { 2845 self.checkValid(); 2846 llvm::SmallVector<MlirType, 4> argTypes; 2847 llvm::SmallVector<MlirLocation, 4> argLocs; 2848 argTypes.reserve(pyArgTypes.size()); 2849 argLocs.reserve(pyArgTypes.size()); 2850 for (auto &pyArg : pyArgTypes) { 2851 argTypes.push_back(pyArg.cast<PyType &>()); 2852 2853 // TODO: Pass in a proper location here. 2854 argLocs.push_back( 2855 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2856 } 2857 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2858 argLocs.data()); 2859 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2860 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2861 return PyBlock(self.getParentOperation(), block); 2862 }, 2863 "Creates and returns a new Block after this block " 2864 "(with given argument types).") 2865 .def( 2866 "__iter__", 2867 [](PyBlock &self) { 2868 self.checkValid(); 2869 MlirOperation firstOperation = 2870 mlirBlockGetFirstOperation(self.get()); 2871 return PyOperationIterator(self.getParentOperation(), 2872 firstOperation); 2873 }, 2874 "Iterates over operations in the block.") 2875 .def("__eq__", 2876 [](PyBlock &self, PyBlock &other) { 2877 return self.get().ptr == other.get().ptr; 2878 }) 2879 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2880 .def( 2881 "__str__", 2882 [](PyBlock &self) { 2883 self.checkValid(); 2884 PyPrintAccumulator printAccum; 2885 mlirBlockPrint(self.get(), printAccum.getCallback(), 2886 printAccum.getUserData()); 2887 return printAccum.join(); 2888 }, 2889 "Returns the assembly form of the block.") 2890 .def( 2891 "append", 2892 [](PyBlock &self, PyOperationBase &operation) { 2893 if (operation.getOperation().isAttached()) 2894 operation.getOperation().detachFromParent(); 2895 2896 MlirOperation mlirOperation = operation.getOperation().get(); 2897 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2898 operation.getOperation().setAttached( 2899 self.getParentOperation().getObject()); 2900 }, 2901 py::arg("operation"), 2902 "Appends an operation to this block. If the operation is currently " 2903 "in another block, it will be moved."); 2904 2905 //---------------------------------------------------------------------------- 2906 // Mapping of PyInsertionPoint. 2907 //---------------------------------------------------------------------------- 2908 2909 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2910 .def(py::init<PyBlock &>(), py::arg("block"), 2911 "Inserts after the last operation but still inside the block.") 2912 .def("__enter__", &PyInsertionPoint::contextEnter) 2913 .def("__exit__", &PyInsertionPoint::contextExit) 2914 .def_property_readonly_static( 2915 "current", 2916 [](py::object & /*class*/) { 2917 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2918 if (!ip) 2919 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2920 return ip; 2921 }, 2922 "Gets the InsertionPoint bound to the current thread or raises " 2923 "ValueError if none has been set") 2924 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2925 "Inserts before a referenced operation.") 2926 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2927 py::arg("block"), "Inserts at the beginning of the block.") 2928 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2929 py::arg("block"), "Inserts before the block terminator.") 2930 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2931 "Inserts an operation.") 2932 .def_property_readonly( 2933 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2934 "Returns the block that this InsertionPoint points to."); 2935 2936 //---------------------------------------------------------------------------- 2937 // Mapping of PyAttribute. 2938 //---------------------------------------------------------------------------- 2939 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2940 // Delegate to the PyAttribute copy constructor, which will also lifetime 2941 // extend the backing context which owns the MlirAttribute. 2942 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2943 "Casts the passed attribute to the generic Attribute") 2944 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2945 &PyAttribute::getCapsule) 2946 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2947 .def_static( 2948 "parse", 2949 [](std::string attrSpec, DefaultingPyMlirContext context) { 2950 MlirAttribute type = mlirAttributeParseGet( 2951 context->get(), toMlirStringRef(attrSpec)); 2952 // TODO: Rework error reporting once diagnostic engine is exposed 2953 // in C API. 2954 if (mlirAttributeIsNull(type)) { 2955 throw SetPyError(PyExc_ValueError, 2956 Twine("Unable to parse attribute: '") + 2957 attrSpec + "'"); 2958 } 2959 return PyAttribute(context->getRef(), type); 2960 }, 2961 py::arg("asm"), py::arg("context") = py::none(), 2962 "Parses an attribute from an assembly form") 2963 .def_property_readonly( 2964 "context", 2965 [](PyAttribute &self) { return self.getContext().getObject(); }, 2966 "Context that owns the Attribute") 2967 .def_property_readonly("type", 2968 [](PyAttribute &self) { 2969 return PyType(self.getContext()->getRef(), 2970 mlirAttributeGetType(self)); 2971 }) 2972 .def( 2973 "get_named", 2974 [](PyAttribute &self, std::string name) { 2975 return PyNamedAttribute(self, std::move(name)); 2976 }, 2977 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2978 .def("__eq__", 2979 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2980 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2981 .def("__hash__", 2982 [](PyAttribute &self) { 2983 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2984 }) 2985 .def( 2986 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2987 kDumpDocstring) 2988 .def( 2989 "__str__", 2990 [](PyAttribute &self) { 2991 PyPrintAccumulator printAccum; 2992 mlirAttributePrint(self, printAccum.getCallback(), 2993 printAccum.getUserData()); 2994 return printAccum.join(); 2995 }, 2996 "Returns the assembly form of the Attribute.") 2997 .def("__repr__", [](PyAttribute &self) { 2998 // Generally, assembly formats are not printed for __repr__ because 2999 // this can cause exceptionally long debug output and exceptions. 3000 // However, attribute values are generally considered useful and are 3001 // printed. This may need to be re-evaluated if debug dumps end up 3002 // being excessive. 3003 PyPrintAccumulator printAccum; 3004 printAccum.parts.append("Attribute("); 3005 mlirAttributePrint(self, printAccum.getCallback(), 3006 printAccum.getUserData()); 3007 printAccum.parts.append(")"); 3008 return printAccum.join(); 3009 }); 3010 3011 //---------------------------------------------------------------------------- 3012 // Mapping of PyNamedAttribute 3013 //---------------------------------------------------------------------------- 3014 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 3015 .def("__repr__", 3016 [](PyNamedAttribute &self) { 3017 PyPrintAccumulator printAccum; 3018 printAccum.parts.append("NamedAttribute("); 3019 printAccum.parts.append( 3020 py::str(mlirIdentifierStr(self.namedAttr.name).data, 3021 mlirIdentifierStr(self.namedAttr.name).length)); 3022 printAccum.parts.append("="); 3023 mlirAttributePrint(self.namedAttr.attribute, 3024 printAccum.getCallback(), 3025 printAccum.getUserData()); 3026 printAccum.parts.append(")"); 3027 return printAccum.join(); 3028 }) 3029 .def_property_readonly( 3030 "name", 3031 [](PyNamedAttribute &self) { 3032 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 3033 mlirIdentifierStr(self.namedAttr.name).length); 3034 }, 3035 "The name of the NamedAttribute binding") 3036 .def_property_readonly( 3037 "attr", 3038 [](PyNamedAttribute &self) { 3039 // TODO: When named attribute is removed/refactored, also remove 3040 // this constructor (it does an inefficient table lookup). 3041 auto contextRef = PyMlirContext::forContext( 3042 mlirAttributeGetContext(self.namedAttr.attribute)); 3043 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 3044 }, 3045 py::keep_alive<0, 1>(), 3046 "The underlying generic attribute of the NamedAttribute binding"); 3047 3048 //---------------------------------------------------------------------------- 3049 // Mapping of PyType. 3050 //---------------------------------------------------------------------------- 3051 py::class_<PyType>(m, "Type", py::module_local()) 3052 // Delegate to the PyType copy constructor, which will also lifetime 3053 // extend the backing context which owns the MlirType. 3054 .def(py::init<PyType &>(), py::arg("cast_from_type"), 3055 "Casts the passed type to the generic Type") 3056 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 3057 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 3058 .def_static( 3059 "parse", 3060 [](std::string typeSpec, DefaultingPyMlirContext context) { 3061 MlirType type = 3062 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 3063 // TODO: Rework error reporting once diagnostic engine is exposed 3064 // in C API. 3065 if (mlirTypeIsNull(type)) { 3066 throw SetPyError(PyExc_ValueError, 3067 Twine("Unable to parse type: '") + typeSpec + 3068 "'"); 3069 } 3070 return PyType(context->getRef(), type); 3071 }, 3072 py::arg("asm"), py::arg("context") = py::none(), 3073 kContextParseTypeDocstring) 3074 .def_property_readonly( 3075 "context", [](PyType &self) { return self.getContext().getObject(); }, 3076 "Context that owns the Type") 3077 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3078 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 3079 .def("__hash__", 3080 [](PyType &self) { 3081 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3082 }) 3083 .def( 3084 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3085 .def( 3086 "__str__", 3087 [](PyType &self) { 3088 PyPrintAccumulator printAccum; 3089 mlirTypePrint(self, printAccum.getCallback(), 3090 printAccum.getUserData()); 3091 return printAccum.join(); 3092 }, 3093 "Returns the assembly form of the type.") 3094 .def("__repr__", [](PyType &self) { 3095 // Generally, assembly formats are not printed for __repr__ because 3096 // this can cause exceptionally long debug output and exceptions. 3097 // However, types are an exception as they typically have compact 3098 // assembly forms and printing them is useful. 3099 PyPrintAccumulator printAccum; 3100 printAccum.parts.append("Type("); 3101 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 3102 printAccum.parts.append(")"); 3103 return printAccum.join(); 3104 }); 3105 3106 //---------------------------------------------------------------------------- 3107 // Mapping of Value. 3108 //---------------------------------------------------------------------------- 3109 py::class_<PyValue>(m, "Value", py::module_local()) 3110 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3111 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3112 .def_property_readonly( 3113 "context", 3114 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3115 "Context in which the value lives.") 3116 .def( 3117 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3118 kDumpDocstring) 3119 .def_property_readonly( 3120 "owner", 3121 [](PyValue &self) { 3122 assert(mlirOperationEqual(self.getParentOperation()->get(), 3123 mlirOpResultGetOwner(self.get())) && 3124 "expected the owner of the value in Python to match that in " 3125 "the IR"); 3126 return self.getParentOperation().getObject(); 3127 }) 3128 .def("__eq__", 3129 [](PyValue &self, PyValue &other) { 3130 return self.get().ptr == other.get().ptr; 3131 }) 3132 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 3133 .def("__hash__", 3134 [](PyValue &self) { 3135 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3136 }) 3137 .def( 3138 "__str__", 3139 [](PyValue &self) { 3140 PyPrintAccumulator printAccum; 3141 printAccum.parts.append("Value("); 3142 mlirValuePrint(self.get(), printAccum.getCallback(), 3143 printAccum.getUserData()); 3144 printAccum.parts.append(")"); 3145 return printAccum.join(); 3146 }, 3147 kValueDunderStrDocstring) 3148 .def_property_readonly("type", [](PyValue &self) { 3149 return PyType(self.getParentOperation()->getContext(), 3150 mlirValueGetType(self.get())); 3151 }); 3152 PyBlockArgument::bind(m); 3153 PyOpResult::bind(m); 3154 3155 //---------------------------------------------------------------------------- 3156 // Mapping of SymbolTable. 3157 //---------------------------------------------------------------------------- 3158 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 3159 .def(py::init<PyOperationBase &>()) 3160 .def("__getitem__", &PySymbolTable::dunderGetItem) 3161 .def("insert", &PySymbolTable::insert, py::arg("operation")) 3162 .def("erase", &PySymbolTable::erase, py::arg("operation")) 3163 .def("__delitem__", &PySymbolTable::dunderDel) 3164 .def("__contains__", 3165 [](PySymbolTable &table, const std::string &name) { 3166 return !mlirOperationIsNull(mlirSymbolTableLookup( 3167 table, mlirStringRefCreate(name.data(), name.length()))); 3168 }) 3169 // Static helpers. 3170 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3171 py::arg("symbol"), py::arg("name")) 3172 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3173 py::arg("symbol")) 3174 .def_static("get_visibility", &PySymbolTable::getVisibility, 3175 py::arg("symbol")) 3176 .def_static("set_visibility", &PySymbolTable::setVisibility, 3177 py::arg("symbol"), py::arg("visibility")) 3178 .def_static("replace_all_symbol_uses", 3179 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 3180 py::arg("new_symbol"), py::arg("from_op")) 3181 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3182 py::arg("from_op"), py::arg("all_sym_uses_visible"), 3183 py::arg("callback")); 3184 3185 // Container bindings. 3186 PyBlockArgumentList::bind(m); 3187 PyBlockIterator::bind(m); 3188 PyBlockList::bind(m); 3189 PyOperationIterator::bind(m); 3190 PyOperationList::bind(m); 3191 PyOpAttributeMap::bind(m); 3192 PyOpOperandList::bind(m); 3193 PyOpResultList::bind(m); 3194 PyRegionIterator::bind(m); 3195 PyRegionList::bind(m); 3196 3197 // Debug bindings. 3198 PyGlobalDebugFlag::bind(m); 3199 } 3200