1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "Globals.h" 12 #include "PybindUtils.h" 13 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/BuiltinTypes.h" 17 #include "mlir-c/Debug.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Registration.h" 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include <pybind11/stl.h> 23 24 namespace py = pybind11; 25 using namespace mlir; 26 using namespace mlir::python; 27 28 using llvm::SmallVector; 29 using llvm::StringRef; 30 using llvm::Twine; 31 32 //------------------------------------------------------------------------------ 33 // Docstrings (trivial, non-duplicated docstrings are included inline). 34 //------------------------------------------------------------------------------ 35 36 static const char kContextParseTypeDocstring[] = 37 R"(Parses the assembly form of a type. 38 39 Returns a Type object or raises a ValueError if the type cannot be parsed. 40 41 See also: https://mlir.llvm.org/docs/LangRef/#type-system 42 )"; 43 44 static const char kContextGetCallSiteLocationDocstring[] = 45 R"(Gets a Location representing a caller and callsite)"; 46 47 static const char kContextGetFileLocationDocstring[] = 48 R"(Gets a Location representing a file, line and column)"; 49 50 static const char kContextGetFusedLocationDocstring[] = 51 R"(Gets a Location representing a fused location with optional metadata)"; 52 53 static const char kContextGetNameLocationDocString[] = 54 R"(Gets a Location representing a named location with optional child location)"; 55 56 static const char kModuleParseDocstring[] = 57 R"(Parses a module's assembly format from a string. 58 59 Returns a new MlirModule or raises a ValueError if the parsing fails. 60 61 See also: https://mlir.llvm.org/docs/LangRef/ 62 )"; 63 64 static const char kOperationCreateDocstring[] = 65 R"(Creates a new operation. 66 67 Args: 68 name: Operation name (e.g. "dialect.operation"). 69 results: Sequence of Type representing op result types. 70 attributes: Dict of str:Attribute. 71 successors: List of Block for the operation's successors. 72 regions: Number of regions to create. 73 location: A Location object (defaults to resolve from context manager). 74 ip: An InsertionPoint (defaults to resolve from context manager or set to 75 False to disable insertion, even with an insertion point set in the 76 context manager). 77 Returns: 78 A new "detached" Operation object. Detached operations can be added 79 to blocks, which causes them to become "attached." 80 )"; 81 82 static const char kOperationPrintDocstring[] = 83 R"(Prints the assembly form of the operation to a file like object. 84 85 Args: 86 file: The file like object to write to. Defaults to sys.stdout. 87 binary: Whether to write bytes (True) or str (False). Defaults to False. 88 large_elements_limit: Whether to elide elements attributes above this 89 number of elements. Defaults to None (no limit). 90 enable_debug_info: Whether to print debug/location information. Defaults 91 to False. 92 pretty_debug_info: Whether to format debug information for easier reading 93 by a human (warning: the result is unparseable). 94 print_generic_op_form: Whether to print the generic assembly forms of all 95 ops. Defaults to False. 96 use_local_Scope: Whether to print in a way that is more optimized for 97 multi-threaded access but may not be consistent with how the overall 98 module prints. 99 assume_verified: By default, if not printing generic form, the verifier 100 will be run and if it fails, generic form will be printed with a comment 101 about failed verification. While a reasonable default for interactive use, 102 for systematic use, it is often better for the caller to verify explicitly 103 and report failures in a more robust fashion. Set this to True if doing this 104 in order to avoid running a redundant verification. If the IR is actually 105 invalid, behavior is undefined. 106 )"; 107 108 static const char kOperationGetAsmDocstring[] = 109 R"(Gets the assembly form of the operation with all options available. 110 111 Args: 112 binary: Whether to return a bytes (True) or str (False) object. Defaults to 113 False. 114 ... others ...: See the print() method for common keyword arguments for 115 configuring the printout. 116 Returns: 117 Either a bytes or str object, depending on the setting of the 'binary' 118 argument. 119 )"; 120 121 static const char kOperationStrDunderDocstring[] = 122 R"(Gets the assembly form of the operation with default options. 123 124 If more advanced control over the assembly formatting or I/O options is needed, 125 use the dedicated print or get_asm method, which supports keyword arguments to 126 customize behavior. 127 )"; 128 129 static const char kDumpDocstring[] = 130 R"(Dumps a debug representation of the object to stderr.)"; 131 132 static const char kAppendBlockDocstring[] = 133 R"(Appends a new block, with argument types as positional args. 134 135 Returns: 136 The created block. 137 )"; 138 139 static const char kValueDunderStrDocstring[] = 140 R"(Returns the string form of the value. 141 142 If the value is a block argument, this is the assembly form of its type and the 143 position in the argument list. If the value is an operation result, this is 144 equivalent to printing the operation that produced it. 145 )"; 146 147 //------------------------------------------------------------------------------ 148 // Utilities. 149 //------------------------------------------------------------------------------ 150 151 /// Helper for creating an @classmethod. 152 template <class Func, typename... Args> 153 py::object classmethod(Func f, Args... args) { 154 py::object cf = py::cpp_function(f, args...); 155 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr()))); 156 } 157 158 static py::object 159 createCustomDialectWrapper(const std::string &dialectNamespace, 160 py::object dialectDescriptor) { 161 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 162 if (!dialectClass) { 163 // Use the base class. 164 return py::cast(PyDialect(std::move(dialectDescriptor))); 165 } 166 167 // Create the custom implementation. 168 return (*dialectClass)(std::move(dialectDescriptor)); 169 } 170 171 static MlirStringRef toMlirStringRef(const std::string &s) { 172 return mlirStringRefCreate(s.data(), s.size()); 173 } 174 175 /// Wrapper for the global LLVM debugging flag. 176 struct PyGlobalDebugFlag { 177 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } 178 179 static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } 180 181 static void bind(py::module &m) { 182 // Debug flags. 183 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local()) 184 .def_property_static("flag", &PyGlobalDebugFlag::get, 185 &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); 186 } 187 }; 188 189 //------------------------------------------------------------------------------ 190 // Collections. 191 //------------------------------------------------------------------------------ 192 193 namespace { 194 195 class PyRegionIterator { 196 public: 197 PyRegionIterator(PyOperationRef operation) 198 : operation(std::move(operation)) {} 199 200 PyRegionIterator &dunderIter() { return *this; } 201 202 PyRegion dunderNext() { 203 operation->checkValid(); 204 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 205 throw py::stop_iteration(); 206 } 207 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 208 return PyRegion(operation, region); 209 } 210 211 static void bind(py::module &m) { 212 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local()) 213 .def("__iter__", &PyRegionIterator::dunderIter) 214 .def("__next__", &PyRegionIterator::dunderNext); 215 } 216 217 private: 218 PyOperationRef operation; 219 int nextIndex = 0; 220 }; 221 222 /// Regions of an op are fixed length and indexed numerically so are represented 223 /// with a sequence-like container. 224 class PyRegionList { 225 public: 226 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 227 228 intptr_t dunderLen() { 229 operation->checkValid(); 230 return mlirOperationGetNumRegions(operation->get()); 231 } 232 233 PyRegion dunderGetItem(intptr_t index) { 234 // dunderLen checks validity. 235 if (index < 0 || index >= dunderLen()) { 236 throw SetPyError(PyExc_IndexError, 237 "attempt to access out of bounds region"); 238 } 239 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 240 return PyRegion(operation, region); 241 } 242 243 static void bind(py::module &m) { 244 py::class_<PyRegionList>(m, "RegionSequence", py::module_local()) 245 .def("__len__", &PyRegionList::dunderLen) 246 .def("__getitem__", &PyRegionList::dunderGetItem); 247 } 248 249 private: 250 PyOperationRef operation; 251 }; 252 253 class PyBlockIterator { 254 public: 255 PyBlockIterator(PyOperationRef operation, MlirBlock next) 256 : operation(std::move(operation)), next(next) {} 257 258 PyBlockIterator &dunderIter() { return *this; } 259 260 PyBlock dunderNext() { 261 operation->checkValid(); 262 if (mlirBlockIsNull(next)) { 263 throw py::stop_iteration(); 264 } 265 266 PyBlock returnBlock(operation, next); 267 next = mlirBlockGetNextInRegion(next); 268 return returnBlock; 269 } 270 271 static void bind(py::module &m) { 272 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local()) 273 .def("__iter__", &PyBlockIterator::dunderIter) 274 .def("__next__", &PyBlockIterator::dunderNext); 275 } 276 277 private: 278 PyOperationRef operation; 279 MlirBlock next; 280 }; 281 282 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 283 /// we present them as a more full-featured list-like container but optimize 284 /// it for forward iteration. Blocks are always owned by a region. 285 class PyBlockList { 286 public: 287 PyBlockList(PyOperationRef operation, MlirRegion region) 288 : operation(std::move(operation)), region(region) {} 289 290 PyBlockIterator dunderIter() { 291 operation->checkValid(); 292 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 293 } 294 295 intptr_t dunderLen() { 296 operation->checkValid(); 297 intptr_t count = 0; 298 MlirBlock block = mlirRegionGetFirstBlock(region); 299 while (!mlirBlockIsNull(block)) { 300 count += 1; 301 block = mlirBlockGetNextInRegion(block); 302 } 303 return count; 304 } 305 306 PyBlock dunderGetItem(intptr_t index) { 307 operation->checkValid(); 308 if (index < 0) { 309 throw SetPyError(PyExc_IndexError, 310 "attempt to access out of bounds block"); 311 } 312 MlirBlock block = mlirRegionGetFirstBlock(region); 313 while (!mlirBlockIsNull(block)) { 314 if (index == 0) { 315 return PyBlock(operation, block); 316 } 317 block = mlirBlockGetNextInRegion(block); 318 index -= 1; 319 } 320 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); 321 } 322 323 PyBlock appendBlock(py::args pyArgTypes) { 324 operation->checkValid(); 325 llvm::SmallVector<MlirType, 4> argTypes; 326 argTypes.reserve(pyArgTypes.size()); 327 for (auto &pyArg : pyArgTypes) { 328 argTypes.push_back(pyArg.cast<PyType &>()); 329 } 330 331 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 332 mlirRegionAppendOwnedBlock(region, block); 333 return PyBlock(operation, block); 334 } 335 336 static void bind(py::module &m) { 337 py::class_<PyBlockList>(m, "BlockList", py::module_local()) 338 .def("__getitem__", &PyBlockList::dunderGetItem) 339 .def("__iter__", &PyBlockList::dunderIter) 340 .def("__len__", &PyBlockList::dunderLen) 341 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); 342 } 343 344 private: 345 PyOperationRef operation; 346 MlirRegion region; 347 }; 348 349 class PyOperationIterator { 350 public: 351 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 352 : parentOperation(std::move(parentOperation)), next(next) {} 353 354 PyOperationIterator &dunderIter() { return *this; } 355 356 py::object dunderNext() { 357 parentOperation->checkValid(); 358 if (mlirOperationIsNull(next)) { 359 throw py::stop_iteration(); 360 } 361 362 PyOperationRef returnOperation = 363 PyOperation::forOperation(parentOperation->getContext(), next); 364 next = mlirOperationGetNextInBlock(next); 365 return returnOperation->createOpView(); 366 } 367 368 static void bind(py::module &m) { 369 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local()) 370 .def("__iter__", &PyOperationIterator::dunderIter) 371 .def("__next__", &PyOperationIterator::dunderNext); 372 } 373 374 private: 375 PyOperationRef parentOperation; 376 MlirOperation next; 377 }; 378 379 /// Operations are exposed by the C-API as a forward-only linked list. In 380 /// Python, we present them as a more full-featured list-like container but 381 /// optimize it for forward iteration. Iterable operations are always owned 382 /// by a block. 383 class PyOperationList { 384 public: 385 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 386 : parentOperation(std::move(parentOperation)), block(block) {} 387 388 PyOperationIterator dunderIter() { 389 parentOperation->checkValid(); 390 return PyOperationIterator(parentOperation, 391 mlirBlockGetFirstOperation(block)); 392 } 393 394 intptr_t dunderLen() { 395 parentOperation->checkValid(); 396 intptr_t count = 0; 397 MlirOperation childOp = mlirBlockGetFirstOperation(block); 398 while (!mlirOperationIsNull(childOp)) { 399 count += 1; 400 childOp = mlirOperationGetNextInBlock(childOp); 401 } 402 return count; 403 } 404 405 py::object dunderGetItem(intptr_t index) { 406 parentOperation->checkValid(); 407 if (index < 0) { 408 throw SetPyError(PyExc_IndexError, 409 "attempt to access out of bounds operation"); 410 } 411 MlirOperation childOp = mlirBlockGetFirstOperation(block); 412 while (!mlirOperationIsNull(childOp)) { 413 if (index == 0) { 414 return PyOperation::forOperation(parentOperation->getContext(), childOp) 415 ->createOpView(); 416 } 417 childOp = mlirOperationGetNextInBlock(childOp); 418 index -= 1; 419 } 420 throw SetPyError(PyExc_IndexError, 421 "attempt to access out of bounds operation"); 422 } 423 424 static void bind(py::module &m) { 425 py::class_<PyOperationList>(m, "OperationList", py::module_local()) 426 .def("__getitem__", &PyOperationList::dunderGetItem) 427 .def("__iter__", &PyOperationList::dunderIter) 428 .def("__len__", &PyOperationList::dunderLen); 429 } 430 431 private: 432 PyOperationRef parentOperation; 433 MlirBlock block; 434 }; 435 436 } // namespace 437 438 //------------------------------------------------------------------------------ 439 // PyMlirContext 440 //------------------------------------------------------------------------------ 441 442 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 443 py::gil_scoped_acquire acquire; 444 auto &liveContexts = getLiveContexts(); 445 liveContexts[context.ptr] = this; 446 } 447 448 PyMlirContext::~PyMlirContext() { 449 // Note that the only public way to construct an instance is via the 450 // forContext method, which always puts the associated handle into 451 // liveContexts. 452 py::gil_scoped_acquire acquire; 453 getLiveContexts().erase(context.ptr); 454 mlirContextDestroy(context); 455 } 456 457 py::object PyMlirContext::getCapsule() { 458 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); 459 } 460 461 py::object PyMlirContext::createFromCapsule(py::object capsule) { 462 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 463 if (mlirContextIsNull(rawContext)) 464 throw py::error_already_set(); 465 return forContext(rawContext).releaseObject(); 466 } 467 468 PyMlirContext *PyMlirContext::createNewContextForInit() { 469 MlirContext context = mlirContextCreate(); 470 mlirRegisterAllDialects(context); 471 return new PyMlirContext(context); 472 } 473 474 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 475 py::gil_scoped_acquire acquire; 476 auto &liveContexts = getLiveContexts(); 477 auto it = liveContexts.find(context.ptr); 478 if (it == liveContexts.end()) { 479 // Create. 480 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 481 py::object pyRef = py::cast(unownedContextWrapper); 482 assert(pyRef && "cast to py::object failed"); 483 liveContexts[context.ptr] = unownedContextWrapper; 484 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 485 } 486 // Use existing. 487 py::object pyRef = py::cast(it->second); 488 return PyMlirContextRef(it->second, std::move(pyRef)); 489 } 490 491 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 492 static LiveContextMap liveContexts; 493 return liveContexts; 494 } 495 496 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } 497 498 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } 499 500 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 501 502 pybind11::object PyMlirContext::contextEnter() { 503 return PyThreadContextEntry::pushContext(*this); 504 } 505 506 void PyMlirContext::contextExit(pybind11::object excType, 507 pybind11::object excVal, 508 pybind11::object excTb) { 509 PyThreadContextEntry::popContext(*this); 510 } 511 512 PyMlirContext &DefaultingPyMlirContext::resolve() { 513 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 514 if (!context) { 515 throw SetPyError( 516 PyExc_RuntimeError, 517 "An MLIR function requires a Context but none was provided in the call " 518 "or from the surrounding environment. Either pass to the function with " 519 "a 'context=' argument or establish a default using 'with Context():'"); 520 } 521 return *context; 522 } 523 524 //------------------------------------------------------------------------------ 525 // PyThreadContextEntry management 526 //------------------------------------------------------------------------------ 527 528 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 529 static thread_local std::vector<PyThreadContextEntry> stack; 530 return stack; 531 } 532 533 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 534 auto &stack = getStack(); 535 if (stack.empty()) 536 return nullptr; 537 return &stack.back(); 538 } 539 540 void PyThreadContextEntry::push(FrameKind frameKind, py::object context, 541 py::object insertionPoint, 542 py::object location) { 543 auto &stack = getStack(); 544 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 545 std::move(location)); 546 // If the new stack has more than one entry and the context of the new top 547 // entry matches the previous, copy the insertionPoint and location from the 548 // previous entry if missing from the new top entry. 549 if (stack.size() > 1) { 550 auto &prev = *(stack.rbegin() + 1); 551 auto ¤t = stack.back(); 552 if (current.context.is(prev.context)) { 553 // Default non-context objects from the previous entry. 554 if (!current.insertionPoint) 555 current.insertionPoint = prev.insertionPoint; 556 if (!current.location) 557 current.location = prev.location; 558 } 559 } 560 } 561 562 PyMlirContext *PyThreadContextEntry::getContext() { 563 if (!context) 564 return nullptr; 565 return py::cast<PyMlirContext *>(context); 566 } 567 568 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 569 if (!insertionPoint) 570 return nullptr; 571 return py::cast<PyInsertionPoint *>(insertionPoint); 572 } 573 574 PyLocation *PyThreadContextEntry::getLocation() { 575 if (!location) 576 return nullptr; 577 return py::cast<PyLocation *>(location); 578 } 579 580 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 581 auto *tos = getTopOfStack(); 582 return tos ? tos->getContext() : nullptr; 583 } 584 585 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 586 auto *tos = getTopOfStack(); 587 return tos ? tos->getInsertionPoint() : nullptr; 588 } 589 590 PyLocation *PyThreadContextEntry::getDefaultLocation() { 591 auto *tos = getTopOfStack(); 592 return tos ? tos->getLocation() : nullptr; 593 } 594 595 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { 596 py::object contextObj = py::cast(context); 597 push(FrameKind::Context, /*context=*/contextObj, 598 /*insertionPoint=*/py::object(), 599 /*location=*/py::object()); 600 return contextObj; 601 } 602 603 void PyThreadContextEntry::popContext(PyMlirContext &context) { 604 auto &stack = getStack(); 605 if (stack.empty()) 606 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 607 auto &tos = stack.back(); 608 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 609 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); 610 stack.pop_back(); 611 } 612 613 py::object 614 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { 615 py::object contextObj = 616 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 617 py::object insertionPointObj = py::cast(insertionPoint); 618 push(FrameKind::InsertionPoint, 619 /*context=*/contextObj, 620 /*insertionPoint=*/insertionPointObj, 621 /*location=*/py::object()); 622 return insertionPointObj; 623 } 624 625 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 626 auto &stack = getStack(); 627 if (stack.empty()) 628 throw SetPyError(PyExc_RuntimeError, 629 "Unbalanced InsertionPoint enter/exit"); 630 auto &tos = stack.back(); 631 if (tos.frameKind != FrameKind::InsertionPoint && 632 tos.getInsertionPoint() != &insertionPoint) 633 throw SetPyError(PyExc_RuntimeError, 634 "Unbalanced InsertionPoint enter/exit"); 635 stack.pop_back(); 636 } 637 638 py::object PyThreadContextEntry::pushLocation(PyLocation &location) { 639 py::object contextObj = location.getContext().getObject(); 640 py::object locationObj = py::cast(location); 641 push(FrameKind::Location, /*context=*/contextObj, 642 /*insertionPoint=*/py::object(), 643 /*location=*/locationObj); 644 return locationObj; 645 } 646 647 void PyThreadContextEntry::popLocation(PyLocation &location) { 648 auto &stack = getStack(); 649 if (stack.empty()) 650 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 651 auto &tos = stack.back(); 652 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 653 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); 654 stack.pop_back(); 655 } 656 657 //------------------------------------------------------------------------------ 658 // PyDialect, PyDialectDescriptor, PyDialects 659 //------------------------------------------------------------------------------ 660 661 MlirDialect PyDialects::getDialectForKey(const std::string &key, 662 bool attrError) { 663 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 664 {key.data(), key.size()}); 665 if (mlirDialectIsNull(dialect)) { 666 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, 667 Twine("Dialect '") + key + "' not found"); 668 } 669 return dialect; 670 } 671 672 //------------------------------------------------------------------------------ 673 // PyLocation 674 //------------------------------------------------------------------------------ 675 676 py::object PyLocation::getCapsule() { 677 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); 678 } 679 680 PyLocation PyLocation::createFromCapsule(py::object capsule) { 681 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 682 if (mlirLocationIsNull(rawLoc)) 683 throw py::error_already_set(); 684 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 685 rawLoc); 686 } 687 688 py::object PyLocation::contextEnter() { 689 return PyThreadContextEntry::pushLocation(*this); 690 } 691 692 void PyLocation::contextExit(py::object excType, py::object excVal, 693 py::object excTb) { 694 PyThreadContextEntry::popLocation(*this); 695 } 696 697 PyLocation &DefaultingPyLocation::resolve() { 698 auto *location = PyThreadContextEntry::getDefaultLocation(); 699 if (!location) { 700 throw SetPyError( 701 PyExc_RuntimeError, 702 "An MLIR function requires a Location but none was provided in the " 703 "call or from the surrounding environment. Either pass to the function " 704 "with a 'loc=' argument or establish a default using 'with loc:'"); 705 } 706 return *location; 707 } 708 709 //------------------------------------------------------------------------------ 710 // PyModule 711 //------------------------------------------------------------------------------ 712 713 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 714 : BaseContextObject(std::move(contextRef)), module(module) {} 715 716 PyModule::~PyModule() { 717 py::gil_scoped_acquire acquire; 718 auto &liveModules = getContext()->liveModules; 719 assert(liveModules.count(module.ptr) == 1 && 720 "destroying module not in live map"); 721 liveModules.erase(module.ptr); 722 mlirModuleDestroy(module); 723 } 724 725 PyModuleRef PyModule::forModule(MlirModule module) { 726 MlirContext context = mlirModuleGetContext(module); 727 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 728 729 py::gil_scoped_acquire acquire; 730 auto &liveModules = contextRef->liveModules; 731 auto it = liveModules.find(module.ptr); 732 if (it == liveModules.end()) { 733 // Create. 734 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 735 // Note that the default return value policy on cast is automatic_reference, 736 // which does not take ownership (delete will not be called). 737 // Just be explicit. 738 py::object pyRef = 739 py::cast(unownedModule, py::return_value_policy::take_ownership); 740 unownedModule->handle = pyRef; 741 liveModules[module.ptr] = 742 std::make_pair(unownedModule->handle, unownedModule); 743 return PyModuleRef(unownedModule, std::move(pyRef)); 744 } 745 // Use existing. 746 PyModule *existing = it->second.second; 747 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 748 return PyModuleRef(existing, std::move(pyRef)); 749 } 750 751 py::object PyModule::createFromCapsule(py::object capsule) { 752 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 753 if (mlirModuleIsNull(rawModule)) 754 throw py::error_already_set(); 755 return forModule(rawModule).releaseObject(); 756 } 757 758 py::object PyModule::getCapsule() { 759 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); 760 } 761 762 //------------------------------------------------------------------------------ 763 // PyOperation 764 //------------------------------------------------------------------------------ 765 766 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 767 : BaseContextObject(std::move(contextRef)), operation(operation) {} 768 769 PyOperation::~PyOperation() { 770 // If the operation has already been invalidated there is nothing to do. 771 if (!valid) 772 return; 773 auto &liveOperations = getContext()->liveOperations; 774 assert(liveOperations.count(operation.ptr) == 1 && 775 "destroying operation not in live map"); 776 liveOperations.erase(operation.ptr); 777 if (!isAttached()) { 778 mlirOperationDestroy(operation); 779 } 780 } 781 782 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 783 MlirOperation operation, 784 py::object parentKeepAlive) { 785 auto &liveOperations = contextRef->liveOperations; 786 // Create. 787 PyOperation *unownedOperation = 788 new PyOperation(std::move(contextRef), operation); 789 // Note that the default return value policy on cast is automatic_reference, 790 // which does not take ownership (delete will not be called). 791 // Just be explicit. 792 py::object pyRef = 793 py::cast(unownedOperation, py::return_value_policy::take_ownership); 794 unownedOperation->handle = pyRef; 795 if (parentKeepAlive) { 796 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 797 } 798 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); 799 return PyOperationRef(unownedOperation, std::move(pyRef)); 800 } 801 802 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 803 MlirOperation operation, 804 py::object parentKeepAlive) { 805 auto &liveOperations = contextRef->liveOperations; 806 auto it = liveOperations.find(operation.ptr); 807 if (it == liveOperations.end()) { 808 // Create. 809 return createInstance(std::move(contextRef), operation, 810 std::move(parentKeepAlive)); 811 } 812 // Use existing. 813 PyOperation *existing = it->second.second; 814 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); 815 return PyOperationRef(existing, std::move(pyRef)); 816 } 817 818 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 819 MlirOperation operation, 820 py::object parentKeepAlive) { 821 auto &liveOperations = contextRef->liveOperations; 822 assert(liveOperations.count(operation.ptr) == 0 && 823 "cannot create detached operation that already exists"); 824 (void)liveOperations; 825 826 PyOperationRef created = createInstance(std::move(contextRef), operation, 827 std::move(parentKeepAlive)); 828 created->attached = false; 829 return created; 830 } 831 832 void PyOperation::checkValid() const { 833 if (!valid) { 834 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); 835 } 836 } 837 838 void PyOperationBase::print(py::object fileObject, bool binary, 839 llvm::Optional<int64_t> largeElementsLimit, 840 bool enableDebugInfo, bool prettyDebugInfo, 841 bool printGenericOpForm, bool useLocalScope, 842 bool assumeVerified) { 843 PyOperation &operation = getOperation(); 844 operation.checkValid(); 845 if (fileObject.is_none()) 846 fileObject = py::module::import("sys").attr("stdout"); 847 848 if (!assumeVerified && !printGenericOpForm && 849 !mlirOperationVerify(operation)) { 850 std::string message("// Verification failed, printing generic form\n"); 851 if (binary) { 852 fileObject.attr("write")(py::bytes(message)); 853 } else { 854 fileObject.attr("write")(py::str(message)); 855 } 856 printGenericOpForm = true; 857 } 858 859 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 860 if (largeElementsLimit) 861 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 862 if (enableDebugInfo) 863 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); 864 if (printGenericOpForm) 865 mlirOpPrintingFlagsPrintGenericOpForm(flags); 866 867 PyFileAccumulator accum(fileObject, binary); 868 py::gil_scoped_release(); 869 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 870 accum.getUserData()); 871 mlirOpPrintingFlagsDestroy(flags); 872 } 873 874 py::object PyOperationBase::getAsm(bool binary, 875 llvm::Optional<int64_t> largeElementsLimit, 876 bool enableDebugInfo, bool prettyDebugInfo, 877 bool printGenericOpForm, bool useLocalScope, 878 bool assumeVerified) { 879 py::object fileObject; 880 if (binary) { 881 fileObject = py::module::import("io").attr("BytesIO")(); 882 } else { 883 fileObject = py::module::import("io").attr("StringIO")(); 884 } 885 print(fileObject, /*binary=*/binary, 886 /*largeElementsLimit=*/largeElementsLimit, 887 /*enableDebugInfo=*/enableDebugInfo, 888 /*prettyDebugInfo=*/prettyDebugInfo, 889 /*printGenericOpForm=*/printGenericOpForm, 890 /*useLocalScope=*/useLocalScope, 891 /*assumeVerified=*/assumeVerified); 892 893 return fileObject.attr("getvalue")(); 894 } 895 896 void PyOperationBase::moveAfter(PyOperationBase &other) { 897 PyOperation &operation = getOperation(); 898 PyOperation &otherOp = other.getOperation(); 899 operation.checkValid(); 900 otherOp.checkValid(); 901 mlirOperationMoveAfter(operation, otherOp); 902 operation.parentKeepAlive = otherOp.parentKeepAlive; 903 } 904 905 void PyOperationBase::moveBefore(PyOperationBase &other) { 906 PyOperation &operation = getOperation(); 907 PyOperation &otherOp = other.getOperation(); 908 operation.checkValid(); 909 otherOp.checkValid(); 910 mlirOperationMoveBefore(operation, otherOp); 911 operation.parentKeepAlive = otherOp.parentKeepAlive; 912 } 913 914 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() { 915 checkValid(); 916 if (!isAttached()) 917 throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); 918 MlirOperation operation = mlirOperationGetParentOperation(get()); 919 if (mlirOperationIsNull(operation)) 920 return {}; 921 return PyOperation::forOperation(getContext(), operation); 922 } 923 924 PyBlock PyOperation::getBlock() { 925 checkValid(); 926 llvm::Optional<PyOperationRef> parentOperation = getParentOperation(); 927 MlirBlock block = mlirOperationGetBlock(get()); 928 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 929 assert(parentOperation && "Operation has no parent"); 930 return PyBlock{std::move(*parentOperation), block}; 931 } 932 933 py::object PyOperation::getCapsule() { 934 checkValid(); 935 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get())); 936 } 937 938 py::object PyOperation::createFromCapsule(py::object capsule) { 939 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 940 if (mlirOperationIsNull(rawOperation)) 941 throw py::error_already_set(); 942 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 943 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 944 .releaseObject(); 945 } 946 947 py::object PyOperation::create( 948 std::string name, llvm::Optional<std::vector<PyType *>> results, 949 llvm::Optional<std::vector<PyValue *>> operands, 950 llvm::Optional<py::dict> attributes, 951 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 952 DefaultingPyLocation location, py::object maybeIp) { 953 llvm::SmallVector<MlirValue, 4> mlirOperands; 954 llvm::SmallVector<MlirType, 4> mlirResults; 955 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 956 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 957 958 // General parameter validation. 959 if (regions < 0) 960 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); 961 962 // Unpack/validate operands. 963 if (operands) { 964 mlirOperands.reserve(operands->size()); 965 for (PyValue *operand : *operands) { 966 if (!operand) 967 throw SetPyError(PyExc_ValueError, "operand value cannot be None"); 968 mlirOperands.push_back(operand->get()); 969 } 970 } 971 972 // Unpack/validate results. 973 if (results) { 974 mlirResults.reserve(results->size()); 975 for (PyType *result : *results) { 976 // TODO: Verify result type originate from the same context. 977 if (!result) 978 throw SetPyError(PyExc_ValueError, "result type cannot be None"); 979 mlirResults.push_back(*result); 980 } 981 } 982 // Unpack/validate attributes. 983 if (attributes) { 984 mlirAttributes.reserve(attributes->size()); 985 for (auto &it : *attributes) { 986 std::string key; 987 try { 988 key = it.first.cast<std::string>(); 989 } catch (py::cast_error &err) { 990 std::string msg = "Invalid attribute key (not a string) when " 991 "attempting to create the operation \"" + 992 name + "\" (" + err.what() + ")"; 993 throw py::cast_error(msg); 994 } 995 try { 996 auto &attribute = it.second.cast<PyAttribute &>(); 997 // TODO: Verify attribute originates from the same context. 998 mlirAttributes.emplace_back(std::move(key), attribute); 999 } catch (py::reference_cast_error &) { 1000 // This exception seems thrown when the value is "None". 1001 std::string msg = 1002 "Found an invalid (`None`?) attribute value for the key \"" + key + 1003 "\" when attempting to create the operation \"" + name + "\""; 1004 throw py::cast_error(msg); 1005 } catch (py::cast_error &err) { 1006 std::string msg = "Invalid attribute value for the key \"" + key + 1007 "\" when attempting to create the operation \"" + 1008 name + "\" (" + err.what() + ")"; 1009 throw py::cast_error(msg); 1010 } 1011 } 1012 } 1013 // Unpack/validate successors. 1014 if (successors) { 1015 mlirSuccessors.reserve(successors->size()); 1016 for (auto *successor : *successors) { 1017 // TODO: Verify successor originate from the same context. 1018 if (!successor) 1019 throw SetPyError(PyExc_ValueError, "successor block cannot be None"); 1020 mlirSuccessors.push_back(successor->get()); 1021 } 1022 } 1023 1024 // Apply unpacked/validated to the operation state. Beyond this 1025 // point, exceptions cannot be thrown or else the state will leak. 1026 MlirOperationState state = 1027 mlirOperationStateGet(toMlirStringRef(name), location); 1028 if (!mlirOperands.empty()) 1029 mlirOperationStateAddOperands(&state, mlirOperands.size(), 1030 mlirOperands.data()); 1031 if (!mlirResults.empty()) 1032 mlirOperationStateAddResults(&state, mlirResults.size(), 1033 mlirResults.data()); 1034 if (!mlirAttributes.empty()) { 1035 // Note that the attribute names directly reference bytes in 1036 // mlirAttributes, so that vector must not be changed from here 1037 // on. 1038 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1039 mlirNamedAttributes.reserve(mlirAttributes.size()); 1040 for (auto &it : mlirAttributes) 1041 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1042 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1043 toMlirStringRef(it.first)), 1044 it.second)); 1045 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1046 mlirNamedAttributes.data()); 1047 } 1048 if (!mlirSuccessors.empty()) 1049 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1050 mlirSuccessors.data()); 1051 if (regions) { 1052 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1053 mlirRegions.resize(regions); 1054 for (int i = 0; i < regions; ++i) 1055 mlirRegions[i] = mlirRegionCreate(); 1056 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1057 mlirRegions.data()); 1058 } 1059 1060 // Construct the operation. 1061 MlirOperation operation = mlirOperationCreate(&state); 1062 PyOperationRef created = 1063 PyOperation::createDetached(location->getContext(), operation); 1064 1065 // InsertPoint active? 1066 if (!maybeIp.is(py::cast(false))) { 1067 PyInsertionPoint *ip; 1068 if (maybeIp.is_none()) { 1069 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1070 } else { 1071 ip = py::cast<PyInsertionPoint *>(maybeIp); 1072 } 1073 if (ip) 1074 ip->insert(*created.get()); 1075 } 1076 1077 return created->createOpView(); 1078 } 1079 1080 py::object PyOperation::createOpView() { 1081 checkValid(); 1082 MlirIdentifier ident = mlirOperationGetName(get()); 1083 MlirStringRef identStr = mlirIdentifierStr(ident); 1084 auto opViewClass = PyGlobals::get().lookupRawOpViewClass( 1085 StringRef(identStr.data, identStr.length)); 1086 if (opViewClass) 1087 return (*opViewClass)(getRef().getObject()); 1088 return py::cast(PyOpView(getRef().getObject())); 1089 } 1090 1091 void PyOperation::erase() { 1092 checkValid(); 1093 // TODO: Fix memory hazards when erasing a tree of operations for which a deep 1094 // Python reference to a child operation is live. All children should also 1095 // have their `valid` bit set to false. 1096 auto &liveOperations = getContext()->liveOperations; 1097 if (liveOperations.count(operation.ptr)) 1098 liveOperations.erase(operation.ptr); 1099 mlirOperationDestroy(operation); 1100 valid = false; 1101 } 1102 1103 //------------------------------------------------------------------------------ 1104 // PyOpView 1105 //------------------------------------------------------------------------------ 1106 1107 py::object 1108 PyOpView::buildGeneric(py::object cls, py::list resultTypeList, 1109 py::list operandList, 1110 llvm::Optional<py::dict> attributes, 1111 llvm::Optional<std::vector<PyBlock *>> successors, 1112 llvm::Optional<int> regions, 1113 DefaultingPyLocation location, py::object maybeIp) { 1114 PyMlirContextRef context = location->getContext(); 1115 // Class level operation construction metadata. 1116 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME")); 1117 // Operand and result segment specs are either none, which does no 1118 // variadic unpacking, or a list of ints with segment sizes, where each 1119 // element is either a positive number (typically 1 for a scalar) or -1 to 1120 // indicate that it is derived from the length of the same-indexed operand 1121 // or result (implying that it is a list at that position). 1122 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); 1123 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); 1124 1125 std::vector<uint32_t> operandSegmentLengths; 1126 std::vector<uint32_t> resultSegmentLengths; 1127 1128 // Validate/determine region count. 1129 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 1130 int opMinRegionCount = std::get<0>(opRegionSpec); 1131 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1132 if (!regions) { 1133 regions = opMinRegionCount; 1134 } 1135 if (*regions < opMinRegionCount) { 1136 throw py::value_error( 1137 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1138 llvm::Twine(opMinRegionCount) + 1139 " regions but was built with regions=" + llvm::Twine(*regions)) 1140 .str()); 1141 } 1142 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1143 throw py::value_error( 1144 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1145 llvm::Twine(opMinRegionCount) + 1146 " regions but was built with regions=" + llvm::Twine(*regions)) 1147 .str()); 1148 } 1149 1150 // Unpack results. 1151 std::vector<PyType *> resultTypes; 1152 resultTypes.reserve(resultTypeList.size()); 1153 if (resultSegmentSpecObj.is_none()) { 1154 // Non-variadic result unpacking. 1155 for (auto it : llvm::enumerate(resultTypeList)) { 1156 try { 1157 resultTypes.push_back(py::cast<PyType *>(it.value())); 1158 if (!resultTypes.back()) 1159 throw py::cast_error(); 1160 } catch (py::cast_error &err) { 1161 throw py::value_error((llvm::Twine("Result ") + 1162 llvm::Twine(it.index()) + " of operation \"" + 1163 name + "\" must be a Type (" + err.what() + ")") 1164 .str()); 1165 } 1166 } 1167 } else { 1168 // Sized result unpacking. 1169 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj); 1170 if (resultSegmentSpec.size() != resultTypeList.size()) { 1171 throw py::value_error((llvm::Twine("Operation \"") + name + 1172 "\" requires " + 1173 llvm::Twine(resultSegmentSpec.size()) + 1174 " result segments but was provided " + 1175 llvm::Twine(resultTypeList.size())) 1176 .str()); 1177 } 1178 resultSegmentLengths.reserve(resultTypeList.size()); 1179 for (auto it : 1180 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { 1181 int segmentSpec = std::get<1>(it.value()); 1182 if (segmentSpec == 1 || segmentSpec == 0) { 1183 // Unpack unary element. 1184 try { 1185 auto *resultType = py::cast<PyType *>(std::get<0>(it.value())); 1186 if (resultType) { 1187 resultTypes.push_back(resultType); 1188 resultSegmentLengths.push_back(1); 1189 } else if (segmentSpec == 0) { 1190 // Allowed to be optional. 1191 resultSegmentLengths.push_back(0); 1192 } else { 1193 throw py::cast_error("was None and result is not optional"); 1194 } 1195 } catch (py::cast_error &err) { 1196 throw py::value_error((llvm::Twine("Result ") + 1197 llvm::Twine(it.index()) + " of operation \"" + 1198 name + "\" must be a Type (" + err.what() + 1199 ")") 1200 .str()); 1201 } 1202 } else if (segmentSpec == -1) { 1203 // Unpack sequence by appending. 1204 try { 1205 if (std::get<0>(it.value()).is_none()) { 1206 // Treat it as an empty list. 1207 resultSegmentLengths.push_back(0); 1208 } else { 1209 // Unpack the list. 1210 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1211 for (py::object segmentItem : segment) { 1212 resultTypes.push_back(py::cast<PyType *>(segmentItem)); 1213 if (!resultTypes.back()) { 1214 throw py::cast_error("contained a None item"); 1215 } 1216 } 1217 resultSegmentLengths.push_back(segment.size()); 1218 } 1219 } catch (std::exception &err) { 1220 // NOTE: Sloppy to be using a catch-all here, but there are at least 1221 // three different unrelated exceptions that can be thrown in the 1222 // above "casts". Just keep the scope above small and catch them all. 1223 throw py::value_error((llvm::Twine("Result ") + 1224 llvm::Twine(it.index()) + " of operation \"" + 1225 name + "\" must be a Sequence of Types (" + 1226 err.what() + ")") 1227 .str()); 1228 } 1229 } else { 1230 throw py::value_error("Unexpected segment spec"); 1231 } 1232 } 1233 } 1234 1235 // Unpack operands. 1236 std::vector<PyValue *> operands; 1237 operands.reserve(operands.size()); 1238 if (operandSegmentSpecObj.is_none()) { 1239 // Non-sized operand unpacking. 1240 for (auto it : llvm::enumerate(operandList)) { 1241 try { 1242 operands.push_back(py::cast<PyValue *>(it.value())); 1243 if (!operands.back()) 1244 throw py::cast_error(); 1245 } catch (py::cast_error &err) { 1246 throw py::value_error((llvm::Twine("Operand ") + 1247 llvm::Twine(it.index()) + " of operation \"" + 1248 name + "\" must be a Value (" + err.what() + ")") 1249 .str()); 1250 } 1251 } 1252 } else { 1253 // Sized operand unpacking. 1254 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj); 1255 if (operandSegmentSpec.size() != operandList.size()) { 1256 throw py::value_error((llvm::Twine("Operation \"") + name + 1257 "\" requires " + 1258 llvm::Twine(operandSegmentSpec.size()) + 1259 "operand segments but was provided " + 1260 llvm::Twine(operandList.size())) 1261 .str()); 1262 } 1263 operandSegmentLengths.reserve(operandList.size()); 1264 for (auto it : 1265 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1266 int segmentSpec = std::get<1>(it.value()); 1267 if (segmentSpec == 1 || segmentSpec == 0) { 1268 // Unpack unary element. 1269 try { 1270 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); 1271 if (operandValue) { 1272 operands.push_back(operandValue); 1273 operandSegmentLengths.push_back(1); 1274 } else if (segmentSpec == 0) { 1275 // Allowed to be optional. 1276 operandSegmentLengths.push_back(0); 1277 } else { 1278 throw py::cast_error("was None and operand is not optional"); 1279 } 1280 } catch (py::cast_error &err) { 1281 throw py::value_error((llvm::Twine("Operand ") + 1282 llvm::Twine(it.index()) + " of operation \"" + 1283 name + "\" must be a Value (" + err.what() + 1284 ")") 1285 .str()); 1286 } 1287 } else if (segmentSpec == -1) { 1288 // Unpack sequence by appending. 1289 try { 1290 if (std::get<0>(it.value()).is_none()) { 1291 // Treat it as an empty list. 1292 operandSegmentLengths.push_back(0); 1293 } else { 1294 // Unpack the list. 1295 auto segment = py::cast<py::sequence>(std::get<0>(it.value())); 1296 for (py::object segmentItem : segment) { 1297 operands.push_back(py::cast<PyValue *>(segmentItem)); 1298 if (!operands.back()) { 1299 throw py::cast_error("contained a None item"); 1300 } 1301 } 1302 operandSegmentLengths.push_back(segment.size()); 1303 } 1304 } catch (std::exception &err) { 1305 // NOTE: Sloppy to be using a catch-all here, but there are at least 1306 // three different unrelated exceptions that can be thrown in the 1307 // above "casts". Just keep the scope above small and catch them all. 1308 throw py::value_error((llvm::Twine("Operand ") + 1309 llvm::Twine(it.index()) + " of operation \"" + 1310 name + "\" must be a Sequence of Values (" + 1311 err.what() + ")") 1312 .str()); 1313 } 1314 } else { 1315 throw py::value_error("Unexpected segment spec"); 1316 } 1317 } 1318 } 1319 1320 // Merge operand/result segment lengths into attributes if needed. 1321 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 1322 // Dup. 1323 if (attributes) { 1324 attributes = py::dict(*attributes); 1325 } else { 1326 attributes = py::dict(); 1327 } 1328 if (attributes->contains("result_segment_sizes") || 1329 attributes->contains("operand_segment_sizes")) { 1330 throw py::value_error("Manually setting a 'result_segment_sizes' or " 1331 "'operand_segment_sizes' attribute is unsupported. " 1332 "Use Operation.create for such low-level access."); 1333 } 1334 1335 // Add result_segment_sizes attribute. 1336 if (!resultSegmentLengths.empty()) { 1337 int64_t size = resultSegmentLengths.size(); 1338 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1339 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1340 resultSegmentLengths.size(), resultSegmentLengths.data()); 1341 (*attributes)["result_segment_sizes"] = 1342 PyAttribute(context, segmentLengthAttr); 1343 } 1344 1345 // Add operand_segment_sizes attribute. 1346 if (!operandSegmentLengths.empty()) { 1347 int64_t size = operandSegmentLengths.size(); 1348 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( 1349 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), 1350 operandSegmentLengths.size(), operandSegmentLengths.data()); 1351 (*attributes)["operand_segment_sizes"] = 1352 PyAttribute(context, segmentLengthAttr); 1353 } 1354 } 1355 1356 // Delegate to create. 1357 return PyOperation::create(std::move(name), 1358 /*results=*/std::move(resultTypes), 1359 /*operands=*/std::move(operands), 1360 /*attributes=*/std::move(attributes), 1361 /*successors=*/std::move(successors), 1362 /*regions=*/*regions, location, maybeIp); 1363 } 1364 1365 PyOpView::PyOpView(py::object operationObject) 1366 // Casting through the PyOperationBase base-class and then back to the 1367 // Operation lets us accept any PyOperationBase subclass. 1368 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()), 1369 operationObject(operation.getRef().getObject()) {} 1370 1371 py::object PyOpView::createRawSubclass(py::object userClass) { 1372 // This is... a little gross. The typical pattern is to have a pure python 1373 // class that extends OpView like: 1374 // class AddFOp(_cext.ir.OpView): 1375 // def __init__(self, loc, lhs, rhs): 1376 // operation = loc.context.create_operation( 1377 // "addf", lhs, rhs, results=[lhs.type]) 1378 // super().__init__(operation) 1379 // 1380 // I.e. The goal of the user facing type is to provide a nice constructor 1381 // that has complete freedom for the op under construction. This is at odds 1382 // with our other desire to sometimes create this object by just passing an 1383 // operation (to initialize the base class). We could do *arg and **kwargs 1384 // munging to try to make it work, but instead, we synthesize a new class 1385 // on the fly which extends this user class (AddFOp in this example) and 1386 // *give it* the base class's __init__ method, thus bypassing the 1387 // intermediate subclass's __init__ method entirely. While slightly, 1388 // underhanded, this is safe/legal because the type hierarchy has not changed 1389 // (we just added a new leaf) and we aren't mucking around with __new__. 1390 // Typically, this new class will be stored on the original as "_Raw" and will 1391 // be used for casts and other things that need a variant of the class that 1392 // is initialized purely from an operation. 1393 py::object parentMetaclass = 1394 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); 1395 py::dict attributes; 1396 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from 1397 // now. 1398 // auto opViewType = py::type::of<PyOpView>(); 1399 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); 1400 attributes["__init__"] = opViewType.attr("__init__"); 1401 py::str origName = userClass.attr("__name__"); 1402 py::str newName = py::str("_") + origName; 1403 return parentMetaclass(newName, py::make_tuple(userClass), attributes); 1404 } 1405 1406 //------------------------------------------------------------------------------ 1407 // PyInsertionPoint. 1408 //------------------------------------------------------------------------------ 1409 1410 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 1411 1412 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 1413 : refOperation(beforeOperationBase.getOperation().getRef()), 1414 block((*refOperation)->getBlock()) {} 1415 1416 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 1417 PyOperation &operation = operationBase.getOperation(); 1418 if (operation.isAttached()) 1419 throw SetPyError(PyExc_ValueError, 1420 "Attempt to insert operation that is already attached"); 1421 block.getParentOperation()->checkValid(); 1422 MlirOperation beforeOp = {nullptr}; 1423 if (refOperation) { 1424 // Insert before operation. 1425 (*refOperation)->checkValid(); 1426 beforeOp = (*refOperation)->get(); 1427 } else { 1428 // Insert at end (before null) is only valid if the block does not 1429 // already end in a known terminator (violating this will cause assertion 1430 // failures later). 1431 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 1432 throw py::index_error("Cannot insert operation at the end of a block " 1433 "that already has a terminator. Did you mean to " 1434 "use 'InsertionPoint.at_block_terminator(block)' " 1435 "versus 'InsertionPoint(block)'?"); 1436 } 1437 } 1438 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 1439 operation.setAttached(); 1440 } 1441 1442 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 1443 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 1444 if (mlirOperationIsNull(firstOp)) { 1445 // Just insert at end. 1446 return PyInsertionPoint(block); 1447 } 1448 1449 // Insert before first op. 1450 PyOperationRef firstOpRef = PyOperation::forOperation( 1451 block.getParentOperation()->getContext(), firstOp); 1452 return PyInsertionPoint{block, std::move(firstOpRef)}; 1453 } 1454 1455 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 1456 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 1457 if (mlirOperationIsNull(terminator)) 1458 throw SetPyError(PyExc_ValueError, "Block has no terminator"); 1459 PyOperationRef terminatorOpRef = PyOperation::forOperation( 1460 block.getParentOperation()->getContext(), terminator); 1461 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 1462 } 1463 1464 py::object PyInsertionPoint::contextEnter() { 1465 return PyThreadContextEntry::pushInsertionPoint(*this); 1466 } 1467 1468 void PyInsertionPoint::contextExit(pybind11::object excType, 1469 pybind11::object excVal, 1470 pybind11::object excTb) { 1471 PyThreadContextEntry::popInsertionPoint(*this); 1472 } 1473 1474 //------------------------------------------------------------------------------ 1475 // PyAttribute. 1476 //------------------------------------------------------------------------------ 1477 1478 bool PyAttribute::operator==(const PyAttribute &other) { 1479 return mlirAttributeEqual(attr, other.attr); 1480 } 1481 1482 py::object PyAttribute::getCapsule() { 1483 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); 1484 } 1485 1486 PyAttribute PyAttribute::createFromCapsule(py::object capsule) { 1487 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 1488 if (mlirAttributeIsNull(rawAttr)) 1489 throw py::error_already_set(); 1490 return PyAttribute( 1491 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 1492 } 1493 1494 //------------------------------------------------------------------------------ 1495 // PyNamedAttribute. 1496 //------------------------------------------------------------------------------ 1497 1498 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 1499 : ownedName(new std::string(std::move(ownedName))) { 1500 namedAttr = mlirNamedAttributeGet( 1501 mlirIdentifierGet(mlirAttributeGetContext(attr), 1502 toMlirStringRef(*this->ownedName)), 1503 attr); 1504 } 1505 1506 //------------------------------------------------------------------------------ 1507 // PyType. 1508 //------------------------------------------------------------------------------ 1509 1510 bool PyType::operator==(const PyType &other) { 1511 return mlirTypeEqual(type, other.type); 1512 } 1513 1514 py::object PyType::getCapsule() { 1515 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); 1516 } 1517 1518 PyType PyType::createFromCapsule(py::object capsule) { 1519 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 1520 if (mlirTypeIsNull(rawType)) 1521 throw py::error_already_set(); 1522 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 1523 rawType); 1524 } 1525 1526 //------------------------------------------------------------------------------ 1527 // PyValue and subclases. 1528 //------------------------------------------------------------------------------ 1529 1530 pybind11::object PyValue::getCapsule() { 1531 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); 1532 } 1533 1534 PyValue PyValue::createFromCapsule(pybind11::object capsule) { 1535 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 1536 if (mlirValueIsNull(value)) 1537 throw py::error_already_set(); 1538 MlirOperation owner; 1539 if (mlirValueIsAOpResult(value)) 1540 owner = mlirOpResultGetOwner(value); 1541 if (mlirValueIsABlockArgument(value)) 1542 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 1543 if (mlirOperationIsNull(owner)) 1544 throw py::error_already_set(); 1545 MlirContext ctx = mlirOperationGetContext(owner); 1546 PyOperationRef ownerRef = 1547 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 1548 return PyValue(ownerRef, value); 1549 } 1550 1551 //------------------------------------------------------------------------------ 1552 // PySymbolTable. 1553 //------------------------------------------------------------------------------ 1554 1555 PySymbolTable::PySymbolTable(PyOperationBase &operation) 1556 : operation(operation.getOperation().getRef()) { 1557 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 1558 if (mlirSymbolTableIsNull(symbolTable)) { 1559 throw py::cast_error("Operation is not a Symbol Table."); 1560 } 1561 } 1562 1563 py::object PySymbolTable::dunderGetItem(const std::string &name) { 1564 operation->checkValid(); 1565 MlirOperation symbol = mlirSymbolTableLookup( 1566 symbolTable, mlirStringRefCreate(name.data(), name.length())); 1567 if (mlirOperationIsNull(symbol)) 1568 throw py::key_error("Symbol '" + name + "' not in the symbol table."); 1569 1570 return PyOperation::forOperation(operation->getContext(), symbol, 1571 operation.getObject()) 1572 ->createOpView(); 1573 } 1574 1575 void PySymbolTable::erase(PyOperationBase &symbol) { 1576 operation->checkValid(); 1577 symbol.getOperation().checkValid(); 1578 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 1579 // The operation is also erased, so we must invalidate it. There may be Python 1580 // references to this operation so we don't want to delete it from the list of 1581 // live operations here. 1582 symbol.getOperation().valid = false; 1583 } 1584 1585 void PySymbolTable::dunderDel(const std::string &name) { 1586 py::object operation = dunderGetItem(name); 1587 erase(py::cast<PyOperationBase &>(operation)); 1588 } 1589 1590 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { 1591 operation->checkValid(); 1592 symbol.getOperation().checkValid(); 1593 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 1594 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 1595 if (mlirAttributeIsNull(symbolAttr)) 1596 throw py::value_error("Expected operation to have a symbol name."); 1597 return PyAttribute( 1598 symbol.getOperation().getContext(), 1599 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); 1600 } 1601 1602 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 1603 // Op must already be a symbol. 1604 PyOperation &operation = symbol.getOperation(); 1605 operation.checkValid(); 1606 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1607 MlirAttribute existingNameAttr = 1608 mlirOperationGetAttributeByName(operation.get(), attrName); 1609 if (mlirAttributeIsNull(existingNameAttr)) 1610 throw py::value_error("Expected operation to have a symbol name."); 1611 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); 1612 } 1613 1614 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 1615 const std::string &name) { 1616 // Op must already be a symbol. 1617 PyOperation &operation = symbol.getOperation(); 1618 operation.checkValid(); 1619 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 1620 MlirAttribute existingNameAttr = 1621 mlirOperationGetAttributeByName(operation.get(), attrName); 1622 if (mlirAttributeIsNull(existingNameAttr)) 1623 throw py::value_error("Expected operation to have a symbol name."); 1624 MlirAttribute newNameAttr = 1625 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 1626 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 1627 } 1628 1629 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 1630 PyOperation &operation = symbol.getOperation(); 1631 operation.checkValid(); 1632 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1633 MlirAttribute existingVisAttr = 1634 mlirOperationGetAttributeByName(operation.get(), attrName); 1635 if (mlirAttributeIsNull(existingVisAttr)) 1636 throw py::value_error("Expected operation to have a symbol visibility."); 1637 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); 1638 } 1639 1640 void PySymbolTable::setVisibility(PyOperationBase &symbol, 1641 const std::string &visibility) { 1642 if (visibility != "public" && visibility != "private" && 1643 visibility != "nested") 1644 throw py::value_error( 1645 "Expected visibility to be 'public', 'private' or 'nested'"); 1646 PyOperation &operation = symbol.getOperation(); 1647 operation.checkValid(); 1648 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 1649 MlirAttribute existingVisAttr = 1650 mlirOperationGetAttributeByName(operation.get(), attrName); 1651 if (mlirAttributeIsNull(existingVisAttr)) 1652 throw py::value_error("Expected operation to have a symbol visibility."); 1653 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 1654 toMlirStringRef(visibility)); 1655 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 1656 } 1657 1658 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 1659 const std::string &newSymbol, 1660 PyOperationBase &from) { 1661 PyOperation &fromOperation = from.getOperation(); 1662 fromOperation.checkValid(); 1663 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 1664 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 1665 from.getOperation()))) 1666 1667 throw py::value_error("Symbol rename failed"); 1668 } 1669 1670 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 1671 bool allSymUsesVisible, 1672 py::object callback) { 1673 PyOperation &fromOperation = from.getOperation(); 1674 fromOperation.checkValid(); 1675 struct UserData { 1676 PyMlirContextRef context; 1677 py::object callback; 1678 bool gotException; 1679 std::string exceptionWhat; 1680 py::object exceptionType; 1681 }; 1682 UserData userData{ 1683 fromOperation.getContext(), std::move(callback), false, {}, {}}; 1684 mlirSymbolTableWalkSymbolTables( 1685 fromOperation.get(), allSymUsesVisible, 1686 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 1687 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 1688 auto pyFoundOp = 1689 PyOperation::forOperation(calleeUserData->context, foundOp); 1690 if (calleeUserData->gotException) 1691 return; 1692 try { 1693 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 1694 } catch (py::error_already_set &e) { 1695 calleeUserData->gotException = true; 1696 calleeUserData->exceptionWhat = e.what(); 1697 calleeUserData->exceptionType = e.type(); 1698 } 1699 }, 1700 static_cast<void *>(&userData)); 1701 if (userData.gotException) { 1702 std::string message("Exception raised in callback: "); 1703 message.append(userData.exceptionWhat); 1704 throw std::runtime_error(std::move(message)); 1705 } 1706 } 1707 1708 namespace { 1709 /// CRTP base class for Python MLIR values that subclass Value and should be 1710 /// castable from it. The value hierarchy is one level deep and is not supposed 1711 /// to accommodate other levels unless core MLIR changes. 1712 template <typename DerivedTy> 1713 class PyConcreteValue : public PyValue { 1714 public: 1715 // Derived classes must define statics for: 1716 // IsAFunctionTy isaFunction 1717 // const char *pyClassName 1718 // and redefine bindDerived. 1719 using ClassTy = py::class_<DerivedTy, PyValue>; 1720 using IsAFunctionTy = bool (*)(MlirValue); 1721 1722 PyConcreteValue() = default; 1723 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1724 : PyValue(operationRef, value) {} 1725 PyConcreteValue(PyValue &orig) 1726 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1727 1728 /// Attempts to cast the original value to the derived type and throws on 1729 /// type mismatches. 1730 static MlirValue castFrom(PyValue &orig) { 1731 if (!DerivedTy::isaFunction(orig.get())) { 1732 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 1733 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + 1734 DerivedTy::pyClassName + 1735 " (from " + origRepr + ")"); 1736 } 1737 return orig.get(); 1738 } 1739 1740 /// Binds the Python module objects to functions of this class. 1741 static void bind(py::module &m) { 1742 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 1743 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value")); 1744 cls.def_static( 1745 "isinstance", 1746 [](PyValue &otherValue) -> bool { 1747 return DerivedTy::isaFunction(otherValue); 1748 }, 1749 py::arg("other_value")); 1750 DerivedTy::bindDerived(cls); 1751 } 1752 1753 /// Implemented by derived classes to add methods to the Python subclass. 1754 static void bindDerived(ClassTy &m) {} 1755 }; 1756 1757 /// Python wrapper for MlirBlockArgument. 1758 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 1759 public: 1760 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 1761 static constexpr const char *pyClassName = "BlockArgument"; 1762 using PyConcreteValue::PyConcreteValue; 1763 1764 static void bindDerived(ClassTy &c) { 1765 c.def_property_readonly("owner", [](PyBlockArgument &self) { 1766 return PyBlock(self.getParentOperation(), 1767 mlirBlockArgumentGetOwner(self.get())); 1768 }); 1769 c.def_property_readonly("arg_number", [](PyBlockArgument &self) { 1770 return mlirBlockArgumentGetArgNumber(self.get()); 1771 }); 1772 c.def( 1773 "set_type", 1774 [](PyBlockArgument &self, PyType type) { 1775 return mlirBlockArgumentSetType(self.get(), type); 1776 }, 1777 py::arg("type")); 1778 } 1779 }; 1780 1781 /// Python wrapper for MlirOpResult. 1782 class PyOpResult : public PyConcreteValue<PyOpResult> { 1783 public: 1784 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; 1785 static constexpr const char *pyClassName = "OpResult"; 1786 using PyConcreteValue::PyConcreteValue; 1787 1788 static void bindDerived(ClassTy &c) { 1789 c.def_property_readonly("owner", [](PyOpResult &self) { 1790 assert( 1791 mlirOperationEqual(self.getParentOperation()->get(), 1792 mlirOpResultGetOwner(self.get())) && 1793 "expected the owner of the value in Python to match that in the IR"); 1794 return self.getParentOperation().getObject(); 1795 }); 1796 c.def_property_readonly("result_number", [](PyOpResult &self) { 1797 return mlirOpResultGetResultNumber(self.get()); 1798 }); 1799 } 1800 }; 1801 1802 /// Returns the list of types of the values held by container. 1803 template <typename Container> 1804 static std::vector<PyType> getValueTypes(Container &container, 1805 PyMlirContextRef &context) { 1806 std::vector<PyType> result; 1807 result.reserve(container.getNumElements()); 1808 for (int i = 0, e = container.getNumElements(); i < e; ++i) { 1809 result.push_back( 1810 PyType(context, mlirValueGetType(container.getElement(i).get()))); 1811 } 1812 return result; 1813 } 1814 1815 /// A list of block arguments. Internally, these are stored as consecutive 1816 /// elements, random access is cheap. The argument list is associated with the 1817 /// operation that contains the block (detached blocks are not allowed in 1818 /// Python bindings) and extends its lifetime. 1819 class PyBlockArgumentList 1820 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 1821 public: 1822 static constexpr const char *pyClassName = "BlockArgumentList"; 1823 1824 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 1825 intptr_t startIndex = 0, intptr_t length = -1, 1826 intptr_t step = 1) 1827 : Sliceable(startIndex, 1828 length == -1 ? mlirBlockGetNumArguments(block) : length, 1829 step), 1830 operation(std::move(operation)), block(block) {} 1831 1832 /// Returns the number of arguments in the list. 1833 intptr_t getNumElements() { 1834 operation->checkValid(); 1835 return mlirBlockGetNumArguments(block); 1836 } 1837 1838 /// Returns `pos`-the element in the list. Asserts on out-of-bounds. 1839 PyBlockArgument getElement(intptr_t pos) { 1840 MlirValue argument = mlirBlockGetArgument(block, pos); 1841 return PyBlockArgument(operation, argument); 1842 } 1843 1844 /// Returns a sublist of this list. 1845 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 1846 intptr_t step) { 1847 return PyBlockArgumentList(operation, block, startIndex, length, step); 1848 } 1849 1850 static void bindDerived(ClassTy &c) { 1851 c.def_property_readonly("types", [](PyBlockArgumentList &self) { 1852 return getValueTypes(self, self.operation->getContext()); 1853 }); 1854 } 1855 1856 private: 1857 PyOperationRef operation; 1858 MlirBlock block; 1859 }; 1860 1861 /// A list of operation operands. Internally, these are stored as consecutive 1862 /// elements, random access is cheap. The result list is associated with the 1863 /// operation whose results these are, and extends the lifetime of this 1864 /// operation. 1865 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 1866 public: 1867 static constexpr const char *pyClassName = "OpOperandList"; 1868 1869 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 1870 intptr_t length = -1, intptr_t step = 1) 1871 : Sliceable(startIndex, 1872 length == -1 ? mlirOperationGetNumOperands(operation->get()) 1873 : length, 1874 step), 1875 operation(operation) {} 1876 1877 intptr_t getNumElements() { 1878 operation->checkValid(); 1879 return mlirOperationGetNumOperands(operation->get()); 1880 } 1881 1882 PyValue getElement(intptr_t pos) { 1883 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 1884 MlirOperation owner; 1885 if (mlirValueIsAOpResult(operand)) 1886 owner = mlirOpResultGetOwner(operand); 1887 else if (mlirValueIsABlockArgument(operand)) 1888 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 1889 else 1890 assert(false && "Value must be an block arg or op result."); 1891 PyOperationRef pyOwner = 1892 PyOperation::forOperation(operation->getContext(), owner); 1893 return PyValue(pyOwner, operand); 1894 } 1895 1896 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1897 return PyOpOperandList(operation, startIndex, length, step); 1898 } 1899 1900 void dunderSetItem(intptr_t index, PyValue value) { 1901 index = wrapIndex(index); 1902 mlirOperationSetOperand(operation->get(), index, value.get()); 1903 } 1904 1905 static void bindDerived(ClassTy &c) { 1906 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 1907 } 1908 1909 private: 1910 PyOperationRef operation; 1911 }; 1912 1913 /// A list of operation results. Internally, these are stored as consecutive 1914 /// elements, random access is cheap. The result list is associated with the 1915 /// operation whose results these are, and extends the lifetime of this 1916 /// operation. 1917 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { 1918 public: 1919 static constexpr const char *pyClassName = "OpResultList"; 1920 1921 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, 1922 intptr_t length = -1, intptr_t step = 1) 1923 : Sliceable(startIndex, 1924 length == -1 ? mlirOperationGetNumResults(operation->get()) 1925 : length, 1926 step), 1927 operation(operation) {} 1928 1929 intptr_t getNumElements() { 1930 operation->checkValid(); 1931 return mlirOperationGetNumResults(operation->get()); 1932 } 1933 1934 PyOpResult getElement(intptr_t index) { 1935 PyValue value(operation, mlirOperationGetResult(operation->get(), index)); 1936 return PyOpResult(value); 1937 } 1938 1939 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 1940 return PyOpResultList(operation, startIndex, length, step); 1941 } 1942 1943 static void bindDerived(ClassTy &c) { 1944 c.def_property_readonly("types", [](PyOpResultList &self) { 1945 return getValueTypes(self, self.operation->getContext()); 1946 }); 1947 } 1948 1949 private: 1950 PyOperationRef operation; 1951 }; 1952 1953 /// A list of operation attributes. Can be indexed by name, producing 1954 /// attributes, or by index, producing named attributes. 1955 class PyOpAttributeMap { 1956 public: 1957 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} 1958 1959 PyAttribute dunderGetItemNamed(const std::string &name) { 1960 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 1961 toMlirStringRef(name)); 1962 if (mlirAttributeIsNull(attr)) { 1963 throw SetPyError(PyExc_KeyError, 1964 "attempt to access a non-existent attribute"); 1965 } 1966 return PyAttribute(operation->getContext(), attr); 1967 } 1968 1969 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 1970 if (index < 0 || index >= dunderLen()) { 1971 throw SetPyError(PyExc_IndexError, 1972 "attempt to access out of bounds attribute"); 1973 } 1974 MlirNamedAttribute namedAttr = 1975 mlirOperationGetAttribute(operation->get(), index); 1976 return PyNamedAttribute( 1977 namedAttr.attribute, 1978 std::string(mlirIdentifierStr(namedAttr.name).data, 1979 mlirIdentifierStr(namedAttr.name).length)); 1980 } 1981 1982 void dunderSetItem(const std::string &name, PyAttribute attr) { 1983 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 1984 attr); 1985 } 1986 1987 void dunderDelItem(const std::string &name) { 1988 int removed = mlirOperationRemoveAttributeByName(operation->get(), 1989 toMlirStringRef(name)); 1990 if (!removed) 1991 throw SetPyError(PyExc_KeyError, 1992 "attempt to delete a non-existent attribute"); 1993 } 1994 1995 intptr_t dunderLen() { 1996 return mlirOperationGetNumAttributes(operation->get()); 1997 } 1998 1999 bool dunderContains(const std::string &name) { 2000 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2001 operation->get(), toMlirStringRef(name))); 2002 } 2003 2004 static void bind(py::module &m) { 2005 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local()) 2006 .def("__contains__", &PyOpAttributeMap::dunderContains) 2007 .def("__len__", &PyOpAttributeMap::dunderLen) 2008 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2009 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2010 .def("__setitem__", &PyOpAttributeMap::dunderSetItem) 2011 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2012 } 2013 2014 private: 2015 PyOperationRef operation; 2016 }; 2017 2018 } // namespace 2019 2020 //------------------------------------------------------------------------------ 2021 // Populates the core exports of the 'ir' submodule. 2022 //------------------------------------------------------------------------------ 2023 2024 void mlir::python::populateIRCore(py::module &m) { 2025 //---------------------------------------------------------------------------- 2026 // Mapping of MlirContext. 2027 //---------------------------------------------------------------------------- 2028 py::class_<PyMlirContext>(m, "Context", py::module_local()) 2029 .def(py::init<>(&PyMlirContext::createNewContextForInit)) 2030 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2031 .def("_get_context_again", 2032 [](PyMlirContext &self) { 2033 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2034 return ref.releaseObject(); 2035 }) 2036 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2037 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2038 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2039 &PyMlirContext::getCapsule) 2040 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2041 .def("__enter__", &PyMlirContext::contextEnter) 2042 .def("__exit__", &PyMlirContext::contextExit) 2043 .def_property_readonly_static( 2044 "current", 2045 [](py::object & /*class*/) { 2046 auto *context = PyThreadContextEntry::getDefaultContext(); 2047 if (!context) 2048 throw SetPyError(PyExc_ValueError, "No current Context"); 2049 return context; 2050 }, 2051 "Gets the Context bound to the current thread or raises ValueError") 2052 .def_property_readonly( 2053 "dialects", 2054 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2055 "Gets a container for accessing dialects by name") 2056 .def_property_readonly( 2057 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2058 "Alias for 'dialect'") 2059 .def( 2060 "get_dialect_descriptor", 2061 [=](PyMlirContext &self, std::string &name) { 2062 MlirDialect dialect = mlirContextGetOrLoadDialect( 2063 self.get(), {name.data(), name.size()}); 2064 if (mlirDialectIsNull(dialect)) { 2065 throw SetPyError(PyExc_ValueError, 2066 Twine("Dialect '") + name + "' not found"); 2067 } 2068 return PyDialectDescriptor(self.getRef(), dialect); 2069 }, 2070 py::arg("dialect_name"), 2071 "Gets or loads a dialect by name, returning its descriptor object") 2072 .def_property( 2073 "allow_unregistered_dialects", 2074 [](PyMlirContext &self) -> bool { 2075 return mlirContextGetAllowUnregisteredDialects(self.get()); 2076 }, 2077 [](PyMlirContext &self, bool value) { 2078 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2079 }) 2080 .def( 2081 "enable_multithreading", 2082 [](PyMlirContext &self, bool enable) { 2083 mlirContextEnableMultithreading(self.get(), enable); 2084 }, 2085 py::arg("enable")) 2086 .def( 2087 "is_registered_operation", 2088 [](PyMlirContext &self, std::string &name) { 2089 return mlirContextIsRegisteredOperation( 2090 self.get(), MlirStringRef{name.data(), name.size()}); 2091 }, 2092 py::arg("operation_name")); 2093 2094 //---------------------------------------------------------------------------- 2095 // Mapping of PyDialectDescriptor 2096 //---------------------------------------------------------------------------- 2097 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local()) 2098 .def_property_readonly("namespace", 2099 [](PyDialectDescriptor &self) { 2100 MlirStringRef ns = 2101 mlirDialectGetNamespace(self.get()); 2102 return py::str(ns.data, ns.length); 2103 }) 2104 .def("__repr__", [](PyDialectDescriptor &self) { 2105 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2106 std::string repr("<DialectDescriptor "); 2107 repr.append(ns.data, ns.length); 2108 repr.append(">"); 2109 return repr; 2110 }); 2111 2112 //---------------------------------------------------------------------------- 2113 // Mapping of PyDialects 2114 //---------------------------------------------------------------------------- 2115 py::class_<PyDialects>(m, "Dialects", py::module_local()) 2116 .def("__getitem__", 2117 [=](PyDialects &self, std::string keyName) { 2118 MlirDialect dialect = 2119 self.getDialectForKey(keyName, /*attrError=*/false); 2120 py::object descriptor = 2121 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2122 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2123 }) 2124 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2125 MlirDialect dialect = 2126 self.getDialectForKey(attrName, /*attrError=*/true); 2127 py::object descriptor = 2128 py::cast(PyDialectDescriptor{self.getContext(), dialect}); 2129 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2130 }); 2131 2132 //---------------------------------------------------------------------------- 2133 // Mapping of PyDialect 2134 //---------------------------------------------------------------------------- 2135 py::class_<PyDialect>(m, "Dialect", py::module_local()) 2136 .def(py::init<py::object>(), py::arg("descriptor")) 2137 .def_property_readonly( 2138 "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) 2139 .def("__repr__", [](py::object self) { 2140 auto clazz = self.attr("__class__"); 2141 return py::str("<Dialect ") + 2142 self.attr("descriptor").attr("namespace") + py::str(" (class ") + 2143 clazz.attr("__module__") + py::str(".") + 2144 clazz.attr("__name__") + py::str(")>"); 2145 }); 2146 2147 //---------------------------------------------------------------------------- 2148 // Mapping of Location 2149 //---------------------------------------------------------------------------- 2150 py::class_<PyLocation>(m, "Location", py::module_local()) 2151 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2152 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2153 .def("__enter__", &PyLocation::contextEnter) 2154 .def("__exit__", &PyLocation::contextExit) 2155 .def("__eq__", 2156 [](PyLocation &self, PyLocation &other) -> bool { 2157 return mlirLocationEqual(self, other); 2158 }) 2159 .def("__eq__", [](PyLocation &self, py::object other) { return false; }) 2160 .def_property_readonly_static( 2161 "current", 2162 [](py::object & /*class*/) { 2163 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2164 if (!loc) 2165 throw SetPyError(PyExc_ValueError, "No current Location"); 2166 return loc; 2167 }, 2168 "Gets the Location bound to the current thread or raises ValueError") 2169 .def_static( 2170 "unknown", 2171 [](DefaultingPyMlirContext context) { 2172 return PyLocation(context->getRef(), 2173 mlirLocationUnknownGet(context->get())); 2174 }, 2175 py::arg("context") = py::none(), 2176 "Gets a Location representing an unknown location") 2177 .def_static( 2178 "callsite", 2179 [](PyLocation callee, const std::vector<PyLocation> &frames, 2180 DefaultingPyMlirContext context) { 2181 if (frames.empty()) 2182 throw py::value_error("No caller frames provided"); 2183 MlirLocation caller = frames.back().get(); 2184 for (const PyLocation &frame : 2185 llvm::reverse(llvm::makeArrayRef(frames).drop_back())) 2186 caller = mlirLocationCallSiteGet(frame.get(), caller); 2187 return PyLocation(context->getRef(), 2188 mlirLocationCallSiteGet(callee.get(), caller)); 2189 }, 2190 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), 2191 kContextGetCallSiteLocationDocstring) 2192 .def_static( 2193 "file", 2194 [](std::string filename, int line, int col, 2195 DefaultingPyMlirContext context) { 2196 return PyLocation( 2197 context->getRef(), 2198 mlirLocationFileLineColGet( 2199 context->get(), toMlirStringRef(filename), line, col)); 2200 }, 2201 py::arg("filename"), py::arg("line"), py::arg("col"), 2202 py::arg("context") = py::none(), kContextGetFileLocationDocstring) 2203 .def_static( 2204 "fused", 2205 [](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata, 2206 DefaultingPyMlirContext context) { 2207 if (pyLocations.empty()) 2208 throw py::value_error("No locations provided"); 2209 llvm::SmallVector<MlirLocation, 4> locations; 2210 locations.reserve(pyLocations.size()); 2211 for (auto &pyLocation : pyLocations) 2212 locations.push_back(pyLocation.get()); 2213 MlirLocation location = mlirLocationFusedGet( 2214 context->get(), locations.size(), locations.data(), 2215 metadata ? metadata->get() : MlirAttribute{0}); 2216 return PyLocation(context->getRef(), location); 2217 }, 2218 py::arg("locations"), py::arg("metadata") = py::none(), 2219 py::arg("context") = py::none(), kContextGetFusedLocationDocstring) 2220 .def_static( 2221 "name", 2222 [](std::string name, llvm::Optional<PyLocation> childLoc, 2223 DefaultingPyMlirContext context) { 2224 return PyLocation( 2225 context->getRef(), 2226 mlirLocationNameGet( 2227 context->get(), toMlirStringRef(name), 2228 childLoc ? childLoc->get() 2229 : mlirLocationUnknownGet(context->get()))); 2230 }, 2231 py::arg("name"), py::arg("childLoc") = py::none(), 2232 py::arg("context") = py::none(), kContextGetNameLocationDocString) 2233 .def_property_readonly( 2234 "context", 2235 [](PyLocation &self) { return self.getContext().getObject(); }, 2236 "Context that owns the Location") 2237 .def("__repr__", [](PyLocation &self) { 2238 PyPrintAccumulator printAccum; 2239 mlirLocationPrint(self, printAccum.getCallback(), 2240 printAccum.getUserData()); 2241 return printAccum.join(); 2242 }); 2243 2244 //---------------------------------------------------------------------------- 2245 // Mapping of Module 2246 //---------------------------------------------------------------------------- 2247 py::class_<PyModule>(m, "Module", py::module_local()) 2248 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 2249 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 2250 .def_static( 2251 "parse", 2252 [](const std::string moduleAsm, DefaultingPyMlirContext context) { 2253 MlirModule module = mlirModuleCreateParse( 2254 context->get(), toMlirStringRef(moduleAsm)); 2255 // TODO: Rework error reporting once diagnostic engine is exposed 2256 // in C API. 2257 if (mlirModuleIsNull(module)) { 2258 throw SetPyError( 2259 PyExc_ValueError, 2260 "Unable to parse module assembly (see diagnostics)"); 2261 } 2262 return PyModule::forModule(module).releaseObject(); 2263 }, 2264 py::arg("asm"), py::arg("context") = py::none(), 2265 kModuleParseDocstring) 2266 .def_static( 2267 "create", 2268 [](DefaultingPyLocation loc) { 2269 MlirModule module = mlirModuleCreateEmpty(loc); 2270 return PyModule::forModule(module).releaseObject(); 2271 }, 2272 py::arg("loc") = py::none(), "Creates an empty module") 2273 .def_property_readonly( 2274 "context", 2275 [](PyModule &self) { return self.getContext().getObject(); }, 2276 "Context that created the Module") 2277 .def_property_readonly( 2278 "operation", 2279 [](PyModule &self) { 2280 return PyOperation::forOperation(self.getContext(), 2281 mlirModuleGetOperation(self.get()), 2282 self.getRef().releaseObject()) 2283 .releaseObject(); 2284 }, 2285 "Accesses the module as an operation") 2286 .def_property_readonly( 2287 "body", 2288 [](PyModule &self) { 2289 PyOperationRef module_op = PyOperation::forOperation( 2290 self.getContext(), mlirModuleGetOperation(self.get()), 2291 self.getRef().releaseObject()); 2292 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); 2293 return returnBlock; 2294 }, 2295 "Return the block for this module") 2296 .def( 2297 "dump", 2298 [](PyModule &self) { 2299 mlirOperationDump(mlirModuleGetOperation(self.get())); 2300 }, 2301 kDumpDocstring) 2302 .def( 2303 "__str__", 2304 [](py::object self) { 2305 // Defer to the operation's __str__. 2306 return self.attr("operation").attr("__str__")(); 2307 }, 2308 kOperationStrDunderDocstring); 2309 2310 //---------------------------------------------------------------------------- 2311 // Mapping of Operation. 2312 //---------------------------------------------------------------------------- 2313 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local()) 2314 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2315 [](PyOperationBase &self) { 2316 return self.getOperation().getCapsule(); 2317 }) 2318 .def("__eq__", 2319 [](PyOperationBase &self, PyOperationBase &other) { 2320 return &self.getOperation() == &other.getOperation(); 2321 }) 2322 .def("__eq__", 2323 [](PyOperationBase &self, py::object other) { return false; }) 2324 .def("__hash__", 2325 [](PyOperationBase &self) { 2326 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 2327 }) 2328 .def_property_readonly("attributes", 2329 [](PyOperationBase &self) { 2330 return PyOpAttributeMap( 2331 self.getOperation().getRef()); 2332 }) 2333 .def_property_readonly("operands", 2334 [](PyOperationBase &self) { 2335 return PyOpOperandList( 2336 self.getOperation().getRef()); 2337 }) 2338 .def_property_readonly("regions", 2339 [](PyOperationBase &self) { 2340 return PyRegionList( 2341 self.getOperation().getRef()); 2342 }) 2343 .def_property_readonly( 2344 "results", 2345 [](PyOperationBase &self) { 2346 return PyOpResultList(self.getOperation().getRef()); 2347 }, 2348 "Returns the list of Operation results.") 2349 .def_property_readonly( 2350 "result", 2351 [](PyOperationBase &self) { 2352 auto &operation = self.getOperation(); 2353 auto numResults = mlirOperationGetNumResults(operation); 2354 if (numResults != 1) { 2355 auto name = mlirIdentifierStr(mlirOperationGetName(operation)); 2356 throw SetPyError( 2357 PyExc_ValueError, 2358 Twine("Cannot call .result on operation ") + 2359 StringRef(name.data, name.length) + " which has " + 2360 Twine(numResults) + 2361 " results (it is only valid for operations with a " 2362 "single result)"); 2363 } 2364 return PyOpResult(operation.getRef(), 2365 mlirOperationGetResult(operation, 0)); 2366 }, 2367 "Shortcut to get an op result if it has only one (throws an error " 2368 "otherwise).") 2369 .def_property_readonly( 2370 "location", 2371 [](PyOperationBase &self) { 2372 PyOperation &operation = self.getOperation(); 2373 return PyLocation(operation.getContext(), 2374 mlirOperationGetLocation(operation.get())); 2375 }, 2376 "Returns the source location the operation was defined or derived " 2377 "from.") 2378 .def( 2379 "__str__", 2380 [](PyOperationBase &self) { 2381 return self.getAsm(/*binary=*/false, 2382 /*largeElementsLimit=*/llvm::None, 2383 /*enableDebugInfo=*/false, 2384 /*prettyDebugInfo=*/false, 2385 /*printGenericOpForm=*/false, 2386 /*useLocalScope=*/false, 2387 /*assumeVerified=*/false); 2388 }, 2389 "Returns the assembly form of the operation.") 2390 .def("print", &PyOperationBase::print, 2391 // Careful: Lots of arguments must match up with print method. 2392 py::arg("file") = py::none(), py::arg("binary") = false, 2393 py::arg("large_elements_limit") = py::none(), 2394 py::arg("enable_debug_info") = false, 2395 py::arg("pretty_debug_info") = false, 2396 py::arg("print_generic_op_form") = false, 2397 py::arg("use_local_scope") = false, 2398 py::arg("assume_verified") = false, kOperationPrintDocstring) 2399 .def("get_asm", &PyOperationBase::getAsm, 2400 // Careful: Lots of arguments must match up with get_asm method. 2401 py::arg("binary") = false, 2402 py::arg("large_elements_limit") = py::none(), 2403 py::arg("enable_debug_info") = false, 2404 py::arg("pretty_debug_info") = false, 2405 py::arg("print_generic_op_form") = false, 2406 py::arg("use_local_scope") = false, 2407 py::arg("assume_verified") = false, kOperationGetAsmDocstring) 2408 .def( 2409 "verify", 2410 [](PyOperationBase &self) { 2411 return mlirOperationVerify(self.getOperation()); 2412 }, 2413 "Verify the operation and return true if it passes, false if it " 2414 "fails.") 2415 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), 2416 "Puts self immediately after the other operation in its parent " 2417 "block.") 2418 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), 2419 "Puts self immediately before the other operation in its parent " 2420 "block.") 2421 .def( 2422 "detach_from_parent", 2423 [](PyOperationBase &self) { 2424 PyOperation &operation = self.getOperation(); 2425 operation.checkValid(); 2426 if (!operation.isAttached()) 2427 throw py::value_error("Detached operation has no parent."); 2428 2429 operation.detachFromParent(); 2430 return operation.createOpView(); 2431 }, 2432 "Detaches the operation from its parent block."); 2433 2434 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local()) 2435 .def_static("create", &PyOperation::create, py::arg("name"), 2436 py::arg("results") = py::none(), 2437 py::arg("operands") = py::none(), 2438 py::arg("attributes") = py::none(), 2439 py::arg("successors") = py::none(), py::arg("regions") = 0, 2440 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2441 kOperationCreateDocstring) 2442 .def_property_readonly("parent", 2443 [](PyOperation &self) -> py::object { 2444 auto parent = self.getParentOperation(); 2445 if (parent) 2446 return parent->getObject(); 2447 return py::none(); 2448 }) 2449 .def("erase", &PyOperation::erase) 2450 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2451 &PyOperation::getCapsule) 2452 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 2453 .def_property_readonly("name", 2454 [](PyOperation &self) { 2455 self.checkValid(); 2456 MlirOperation operation = self.get(); 2457 MlirStringRef name = mlirIdentifierStr( 2458 mlirOperationGetName(operation)); 2459 return py::str(name.data, name.length); 2460 }) 2461 .def_property_readonly( 2462 "context", 2463 [](PyOperation &self) { 2464 self.checkValid(); 2465 return self.getContext().getObject(); 2466 }, 2467 "Context that owns the Operation") 2468 .def_property_readonly("opview", &PyOperation::createOpView); 2469 2470 auto opViewClass = 2471 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local()) 2472 .def(py::init<py::object>(), py::arg("operation")) 2473 .def_property_readonly("operation", &PyOpView::getOperationObject) 2474 .def_property_readonly( 2475 "context", 2476 [](PyOpView &self) { 2477 return self.getOperation().getContext().getObject(); 2478 }, 2479 "Context that owns the Operation") 2480 .def("__str__", [](PyOpView &self) { 2481 return py::str(self.getOperationObject()); 2482 }); 2483 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); 2484 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); 2485 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); 2486 opViewClass.attr("build_generic") = classmethod( 2487 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), 2488 py::arg("operands") = py::none(), py::arg("attributes") = py::none(), 2489 py::arg("successors") = py::none(), py::arg("regions") = py::none(), 2490 py::arg("loc") = py::none(), py::arg("ip") = py::none(), 2491 "Builds a specific, generated OpView based on class level attributes."); 2492 2493 //---------------------------------------------------------------------------- 2494 // Mapping of PyRegion. 2495 //---------------------------------------------------------------------------- 2496 py::class_<PyRegion>(m, "Region", py::module_local()) 2497 .def_property_readonly( 2498 "blocks", 2499 [](PyRegion &self) { 2500 return PyBlockList(self.getParentOperation(), self.get()); 2501 }, 2502 "Returns a forward-optimized sequence of blocks.") 2503 .def_property_readonly( 2504 "owner", 2505 [](PyRegion &self) { 2506 return self.getParentOperation()->createOpView(); 2507 }, 2508 "Returns the operation owning this region.") 2509 .def( 2510 "__iter__", 2511 [](PyRegion &self) { 2512 self.checkValid(); 2513 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 2514 return PyBlockIterator(self.getParentOperation(), firstBlock); 2515 }, 2516 "Iterates over blocks in the region.") 2517 .def("__eq__", 2518 [](PyRegion &self, PyRegion &other) { 2519 return self.get().ptr == other.get().ptr; 2520 }) 2521 .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); 2522 2523 //---------------------------------------------------------------------------- 2524 // Mapping of PyBlock. 2525 //---------------------------------------------------------------------------- 2526 py::class_<PyBlock>(m, "Block", py::module_local()) 2527 .def_property_readonly( 2528 "owner", 2529 [](PyBlock &self) { 2530 return self.getParentOperation()->createOpView(); 2531 }, 2532 "Returns the owning operation of this block.") 2533 .def_property_readonly( 2534 "region", 2535 [](PyBlock &self) { 2536 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2537 return PyRegion(self.getParentOperation(), region); 2538 }, 2539 "Returns the owning region of this block.") 2540 .def_property_readonly( 2541 "arguments", 2542 [](PyBlock &self) { 2543 return PyBlockArgumentList(self.getParentOperation(), self.get()); 2544 }, 2545 "Returns a list of block arguments.") 2546 .def_property_readonly( 2547 "operations", 2548 [](PyBlock &self) { 2549 return PyOperationList(self.getParentOperation(), self.get()); 2550 }, 2551 "Returns a forward-optimized sequence of operations.") 2552 .def_static( 2553 "create_at_start", 2554 [](PyRegion &parent, py::list pyArgTypes) { 2555 parent.checkValid(); 2556 llvm::SmallVector<MlirType, 4> argTypes; 2557 argTypes.reserve(pyArgTypes.size()); 2558 for (auto &pyArg : pyArgTypes) { 2559 argTypes.push_back(pyArg.cast<PyType &>()); 2560 } 2561 2562 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2563 mlirRegionInsertOwnedBlock(parent, 0, block); 2564 return PyBlock(parent.getParentOperation(), block); 2565 }, 2566 py::arg("parent"), py::arg("arg_types") = py::list(), 2567 "Creates and returns a new Block at the beginning of the given " 2568 "region (with given argument types).") 2569 .def( 2570 "create_before", 2571 [](PyBlock &self, py::args pyArgTypes) { 2572 self.checkValid(); 2573 llvm::SmallVector<MlirType, 4> argTypes; 2574 argTypes.reserve(pyArgTypes.size()); 2575 for (auto &pyArg : pyArgTypes) { 2576 argTypes.push_back(pyArg.cast<PyType &>()); 2577 } 2578 2579 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2580 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2581 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 2582 return PyBlock(self.getParentOperation(), block); 2583 }, 2584 "Creates and returns a new Block before this block " 2585 "(with given argument types).") 2586 .def( 2587 "create_after", 2588 [](PyBlock &self, py::args pyArgTypes) { 2589 self.checkValid(); 2590 llvm::SmallVector<MlirType, 4> argTypes; 2591 argTypes.reserve(pyArgTypes.size()); 2592 for (auto &pyArg : pyArgTypes) { 2593 argTypes.push_back(pyArg.cast<PyType &>()); 2594 } 2595 2596 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); 2597 MlirRegion region = mlirBlockGetParentRegion(self.get()); 2598 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 2599 return PyBlock(self.getParentOperation(), block); 2600 }, 2601 "Creates and returns a new Block after this block " 2602 "(with given argument types).") 2603 .def( 2604 "__iter__", 2605 [](PyBlock &self) { 2606 self.checkValid(); 2607 MlirOperation firstOperation = 2608 mlirBlockGetFirstOperation(self.get()); 2609 return PyOperationIterator(self.getParentOperation(), 2610 firstOperation); 2611 }, 2612 "Iterates over operations in the block.") 2613 .def("__eq__", 2614 [](PyBlock &self, PyBlock &other) { 2615 return self.get().ptr == other.get().ptr; 2616 }) 2617 .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) 2618 .def( 2619 "__str__", 2620 [](PyBlock &self) { 2621 self.checkValid(); 2622 PyPrintAccumulator printAccum; 2623 mlirBlockPrint(self.get(), printAccum.getCallback(), 2624 printAccum.getUserData()); 2625 return printAccum.join(); 2626 }, 2627 "Returns the assembly form of the block.") 2628 .def( 2629 "append", 2630 [](PyBlock &self, PyOperationBase &operation) { 2631 if (operation.getOperation().isAttached()) 2632 operation.getOperation().detachFromParent(); 2633 2634 MlirOperation mlirOperation = operation.getOperation().get(); 2635 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 2636 operation.getOperation().setAttached( 2637 self.getParentOperation().getObject()); 2638 }, 2639 py::arg("operation"), 2640 "Appends an operation to this block. If the operation is currently " 2641 "in another block, it will be moved."); 2642 2643 //---------------------------------------------------------------------------- 2644 // Mapping of PyInsertionPoint. 2645 //---------------------------------------------------------------------------- 2646 2647 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local()) 2648 .def(py::init<PyBlock &>(), py::arg("block"), 2649 "Inserts after the last operation but still inside the block.") 2650 .def("__enter__", &PyInsertionPoint::contextEnter) 2651 .def("__exit__", &PyInsertionPoint::contextExit) 2652 .def_property_readonly_static( 2653 "current", 2654 [](py::object & /*class*/) { 2655 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 2656 if (!ip) 2657 throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); 2658 return ip; 2659 }, 2660 "Gets the InsertionPoint bound to the current thread or raises " 2661 "ValueError if none has been set") 2662 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), 2663 "Inserts before a referenced operation.") 2664 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 2665 py::arg("block"), "Inserts at the beginning of the block.") 2666 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 2667 py::arg("block"), "Inserts before the block terminator.") 2668 .def("insert", &PyInsertionPoint::insert, py::arg("operation"), 2669 "Inserts an operation.") 2670 .def_property_readonly( 2671 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 2672 "Returns the block that this InsertionPoint points to."); 2673 2674 //---------------------------------------------------------------------------- 2675 // Mapping of PyAttribute. 2676 //---------------------------------------------------------------------------- 2677 py::class_<PyAttribute>(m, "Attribute", py::module_local()) 2678 // Delegate to the PyAttribute copy constructor, which will also lifetime 2679 // extend the backing context which owns the MlirAttribute. 2680 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"), 2681 "Casts the passed attribute to the generic Attribute") 2682 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 2683 &PyAttribute::getCapsule) 2684 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 2685 .def_static( 2686 "parse", 2687 [](std::string attrSpec, DefaultingPyMlirContext context) { 2688 MlirAttribute type = mlirAttributeParseGet( 2689 context->get(), toMlirStringRef(attrSpec)); 2690 // TODO: Rework error reporting once diagnostic engine is exposed 2691 // in C API. 2692 if (mlirAttributeIsNull(type)) { 2693 throw SetPyError(PyExc_ValueError, 2694 Twine("Unable to parse attribute: '") + 2695 attrSpec + "'"); 2696 } 2697 return PyAttribute(context->getRef(), type); 2698 }, 2699 py::arg("asm"), py::arg("context") = py::none(), 2700 "Parses an attribute from an assembly form") 2701 .def_property_readonly( 2702 "context", 2703 [](PyAttribute &self) { return self.getContext().getObject(); }, 2704 "Context that owns the Attribute") 2705 .def_property_readonly("type", 2706 [](PyAttribute &self) { 2707 return PyType(self.getContext()->getRef(), 2708 mlirAttributeGetType(self)); 2709 }) 2710 .def( 2711 "get_named", 2712 [](PyAttribute &self, std::string name) { 2713 return PyNamedAttribute(self, std::move(name)); 2714 }, 2715 py::keep_alive<0, 1>(), "Binds a name to the attribute") 2716 .def("__eq__", 2717 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 2718 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) 2719 .def("__hash__", 2720 [](PyAttribute &self) { 2721 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2722 }) 2723 .def( 2724 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 2725 kDumpDocstring) 2726 .def( 2727 "__str__", 2728 [](PyAttribute &self) { 2729 PyPrintAccumulator printAccum; 2730 mlirAttributePrint(self, printAccum.getCallback(), 2731 printAccum.getUserData()); 2732 return printAccum.join(); 2733 }, 2734 "Returns the assembly form of the Attribute.") 2735 .def("__repr__", [](PyAttribute &self) { 2736 // Generally, assembly formats are not printed for __repr__ because 2737 // this can cause exceptionally long debug output and exceptions. 2738 // However, attribute values are generally considered useful and are 2739 // printed. This may need to be re-evaluated if debug dumps end up 2740 // being excessive. 2741 PyPrintAccumulator printAccum; 2742 printAccum.parts.append("Attribute("); 2743 mlirAttributePrint(self, printAccum.getCallback(), 2744 printAccum.getUserData()); 2745 printAccum.parts.append(")"); 2746 return printAccum.join(); 2747 }); 2748 2749 //---------------------------------------------------------------------------- 2750 // Mapping of PyNamedAttribute 2751 //---------------------------------------------------------------------------- 2752 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local()) 2753 .def("__repr__", 2754 [](PyNamedAttribute &self) { 2755 PyPrintAccumulator printAccum; 2756 printAccum.parts.append("NamedAttribute("); 2757 printAccum.parts.append( 2758 py::str(mlirIdentifierStr(self.namedAttr.name).data, 2759 mlirIdentifierStr(self.namedAttr.name).length)); 2760 printAccum.parts.append("="); 2761 mlirAttributePrint(self.namedAttr.attribute, 2762 printAccum.getCallback(), 2763 printAccum.getUserData()); 2764 printAccum.parts.append(")"); 2765 return printAccum.join(); 2766 }) 2767 .def_property_readonly( 2768 "name", 2769 [](PyNamedAttribute &self) { 2770 return py::str(mlirIdentifierStr(self.namedAttr.name).data, 2771 mlirIdentifierStr(self.namedAttr.name).length); 2772 }, 2773 "The name of the NamedAttribute binding") 2774 .def_property_readonly( 2775 "attr", 2776 [](PyNamedAttribute &self) { 2777 // TODO: When named attribute is removed/refactored, also remove 2778 // this constructor (it does an inefficient table lookup). 2779 auto contextRef = PyMlirContext::forContext( 2780 mlirAttributeGetContext(self.namedAttr.attribute)); 2781 return PyAttribute(std::move(contextRef), self.namedAttr.attribute); 2782 }, 2783 py::keep_alive<0, 1>(), 2784 "The underlying generic attribute of the NamedAttribute binding"); 2785 2786 //---------------------------------------------------------------------------- 2787 // Mapping of PyType. 2788 //---------------------------------------------------------------------------- 2789 py::class_<PyType>(m, "Type", py::module_local()) 2790 // Delegate to the PyType copy constructor, which will also lifetime 2791 // extend the backing context which owns the MlirType. 2792 .def(py::init<PyType &>(), py::arg("cast_from_type"), 2793 "Casts the passed type to the generic Type") 2794 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 2795 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 2796 .def_static( 2797 "parse", 2798 [](std::string typeSpec, DefaultingPyMlirContext context) { 2799 MlirType type = 2800 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 2801 // TODO: Rework error reporting once diagnostic engine is exposed 2802 // in C API. 2803 if (mlirTypeIsNull(type)) { 2804 throw SetPyError(PyExc_ValueError, 2805 Twine("Unable to parse type: '") + typeSpec + 2806 "'"); 2807 } 2808 return PyType(context->getRef(), type); 2809 }, 2810 py::arg("asm"), py::arg("context") = py::none(), 2811 kContextParseTypeDocstring) 2812 .def_property_readonly( 2813 "context", [](PyType &self) { return self.getContext().getObject(); }, 2814 "Context that owns the Type") 2815 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 2816 .def("__eq__", [](PyType &self, py::object &other) { return false; }) 2817 .def("__hash__", 2818 [](PyType &self) { 2819 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2820 }) 2821 .def( 2822 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 2823 .def( 2824 "__str__", 2825 [](PyType &self) { 2826 PyPrintAccumulator printAccum; 2827 mlirTypePrint(self, printAccum.getCallback(), 2828 printAccum.getUserData()); 2829 return printAccum.join(); 2830 }, 2831 "Returns the assembly form of the type.") 2832 .def("__repr__", [](PyType &self) { 2833 // Generally, assembly formats are not printed for __repr__ because 2834 // this can cause exceptionally long debug output and exceptions. 2835 // However, types are an exception as they typically have compact 2836 // assembly forms and printing them is useful. 2837 PyPrintAccumulator printAccum; 2838 printAccum.parts.append("Type("); 2839 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 2840 printAccum.parts.append(")"); 2841 return printAccum.join(); 2842 }); 2843 2844 //---------------------------------------------------------------------------- 2845 // Mapping of Value. 2846 //---------------------------------------------------------------------------- 2847 py::class_<PyValue>(m, "Value", py::module_local()) 2848 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 2849 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 2850 .def_property_readonly( 2851 "context", 2852 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 2853 "Context in which the value lives.") 2854 .def( 2855 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 2856 kDumpDocstring) 2857 .def_property_readonly( 2858 "owner", 2859 [](PyValue &self) { 2860 assert(mlirOperationEqual(self.getParentOperation()->get(), 2861 mlirOpResultGetOwner(self.get())) && 2862 "expected the owner of the value in Python to match that in " 2863 "the IR"); 2864 return self.getParentOperation().getObject(); 2865 }) 2866 .def("__eq__", 2867 [](PyValue &self, PyValue &other) { 2868 return self.get().ptr == other.get().ptr; 2869 }) 2870 .def("__eq__", [](PyValue &self, py::object other) { return false; }) 2871 .def("__hash__", 2872 [](PyValue &self) { 2873 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 2874 }) 2875 .def( 2876 "__str__", 2877 [](PyValue &self) { 2878 PyPrintAccumulator printAccum; 2879 printAccum.parts.append("Value("); 2880 mlirValuePrint(self.get(), printAccum.getCallback(), 2881 printAccum.getUserData()); 2882 printAccum.parts.append(")"); 2883 return printAccum.join(); 2884 }, 2885 kValueDunderStrDocstring) 2886 .def_property_readonly("type", [](PyValue &self) { 2887 return PyType(self.getParentOperation()->getContext(), 2888 mlirValueGetType(self.get())); 2889 }); 2890 PyBlockArgument::bind(m); 2891 PyOpResult::bind(m); 2892 2893 //---------------------------------------------------------------------------- 2894 // Mapping of SymbolTable. 2895 //---------------------------------------------------------------------------- 2896 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local()) 2897 .def(py::init<PyOperationBase &>()) 2898 .def("__getitem__", &PySymbolTable::dunderGetItem) 2899 .def("insert", &PySymbolTable::insert, py::arg("operation")) 2900 .def("erase", &PySymbolTable::erase, py::arg("operation")) 2901 .def("__delitem__", &PySymbolTable::dunderDel) 2902 .def("__contains__", 2903 [](PySymbolTable &table, const std::string &name) { 2904 return !mlirOperationIsNull(mlirSymbolTableLookup( 2905 table, mlirStringRefCreate(name.data(), name.length()))); 2906 }) 2907 // Static helpers. 2908 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 2909 py::arg("symbol"), py::arg("name")) 2910 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 2911 py::arg("symbol")) 2912 .def_static("get_visibility", &PySymbolTable::getVisibility, 2913 py::arg("symbol")) 2914 .def_static("set_visibility", &PySymbolTable::setVisibility, 2915 py::arg("symbol"), py::arg("visibility")) 2916 .def_static("replace_all_symbol_uses", 2917 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), 2918 py::arg("new_symbol"), py::arg("from_op")) 2919 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 2920 py::arg("from_op"), py::arg("all_sym_uses_visible"), 2921 py::arg("callback")); 2922 2923 // Container bindings. 2924 PyBlockArgumentList::bind(m); 2925 PyBlockIterator::bind(m); 2926 PyBlockList::bind(m); 2927 PyOperationIterator::bind(m); 2928 PyOperationList::bind(m); 2929 PyOpAttributeMap::bind(m); 2930 PyOpOperandList::bind(m); 2931 PyOpResultList::bind(m); 2932 PyRegionIterator::bind(m); 2933 PyRegionList::bind(m); 2934 2935 // Debug bindings. 2936 PyGlobalDebugFlag::bind(m); 2937 } 2938