1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 //#include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 23 #include <utility> 24 25 namespace py = pybind11; 26 using namespace mlir; 27 using namespace mlir::python; 28 29 using llvm::SmallVector; 30 using llvm::StringRef; 31 using llvm::Twine; 32 33 //------------------------------------------------------------------------------ 34 // Docstrings (trivial, non-duplicated docstrings are included inline). 35 //------------------------------------------------------------------------------ 36 37 static const char kContextParseTypeDocstring[] = 38 R"(Parses the assembly form of a type. 39 40 Returns a Type object or raises a ValueError if the type cannot be parsed. 
41 42 See also: https://mlir.llvm.org/docs/LangRef/#type-system 43 )"; 44 45 static const char kContextGetCallSiteLocationDocstring[] = 46 R"(Gets a Location representing a caller and callsite)"; 47 48 static const char kContextGetFileLocationDocstring[] = 49 R"(Gets a Location representing a file, line and column)"; 50 51 static const char kContextGetFusedLocationDocstring[] = 52 R"(Gets a Location representing a fused location with optional metadata)"; 53 54 static const char kContextGetNameLocationDocString[] = 55 R"(Gets a Location representing a named location with optional child location)"; 56 57 static const char kModuleParseDocstring[] = 58 R"(Parses a module's assembly format from a string. 59 60 Returns a new MlirModule or raises a ValueError if the parsing fails. 61 62 See also: https://mlir.llvm.org/docs/LangRef/ 63 )"; 64 65 static const char kOperationCreateDocstring[] = 66 R"(Creates a new operation. 67 68 Args: 69 name: Operation name (e.g. "dialect.operation"). 70 results: Sequence of Type representing op result types. 71 attributes: Dict of str:Attribute. 72 successors: List of Block for the operation's successors. 73 regions: Number of regions to create. 74 location: A Location object (defaults to resolve from context manager). 75 ip: An InsertionPoint (defaults to resolve from context manager or set to 76 False to disable insertion, even with an insertion point set in the 77 context manager). 78 Returns: 79 A new "detached" Operation object. Detached operations can be added 80 to blocks, which causes them to become "attached." 81 )"; 82 83 static const char kOperationPrintDocstring[] = 84 R"(Prints the assembly form of the operation to a file like object. 85 86 Args: 87 file: The file like object to write to. Defaults to sys.stdout. 88 binary: Whether to write bytes (True) or str (False). Defaults to False. 89 large_elements_limit: Whether to elide elements attributes above this 90 number of elements. Defaults to None (no limit). 
91 enable_debug_info: Whether to print debug/location information. Defaults 92 to False. 93 pretty_debug_info: Whether to format debug information for easier reading 94 by a human (warning: the result is unparseable). 95 print_generic_op_form: Whether to print the generic assembly forms of all 96 ops. Defaults to False. 97 use_local_Scope: Whether to print in a way that is more optimized for 98 multi-threaded access but may not be consistent with how the overall 99 module prints. 100 assume_verified: By default, if not printing generic form, the verifier 101 will be run and if it fails, generic form will be printed with a comment 102 about failed verification. While a reasonable default for interactive use, 103 for systematic use, it is often better for the caller to verify explicitly 104 and report failures in a more robust fashion. Set this to True if doing this 105 in order to avoid running a redundant verification. If the IR is actually 106 invalid, behavior is undefined. 107 )"; 108 109 static const char kOperationGetAsmDocstring[] = 110 R"(Gets the assembly form of the operation with all options available. 111 112 Args: 113 binary: Whether to return a bytes (True) or str (False) object. Defaults to 114 False. 115 ... others ...: See the print() method for common keyword arguments for 116 configuring the printout. 117 Returns: 118 Either a bytes or str object, depending on the setting of the 'binary' 119 argument. 120 )"; 121 122 static const char kOperationStrDunderDocstring[] = 123 R"(Gets the assembly form of the operation with default options. 124 125 If more advanced control over the assembly formatting or I/O options is needed, 126 use the dedicated print or get_asm method, which supports keyword arguments to 127 customize behavior. 
128 )"; 129 130 static const char kDumpDocstring[] = 131 R"(Dumps a debug representation of the object to stderr.)"; 132 133 static const char kAppendBlockDocstring[] = 134 R"(Appends a new block, with argument types as positional args. 135 136 Returns: 137 The created block. 138 )"; 139 140 static const char kValueDunderStrDocstring[] = 141 R"(Returns the string form of the value. 142 143 If the value is a block argument, this is the assembly form of its type and the 144 position in the argument list. If the value is an operation result, this is 145 equivalent to printing the operation that produced it. 146 )"; 147 148 //------------------------------------------------------------------------------ 149 // Utilities. 150 //------------------------------------------------------------------------------ 151 152 /// Helper for creating an @classmethod. 153 template <class Func, typename... Args> 154 py::object classmethod(Func f, Args... args) { 155 py::object cf = py::cpp_function(f, args...); 156 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 157 } 158 159 static py::object 160 createCustomDialectWrapper(const std::string &dialectNamespace, 161 py::object dialectDescriptor) { 162 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 163 if (!dialectClass) { 164 // Use the base class. 165 return py::cast(PyDialect(std::move(dialectDescriptor))); 166 } 167 168 // Create the custom implementation. 169 return (*dialectClass)(std::move(dialectDescriptor)); 170 } 171 172 static MlirStringRef toMlirStringRef(const std::string &s) { 173 return mlirStringRefCreate(s.data(), s.size()); 174 } 175 176 /// Wrapper for the global LLVM debugging flag. 177 struct PyGlobalDebugFlag { 178 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 179 180 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } 181 182 static void bind(py::module &m) { 183 // Debug flags. 
184 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 185 .def_property_static("flag", &PyGlobalDebugFlag::get, 186 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 187 } 188 }; 189 190 //------------------------------------------------------------------------------ 191 // Collections. 192 //------------------------------------------------------------------------------ 193 194 namespace { 195 196 class PyRegionIterator { 197 public: 198 PyRegionIterator(PyOperationRef operation) 199 : operation(std::move(operation)) {} 200 201 PyRegionIterator &dunderIter() { return *this; } 202 203 PyRegion dunderNext() { 204 operation->checkValid(); 205 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 206 throw py::stop_iteration(); 207 } 208 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 209 return PyRegion(operation, region); 210 } 211 212 static void bind(py::module &m) { 213 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 214 .def("__iter__", &PyRegionIterator::dunderIter) 215 .def("__next__", &PyRegionIterator::dunderNext); 216 } 217 218 private: 219 PyOperationRef operation; 220 int nextIndex = 0; 221 }; 222 223 /// Regions of an op are fixed length and indexed numerically so are represented 224 /// with a sequence-like container. 225 class PyRegionList { 226 public: 227 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 228 229 intptr_t dunderLen() { 230 operation->checkValid(); 231 return mlirOperationGetNumRegions(operation->get()); 232 } 233 234 PyRegion dunderGetItem(intptr_t index) { 235 // dunderLen checks validity. 
236 if (index < 0 || index >= dunderLen()) { 237 throw SetPyError(PyExc_IndexError, 238 "attempt to access out of bounds region"); 239 } 240 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 241 return PyRegion(operation, region); 242 } 243 244 static void bind(py::module &m) { 245 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 246 .def("__len__", &PyRegionList::dunderLen) 247 .def("__getitem__", &PyRegionList::dunderGetItem); 248 } 249 250 private: 251 PyOperationRef operation; 252 }; 253 254 class PyBlockIterator { 255 public: 256 PyBlockIterator(PyOperationRef operation, MlirBlock next) 257 : operation(std::move(operation)), next(next) {} 258 259 PyBlockIterator &dunderIter() { return *this; } 260 261 PyBlock dunderNext() { 262 operation->checkValid(); 263 if (mlirBlockIsNull(next)) { 264 throw py::stop_iteration(); 265 } 266 267 PyBlock returnBlock(operation, next); 268 next = mlirBlockGetNextInRegion(next); 269 return returnBlock; 270 } 271 272 static void bind(py::module &m) { 273 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 274 .def("__iter__", &PyBlockIterator::dunderIter) 275 .def("__next__", &PyBlockIterator::dunderNext); 276 } 277 278 private: 279 PyOperationRef operation; 280 MlirBlock next; 281 }; 282 283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 284 /// we present them as a more full-featured list-like container but optimize 285 /// it for forward iteration. Blocks are always owned by a region. 
286 class PyBlockList { 287 public: 288 PyBlockList(PyOperationRef operation, MlirRegion region) 289 : operation(std::move(operation)), region(region) {} 290 291 PyBlockIterator dunderIter() { 292 operation->checkValid(); 293 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 294 } 295 296 intptr_t dunderLen() { 297 operation->checkValid(); 298 intptr_t count = 0; 299 MlirBlock block = mlirRegionGetFirstBlock(region); 300 while (!mlirBlockIsNull(block)) { 301 count += 1; 302 block = mlirBlockGetNextInRegion(block); 303 } 304 return count; 305 } 306 307 PyBlock dunderGetItem(intptr_t index) { 308 operation->checkValid(); 309 if (index < 0) { 310 throw SetPyError(PyExc_IndexError, 311 "attempt to access out of bounds block"); 312 } 313 MlirBlock block = mlirRegionGetFirstBlock(region); 314 while (!mlirBlockIsNull(block)) { 315 if (index == 0) { 316 return PyBlock(operation, block); 317 } 318 block = mlirBlockGetNextInRegion(block); 319 index -= 1; 320 } 321 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 322 } 323 324 PyBlock appendBlock(const py::args &pyArgTypes) { 325 operation->checkValid(); 326 llvm::SmallVector<MlirType, 4> argTypes; 327 llvm::SmallVector<MlirLocation, 4> argLocs; 328 argTypes.reserve(pyArgTypes.size()); 329 argLocs.reserve(pyArgTypes.size()); 330 for (auto &pyArg : pyArgTypes) { 331 argTypes.push_back(pyArg.cast<PyType &>()); 332 // TODO: Pass in a proper location here. 
333 argLocs.push_back( 334 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 335 } 336 337 MlirBlock block = 338 mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 339 mlirRegionAppendOwnedBlock(region, block); 340 return PyBlock(operation, block); 341 } 342 343 static void bind(py::module &m) { 344 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 345 .def("__getitem__", &PyBlockList::dunderGetItem) 346 .def("__iter__", &PyBlockList::dunderIter) 347 .def("__len__", &PyBlockList::dunderLen) 348 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 349 } 350 351 private: 352 PyOperationRef operation; 353 MlirRegion region; 354 }; 355 356 class PyOperationIterator { 357 public: 358 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 359 : parentOperation(std::move(parentOperation)), next(next) {} 360 361 PyOperationIterator &dunderIter() { return *this; } 362 363 py::object dunderNext() { 364 parentOperation->checkValid(); 365 if (mlirOperationIsNull(next)) { 366 throw py::stop_iteration(); 367 } 368 369 PyOperationRef returnOperation = 370 PyOperation::forOperation(parentOperation->getContext(), next); 371 next = mlirOperationGetNextInBlock(next); 372 return returnOperation->createOpView(); 373 } 374 375 static void bind(py::module &m) { 376 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 377 .def("__iter__", &PyOperationIterator::dunderIter) 378 .def("__next__", &PyOperationIterator::dunderNext); 379 } 380 381 private: 382 PyOperationRef parentOperation; 383 MlirOperation next; 384 }; 385 386 /// Operations are exposed by the C-API as a forward-only linked list. In 387 /// Python, we present them as a more full-featured list-like container but 388 /// optimize it for forward iteration. Iterable operations are always owned 389 /// by a block. 
390 class PyOperationList { 391 public: 392 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 393 : parentOperation(std::move(parentOperation)), block(block) {} 394 395 PyOperationIterator dunderIter() { 396 parentOperation->checkValid(); 397 return PyOperationIterator(parentOperation, 398 mlirBlockGetFirstOperation(block)); 399 } 400 401 intptr_t dunderLen() { 402 parentOperation->checkValid(); 403 intptr_t count = 0; 404 MlirOperation childOp = mlirBlockGetFirstOperation(block); 405 while (!mlirOperationIsNull(childOp)) { 406 count += 1; 407 childOp = mlirOperationGetNextInBlock(childOp); 408 } 409 return count; 410 } 411 412 py::object dunderGetItem(intptr_t index) { 413 parentOperation->checkValid(); 414 if (index < 0) { 415 throw SetPyError(PyExc_IndexError, 416 "attempt to access out of bounds operation"); 417 } 418 MlirOperation childOp = mlirBlockGetFirstOperation(block); 419 while (!mlirOperationIsNull(childOp)) { 420 if (index == 0) { 421 return PyOperation::forOperation(parentOperation->getContext(), childOp) 422 ->createOpView(); 423 } 424 childOp = mlirOperationGetNextInBlock(childOp); 425 index -= 1; 426 } 427 throw SetPyError(PyExc_IndexError, 428 "attempt to access out of bounds operation"); 429 } 430 431 static void bind(py::module &m) { 432 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 433 .def("__getitem__", &PyOperationList::dunderGetItem) 434 .def("__iter__", &PyOperationList::dunderIter) 435 .def("__len__", &PyOperationList::dunderLen); 436 } 437 438 private: 439 PyOperationRef parentOperation; 440 MlirBlock block; 441 }; 442 443 } // namespace 444 445 //------------------------------------------------------------------------------ 446 // PyMlirContext 447 //------------------------------------------------------------------------------ 448 449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 450 py::gil_scoped_acquire acquire; 451 auto &liveContexts = getLiveContexts(); 452 
liveContexts[context.ptr] = this; 453 } 454 455 PyMlirContext::~PyMlirContext() { 456 // Note that the only public way to construct an instance is via the 457 // forContext method, which always puts the associated handle into 458 // liveContexts. 459 py::gil_scoped_acquire acquire; 460 getLiveContexts().erase(context.ptr); 461 mlirContextDestroy(context); 462 } 463 464 py::object PyMlirContext::getCapsule() { 465 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 466 } 467 468 py::object PyMlirContext::createFromCapsule(py::object capsule) { 469 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 470 if (mlirContextIsNull(rawContext)) 471 throw py::error_already_set(); 472 return forContext(rawContext).releaseObject(); 473 } 474 475 PyMlirContext *PyMlirContext::createNewContextForInit() { 476 MlirContext context = mlirContextCreate(); 477 return new PyMlirContext(context); 478 } 479 480 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 481 py::gil_scoped_acquire acquire; 482 auto &liveContexts = getLiveContexts(); 483 auto it = liveContexts.find(context.ptr); 484 if (it == liveContexts.end()) { 485 // Create. 486 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 487 py::object pyRef = py::cast(unownedContextWrapper); 488 assert(pyRef && "cast to py::object failed"); 489 liveContexts[context.ptr] = unownedContextWrapper; 490 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 491 } 492 // Use existing. 
493 py::object pyRef = py::cast(it->second); 494 return PyMlirContextRef(it->second, std::move(pyRef)); 495 } 496 497 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 498 static LiveContextMap liveContexts; 499 return liveContexts; 500 } 501 502 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 503 504 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 505 506 size_t PyMlirContext::clearLiveOperations() { 507 for (auto &op : liveOperations) 508 op.second.second->setInvalid(); 509 size_t numInvalidated = liveOperations.size(); 510 liveOperations.clear(); 511 return numInvalidated; 512 } 513 514 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 515 516 pybind11::object PyMlirContext::contextEnter() { 517 return PyThreadContextEntry::pushContext(*this); 518 } 519 520 void PyMlirContext::contextExit(const pybind11::object &excType, 521 const pybind11::object &excVal, 522 const pybind11::object &excTb) { 523 PyThreadContextEntry::popContext(*this); 524 } 525 526 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { 527 // Note that ownership is transferred to the delete callback below by way of 528 // an explicit inc_ref (borrow). 529 PyDiagnosticHandler *pyHandler = 530 new PyDiagnosticHandler(get(), std::move(callback)); 531 py::object pyHandlerObject = 532 py::cast(pyHandler, py::return_value_policy::take_ownership); 533 pyHandlerObject.inc_ref(); 534 535 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 536 // guaranteed to be known to pybind. 
537 auto handlerCallback = 538 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 539 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 540 py::object pyDiagnosticObject = 541 py::cast(pyDiagnostic, py::return_value_policy::take_ownership); 542 543 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 544 bool result = false; 545 { 546 // Since this can be called from arbitrary C++ contexts, always get the 547 // gil. 548 py::gil_scoped_acquire gil; 549 try { 550 result = py::cast<bool>(pyHandler->callback(pyDiagnostic)); 551 } catch (std::exception &e) { 552 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 553 e.what()); 554 pyHandler->hadError = true; 555 } 556 } 557 558 pyDiagnostic->invalidate(); 559 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 560 }; 561 auto deleteCallback = +[](void *userData) { 562 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 563 assert(pyHandler->registeredID && "handler is not registered"); 564 pyHandler->registeredID.reset(); 565 566 // Decrement reference, balancing the inc_ref() above. 567 py::object pyHandlerObject = 568 py::cast(pyHandler, py::return_value_policy::reference); 569 pyHandlerObject.dec_ref(); 570 }; 571 572 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 573 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 574 return pyHandlerObject; 575 } 576 577 PyMlirContext &DefaultingPyMlirContext::resolve() { 578 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 579 if (!context) { 580 throw SetPyError( 581 PyExc_RuntimeError, 582 "An MLIR function requires a Context but none was provided in the call " 583 "or from the surrounding environment. 
Either pass to the function with " 584 "a 'context=' argument or establish a default using 'with Context():'"); 585 } 586 return *context; 587 } 588 589 //------------------------------------------------------------------------------ 590 // PyThreadContextEntry management 591 //------------------------------------------------------------------------------ 592 593 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 594 static thread_local std::vector<PyThreadContextEntry> stack; 595 return stack; 596 } 597 598 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 599 auto &stack = getStack(); 600 if (stack.empty()) 601 return nullptr; 602 return &stack.back(); 603 } 604 605 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 606 py::object insertionPoint, 607 py::object location) { 608 auto &stack = getStack(); 609 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 610 std::move(location)); 611 // If the new stack has more than one entry and the context of the new top 612 // entry matches the previous, copy the insertionPoint and location from the 613 // previous entry if missing from the new top entry. 614 if (stack.size() > 1) { 615 auto &prev = *(stack.rbegin() + 1); 616 auto ¤t = stack.back(); 617 if (current.context.is(prev.context)) { 618 // Default non-context objects from the previous entry. 
619 if (!current.insertionPoint) 620 current.insertionPoint = prev.insertionPoint; 621 if (!current.location) 622 current.location = prev.location; 623 } 624 } 625 } 626 627 PyMlirContext *PyThreadContextEntry::getContext() { 628 if (!context) 629 return nullptr; 630 return py::cast<PyMlirContext *>(context); 631 } 632 633 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 634 if (!insertionPoint) 635 return nullptr; 636 return py::cast<PyInsertionPoint *>(insertionPoint); 637 } 638 639 PyLocation *PyThreadContextEntry::getLocation() { 640 if (!location) 641 return nullptr; 642 return py::cast<PyLocation *>(location); 643 } 644 645 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 646 auto *tos = getTopOfStack(); 647 return tos ? tos->getContext() : nullptr; 648 } 649 650 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 651 auto *tos = getTopOfStack(); 652 return tos ? tos->getInsertionPoint() : nullptr; 653 } 654 655 PyLocation *PyThreadContextEntry::getDefaultLocation() { 656 auto *tos = getTopOfStack(); 657 return tos ? 
tos->getLocation() : nullptr; 658 } 659 660 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 661 py::object contextObj = py::cast(context); 662 push(FrameKind::Context, /*context=*/contextObj, 663 /*insertionPoint=*/py::object(), 664 /*location=*/py::object()); 665 return contextObj; 666 } 667 668 void PyThreadContextEntry::popContext(PyMlirContext &context) { 669 auto &stack = getStack(); 670 if (stack.empty()) 671 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 672 auto &tos = stack.back(); 673 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 674 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 675 stack.pop_back(); 676 } 677 678 py::object 679 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 680 py::object contextObj = 681 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 682 py::object insertionPointObj = py::cast(insertionPoint); 683 push(FrameKind::InsertionPoint, 684 /*context=*/contextObj, 685 /*insertionPoint=*/insertionPointObj, 686 /*location=*/py::object()); 687 return insertionPointObj; 688 } 689 690 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 691 auto &stack = getStack(); 692 if (stack.empty()) 693 throw SetPyError(PyExc_RuntimeError, 694 "Unbalanced InsertionPoint enter/exit"); 695 auto &tos = stack.back(); 696 if (tos.frameKind != FrameKind::InsertionPoint && 697 tos.getInsertionPoint() != &insertionPoint) 698 throw SetPyError(PyExc_RuntimeError, 699 "Unbalanced InsertionPoint enter/exit"); 700 stack.pop_back(); 701 } 702 703 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 704 py::object contextObj = location.getContext().getObject(); 705 py::object locationObj = py::cast(location); 706 push(FrameKind::Location, /*context=*/contextObj, 707 /*insertionPoint=*/py::object(), 708 /*location=*/locationObj); 709 return locationObj; 710 } 711 712 void 
PyThreadContextEntry::popLocation(PyLocation &location) { 713 auto &stack = getStack(); 714 if (stack.empty()) 715 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 716 auto &tos = stack.back(); 717 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 718 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 719 stack.pop_back(); 720 } 721 722 //------------------------------------------------------------------------------ 723 // PyDiagnostic* 724 //------------------------------------------------------------------------------ 725 726 void PyDiagnostic::invalidate() { 727 valid = false; 728 if (materializedNotes) { 729 for (auto ¬eObject : *materializedNotes) { 730 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject); 731 note->invalidate(); 732 } 733 } 734 } 735 736 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 737 py::object callback) 738 : context(context), callback(std::move(callback)) {} 739 740 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 741 742 void PyDiagnosticHandler::detach() { 743 if (!registeredID) 744 return; 745 MlirDiagnosticHandlerID localID = *registeredID; 746 mlirContextDetachDiagnosticHandler(context, localID); 747 assert(!registeredID && "should have unregistered"); 748 // Not strictly necessary but keeps stale pointers from being around to cause 749 // issues. 
750 context = {nullptr}; 751 } 752 753 void PyDiagnostic::checkValid() { 754 if (!valid) { 755 throw std::invalid_argument( 756 "Diagnostic is invalid (used outside of callback)"); 757 } 758 } 759 760 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 761 checkValid(); 762 return mlirDiagnosticGetSeverity(diagnostic); 763 } 764 765 PyLocation PyDiagnostic::getLocation() { 766 checkValid(); 767 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 768 MlirContext context = mlirLocationGetContext(loc); 769 return PyLocation(PyMlirContext::forContext(context), loc); 770 } 771 772 py::str PyDiagnostic::getMessage() { 773 checkValid(); 774 py::object fileObject = py::module::import("io").attr("StringIO")(); 775 PyFileAccumulator accum(fileObject, /*binary=*/false); 776 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 777 return fileObject.attr("getvalue")(); 778 } 779 780 py::tuple PyDiagnostic::getNotes() { 781 checkValid(); 782 if (materializedNotes) 783 return *materializedNotes; 784 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 785 materializedNotes = py::tuple(numNotes); 786 for (intptr_t i = 0; i < numNotes; ++i) { 787 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 788 py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); 789 PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); 790 } 791 return *materializedNotes; 792 } 793 794 //------------------------------------------------------------------------------ 795 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry 796 //------------------------------------------------------------------------------ 797 798 MlirDialect PyDialects::getDialectForKey(const std::string &key, 799 bool attrError) { 800 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 801 {key.data(), key.size()}); 802 if (mlirDialectIsNull(dialect)) { 803 throw SetPyError(attrError ? 
PyExc_AttributeError : PyExc_IndexError, 804 Twine("Dialect '") + key + "' not found"); 805 } 806 return dialect; 807 } 808 809 py::object PyDialectRegistry::getCapsule() { 810 return py::reinterpret_steal<py::object>( 811 mlirPythonDialectRegistryToCapsule(*this)); 812 } 813 814 PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { 815 MlirDialectRegistry rawRegistry = 816 mlirPythonCapsuleToDialectRegistry(capsule.ptr()); 817 if (mlirDialectRegistryIsNull(rawRegistry)) 818 throw py::error_already_set(); 819 return PyDialectRegistry(rawRegistry); 820 } 821 822 //------------------------------------------------------------------------------ 823 // PyLocation 824 //------------------------------------------------------------------------------ 825 826 py::object PyLocation::getCapsule() { 827 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 828 } 829 830 PyLocation PyLocation::createFromCapsule(py::object capsule) { 831 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 832 if (mlirLocationIsNull(rawLoc)) 833 throw py::error_already_set(); 834 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 835 rawLoc); 836 } 837 838 py::object PyLocation::contextEnter() { 839 return PyThreadContextEntry::pushLocation(*this); 840 } 841 842 void PyLocation::contextExit(const pybind11::object &excType, 843 const pybind11::object &excVal, 844 const pybind11::object &excTb) { 845 PyThreadContextEntry::popLocation(*this); 846 } 847 848 PyLocation &DefaultingPyLocation::resolve() { 849 auto *location = PyThreadContextEntry::getDefaultLocation(); 850 if (!location) { 851 throw SetPyError( 852 PyExc_RuntimeError, 853 "An MLIR function requires a Location but none was provided in the " 854 "call or from the surrounding environment. 
Either pass to the function " 855 "with a 'loc=' argument or establish a default using 'with loc:'"); 856 } 857 return *location; 858 } 859 860 //------------------------------------------------------------------------------ 861 // PyModule 862 //------------------------------------------------------------------------------ 863 864 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 865 : BaseContextObject(std::move(contextRef)), module(module) {} 866 867 PyModule::~PyModule() { 868 py::gil_scoped_acquire acquire; 869 auto &liveModules = getContext()->liveModules; 870 assert(liveModules.count(module.ptr) == 1 && 871 "destroying module not in live map"); 872 liveModules.erase(module.ptr); 873 mlirModuleDestroy(module); 874 } 875 876 PyModuleRef PyModule::forModule(MlirModule module) { 877 MlirContext context = mlirModuleGetContext(module); 878 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 879 880 py::gil_scoped_acquire acquire; 881 auto &liveModules = contextRef->liveModules; 882 auto it = liveModules.find(module.ptr); 883 if (it == liveModules.end()) { 884 // Create. 885 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 886 // Note that the default return value policy on cast is automatic_reference, 887 // which does not take ownership (delete will not be called). 888 // Just be explicit. 889 py::object pyRef = 890 py::cast(unownedModule, py::return_value_policy::take_ownership); 891 unownedModule->handle = pyRef; 892 liveModules[module.ptr] = 893 std::make_pair(unownedModule->handle, unownedModule); 894 return PyModuleRef(unownedModule, std::move(pyRef)); 895 } 896 // Use existing. 
897 PyModule *existing = it->second.second; 898 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 899 return PyModuleRef(existing, std::move(pyRef)); 900 } 901 902 py::object PyModule::createFromCapsule(py::object capsule) { 903 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 904 if (mlirModuleIsNull(rawModule)) 905 throw py::error_already_set(); 906 return forModule(rawModule).releaseObject(); 907 } 908 909 py::object PyModule::getCapsule() { 910 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 911 } 912 913 //------------------------------------------------------------------------------ 914 // PyOperation 915 //------------------------------------------------------------------------------ 916 917 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 918 : BaseContextObject(std::move(contextRef)), operation(operation) {} 919 920 PyOperation::~PyOperation() { 921 // If the operation has already been invalidated there is nothing to do. 922 if (!valid) 923 return; 924 auto &liveOperations = getContext()->liveOperations; 925 assert(liveOperations.count(operation.ptr) == 1 && 926 "destroying operation not in live map"); 927 liveOperations.erase(operation.ptr); 928 if (!isAttached()) { 929 mlirOperationDestroy(operation); 930 } 931 } 932 933 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 934 MlirOperation operation, 935 py::object parentKeepAlive) { 936 auto &liveOperations = contextRef->liveOperations; 937 // Create. 938 PyOperation *unownedOperation = 939 new PyOperation(std::move(contextRef), operation); 940 // Note that the default return value policy on cast is automatic_reference, 941 // which does not take ownership (delete will not be called). 942 // Just be explicit. 
943 py::object pyRef = 944 py::cast(unownedOperation, py::return_value_policy::take_ownership); 945 unownedOperation->handle = pyRef; 946 if (parentKeepAlive) { 947 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 948 } 949 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 950 return PyOperationRef(unownedOperation, std::move(pyRef)); 951 } 952 953 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 954 MlirOperation operation, 955 py::object parentKeepAlive) { 956 auto &liveOperations = contextRef->liveOperations; 957 auto it = liveOperations.find(operation.ptr); 958 if (it == liveOperations.end()) { 959 // Create. 960 return createInstance(std::move(contextRef), operation, 961 std::move(parentKeepAlive)); 962 } 963 // Use existing. 964 PyOperation *existing = it->second.second; 965 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 966 return PyOperationRef(existing, std::move(pyRef)); 967 } 968 969 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 970 MlirOperation operation, 971 py::object parentKeepAlive) { 972 auto &liveOperations = contextRef->liveOperations; 973 assert(liveOperations.count(operation.ptr) == 0 && 974 "cannot create detached operation that already exists"); 975 (void)liveOperations; 976 977 PyOperationRef created = createInstance(std::move(contextRef), operation, 978 std::move(parentKeepAlive)); 979 created->attached = false; 980 return created; 981 } 982 983 void PyOperation::checkValid() const { 984 if (!valid) { 985 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 986 } 987 } 988 989 void PyOperationBase::print(py::object fileObject, bool binary, 990 llvm::Optional<int64_t> largeElementsLimit, 991 bool enableDebugInfo, bool prettyDebugInfo, 992 bool printGenericOpForm, bool useLocalScope, 993 bool assumeVerified) { 994 PyOperation &operation = getOperation(); 995 operation.checkValid(); 996 if 
(fileObject.is_none()) 997 fileObject = py::module::import("sys").attr("stdout"); 998 999 if (!assumeVerified && !printGenericOpForm && 1000 !mlirOperationVerify(operation)) { 1001 std::string message("// Verification failed, printing generic form\n"); 1002 if (binary) { 1003 fileObject.attr("write")(py::bytes(message)); 1004 } else { 1005 fileObject.attr("write")(py::str(message)); 1006 } 1007 printGenericOpForm = true; 1008 } 1009 1010 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 1011 if (largeElementsLimit) 1012 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 1013 if (enableDebugInfo) 1014 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 1015 if (printGenericOpForm) 1016 mlirOpPrintingFlagsPrintGenericOpForm(flags); 1017 if (useLocalScope) 1018 mlirOpPrintingFlagsUseLocalScope(flags); 1019 1020 PyFileAccumulator accum(fileObject, binary); 1021 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1022 accum.getUserData()); 1023 mlirOpPrintingFlagsDestroy(flags); 1024 } 1025 1026 py::object PyOperationBase::getAsm(bool binary, 1027 llvm::Optional<int64_t> largeElementsLimit, 1028 bool enableDebugInfo, bool prettyDebugInfo, 1029 bool printGenericOpForm, bool useLocalScope, 1030 bool assumeVerified) { 1031 py::object fileObject; 1032 if (binary) { 1033 fileObject = py::module::import("io").attr("BytesIO")(); 1034 } else { 1035 fileObject = py::module::import("io").attr("StringIO")(); 1036 } 1037 print(fileObject, /*binary=*/binary, 1038 /*largeElementsLimit=*/largeElementsLimit, 1039 /*enableDebugInfo=*/enableDebugInfo, 1040 /*prettyDebugInfo=*/prettyDebugInfo, 1041 /*printGenericOpForm=*/printGenericOpForm, 1042 /*useLocalScope=*/useLocalScope, 1043 /*assumeVerified=*/assumeVerified); 1044 1045 return fileObject.attr("getvalue")(); 1046 } 1047 1048 void PyOperationBase::moveAfter(PyOperationBase &other) { 1049 PyOperation &operation = getOperation(); 1050 PyOperation &otherOp = 
other.getOperation(); 1051 operation.checkValid(); 1052 otherOp.checkValid(); 1053 mlirOperationMoveAfter(operation, otherOp); 1054 operation.parentKeepAlive = otherOp.parentKeepAlive; 1055 } 1056 1057 void PyOperationBase::moveBefore(PyOperationBase &other) { 1058 PyOperation &operation = getOperation(); 1059 PyOperation &otherOp = other.getOperation(); 1060 operation.checkValid(); 1061 otherOp.checkValid(); 1062 mlirOperationMoveBefore(operation, otherOp); 1063 operation.parentKeepAlive = otherOp.parentKeepAlive; 1064 } 1065 1066 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 1067 checkValid(); 1068 if (!isAttached()) 1069 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 1070 MlirOperation operation = mlirOperationGetParentOperation(get()); 1071 if (mlirOperationIsNull(operation)) 1072 return {}; 1073 return PyOperation::forOperation(getContext(), operation); 1074 } 1075 1076 PyBlock PyOperation::getBlock() { 1077 checkValid(); 1078 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 1079 MlirBlock block = mlirOperationGetBlock(get()); 1080 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1081 assert(parentOperation && "Operation has no parent"); 1082 return PyBlock{std::move(*parentOperation), block}; 1083 } 1084 1085 py::object PyOperation::getCapsule() { 1086 checkValid(); 1087 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 1088 } 1089 1090 py::object PyOperation::createFromCapsule(py::object capsule) { 1091 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1092 if (mlirOperationIsNull(rawOperation)) 1093 throw py::error_already_set(); 1094 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1095 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1096 .releaseObject(); 1097 } 1098 1099 static void maybeInsertOperation(PyOperationRef &op, 1100 const py::object &maybeIp) { 1101 // InsertPoint 
active? 1102 if (!maybeIp.is(py::cast(false))) { 1103 PyInsertionPoint *ip; 1104 if (maybeIp.is_none()) { 1105 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1106 } else { 1107 ip = py::cast<PyInsertionPoint *>(maybeIp); 1108 } 1109 if (ip) 1110 ip->insert(*op.get()); 1111 } 1112 } 1113 1114 py::object PyOperation::create( 1115 const std::string &name, llvm::Optional<std::vector<PyType *>> results, 1116 llvm::Optional<std::vector<PyValue *>> operands, 1117 llvm::Optional<py::dict> attributes, 1118 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 1119 DefaultingPyLocation location, const py::object &maybeIp) { 1120 llvm::SmallVector<MlirValue, 4> mlirOperands; 1121 llvm::SmallVector<MlirType, 4> mlirResults; 1122 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1123 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1124 1125 // General parameter validation. 1126 if (regions < 0) 1127 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 1128 1129 // Unpack/validate operands. 1130 if (operands) { 1131 mlirOperands.reserve(operands->size()); 1132 for (PyValue *operand : *operands) { 1133 if (!operand) 1134 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 1135 mlirOperands.push_back(operand->get()); 1136 } 1137 } 1138 1139 // Unpack/validate results. 1140 if (results) { 1141 mlirResults.reserve(results->size()); 1142 for (PyType *result : *results) { 1143 // TODO: Verify result type originate from the same context. 1144 if (!result) 1145 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 1146 mlirResults.push_back(*result); 1147 } 1148 } 1149 // Unpack/validate attributes. 
1150 if (attributes) { 1151 mlirAttributes.reserve(attributes->size()); 1152 for (auto &it : *attributes) { 1153 std::string key; 1154 try { 1155 key = it.first.cast<std::string>(); 1156 } catch (py::cast_error &err) { 1157 std::string msg = "Invalid attribute key (not a string) when " 1158 "attempting to create the operation \"" + 1159 name + "\" (" + err.what() + ")"; 1160 throw py::cast_error(msg); 1161 } 1162 try { 1163 auto &attribute = it.second.cast<PyAttribute &>(); 1164 // TODO: Verify attribute originates from the same context. 1165 mlirAttributes.emplace_back(std::move(key), attribute); 1166 } catch (py::reference_cast_error &) { 1167 // This exception seems thrown when the value is "None". 1168 std::string msg = 1169 "Found an invalid (`None`?) attribute value for the key \"" + key + 1170 "\" when attempting to create the operation \"" + name + "\""; 1171 throw py::cast_error(msg); 1172 } catch (py::cast_error &err) { 1173 std::string msg = "Invalid attribute value for the key \"" + key + 1174 "\" when attempting to create the operation \"" + 1175 name + "\" (" + err.what() + ")"; 1176 throw py::cast_error(msg); 1177 } 1178 } 1179 } 1180 // Unpack/validate successors. 1181 if (successors) { 1182 mlirSuccessors.reserve(successors->size()); 1183 for (auto *successor : *successors) { 1184 // TODO: Verify successor originate from the same context. 1185 if (!successor) 1186 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1187 mlirSuccessors.push_back(successor->get()); 1188 } 1189 } 1190 1191 // Apply unpacked/validated to the operation state. Beyond this 1192 // point, exceptions cannot be thrown or else the state will leak. 
1193 MlirOperationState state = 1194 mlirOperationStateGet(toMlirStringRef(name), location); 1195 if (!mlirOperands.empty()) 1196 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1197 mlirOperands.data()); 1198 if (!mlirResults.empty()) 1199 mlirOperationStateAddResults(&state, mlirResults.size(), 1200 mlirResults.data()); 1201 if (!mlirAttributes.empty()) { 1202 // Note that the attribute names directly reference bytes in 1203 // mlirAttributes, so that vector must not be changed from here 1204 // on. 1205 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1206 mlirNamedAttributes.reserve(mlirAttributes.size()); 1207 for (auto &it : mlirAttributes) 1208 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1209 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1210 toMlirStringRef(it.first)), 1211 it.second)); 1212 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1213 mlirNamedAttributes.data()); 1214 } 1215 if (!mlirSuccessors.empty()) 1216 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1217 mlirSuccessors.data()); 1218 if (regions) { 1219 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1220 mlirRegions.resize(regions); 1221 for (int i = 0; i < regions; ++i) 1222 mlirRegions[i] = mlirRegionCreate(); 1223 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1224 mlirRegions.data()); 1225 } 1226 1227 // Construct the operation. 
1228 MlirOperation operation = mlirOperationCreate(&state); 1229 PyOperationRef created = 1230 PyOperation::createDetached(location->getContext(), operation); 1231 maybeInsertOperation(created, maybeIp); 1232 1233 return created->createOpView(); 1234 } 1235 1236 py::object PyOperation::clone(const py::object &maybeIp) { 1237 MlirOperation clonedOperation = mlirOperationClone(operation); 1238 PyOperationRef cloned = 1239 PyOperation::createDetached(getContext(), clonedOperation); 1240 maybeInsertOperation(cloned, maybeIp); 1241 1242 return cloned->createOpView(); 1243 } 1244 1245 py::object PyOperation::createOpView() { 1246 checkValid(); 1247 MlirIdentifier ident = mlirOperationGetName(get()); 1248 MlirStringRef identStr = mlirIdentifierStr(ident); 1249 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1250 StringRef(identStr.data, identStr.length)); 1251 if (opViewClass) 1252 return (*opViewClass)(getRef().getObject()); 1253 return py::cast(PyOpView(getRef().getObject())); 1254 } 1255 1256 void PyOperation::erase() { 1257 checkValid(); 1258 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1259 // Python reference to a child operation is live. All children should also 1260 // have their `valid` bit set to false. 
1261 auto &liveOperations = getContext()->liveOperations; 1262 if (liveOperations.count(operation.ptr)) 1263 liveOperations.erase(operation.ptr); 1264 mlirOperationDestroy(operation); 1265 valid = false; 1266 } 1267 1268 //------------------------------------------------------------------------------ 1269 // PyOpView 1270 //------------------------------------------------------------------------------ 1271 1272 py::object PyOpView::buildGeneric( 1273 const py::object &cls, py::list resultTypeList, py::list operandList, 1274 llvm::Optional<py::dict> attributes, 1275 llvm::Optional<std::vector<PyBlock *>> successors, 1276 llvm::Optional<int> regions, DefaultingPyLocation location, 1277 const py::object &maybeIp) { 1278 PyMlirContextRef context = location->getContext(); 1279 // Class level operation construction metadata. 1280 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1281 // Operand and result segment specs are either none, which does no 1282 // variadic unpacking, or a list of ints with segment sizes, where each 1283 // element is either a positive number (typically 1 for a scalar) or -1 to 1284 // indicate that it is derived from the length of the same-indexed operand 1285 // or result (implying that it is a list at that position). 1286 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1287 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1288 1289 std::vector<uint32_t> operandSegmentLengths; 1290 std::vector<uint32_t> resultSegmentLengths; 1291 1292 // Validate/determine region count. 
1293 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1294 int opMinRegionCount = std::get<0>(opRegionSpec); 1295 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1296 if (!regions) { 1297 regions = opMinRegionCount; 1298 } 1299 if (*regions < opMinRegionCount) { 1300 throw py::value_error( 1301 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1302 llvm::Twine(opMinRegionCount) + 1303 " regions but was built with regions=" + llvm::Twine(*regions)) 1304 .str()); 1305 } 1306 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1307 throw py::value_error( 1308 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1309 llvm::Twine(opMinRegionCount) + 1310 " regions but was built with regions=" + llvm::Twine(*regions)) 1311 .str()); 1312 } 1313 1314 // Unpack results. 1315 std::vector<PyType *> resultTypes; 1316 resultTypes.reserve(resultTypeList.size()); 1317 if (resultSegmentSpecObj.is_none()) { 1318 // Non-variadic result unpacking. 1319 for (const auto &it : llvm::enumerate(resultTypeList)) { 1320 try { 1321 resultTypes.push_back(py::cast<PyType *>(it.value())); 1322 if (!resultTypes.back()) 1323 throw py::cast_error(); 1324 } catch (py::cast_error &err) { 1325 throw py::value_error((llvm::Twine("Result ") + 1326 llvm::Twine(it.index()) + " of operation \"" + 1327 name + "\" must be a Type (" + err.what() + ")") 1328 .str()); 1329 } 1330 } 1331 } else { 1332 // Sized result unpacking. 
1333 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1334 if (resultSegmentSpec.size() != resultTypeList.size()) { 1335 throw py::value_error((llvm::Twine("Operation \"") + name + 1336 "\" requires " + 1337 llvm::Twine(resultSegmentSpec.size()) + 1338 " result segments but was provided " + 1339 llvm::Twine(resultTypeList.size())) 1340 .str()); 1341 } 1342 resultSegmentLengths.reserve(resultTypeList.size()); 1343 for (const auto &it : 1344 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1345 int segmentSpec = std::get<1>(it.value()); 1346 if (segmentSpec == 1 || segmentSpec == 0) { 1347 // Unpack unary element. 1348 try { 1349 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1350 if (resultType) { 1351 resultTypes.push_back(resultType); 1352 resultSegmentLengths.push_back(1); 1353 } else if (segmentSpec == 0) { 1354 // Allowed to be optional. 1355 resultSegmentLengths.push_back(0); 1356 } else { 1357 throw py::cast_error("was None and result is not optional"); 1358 } 1359 } catch (py::cast_error &err) { 1360 throw py::value_error((llvm::Twine("Result ") + 1361 llvm::Twine(it.index()) + " of operation \"" + 1362 name + "\" must be a Type (" + err.what() + 1363 ")") 1364 .str()); 1365 } 1366 } else if (segmentSpec == -1) { 1367 // Unpack sequence by appending. 1368 try { 1369 if (std::get<0>(it.value()).is_none()) { 1370 // Treat it as an empty list. 1371 resultSegmentLengths.push_back(0); 1372 } else { 1373 // Unpack the list. 
1374 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1375 for (py::object segmentItem : segment) { 1376 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1377 if (!resultTypes.back()) { 1378 throw py::cast_error("contained a None item"); 1379 } 1380 } 1381 resultSegmentLengths.push_back(segment.size()); 1382 } 1383 } catch (std::exception &err) { 1384 // NOTE: Sloppy to be using a catch-all here, but there are at least 1385 // three different unrelated exceptions that can be thrown in the 1386 // above "casts". Just keep the scope above small and catch them all. 1387 throw py::value_error((llvm::Twine("Result ") + 1388 llvm::Twine(it.index()) + " of operation \"" + 1389 name + "\" must be a Sequence of Types (" + 1390 err.what() + ")") 1391 .str()); 1392 } 1393 } else { 1394 throw py::value_error("Unexpected segment spec"); 1395 } 1396 } 1397 } 1398 1399 // Unpack operands. 1400 std::vector<PyValue *> operands; 1401 operands.reserve(operands.size()); 1402 if (operandSegmentSpecObj.is_none()) { 1403 // Non-sized operand unpacking. 1404 for (const auto &it : llvm::enumerate(operandList)) { 1405 try { 1406 operands.push_back(py::cast<PyValue *>(it.value())); 1407 if (!operands.back()) 1408 throw py::cast_error(); 1409 } catch (py::cast_error &err) { 1410 throw py::value_error((llvm::Twine("Operand ") + 1411 llvm::Twine(it.index()) + " of operation \"" + 1412 name + "\" must be a Value (" + err.what() + ")") 1413 .str()); 1414 } 1415 } 1416 } else { 1417 // Sized operand unpacking. 
1418 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1419 if (operandSegmentSpec.size() != operandList.size()) { 1420 throw py::value_error((llvm::Twine("Operation \"") + name + 1421 "\" requires " + 1422 llvm::Twine(operandSegmentSpec.size()) + 1423 "operand segments but was provided " + 1424 llvm::Twine(operandList.size())) 1425 .str()); 1426 } 1427 operandSegmentLengths.reserve(operandList.size()); 1428 for (const auto &it : 1429 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1430 int segmentSpec = std::get<1>(it.value()); 1431 if (segmentSpec == 1 || segmentSpec == 0) { 1432 // Unpack unary element. 1433 try { 1434 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1435 if (operandValue) { 1436 operands.push_back(operandValue); 1437 operandSegmentLengths.push_back(1); 1438 } else if (segmentSpec == 0) { 1439 // Allowed to be optional. 1440 operandSegmentLengths.push_back(0); 1441 } else { 1442 throw py::cast_error("was None and operand is not optional"); 1443 } 1444 } catch (py::cast_error &err) { 1445 throw py::value_error((llvm::Twine("Operand ") + 1446 llvm::Twine(it.index()) + " of operation \"" + 1447 name + "\" must be a Value (" + err.what() + 1448 ")") 1449 .str()); 1450 } 1451 } else if (segmentSpec == -1) { 1452 // Unpack sequence by appending. 1453 try { 1454 if (std::get<0>(it.value()).is_none()) { 1455 // Treat it as an empty list. 1456 operandSegmentLengths.push_back(0); 1457 } else { 1458 // Unpack the list. 
1459 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1460 for (py::object segmentItem : segment) { 1461 operands.push_back(py::cast<PyValue *>(segmentItem)); 1462 if (!operands.back()) { 1463 throw py::cast_error("contained a None item"); 1464 } 1465 } 1466 operandSegmentLengths.push_back(segment.size()); 1467 } 1468 } catch (std::exception &err) { 1469 // NOTE: Sloppy to be using a catch-all here, but there are at least 1470 // three different unrelated exceptions that can be thrown in the 1471 // above "casts". Just keep the scope above small and catch them all. 1472 throw py::value_error((llvm::Twine("Operand ") + 1473 llvm::Twine(it.index()) + " of operation \"" + 1474 name + "\" must be a Sequence of Values (" + 1475 err.what() + ")") 1476 .str()); 1477 } 1478 } else { 1479 throw py::value_error("Unexpected segment spec"); 1480 } 1481 } 1482 } 1483 1484 // Merge operand/result segment lengths into attributes if needed. 1485 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1486 // Dup. 1487 if (attributes) { 1488 attributes = py::dict(*attributes); 1489 } else { 1490 attributes = py::dict(); 1491 } 1492 if (attributes->contains("result_segment_sizes") || 1493 attributes->contains("operand_segment_sizes")) { 1494 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1495 "'operand_segment_sizes' attribute is unsupported. " 1496 "Use Operation.create for such low-level access."); 1497 } 1498 1499 // Add result_segment_sizes attribute. 1500 if (!resultSegmentLengths.empty()) { 1501 int64_t size = resultSegmentLengths.size(); 1502 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1503 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1504 resultSegmentLengths.size(), resultSegmentLengths.data()); 1505 (*attributes)["result_segment_sizes"] = 1506 PyAttribute(context, segmentLengthAttr); 1507 } 1508 1509 // Add operand_segment_sizes attribute. 
1510 if (!operandSegmentLengths.empty()) { 1511 int64_t size = operandSegmentLengths.size(); 1512 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1513 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1514 operandSegmentLengths.size(), operandSegmentLengths.data()); 1515 (*attributes)["operand_segment_sizes"] = 1516 PyAttribute(context, segmentLengthAttr); 1517 } 1518 } 1519 1520 // Delegate to create. 1521 return PyOperation::create(name, 1522 /*results=*/std::move(resultTypes), 1523 /*operands=*/std::move(operands), 1524 /*attributes=*/std::move(attributes), 1525 /*successors=*/std::move(successors), 1526 /*regions=*/*regions, location, maybeIp); 1527 } 1528 1529 PyOpView::PyOpView(const py::object &operationObject) 1530 // Casting through the PyOperationBase base-class and then back to the 1531 // Operation lets us accept any PyOperationBase subclass. 1532 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1533 operationObject(operation.getRef().getObject()) {} 1534 1535 py::object PyOpView::createRawSubclass(const py::object &userClass) { 1536 // This is... a little gross. The typical pattern is to have a pure python 1537 // class that extends OpView like: 1538 // class AddFOp(_cext.ir.OpView): 1539 // def __init__(self, loc, lhs, rhs): 1540 // operation = loc.context.create_operation( 1541 // "addf", lhs, rhs, results=[lhs.type]) 1542 // super().__init__(operation) 1543 // 1544 // I.e. The goal of the user facing type is to provide a nice constructor 1545 // that has complete freedom for the op under construction. This is at odds 1546 // with our other desire to sometimes create this object by just passing an 1547 // operation (to initialize the base class). 
We could do *arg and **kwargs 1548 // munging to try to make it work, but instead, we synthesize a new class 1549 // on the fly which extends this user class (AddFOp in this example) and 1550 // *give it* the base class's __init__ method, thus bypassing the 1551 // intermediate subclass's __init__ method entirely. While slightly, 1552 // underhanded, this is safe/legal because the type hierarchy has not changed 1553 // (we just added a new leaf) and we aren't mucking around with __new__. 1554 // Typically, this new class will be stored on the original as "_Raw" and will 1555 // be used for casts and other things that need a variant of the class that 1556 // is initialized purely from an operation. 1557 py::object parentMetaclass = 1558 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1559 py::dict attributes; 1560 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1561 // now. 1562 // auto opViewType = py::type::of<PyOpView>(); 1563 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1564 attributes["__init__"] = opViewType.attr("__init__"); 1565 py::str origName = userClass.attr("__name__"); 1566 py::str newName = py::str("_") + origName; 1567 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1568 } 1569 1570 //------------------------------------------------------------------------------ 1571 // PyInsertionPoint. 
1572 //------------------------------------------------------------------------------ 1573 1574 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1575 1576 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1577 : refOperation(beforeOperationBase.getOperation().getRef()), 1578 block((*refOperation)->getBlock()) {} 1579 1580 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1581 PyOperation &operation = operationBase.getOperation(); 1582 if (operation.isAttached()) 1583 throw SetPyError(PyExc_ValueError, 1584 "Attempt to insert operation that is already attached"); 1585 block.getParentOperation()->checkValid(); 1586 MlirOperation beforeOp = {nullptr}; 1587 if (refOperation) { 1588 // Insert before operation. 1589 (*refOperation)->checkValid(); 1590 beforeOp = (*refOperation)->get(); 1591 } else { 1592 // Insert at end (before null) is only valid if the block does not 1593 // already end in a known terminator (violating this will cause assertion 1594 // failures later). 1595 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1596 throw py::index_error("Cannot insert operation at the end of a block " 1597 "that already has a terminator. Did you mean to " 1598 "use 'InsertionPoint.at_block_terminator(block)' " 1599 "versus 'InsertionPoint(block)'?"); 1600 } 1601 } 1602 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1603 operation.setAttached(); 1604 } 1605 1606 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1607 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1608 if (mlirOperationIsNull(firstOp)) { 1609 // Just insert at end. 1610 return PyInsertionPoint(block); 1611 } 1612 1613 // Insert before first op. 
1614 PyOperationRef firstOpRef = PyOperation::forOperation( 1615 block.getParentOperation()->getContext(), firstOp); 1616 return PyInsertionPoint{block, std::move(firstOpRef)}; 1617 } 1618 1619 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1620 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1621 if (mlirOperationIsNull(terminator)) 1622 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1623 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1624 block.getParentOperation()->getContext(), terminator); 1625 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1626 } 1627 1628 py::object PyInsertionPoint::contextEnter() { 1629 return PyThreadContextEntry::pushInsertionPoint(*this); 1630 } 1631 1632 void PyInsertionPoint::contextExit(const pybind11::object &excType, 1633 const pybind11::object &excVal, 1634 const pybind11::object &excTb) { 1635 PyThreadContextEntry::popInsertionPoint(*this); 1636 } 1637 1638 //------------------------------------------------------------------------------ 1639 // PyAttribute. 1640 //------------------------------------------------------------------------------ 1641 1642 bool PyAttribute::operator==(const PyAttribute &other) { 1643 return mlirAttributeEqual(attr, other.attr); 1644 } 1645 1646 py::object PyAttribute::getCapsule() { 1647 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1648 } 1649 1650 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1651 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1652 if (mlirAttributeIsNull(rawAttr)) 1653 throw py::error_already_set(); 1654 return PyAttribute( 1655 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1656 } 1657 1658 //------------------------------------------------------------------------------ 1659 // PyNamedAttribute. 
1660 //------------------------------------------------------------------------------ 1661 1662 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1663 : ownedName(new std::string(std::move(ownedName))) { 1664 namedAttr = mlirNamedAttributeGet( 1665 mlirIdentifierGet(mlirAttributeGetContext(attr), 1666 toMlirStringRef(*this->ownedName)), 1667 attr); 1668 } 1669 1670 //------------------------------------------------------------------------------ 1671 // PyType. 1672 //------------------------------------------------------------------------------ 1673 1674 bool PyType::operator==(const PyType &other) { 1675 return mlirTypeEqual(type, other.type); 1676 } 1677 1678 py::object PyType::getCapsule() { 1679 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1680 } 1681 1682 PyType PyType::createFromCapsule(py::object capsule) { 1683 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1684 if (mlirTypeIsNull(rawType)) 1685 throw py::error_already_set(); 1686 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1687 rawType); 1688 } 1689 1690 //------------------------------------------------------------------------------ 1691 // PyValue and subclases. 
1692 //------------------------------------------------------------------------------ 1693 1694 pybind11::object PyValue::getCapsule() { 1695 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1696 } 1697 1698 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1699 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1700 if (mlirValueIsNull(value)) 1701 throw py::error_already_set(); 1702 MlirOperation owner; 1703 if (mlirValueIsAOpResult(value)) 1704 owner = mlirOpResultGetOwner(value); 1705 if (mlirValueIsABlockArgument(value)) 1706 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1707 if (mlirOperationIsNull(owner)) 1708 throw py::error_already_set(); 1709 MlirContext ctx = mlirOperationGetContext(owner); 1710 PyOperationRef ownerRef = 1711 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1712 return PyValue(ownerRef, value); 1713 } 1714 1715 //------------------------------------------------------------------------------ 1716 // PySymbolTable. 
1717 //------------------------------------------------------------------------------ 1718 1719 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1720 : operation(operation.getOperation().getRef()) { 1721 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1722 if (mlirSymbolTableIsNull(symbolTable)) { 1723 throw py::cast_error("Operation is not a Symbol Table."); 1724 } 1725 } 1726 1727 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1728 operation->checkValid(); 1729 MlirOperation symbol = mlirSymbolTableLookup( 1730 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1731 if (mlirOperationIsNull(symbol)) 1732 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1733 1734 return PyOperation::forOperation(operation->getContext(), symbol, 1735 operation.getObject()) 1736 ->createOpView(); 1737 } 1738 1739 void PySymbolTable::erase(PyOperationBase &symbol) { 1740 operation->checkValid(); 1741 symbol.getOperation().checkValid(); 1742 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1743 // The operation is also erased, so we must invalidate it. There may be Python 1744 // references to this operation so we don't want to delete it from the list of 1745 // live operations here. 
1746 symbol.getOperation().valid = false; 1747 } 1748 1749 void PySymbolTable::dunderDel(const std::string &name) { 1750 py::object operation = dunderGetItem(name); 1751 erase(py::cast<PyOperationBase &>(operation)); 1752 } 1753 1754 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1755 operation->checkValid(); 1756 symbol.getOperation().checkValid(); 1757 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1758 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1759 if (mlirAttributeIsNull(symbolAttr)) 1760 throw py::value_error("Expected operation to have a symbol name."); 1761 return PyAttribute( 1762 symbol.getOperation().getContext(), 1763 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1764 } 1765 1766 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1767 // Op must already be a symbol. 1768 PyOperation &operation = symbol.getOperation(); 1769 operation.checkValid(); 1770 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1771 MlirAttribute existingNameAttr = 1772 mlirOperationGetAttributeByName(operation.get(), attrName); 1773 if (mlirAttributeIsNull(existingNameAttr)) 1774 throw py::value_error("Expected operation to have a symbol name."); 1775 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1776 } 1777 1778 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1779 const std::string &name) { 1780 // Op must already be a symbol. 
1781 PyOperation &operation = symbol.getOperation(); 1782 operation.checkValid(); 1783 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1784 MlirAttribute existingNameAttr = 1785 mlirOperationGetAttributeByName(operation.get(), attrName); 1786 if (mlirAttributeIsNull(existingNameAttr)) 1787 throw py::value_error("Expected operation to have a symbol name."); 1788 MlirAttribute newNameAttr = 1789 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1790 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1791 } 1792 1793 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1794 PyOperation &operation = symbol.getOperation(); 1795 operation.checkValid(); 1796 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1797 MlirAttribute existingVisAttr = 1798 mlirOperationGetAttributeByName(operation.get(), attrName); 1799 if (mlirAttributeIsNull(existingVisAttr)) 1800 throw py::value_error("Expected operation to have a symbol visibility."); 1801 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1802 } 1803 1804 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1805 const std::string &visibility) { 1806 if (visibility != "public" && visibility != "private" && 1807 visibility != "nested") 1808 throw py::value_error( 1809 "Expected visibility to be 'public', 'private' or 'nested'"); 1810 PyOperation &operation = symbol.getOperation(); 1811 operation.checkValid(); 1812 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1813 MlirAttribute existingVisAttr = 1814 mlirOperationGetAttributeByName(operation.get(), attrName); 1815 if (mlirAttributeIsNull(existingVisAttr)) 1816 throw py::value_error("Expected operation to have a symbol visibility."); 1817 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1818 toMlirStringRef(visibility)); 1819 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1820 } 
1821 1822 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1823 const std::string &newSymbol, 1824 PyOperationBase &from) { 1825 PyOperation &fromOperation = from.getOperation(); 1826 fromOperation.checkValid(); 1827 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1828 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1829 from.getOperation()))) 1830 1831 throw py::value_error("Symbol rename failed"); 1832 } 1833 1834 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1835 bool allSymUsesVisible, 1836 py::object callback) { 1837 PyOperation &fromOperation = from.getOperation(); 1838 fromOperation.checkValid(); 1839 struct UserData { 1840 PyMlirContextRef context; 1841 py::object callback; 1842 bool gotException; 1843 std::string exceptionWhat; 1844 py::object exceptionType; 1845 }; 1846 UserData userData{ 1847 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1848 mlirSymbolTableWalkSymbolTables( 1849 fromOperation.get(), allSymUsesVisible, 1850 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1851 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1852 auto pyFoundOp = 1853 PyOperation::forOperation(calleeUserData->context, foundOp); 1854 if (calleeUserData->gotException) 1855 return; 1856 try { 1857 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1858 } catch (py::error_already_set &e) { 1859 calleeUserData->gotException = true; 1860 calleeUserData->exceptionWhat = e.what(); 1861 calleeUserData->exceptionType = e.type(); 1862 } 1863 }, 1864 static_cast<void *>(&userData)); 1865 if (userData.gotException) { 1866 std::string message("Exception raised in callback: "); 1867 message.append(userData.exceptionWhat); 1868 throw std::runtime_error(message); 1869 } 1870 } 1871 1872 namespace { 1873 /// CRTP base class for Python MLIR values that subclass Value and should be 1874 /// castable from it. 
The value hierarchy is one level deep and is not supposed 1875 /// to accommodate other levels unless core MLIR changes. 1876 template <typename DerivedTy> 1877 class PyConcreteValue : public PyValue { 1878 public: 1879 // Derived classes must define statics for: 1880 // IsAFunctionTy isaFunction 1881 // const char *pyClassName 1882 // and redefine bindDerived. 1883 using ClassTy = py::class_<DerivedTy, PyValue>; 1884 using IsAFunctionTy = bool (*)(MlirValue); 1885 1886 PyConcreteValue() = default; 1887 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1888 : PyValue(operationRef, value) {} 1889 PyConcreteValue(PyValue &orig) 1890 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1891 1892 /// Attempts to cast the original value to the derived type and throws on 1893 /// type mismatches. 1894 static MlirValue castFrom(PyValue &orig) { 1895 if (!DerivedTy::isaFunction(orig.get())) { 1896 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1897 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1898 DerivedTy::pyClassName + 1899 " (from " + origRepr + ")"); 1900 } 1901 return orig.get(); 1902 } 1903 1904 /// Binds the Python module objects to functions of this class. 1905 static void bind(py::module &m) { 1906 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1907 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1908 cls.def_static( 1909 "isinstance", 1910 [](PyValue &otherValue) -> bool { 1911 return DerivedTy::isaFunction(otherValue); 1912 }, 1913 py::arg("other_value")); 1914 DerivedTy::bindDerived(cls); 1915 } 1916 1917 /// Implemented by derived classes to add methods to the Python subclass. 1918 static void bindDerived(ClassTy &m) {} 1919 }; 1920 1921 /// Python wrapper for MlirBlockArgument. 
1922 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1923 public: 1924 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1925 static constexpr const char *pyClassName = "BlockArgument"; 1926 using PyConcreteValue::PyConcreteValue; 1927 1928 static void bindDerived(ClassTy &c) { 1929 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1930 return PyBlock(self.getParentOperation(), 1931 mlirBlockArgumentGetOwner(self.get())); 1932 }); 1933 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1934 return mlirBlockArgumentGetArgNumber(self.get()); 1935 }); 1936 c.def( 1937 "set_type", 1938 [](PyBlockArgument &self, PyType type) { 1939 return mlirBlockArgumentSetType(self.get(), type); 1940 }, 1941 py::arg("type")); 1942 } 1943 }; 1944 1945 /// Python wrapper for MlirOpResult. 1946 class PyOpResult : public PyConcreteValue<PyOpResult> { 1947 public: 1948 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1949 static constexpr const char *pyClassName = "OpResult"; 1950 using PyConcreteValue::PyConcreteValue; 1951 1952 static void bindDerived(ClassTy &c) { 1953 c.def_property_readonly("owner", [](PyOpResult &self) { 1954 assert( 1955 mlirOperationEqual(self.getParentOperation()->get(), 1956 mlirOpResultGetOwner(self.get())) && 1957 "expected the owner of the value in Python to match that in the IR"); 1958 return self.getParentOperation().getObject(); 1959 }); 1960 c.def_property_readonly("result_number", [](PyOpResult &self) { 1961 return mlirOpResultGetResultNumber(self.get()); 1962 }); 1963 } 1964 }; 1965 1966 /// Returns the list of types of the values held by container. 
1967 template <typename Container> 1968 static std::vector<PyType> getValueTypes(Container &container, 1969 PyMlirContextRef &context) { 1970 std::vector<PyType> result; 1971 result.reserve(container.getNumElements()); 1972 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1973 result.push_back( 1974 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1975 } 1976 return result; 1977 } 1978 1979 /// A list of block arguments. Internally, these are stored as consecutive 1980 /// elements, random access is cheap. The argument list is associated with the 1981 /// operation that contains the block (detached blocks are not allowed in 1982 /// Python bindings) and extends its lifetime. 1983 class PyBlockArgumentList 1984 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1985 public: 1986 static constexpr const char *pyClassName = "BlockArgumentList"; 1987 1988 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1989 intptr_t startIndex = 0, intptr_t length = -1, 1990 intptr_t step = 1) 1991 : Sliceable(startIndex, 1992 length == -1 ? mlirBlockGetNumArguments(block) : length, 1993 step), 1994 operation(std::move(operation)), block(block) {} 1995 1996 /// Returns the number of arguments in the list. 1997 intptr_t getNumElements() { 1998 operation->checkValid(); 1999 return mlirBlockGetNumArguments(block); 2000 } 2001 2002 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 2003 PyBlockArgument getElement(intptr_t pos) { 2004 MlirValue argument = mlirBlockGetArgument(block, pos); 2005 return PyBlockArgument(operation, argument); 2006 } 2007 2008 /// Returns a sublist of this list. 
2009 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 2010 intptr_t step) { 2011 return PyBlockArgumentList(operation, block, startIndex, length, step); 2012 } 2013 2014 static void bindDerived(ClassTy &c) { 2015 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 2016 return getValueTypes(self, self.operation->getContext()); 2017 }); 2018 } 2019 2020 private: 2021 PyOperationRef operation; 2022 MlirBlock block; 2023 }; 2024 2025 /// A list of operation operands. Internally, these are stored as consecutive 2026 /// elements, random access is cheap. The result list is associated with the 2027 /// operation whose results these are, and extends the lifetime of this 2028 /// operation. 2029 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 2030 public: 2031 static constexpr const char *pyClassName = "OpOperandList"; 2032 2033 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2034 intptr_t length = -1, intptr_t step = 1) 2035 : Sliceable(startIndex, 2036 length == -1 ? 
mlirOperationGetNumOperands(operation->get()) 2037 : length, 2038 step), 2039 operation(operation) {} 2040 2041 intptr_t getNumElements() { 2042 operation->checkValid(); 2043 return mlirOperationGetNumOperands(operation->get()); 2044 } 2045 2046 PyValue getElement(intptr_t pos) { 2047 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2048 MlirOperation owner; 2049 if (mlirValueIsAOpResult(operand)) 2050 owner = mlirOpResultGetOwner(operand); 2051 else if (mlirValueIsABlockArgument(operand)) 2052 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2053 else 2054 assert(false && "Value must be an block arg or op result."); 2055 PyOperationRef pyOwner = 2056 PyOperation::forOperation(operation->getContext(), owner); 2057 return PyValue(pyOwner, operand); 2058 } 2059 2060 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2061 return PyOpOperandList(operation, startIndex, length, step); 2062 } 2063 2064 void dunderSetItem(intptr_t index, PyValue value) { 2065 index = wrapIndex(index); 2066 mlirOperationSetOperand(operation->get(), index, value.get()); 2067 } 2068 2069 static void bindDerived(ClassTy &c) { 2070 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2071 } 2072 2073 private: 2074 PyOperationRef operation; 2075 }; 2076 2077 /// A list of operation results. Internally, these are stored as consecutive 2078 /// elements, random access is cheap. The result list is associated with the 2079 /// operation whose results these are, and extends the lifetime of this 2080 /// operation. 2081 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 2082 public: 2083 static constexpr const char *pyClassName = "OpResultList"; 2084 2085 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 2086 intptr_t length = -1, intptr_t step = 1) 2087 : Sliceable(startIndex, 2088 length == -1 ? 
mlirOperationGetNumResults(operation->get()) 2089 : length, 2090 step), 2091 operation(operation) {} 2092 2093 intptr_t getNumElements() { 2094 operation->checkValid(); 2095 return mlirOperationGetNumResults(operation->get()); 2096 } 2097 2098 PyOpResult getElement(intptr_t index) { 2099 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 2100 return PyOpResult(value); 2101 } 2102 2103 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2104 return PyOpResultList(operation, startIndex, length, step); 2105 } 2106 2107 static void bindDerived(ClassTy &c) { 2108 c.def_property_readonly("types", [](PyOpResultList &self) { 2109 return getValueTypes(self, self.operation->getContext()); 2110 }); 2111 } 2112 2113 private: 2114 PyOperationRef operation; 2115 }; 2116 2117 /// A list of operation attributes. Can be indexed by name, producing 2118 /// attributes, or by index, producing named attributes. 2119 class PyOpAttributeMap { 2120 public: 2121 PyOpAttributeMap(PyOperationRef operation) 2122 : operation(std::move(operation)) {} 2123 2124 PyAttribute dunderGetItemNamed(const std::string &name) { 2125 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2126 toMlirStringRef(name)); 2127 if (mlirAttributeIsNull(attr)) { 2128 throw SetPyError(PyExc_KeyError, 2129 "attempt to access a non-existent attribute"); 2130 } 2131 return PyAttribute(operation->getContext(), attr); 2132 } 2133 2134 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2135 if (index < 0 || index >= dunderLen()) { 2136 throw SetPyError(PyExc_IndexError, 2137 "attempt to access out of bounds attribute"); 2138 } 2139 MlirNamedAttribute namedAttr = 2140 mlirOperationGetAttribute(operation->get(), index); 2141 return PyNamedAttribute( 2142 namedAttr.attribute, 2143 std::string(mlirIdentifierStr(namedAttr.name).data, 2144 mlirIdentifierStr(namedAttr.name).length)); 2145 } 2146 2147 void dunderSetItem(const std::string &name, const PyAttribute 
&attr) { 2148 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2149 attr); 2150 } 2151 2152 void dunderDelItem(const std::string &name) { 2153 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2154 toMlirStringRef(name)); 2155 if (!removed) 2156 throw SetPyError(PyExc_KeyError, 2157 "attempt to delete a non-existent attribute"); 2158 } 2159 2160 intptr_t dunderLen() { 2161 return mlirOperationGetNumAttributes(operation->get()); 2162 } 2163 2164 bool dunderContains(const std::string &name) { 2165 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2166 operation->get(), toMlirStringRef(name))); 2167 } 2168 2169 static void bind(py::module &m) { 2170 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2171 .def("__contains__", &PyOpAttributeMap::dunderContains) 2172 .def("__len__", &PyOpAttributeMap::dunderLen) 2173 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2174 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2175 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2176 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2177 } 2178 2179 private: 2180 PyOperationRef operation; 2181 }; 2182 2183 } // namespace 2184 2185 //------------------------------------------------------------------------------ 2186 // Populates the core exports of the 'ir' submodule. 2187 //------------------------------------------------------------------------------ 2188 2189 void mlir::python::populateIRCore(py::module &m) { 2190 //---------------------------------------------------------------------------- 2191 // Enums. 
2192 //---------------------------------------------------------------------------- 2193 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local()) 2194 .value("ERROR", MlirDiagnosticError) 2195 .value("WARNING", MlirDiagnosticWarning) 2196 .value("NOTE", MlirDiagnosticNote) 2197 .value("REMARK", MlirDiagnosticRemark); 2198 2199 //---------------------------------------------------------------------------- 2200 // Mapping of Diagnostics. 2201 //---------------------------------------------------------------------------- 2202 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local()) 2203 .def_property_readonly("severity", &PyDiagnostic::getSeverity) 2204 .def_property_readonly("location", &PyDiagnostic::getLocation) 2205 .def_property_readonly("message", &PyDiagnostic::getMessage) 2206 .def_property_readonly("notes", &PyDiagnostic::getNotes) 2207 .def("__str__", [](PyDiagnostic &self) -> py::str { 2208 if (!self.isValid()) 2209 return "<Invalid Diagnostic>"; 2210 return self.getMessage(); 2211 }); 2212 2213 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) 2214 .def("detach", &PyDiagnosticHandler::detach) 2215 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) 2216 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 2217 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2218 .def("__exit__", &PyDiagnosticHandler::contextExit); 2219 2220 //---------------------------------------------------------------------------- 2221 // Mapping of MlirContext. 2222 // Note that this is exported as _BaseContext. The containing, Python level 2223 // __init__.py will subclass it with site-specific functionality and set a 2224 // "Context" attribute on this module. 
2225 //---------------------------------------------------------------------------- 2226 py::class_<PyMlirContext>(m, "_BaseContext", py::module_local()) 2227 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2228 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2229 .def("_get_context_again", 2230 [](PyMlirContext &self) { 2231 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2232 return ref.releaseObject(); 2233 }) 2234 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2235 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) 2236 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2237 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2238 &PyMlirContext::getCapsule) 2239 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2240 .def("__enter__", &PyMlirContext::contextEnter) 2241 .def("__exit__", &PyMlirContext::contextExit) 2242 .def_property_readonly_static( 2243 "current", 2244 [](py::object & /*class*/) { 2245 auto *context = PyThreadContextEntry::getDefaultContext(); 2246 if (!context) 2247 throw SetPyError(PyExc_ValueError, "No current Context"); 2248 return context; 2249 }, 2250 "Gets the Context bound to the current thread or raises ValueError") 2251 .def_property_readonly( 2252 "dialects", 2253 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2254 "Gets a container for accessing dialects by name") 2255 .def_property_readonly( 2256 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2257 "Alias for 'dialect'") 2258 .def( 2259 "get_dialect_descriptor", 2260 [=](PyMlirContext &self, std::string &name) { 2261 MlirDialect dialect = mlirContextGetOrLoadDialect( 2262 self.get(), {name.data(), name.size()}); 2263 if (mlirDialectIsNull(dialect)) { 2264 throw SetPyError(PyExc_ValueError, 2265 Twine("Dialect '") + name + "' not found"); 2266 } 2267 return PyDialectDescriptor(self.getRef(), dialect); 2268 }, 2269 
py::arg("dialect_name"), 2270 "Gets or loads a dialect by name, returning its descriptor object") 2271 .def_property( 2272 "allow_unregistered_dialects", 2273 [](PyMlirContext &self) -> bool { 2274 return mlirContextGetAllowUnregisteredDialects(self.get()); 2275 }, 2276 [](PyMlirContext &self, bool value) { 2277 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2278 }) 2279 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2280 py::arg("callback"), 2281 "Attaches a diagnostic handler that will receive callbacks") 2282 .def( 2283 "enable_multithreading", 2284 [](PyMlirContext &self, bool enable) { 2285 mlirContextEnableMultithreading(self.get(), enable); 2286 }, 2287 py::arg("enable")) 2288 .def( 2289 "is_registered_operation", 2290 [](PyMlirContext &self, std::string &name) { 2291 return mlirContextIsRegisteredOperation( 2292 self.get(), MlirStringRef{name.data(), name.size()}); 2293 }, 2294 py::arg("operation_name")) 2295 .def( 2296 "append_dialect_registry", 2297 [](PyMlirContext &self, PyDialectRegistry ®istry) { 2298 mlirContextAppendDialectRegistry(self.get(), registry); 2299 }, 2300 py::arg("registry")) 2301 .def("load_all_available_dialects", [](PyMlirContext &self) { 2302 mlirContextLoadAllAvailableDialects(self.get()); 2303 }); 2304 2305 //---------------------------------------------------------------------------- 2306 // Mapping of PyDialectDescriptor 2307 //---------------------------------------------------------------------------- 2308 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2309 .def_property_readonly("namespace", 2310 [](PyDialectDescriptor &self) { 2311 MlirStringRef ns = 2312 mlirDialectGetNamespace(self.get()); 2313 return py::str(ns.data, ns.length); 2314 }) 2315 .def("__repr__", [](PyDialectDescriptor &self) { 2316 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2317 std::string repr("<DialectDescriptor "); 2318 repr.append(ns.data, ns.length); 2319 repr.append(">"); 
2320 return repr; 2321 }); 2322 2323 //---------------------------------------------------------------------------- 2324 // Mapping of PyDialects 2325 //---------------------------------------------------------------------------- 2326 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2327 .def("__getitem__", 2328 [=](PyDialects &self, std::string keyName) { 2329 MlirDialect dialect = 2330 self.getDialectForKey(keyName, /*attrError=*/false); 2331 py::object descriptor = 2332 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2333 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2334 }) 2335 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2336 MlirDialect dialect = 2337 self.getDialectForKey(attrName, /*attrError=*/true); 2338 py::object descriptor = 2339 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2340 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2341 }); 2342 2343 //---------------------------------------------------------------------------- 2344 // Mapping of PyDialect 2345 //---------------------------------------------------------------------------- 2346 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2347 .def(py::init<py::object>(), py::arg("descriptor")) 2348 .def_property_readonly( 2349 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2350 .def("__repr__", [](py::object self) { 2351 auto clazz = self.attr("__class__"); 2352 return py::str("<Dialect ") + 2353 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2354 clazz.attr("__module__") + py::str(".") + 2355 clazz.attr("__name__") + py::str(")>"); 2356 }); 2357 2358 //---------------------------------------------------------------------------- 2359 // Mapping of PyDialectRegistry 2360 //---------------------------------------------------------------------------- 2361 py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local()) 2362 
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2363 &PyDialectRegistry::getCapsule) 2364 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) 2365 .def(py::init<>()); 2366 2367 //---------------------------------------------------------------------------- 2368 // Mapping of Location 2369 //---------------------------------------------------------------------------- 2370 py::class_<PyLocation>(m, "Location", py::module_local()) 2371 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2372 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2373 .def("__enter__", &PyLocation::contextEnter) 2374 .def("__exit__", &PyLocation::contextExit) 2375 .def("__eq__", 2376 [](PyLocation &self, PyLocation &other) -> bool { 2377 return mlirLocationEqual(self, other); 2378 }) 2379 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2380 .def_property_readonly_static( 2381 "current", 2382 [](py::object & /*class*/) { 2383 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2384 if (!loc) 2385 throw SetPyError(PyExc_ValueError, "No current Location"); 2386 return loc; 2387 }, 2388 "Gets the Location bound to the current thread or raises ValueError") 2389 .def_static( 2390 "unknown", 2391 [](DefaultingPyMlirContext context) { 2392 return PyLocation(context->getRef(), 2393 mlirLocationUnknownGet(context->get())); 2394 }, 2395 py::arg("context") = py::none(), 2396 "Gets a Location representing an unknown location") 2397 .def_static( 2398 "callsite", 2399 [](PyLocation callee, const std::vector<PyLocation> &frames, 2400 DefaultingPyMlirContext context) { 2401 if (frames.empty()) 2402 throw py::value_error("No caller frames provided"); 2403 MlirLocation caller = frames.back().get(); 2404 for (const PyLocation &frame : 2405 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2406 caller = mlirLocationCallSiteGet(frame.get(), caller); 2407 return PyLocation(context->getRef(), 2408 
mlirLocationCallSiteGet(callee.get(), caller)); 2409 }, 2410 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2411 kContextGetCallSiteLocationDocstring) 2412 .def_static( 2413 "file", 2414 [](std::string filename, int line, int col, 2415 DefaultingPyMlirContext context) { 2416 return PyLocation( 2417 context->getRef(), 2418 mlirLocationFileLineColGet( 2419 context->get(), toMlirStringRef(filename), line, col)); 2420 }, 2421 py::arg("filename"), py::arg("line"), py::arg("col"), 2422 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2423 .def_static( 2424 "fused", 2425 [](const std::vector<PyLocation> &pyLocations, 2426 llvm::Optional<PyAttribute> metadata, 2427 DefaultingPyMlirContext context) { 2428 llvm::SmallVector<MlirLocation, 4> locations; 2429 locations.reserve(pyLocations.size()); 2430 for (auto &pyLocation : pyLocations) 2431 locations.push_back(pyLocation.get()); 2432 MlirLocation location = mlirLocationFusedGet( 2433 context->get(), locations.size(), locations.data(), 2434 metadata ? metadata->get() : MlirAttribute{0}); 2435 return PyLocation(context->getRef(), location); 2436 }, 2437 py::arg("locations"), py::arg("metadata") = py::none(), 2438 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2439 .def_static( 2440 "name", 2441 [](std::string name, llvm::Optional<PyLocation> childLoc, 2442 DefaultingPyMlirContext context) { 2443 return PyLocation( 2444 context->getRef(), 2445 mlirLocationNameGet( 2446 context->get(), toMlirStringRef(name), 2447 childLoc ? 
childLoc->get() 2448 : mlirLocationUnknownGet(context->get()))); 2449 }, 2450 py::arg("name"), py::arg("childLoc") = py::none(), 2451 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2452 .def_property_readonly( 2453 "context", 2454 [](PyLocation &self) { return self.getContext().getObject(); }, 2455 "Context that owns the Location") 2456 .def( 2457 "emit_error", 2458 [](PyLocation &self, std::string message) { 2459 mlirEmitError(self, message.c_str()); 2460 }, 2461 py::arg("message"), "Emits an error at this location") 2462 .def("__repr__", [](PyLocation &self) { 2463 PyPrintAccumulator printAccum; 2464 mlirLocationPrint(self, printAccum.getCallback(), 2465 printAccum.getUserData()); 2466 return printAccum.join(); 2467 }); 2468 2469 //---------------------------------------------------------------------------- 2470 // Mapping of Module 2471 //---------------------------------------------------------------------------- 2472 py::class_<PyModule>(m, "Module", py::module_local()) 2473 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2474 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2475 .def_static( 2476 "parse", 2477 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2478 MlirModule module = mlirModuleCreateParse( 2479 context->get(), toMlirStringRef(moduleAsm)); 2480 // TODO: Rework error reporting once diagnostic engine is exposed 2481 // in C API. 
2482 if (mlirModuleIsNull(module)) { 2483 throw SetPyError( 2484 PyExc_ValueError, 2485 "Unable to parse module assembly (see diagnostics)"); 2486 } 2487 return PyModule::forModule(module).releaseObject(); 2488 }, 2489 py::arg("asm"), py::arg("context") = py::none(), 2490 kModuleParseDocstring) 2491 .def_static( 2492 "create", 2493 [](DefaultingPyLocation loc) { 2494 MlirModule module = mlirModuleCreateEmpty(loc); 2495 return PyModule::forModule(module).releaseObject(); 2496 }, 2497 py::arg("loc") = py::none(), "Creates an empty module") 2498 .def_property_readonly( 2499 "context", 2500 [](PyModule &self) { return self.getContext().getObject(); }, 2501 "Context that created the Module") 2502 .def_property_readonly( 2503 "operation", 2504 [](PyModule &self) { 2505 return PyOperation::forOperation(self.getContext(), 2506 mlirModuleGetOperation(self.get()), 2507 self.getRef().releaseObject()) 2508 .releaseObject(); 2509 }, 2510 "Accesses the module as an operation") 2511 .def_property_readonly( 2512 "body", 2513 [](PyModule &self) { 2514 PyOperationRef moduleOp = PyOperation::forOperation( 2515 self.getContext(), mlirModuleGetOperation(self.get()), 2516 self.getRef().releaseObject()); 2517 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 2518 return returnBlock; 2519 }, 2520 "Return the block for this module") 2521 .def( 2522 "dump", 2523 [](PyModule &self) { 2524 mlirOperationDump(mlirModuleGetOperation(self.get())); 2525 }, 2526 kDumpDocstring) 2527 .def( 2528 "__str__", 2529 [](py::object self) { 2530 // Defer to the operation's __str__. 2531 return self.attr("operation").attr("__str__")(); 2532 }, 2533 kOperationStrDunderDocstring); 2534 2535 //---------------------------------------------------------------------------- 2536 // Mapping of Operation. 
2537 //---------------------------------------------------------------------------- 2538 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2539 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2540 [](PyOperationBase &self) { 2541 return self.getOperation().getCapsule(); 2542 }) 2543 .def("__eq__", 2544 [](PyOperationBase &self, PyOperationBase &other) { 2545 return &self.getOperation() == &other.getOperation(); 2546 }) 2547 .def("__eq__", 2548 [](PyOperationBase &self, py::object other) { return false; }) 2549 .def("__hash__", 2550 [](PyOperationBase &self) { 2551 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2552 }) 2553 .def_property_readonly("attributes", 2554 [](PyOperationBase &self) { 2555 return PyOpAttributeMap( 2556 self.getOperation().getRef()); 2557 }) 2558 .def_property_readonly("operands", 2559 [](PyOperationBase &self) { 2560 return PyOpOperandList( 2561 self.getOperation().getRef()); 2562 }) 2563 .def_property_readonly("regions", 2564 [](PyOperationBase &self) { 2565 return PyRegionList( 2566 self.getOperation().getRef()); 2567 }) 2568 .def_property_readonly( 2569 "results", 2570 [](PyOperationBase &self) { 2571 return PyOpResultList(self.getOperation().getRef()); 2572 }, 2573 "Returns the list of Operation results.") 2574 .def_property_readonly( 2575 "result", 2576 [](PyOperationBase &self) { 2577 auto &operation = self.getOperation(); 2578 auto numResults = mlirOperationGetNumResults(operation); 2579 if (numResults != 1) { 2580 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2581 throw SetPyError( 2582 PyExc_ValueError, 2583 Twine("Cannot call .result on operation ") + 2584 StringRef(name.data, name.length) + " which has " + 2585 Twine(numResults) + 2586 " results (it is only valid for operations with a " 2587 "single result)"); 2588 } 2589 return PyOpResult(operation.getRef(), 2590 mlirOperationGetResult(operation, 0)); 2591 }, 2592 "Shortcut to get an op result if it has only one (throws 
an error " 2593 "otherwise).") 2594 .def_property_readonly( 2595 "location", 2596 [](PyOperationBase &self) { 2597 PyOperation &operation = self.getOperation(); 2598 return PyLocation(operation.getContext(), 2599 mlirOperationGetLocation(operation.get())); 2600 }, 2601 "Returns the source location the operation was defined or derived " 2602 "from.") 2603 .def( 2604 "__str__", 2605 [](PyOperationBase &self) { 2606 return self.getAsm(/*binary=*/false, 2607 /*largeElementsLimit=*/llvm::None, 2608 /*enableDebugInfo=*/false, 2609 /*prettyDebugInfo=*/false, 2610 /*printGenericOpForm=*/false, 2611 /*useLocalScope=*/false, 2612 /*assumeVerified=*/false); 2613 }, 2614 "Returns the assembly form of the operation.") 2615 .def("print", &PyOperationBase::print, 2616 // Careful: Lots of arguments must match up with print method. 2617 py::arg("file") = py::none(), py::arg("binary") = false, 2618 py::arg("large_elements_limit") = py::none(), 2619 py::arg("enable_debug_info") = false, 2620 py::arg("pretty_debug_info") = false, 2621 py::arg("print_generic_op_form") = false, 2622 py::arg("use_local_scope") = false, 2623 py::arg("assume_verified") = false, kOperationPrintDocstring) 2624 .def("get_asm", &PyOperationBase::getAsm, 2625 // Careful: Lots of arguments must match up with get_asm method. 
2626 py::arg("binary") = false, 2627 py::arg("large_elements_limit") = py::none(), 2628 py::arg("enable_debug_info") = false, 2629 py::arg("pretty_debug_info") = false, 2630 py::arg("print_generic_op_form") = false, 2631 py::arg("use_local_scope") = false, 2632 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2633 .def( 2634 "verify", 2635 [](PyOperationBase &self) { 2636 return mlirOperationVerify(self.getOperation()); 2637 }, 2638 "Verify the operation and return true if it passes, false if it " 2639 "fails.") 2640 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2641 "Puts self immediately after the other operation in its parent " 2642 "block.") 2643 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2644 "Puts self immediately before the other operation in its parent " 2645 "block.") 2646 .def( 2647 "detach_from_parent", 2648 [](PyOperationBase &self) { 2649 PyOperation &operation = self.getOperation(); 2650 operation.checkValid(); 2651 if (!operation.isAttached()) 2652 throw py::value_error("Detached operation has no parent."); 2653 2654 operation.detachFromParent(); 2655 return operation.createOpView(); 2656 }, 2657 "Detaches the operation from its parent block."); 2658 2659 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2660 .def_static("create", &PyOperation::create, py::arg("name"), 2661 py::arg("results") = py::none(), 2662 py::arg("operands") = py::none(), 2663 py::arg("attributes") = py::none(), 2664 py::arg("successors") = py::none(), py::arg("regions") = 0, 2665 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2666 kOperationCreateDocstring) 2667 .def_property_readonly("parent", 2668 [](PyOperation &self) -> py::object { 2669 auto parent = self.getParentOperation(); 2670 if (parent) 2671 return parent->getObject(); 2672 return py::none(); 2673 }) 2674 .def("erase", &PyOperation::erase) 2675 .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) 2676 
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2677 &PyOperation::getCapsule) 2678 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2679 .def_property_readonly("name", 2680 [](PyOperation &self) { 2681 self.checkValid(); 2682 MlirOperation operation = self.get(); 2683 MlirStringRef name = mlirIdentifierStr( 2684 mlirOperationGetName(operation)); 2685 return py::str(name.data, name.length); 2686 }) 2687 .def_property_readonly( 2688 "context", 2689 [](PyOperation &self) { 2690 self.checkValid(); 2691 return self.getContext().getObject(); 2692 }, 2693 "Context that owns the Operation") 2694 .def_property_readonly("opview", &PyOperation::createOpView); 2695 2696 auto opViewClass = 2697 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2698 .def(py::init<py::object>(), py::arg("operation")) 2699 .def_property_readonly("operation", &PyOpView::getOperationObject) 2700 .def_property_readonly( 2701 "context", 2702 [](PyOpView &self) { 2703 return self.getOperation().getContext().getObject(); 2704 }, 2705 "Context that owns the Operation") 2706 .def("__str__", [](PyOpView &self) { 2707 return py::str(self.getOperationObject()); 2708 }); 2709 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2710 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2711 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2712 opViewClass.attr("build_generic") = classmethod( 2713 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2714 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2715 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2716 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2717 "Builds a specific, generated OpView based on class level attributes."); 2718 2719 //---------------------------------------------------------------------------- 2720 // Mapping of PyRegion. 
2721 //---------------------------------------------------------------------------- 2722 py::class_<PyRegion>(m, "Region", py::module_local()) 2723 .def_property_readonly( 2724 "blocks", 2725 [](PyRegion &self) { 2726 return PyBlockList(self.getParentOperation(), self.get()); 2727 }, 2728 "Returns a forward-optimized sequence of blocks.") 2729 .def_property_readonly( 2730 "owner", 2731 [](PyRegion &self) { 2732 return self.getParentOperation()->createOpView(); 2733 }, 2734 "Returns the operation owning this region.") 2735 .def( 2736 "__iter__", 2737 [](PyRegion &self) { 2738 self.checkValid(); 2739 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2740 return PyBlockIterator(self.getParentOperation(), firstBlock); 2741 }, 2742 "Iterates over blocks in the region.") 2743 .def("__eq__", 2744 [](PyRegion &self, PyRegion &other) { 2745 return self.get().ptr == other.get().ptr; 2746 }) 2747 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2748 2749 //---------------------------------------------------------------------------- 2750 // Mapping of PyBlock. 
2751 //---------------------------------------------------------------------------- 2752 py::class_<PyBlock>(m, "Block", py::module_local()) 2753 .def_property_readonly( 2754 "owner", 2755 [](PyBlock &self) { 2756 return self.getParentOperation()->createOpView(); 2757 }, 2758 "Returns the owning operation of this block.") 2759 .def_property_readonly( 2760 "region", 2761 [](PyBlock &self) { 2762 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2763 return PyRegion(self.getParentOperation(), region); 2764 }, 2765 "Returns the owning region of this block.") 2766 .def_property_readonly( 2767 "arguments", 2768 [](PyBlock &self) { 2769 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2770 }, 2771 "Returns a list of block arguments.") 2772 .def_property_readonly( 2773 "operations", 2774 [](PyBlock &self) { 2775 return PyOperationList(self.getParentOperation(), self.get()); 2776 }, 2777 "Returns a forward-optimized sequence of operations.") 2778 .def_static( 2779 "create_at_start", 2780 [](PyRegion &parent, py::list pyArgTypes) { 2781 parent.checkValid(); 2782 llvm::SmallVector<MlirType, 4> argTypes; 2783 llvm::SmallVector<MlirLocation, 4> argLocs; 2784 argTypes.reserve(pyArgTypes.size()); 2785 argLocs.reserve(pyArgTypes.size()); 2786 for (auto &pyArg : pyArgTypes) { 2787 argTypes.push_back(pyArg.cast<PyType &>()); 2788 // TODO: Pass in a proper location here. 
2789 argLocs.push_back( 2790 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2791 } 2792 2793 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2794 argLocs.data()); 2795 mlirRegionInsertOwnedBlock(parent, 0, block); 2796 return PyBlock(parent.getParentOperation(), block); 2797 }, 2798 py::arg("parent"), py::arg("arg_types") = py::list(), 2799 "Creates and returns a new Block at the beginning of the given " 2800 "region (with given argument types).") 2801 .def( 2802 "append_to", 2803 [](PyBlock &self, PyRegion ®ion) { 2804 MlirBlock b = self.get(); 2805 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) 2806 mlirBlockDetach(b); 2807 mlirRegionAppendOwnedBlock(region.get(), b); 2808 }, 2809 "Append this block to a region, transferring ownership if necessary") 2810 .def( 2811 "create_before", 2812 [](PyBlock &self, py::args pyArgTypes) { 2813 self.checkValid(); 2814 llvm::SmallVector<MlirType, 4> argTypes; 2815 llvm::SmallVector<MlirLocation, 4> argLocs; 2816 argTypes.reserve(pyArgTypes.size()); 2817 argLocs.reserve(pyArgTypes.size()); 2818 for (auto &pyArg : pyArgTypes) { 2819 argTypes.push_back(pyArg.cast<PyType &>()); 2820 // TODO: Pass in a proper location here. 
2821 argLocs.push_back( 2822 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2823 } 2824 2825 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2826 argLocs.data()); 2827 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2828 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2829 return PyBlock(self.getParentOperation(), block); 2830 }, 2831 "Creates and returns a new Block before this block " 2832 "(with given argument types).") 2833 .def( 2834 "create_after", 2835 [](PyBlock &self, py::args pyArgTypes) { 2836 self.checkValid(); 2837 llvm::SmallVector<MlirType, 4> argTypes; 2838 llvm::SmallVector<MlirLocation, 4> argLocs; 2839 argTypes.reserve(pyArgTypes.size()); 2840 argLocs.reserve(pyArgTypes.size()); 2841 for (auto &pyArg : pyArgTypes) { 2842 argTypes.push_back(pyArg.cast<PyType &>()); 2843 2844 // TODO: Pass in a proper location here. 2845 argLocs.push_back( 2846 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); 2847 } 2848 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), 2849 argLocs.data()); 2850 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2851 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2852 return PyBlock(self.getParentOperation(), block); 2853 }, 2854 "Creates and returns a new Block after this block " 2855 "(with given argument types).") 2856 .def( 2857 "__iter__", 2858 [](PyBlock &self) { 2859 self.checkValid(); 2860 MlirOperation firstOperation = 2861 mlirBlockGetFirstOperation(self.get()); 2862 return PyOperationIterator(self.getParentOperation(), 2863 firstOperation); 2864 }, 2865 "Iterates over operations in the block.") 2866 .def("__eq__", 2867 [](PyBlock &self, PyBlock &other) { 2868 return self.get().ptr == other.get().ptr; 2869 }) 2870 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2871 .def( 2872 "__str__", 2873 [](PyBlock &self) { 2874 self.checkValid(); 2875 PyPrintAccumulator printAccum; 2876 
mlirBlockPrint(self.get(), printAccum.getCallback(), 2877 printAccum.getUserData()); 2878 return printAccum.join(); 2879 }, 2880 "Returns the assembly form of the block.") 2881 .def( 2882 "append", 2883 [](PyBlock &self, PyOperationBase &operation) { 2884 if (operation.getOperation().isAttached()) 2885 operation.getOperation().detachFromParent(); 2886 2887 MlirOperation mlirOperation = operation.getOperation().get(); 2888 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2889 operation.getOperation().setAttached( 2890 self.getParentOperation().getObject()); 2891 }, 2892 py::arg("operation"), 2893 "Appends an operation to this block. If the operation is currently " 2894 "in another block, it will be moved."); 2895 2896 //---------------------------------------------------------------------------- 2897 // Mapping of PyInsertionPoint. 2898 //---------------------------------------------------------------------------- 2899 2900 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2901 .def(py::init<PyBlock &>(), py::arg("block"), 2902 "Inserts after the last operation but still inside the block.") 2903 .def("__enter__", &PyInsertionPoint::contextEnter) 2904 .def("__exit__", &PyInsertionPoint::contextExit) 2905 .def_property_readonly_static( 2906 "current", 2907 [](py::object & /*class*/) { 2908 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2909 if (!ip) 2910 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2911 return ip; 2912 }, 2913 "Gets the InsertionPoint bound to the current thread or raises " 2914 "ValueError if none has been set") 2915 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2916 "Inserts before a referenced operation.") 2917 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2918 py::arg("block"), "Inserts at the beginning of the block.") 2919 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2920 py::arg("block"), "Inserts before the block 
terminator.") 2921 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2922 "Inserts an operation.") 2923 .def_property_readonly( 2924 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2925 "Returns the block that this InsertionPoint points to."); 2926 2927 //---------------------------------------------------------------------------- 2928 // Mapping of PyAttribute. 2929 //---------------------------------------------------------------------------- 2930 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2931 // Delegate to the PyAttribute copy constructor, which will also lifetime 2932 // extend the backing context which owns the MlirAttribute. 2933 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2934 "Casts the passed attribute to the generic Attribute") 2935 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2936 &PyAttribute::getCapsule) 2937 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2938 .def_static( 2939 "parse", 2940 [](std::string attrSpec, DefaultingPyMlirContext context) { 2941 MlirAttribute type = mlirAttributeParseGet( 2942 context->get(), toMlirStringRef(attrSpec)); 2943 // TODO: Rework error reporting once diagnostic engine is exposed 2944 // in C API. 
2945 if (mlirAttributeIsNull(type)) { 2946 throw SetPyError(PyExc_ValueError, 2947 Twine("Unable to parse attribute: '") + 2948 attrSpec + "'"); 2949 } 2950 return PyAttribute(context->getRef(), type); 2951 }, 2952 py::arg("asm"), py::arg("context") = py::none(), 2953 "Parses an attribute from an assembly form") 2954 .def_property_readonly( 2955 "context", 2956 [](PyAttribute &self) { return self.getContext().getObject(); }, 2957 "Context that owns the Attribute") 2958 .def_property_readonly("type", 2959 [](PyAttribute &self) { 2960 return PyType(self.getContext()->getRef(), 2961 mlirAttributeGetType(self)); 2962 }) 2963 .def( 2964 "get_named", 2965 [](PyAttribute &self, std::string name) { 2966 return PyNamedAttribute(self, std::move(name)); 2967 }, 2968 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2969 .def("__eq__", 2970 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2971 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2972 .def("__hash__", 2973 [](PyAttribute &self) { 2974 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2975 }) 2976 .def( 2977 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2978 kDumpDocstring) 2979 .def( 2980 "__str__", 2981 [](PyAttribute &self) { 2982 PyPrintAccumulator printAccum; 2983 mlirAttributePrint(self, printAccum.getCallback(), 2984 printAccum.getUserData()); 2985 return printAccum.join(); 2986 }, 2987 "Returns the assembly form of the Attribute.") 2988 .def("__repr__", [](PyAttribute &self) { 2989 // Generally, assembly formats are not printed for __repr__ because 2990 // this can cause exceptionally long debug output and exceptions. 2991 // However, attribute values are generally considered useful and are 2992 // printed. This may need to be re-evaluated if debug dumps end up 2993 // being excessive. 
2994 PyPrintAccumulator printAccum; 2995 printAccum.parts.append("Attribute("); 2996 mlirAttributePrint(self, printAccum.getCallback(), 2997 printAccum.getUserData()); 2998 printAccum.parts.append(")"); 2999 return printAccum.join(); 3000 }); 3001 3002 //---------------------------------------------------------------------------- 3003 // Mapping of PyNamedAttribute 3004 //---------------------------------------------------------------------------- 3005 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 3006 .def("__repr__", 3007 [](PyNamedAttribute &self) { 3008 PyPrintAccumulator printAccum; 3009 printAccum.parts.append("NamedAttribute("); 3010 printAccum.parts.append( 3011 py::str(mlirIdentifierStr(self.namedAttr.name).data, 3012 mlirIdentifierStr(self.namedAttr.name).length)); 3013 printAccum.parts.append("="); 3014 mlirAttributePrint(self.namedAttr.attribute, 3015 printAccum.getCallback(), 3016 printAccum.getUserData()); 3017 printAccum.parts.append(")"); 3018 return printAccum.join(); 3019 }) 3020 .def_property_readonly( 3021 "name", 3022 [](PyNamedAttribute &self) { 3023 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 3024 mlirIdentifierStr(self.namedAttr.name).length); 3025 }, 3026 "The name of the NamedAttribute binding") 3027 .def_property_readonly( 3028 "attr", 3029 [](PyNamedAttribute &self) { 3030 // TODO: When named attribute is removed/refactored, also remove 3031 // this constructor (it does an inefficient table lookup). 3032 auto contextRef = PyMlirContext::forContext( 3033 mlirAttributeGetContext(self.namedAttr.attribute)); 3034 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 3035 }, 3036 py::keep_alive<0, 1>(), 3037 "The underlying generic attribute of the NamedAttribute binding"); 3038 3039 //---------------------------------------------------------------------------- 3040 // Mapping of PyType. 
3041 //---------------------------------------------------------------------------- 3042 py::class_<PyType>(m, "Type", py::module_local()) 3043 // Delegate to the PyType copy constructor, which will also lifetime 3044 // extend the backing context which owns the MlirType. 3045 .def(py::init<PyType &>(), py::arg("cast_from_type"), 3046 "Casts the passed type to the generic Type") 3047 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 3048 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 3049 .def_static( 3050 "parse", 3051 [](std::string typeSpec, DefaultingPyMlirContext context) { 3052 MlirType type = 3053 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 3054 // TODO: Rework error reporting once diagnostic engine is exposed 3055 // in C API. 3056 if (mlirTypeIsNull(type)) { 3057 throw SetPyError(PyExc_ValueError, 3058 Twine("Unable to parse type: '") + typeSpec + 3059 "'"); 3060 } 3061 return PyType(context->getRef(), type); 3062 }, 3063 py::arg("asm"), py::arg("context") = py::none(), 3064 kContextParseTypeDocstring) 3065 .def_property_readonly( 3066 "context", [](PyType &self) { return self.getContext().getObject(); }, 3067 "Context that owns the Type") 3068 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3069 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 3070 .def("__hash__", 3071 [](PyType &self) { 3072 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3073 }) 3074 .def( 3075 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3076 .def( 3077 "__str__", 3078 [](PyType &self) { 3079 PyPrintAccumulator printAccum; 3080 mlirTypePrint(self, printAccum.getCallback(), 3081 printAccum.getUserData()); 3082 return printAccum.join(); 3083 }, 3084 "Returns the assembly form of the type.") 3085 .def("__repr__", [](PyType &self) { 3086 // Generally, assembly formats are not printed for __repr__ because 3087 // this can cause exceptionally long debug 
output and exceptions. 3088 // However, types are an exception as they typically have compact 3089 // assembly forms and printing them is useful. 3090 PyPrintAccumulator printAccum; 3091 printAccum.parts.append("Type("); 3092 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 3093 printAccum.parts.append(")"); 3094 return printAccum.join(); 3095 }); 3096 3097 //---------------------------------------------------------------------------- 3098 // Mapping of Value. 3099 //---------------------------------------------------------------------------- 3100 py::class_<PyValue>(m, "Value", py::module_local()) 3101 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3102 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3103 .def_property_readonly( 3104 "context", 3105 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3106 "Context in which the value lives.") 3107 .def( 3108 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3109 kDumpDocstring) 3110 .def_property_readonly( 3111 "owner", 3112 [](PyValue &self) { 3113 assert(mlirOperationEqual(self.getParentOperation()->get(), 3114 mlirOpResultGetOwner(self.get())) && 3115 "expected the owner of the value in Python to match that in " 3116 "the IR"); 3117 return self.getParentOperation().getObject(); 3118 }) 3119 .def("__eq__", 3120 [](PyValue &self, PyValue &other) { 3121 return self.get().ptr == other.get().ptr; 3122 }) 3123 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 3124 .def("__hash__", 3125 [](PyValue &self) { 3126 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3127 }) 3128 .def( 3129 "__str__", 3130 [](PyValue &self) { 3131 PyPrintAccumulator printAccum; 3132 printAccum.parts.append("Value("); 3133 mlirValuePrint(self.get(), printAccum.getCallback(), 3134 printAccum.getUserData()); 3135 printAccum.parts.append(")"); 3136 return printAccum.join(); 3137 }, 3138 kValueDunderStrDocstring) 3139 
.def_property_readonly("type", [](PyValue &self) { 3140 return PyType(self.getParentOperation()->getContext(), 3141 mlirValueGetType(self.get())); 3142 }); 3143 PyBlockArgument::bind(m); 3144 PyOpResult::bind(m); 3145 3146 //---------------------------------------------------------------------------- 3147 // Mapping of SymbolTable. 3148 //---------------------------------------------------------------------------- 3149 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 3150 .def(py::init<PyOperationBase &>()) 3151 .def("__getitem__", &PySymbolTable::dunderGetItem) 3152 .def("insert", &PySymbolTable::insert, py::arg("operation")) 3153 .def("erase", &PySymbolTable::erase, py::arg("operation")) 3154 .def("__delitem__", &PySymbolTable::dunderDel) 3155 .def("__contains__", 3156 [](PySymbolTable &table, const std::string &name) { 3157 return !mlirOperationIsNull(mlirSymbolTableLookup( 3158 table, mlirStringRefCreate(name.data(), name.length()))); 3159 }) 3160 // Static helpers. 3161 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3162 py::arg("symbol"), py::arg("name")) 3163 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3164 py::arg("symbol")) 3165 .def_static("get_visibility", &PySymbolTable::getVisibility, 3166 py::arg("symbol")) 3167 .def_static("set_visibility", &PySymbolTable::setVisibility, 3168 py::arg("symbol"), py::arg("visibility")) 3169 .def_static("replace_all_symbol_uses", 3170 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 3171 py::arg("new_symbol"), py::arg("from_op")) 3172 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3173 py::arg("from_op"), py::arg("all_sym_uses_visible"), 3174 py::arg("callback")); 3175 3176 // Container bindings. 
// Each container/iterator helper exposes a static bind() that is handed the
// module `m`; presumably each call registers that helper's Python class on
// the module (the bind() bodies live elsewhere, so confirm there).
// NOTE(review): the call order is preserved as-is; whether it matters is not
// visible from here.
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyRegionIterator::bind(m);
PyRegionList::bind(m);

// Debug bindings.
PyGlobalDebugFlag::bind(m);
}