1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IRModule.h"
10
11 #include "Globals.h"
12 #include "PybindUtils.h"
13
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 //#include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22
23 #include <utility>
24
25 namespace py = pybind11;
26 using namespace mlir;
27 using namespace mlir::python;
28
29 using llvm::SmallVector;
30 using llvm::StringRef;
31 using llvm::Twine;
32
33 //------------------------------------------------------------------------------
34 // Docstrings (trivial, non-duplicated docstrings are included inline).
35 //------------------------------------------------------------------------------
36
// Docstrings attached to Python bindings. Everything between R"( and )" is
// runtime text handed to pybind11, so edits inside the literals change
// user-visible output.

// Docstring: parse a type from its assembly form.
static const char kContextParseTypeDocstring[] =
R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring: build a call-site location.
static const char kContextGetCallSiteLocationDocstring[] =
R"(Gets a Location representing a caller and callsite)";

// Docstring: build a file/line/column location.
static const char kContextGetFileLocationDocstring[] =
R"(Gets a Location representing a file, line and column)";

// Docstring: build a fused location.
static const char kContextGetFusedLocationDocstring[] =
R"(Gets a Location representing a fused location with optional metadata)";

// Docstring: build a named location. (Note the "DocString" spelling in the
// identifier differs from its siblings; references elsewhere use this name.)
static const char kContextGetNameLocationDocString[] =
R"(Gets a Location representing a named location with optional child location)";

// Docstring: parse a module from its assembly form.
static const char kModuleParseDocstring[] =
R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring: generic operation creation.
static const char kOperationCreateDocstring[] =
R"(Creates a new operation.

Args:
name: Operation name (e.g. "dialect.operation").
results: Sequence of Type representing op result types.
attributes: Dict of str:Attribute.
successors: List of Block for the operation's successors.
regions: Number of regions to create.
location: A Location object (defaults to resolve from context manager).
ip: An InsertionPoint (defaults to resolve from context manager or set to
False to disable insertion, even with an insertion point set in the
context manager).
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
)";
82
// Docstring for the operation print method. The keyword names listed under
// "Args:" must match the names the binding actually accepts; the option for
// local-scope printing is spelled `use_local_scope` (all lower case).
static const char kOperationPrintDocstring[] =
R"(Prints the assembly form of the operation to a file like object.

Args:
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
large_elements_limit: Whether to elide elements attributes above this
number of elements. Defaults to None (no limit).
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
by a human (warning: the result is unparseable).
print_generic_op_form: Whether to print the generic assembly forms of all
ops. Defaults to False.
use_local_scope: Whether to print in a way that is more optimized for
multi-threaded access but may not be consistent with how the overall
module prints.
assume_verified: By default, if not printing generic form, the verifier
will be run and if it fails, generic form will be printed with a comment
about failed verification. While a reasonable default for interactive use,
for systematic use, it is often better for the caller to verify explicitly
and report failures in a more robust fashion. Set this to True if doing this
in order to avoid running a redundant verification. If the IR is actually
invalid, behavior is undefined.
)";
108
// More binding docstrings; the raw-string contents are runtime text.

// Docstring: return the operation's assembly as str/bytes.
static const char kOperationGetAsmDocstring[] =
R"(Gets the assembly form of the operation with all options available.

Args:
binary: Whether to return a bytes (True) or str (False) object. Defaults to
False.
... others ...: See the print() method for common keyword arguments for
configuring the printout.
Returns:
Either a bytes or str object, depending on the setting of the 'binary'
argument.
)";

// Docstring: str() of an operation (default print options).
static const char kOperationStrDunderDocstring[] =
R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Docstring: debug dump to stderr.
static const char kDumpDocstring[] =
R"(Dumps a debug representation of the object to stderr.)";

// Docstring used by BlockList.append in this file.
static const char kAppendBlockDocstring[] =
R"(Appends a new block, with argument types as positional args.

Returns:
The created block.
)";

// Docstring: str() of a value.
static const char kValueDunderStrDocstring[] =
R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
147
148 //------------------------------------------------------------------------------
149 // Utilities.
150 //------------------------------------------------------------------------------
151
152 /// Helper for creating an @classmethod.
153 template <class Func, typename... Args>
classmethod(Func f,Args...args)154 py::object classmethod(Func f, Args... args) {
155 py::object cf = py::cpp_function(f, args...);
156 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
157 }
158
159 static py::object
createCustomDialectWrapper(const std::string & dialectNamespace,py::object dialectDescriptor)160 createCustomDialectWrapper(const std::string &dialectNamespace,
161 py::object dialectDescriptor) {
162 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
163 if (!dialectClass) {
164 // Use the base class.
165 return py::cast(PyDialect(std::move(dialectDescriptor)));
166 }
167
168 // Create the custom implementation.
169 return (*dialectClass)(std::move(dialectDescriptor));
170 }
171
toMlirStringRef(const std::string & s)172 static MlirStringRef toMlirStringRef(const std::string &s) {
173 return mlirStringRefCreate(s.data(), s.size());
174 }
175
176 /// Wrapper for the global LLVM debugging flag.
177 struct PyGlobalDebugFlag {
setPyGlobalDebugFlag178 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
179
getPyGlobalDebugFlag180 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
181
bindPyGlobalDebugFlag182 static void bind(py::module &m) {
183 // Debug flags.
184 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
185 .def_property_static("flag", &PyGlobalDebugFlag::get,
186 &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
187 }
188 };
189
190 //------------------------------------------------------------------------------
191 // Collections.
192 //------------------------------------------------------------------------------
193
194 namespace {
195
196 class PyRegionIterator {
197 public:
PyRegionIterator(PyOperationRef operation)198 PyRegionIterator(PyOperationRef operation)
199 : operation(std::move(operation)) {}
200
dunderIter()201 PyRegionIterator &dunderIter() { return *this; }
202
dunderNext()203 PyRegion dunderNext() {
204 operation->checkValid();
205 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
206 throw py::stop_iteration();
207 }
208 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
209 return PyRegion(operation, region);
210 }
211
bind(py::module & m)212 static void bind(py::module &m) {
213 py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
214 .def("__iter__", &PyRegionIterator::dunderIter)
215 .def("__next__", &PyRegionIterator::dunderNext);
216 }
217
218 private:
219 PyOperationRef operation;
220 int nextIndex = 0;
221 };
222
223 /// Regions of an op are fixed length and indexed numerically so are represented
224 /// with a sequence-like container.
225 class PyRegionList {
226 public:
PyRegionList(PyOperationRef operation)227 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
228
dunderLen()229 intptr_t dunderLen() {
230 operation->checkValid();
231 return mlirOperationGetNumRegions(operation->get());
232 }
233
dunderGetItem(intptr_t index)234 PyRegion dunderGetItem(intptr_t index) {
235 // dunderLen checks validity.
236 if (index < 0 || index >= dunderLen()) {
237 throw SetPyError(PyExc_IndexError,
238 "attempt to access out of bounds region");
239 }
240 MlirRegion region = mlirOperationGetRegion(operation->get(), index);
241 return PyRegion(operation, region);
242 }
243
bind(py::module & m)244 static void bind(py::module &m) {
245 py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
246 .def("__len__", &PyRegionList::dunderLen)
247 .def("__getitem__", &PyRegionList::dunderGetItem);
248 }
249
250 private:
251 PyOperationRef operation;
252 };
253
254 class PyBlockIterator {
255 public:
PyBlockIterator(PyOperationRef operation,MlirBlock next)256 PyBlockIterator(PyOperationRef operation, MlirBlock next)
257 : operation(std::move(operation)), next(next) {}
258
dunderIter()259 PyBlockIterator &dunderIter() { return *this; }
260
dunderNext()261 PyBlock dunderNext() {
262 operation->checkValid();
263 if (mlirBlockIsNull(next)) {
264 throw py::stop_iteration();
265 }
266
267 PyBlock returnBlock(operation, next);
268 next = mlirBlockGetNextInRegion(next);
269 return returnBlock;
270 }
271
bind(py::module & m)272 static void bind(py::module &m) {
273 py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
274 .def("__iter__", &PyBlockIterator::dunderIter)
275 .def("__next__", &PyBlockIterator::dunderNext);
276 }
277
278 private:
279 PyOperationRef operation;
280 MlirBlock next;
281 };
282
283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
284 /// we present them as a more full-featured list-like container but optimize
285 /// it for forward iteration. Blocks are always owned by a region.
286 class PyBlockList {
287 public:
PyBlockList(PyOperationRef operation,MlirRegion region)288 PyBlockList(PyOperationRef operation, MlirRegion region)
289 : operation(std::move(operation)), region(region) {}
290
dunderIter()291 PyBlockIterator dunderIter() {
292 operation->checkValid();
293 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
294 }
295
dunderLen()296 intptr_t dunderLen() {
297 operation->checkValid();
298 intptr_t count = 0;
299 MlirBlock block = mlirRegionGetFirstBlock(region);
300 while (!mlirBlockIsNull(block)) {
301 count += 1;
302 block = mlirBlockGetNextInRegion(block);
303 }
304 return count;
305 }
306
dunderGetItem(intptr_t index)307 PyBlock dunderGetItem(intptr_t index) {
308 operation->checkValid();
309 if (index < 0) {
310 throw SetPyError(PyExc_IndexError,
311 "attempt to access out of bounds block");
312 }
313 MlirBlock block = mlirRegionGetFirstBlock(region);
314 while (!mlirBlockIsNull(block)) {
315 if (index == 0) {
316 return PyBlock(operation, block);
317 }
318 block = mlirBlockGetNextInRegion(block);
319 index -= 1;
320 }
321 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
322 }
323
appendBlock(const py::args & pyArgTypes)324 PyBlock appendBlock(const py::args &pyArgTypes) {
325 operation->checkValid();
326 llvm::SmallVector<MlirType, 4> argTypes;
327 llvm::SmallVector<MlirLocation, 4> argLocs;
328 argTypes.reserve(pyArgTypes.size());
329 argLocs.reserve(pyArgTypes.size());
330 for (auto &pyArg : pyArgTypes) {
331 argTypes.push_back(pyArg.cast<PyType &>());
332 // TODO: Pass in a proper location here.
333 argLocs.push_back(
334 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
335 }
336
337 MlirBlock block =
338 mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
339 mlirRegionAppendOwnedBlock(region, block);
340 return PyBlock(operation, block);
341 }
342
bind(py::module & m)343 static void bind(py::module &m) {
344 py::class_<PyBlockList>(m, "BlockList", py::module_local())
345 .def("__getitem__", &PyBlockList::dunderGetItem)
346 .def("__iter__", &PyBlockList::dunderIter)
347 .def("__len__", &PyBlockList::dunderLen)
348 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
349 }
350
351 private:
352 PyOperationRef operation;
353 MlirRegion region;
354 };
355
356 class PyOperationIterator {
357 public:
PyOperationIterator(PyOperationRef parentOperation,MlirOperation next)358 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
359 : parentOperation(std::move(parentOperation)), next(next) {}
360
dunderIter()361 PyOperationIterator &dunderIter() { return *this; }
362
dunderNext()363 py::object dunderNext() {
364 parentOperation->checkValid();
365 if (mlirOperationIsNull(next)) {
366 throw py::stop_iteration();
367 }
368
369 PyOperationRef returnOperation =
370 PyOperation::forOperation(parentOperation->getContext(), next);
371 next = mlirOperationGetNextInBlock(next);
372 return returnOperation->createOpView();
373 }
374
bind(py::module & m)375 static void bind(py::module &m) {
376 py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
377 .def("__iter__", &PyOperationIterator::dunderIter)
378 .def("__next__", &PyOperationIterator::dunderNext);
379 }
380
381 private:
382 PyOperationRef parentOperation;
383 MlirOperation next;
384 };
385
386 /// Operations are exposed by the C-API as a forward-only linked list. In
387 /// Python, we present them as a more full-featured list-like container but
388 /// optimize it for forward iteration. Iterable operations are always owned
389 /// by a block.
390 class PyOperationList {
391 public:
PyOperationList(PyOperationRef parentOperation,MlirBlock block)392 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
393 : parentOperation(std::move(parentOperation)), block(block) {}
394
dunderIter()395 PyOperationIterator dunderIter() {
396 parentOperation->checkValid();
397 return PyOperationIterator(parentOperation,
398 mlirBlockGetFirstOperation(block));
399 }
400
dunderLen()401 intptr_t dunderLen() {
402 parentOperation->checkValid();
403 intptr_t count = 0;
404 MlirOperation childOp = mlirBlockGetFirstOperation(block);
405 while (!mlirOperationIsNull(childOp)) {
406 count += 1;
407 childOp = mlirOperationGetNextInBlock(childOp);
408 }
409 return count;
410 }
411
dunderGetItem(intptr_t index)412 py::object dunderGetItem(intptr_t index) {
413 parentOperation->checkValid();
414 if (index < 0) {
415 throw SetPyError(PyExc_IndexError,
416 "attempt to access out of bounds operation");
417 }
418 MlirOperation childOp = mlirBlockGetFirstOperation(block);
419 while (!mlirOperationIsNull(childOp)) {
420 if (index == 0) {
421 return PyOperation::forOperation(parentOperation->getContext(), childOp)
422 ->createOpView();
423 }
424 childOp = mlirOperationGetNextInBlock(childOp);
425 index -= 1;
426 }
427 throw SetPyError(PyExc_IndexError,
428 "attempt to access out of bounds operation");
429 }
430
bind(py::module & m)431 static void bind(py::module &m) {
432 py::class_<PyOperationList>(m, "OperationList", py::module_local())
433 .def("__getitem__", &PyOperationList::dunderGetItem)
434 .def("__iter__", &PyOperationList::dunderIter)
435 .def("__len__", &PyOperationList::dunderLen);
436 }
437
438 private:
439 PyOperationRef parentOperation;
440 MlirBlock block;
441 };
442
443 } // namespace
444
445 //------------------------------------------------------------------------------
446 // PyMlirContext
447 //------------------------------------------------------------------------------
448
PyMlirContext(MlirContext context)449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
450 py::gil_scoped_acquire acquire;
451 auto &liveContexts = getLiveContexts();
452 liveContexts[context.ptr] = this;
453 }
454
~PyMlirContext()455 PyMlirContext::~PyMlirContext() {
456 // Note that the only public way to construct an instance is via the
457 // forContext method, which always puts the associated handle into
458 // liveContexts.
459 py::gil_scoped_acquire acquire;
460 getLiveContexts().erase(context.ptr);
461 mlirContextDestroy(context);
462 }
463
getCapsule()464 py::object PyMlirContext::getCapsule() {
465 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
466 }
467
createFromCapsule(py::object capsule)468 py::object PyMlirContext::createFromCapsule(py::object capsule) {
469 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
470 if (mlirContextIsNull(rawContext))
471 throw py::error_already_set();
472 return forContext(rawContext).releaseObject();
473 }
474
createNewContextForInit()475 PyMlirContext *PyMlirContext::createNewContextForInit() {
476 MlirContext context = mlirContextCreate();
477 return new PyMlirContext(context);
478 }
479
forContext(MlirContext context)480 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
481 py::gil_scoped_acquire acquire;
482 auto &liveContexts = getLiveContexts();
483 auto it = liveContexts.find(context.ptr);
484 if (it == liveContexts.end()) {
485 // Create.
486 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
487 py::object pyRef = py::cast(unownedContextWrapper);
488 assert(pyRef && "cast to py::object failed");
489 liveContexts[context.ptr] = unownedContextWrapper;
490 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
491 }
492 // Use existing.
493 py::object pyRef = py::cast(it->second);
494 return PyMlirContextRef(it->second, std::move(pyRef));
495 }
496
getLiveContexts()497 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
498 static LiveContextMap liveContexts;
499 return liveContexts;
500 }
501
getLiveCount()502 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
503
getLiveOperationCount()504 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
505
clearLiveOperations()506 size_t PyMlirContext::clearLiveOperations() {
507 for (auto &op : liveOperations)
508 op.second.second->setInvalid();
509 size_t numInvalidated = liveOperations.size();
510 liveOperations.clear();
511 return numInvalidated;
512 }
513
getLiveModuleCount()514 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
515
contextEnter()516 pybind11::object PyMlirContext::contextEnter() {
517 return PyThreadContextEntry::pushContext(*this);
518 }
519
contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)520 void PyMlirContext::contextExit(const pybind11::object &excType,
521 const pybind11::object &excVal,
522 const pybind11::object &excTb) {
523 PyThreadContextEntry::popContext(*this);
524 }
525
attachDiagnosticHandler(py::object callback)526 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
527 // Note that ownership is transferred to the delete callback below by way of
528 // an explicit inc_ref (borrow).
529 PyDiagnosticHandler *pyHandler =
530 new PyDiagnosticHandler(get(), std::move(callback));
531 py::object pyHandlerObject =
532 py::cast(pyHandler, py::return_value_policy::take_ownership);
533 pyHandlerObject.inc_ref();
534
535 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
536 // guaranteed to be known to pybind.
537 auto handlerCallback =
538 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
539 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
540 py::object pyDiagnosticObject =
541 py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
542
543 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
544 bool result = false;
545 {
546 // Since this can be called from arbitrary C++ contexts, always get the
547 // gil.
548 py::gil_scoped_acquire gil;
549 try {
550 result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
551 } catch (std::exception &e) {
552 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
553 e.what());
554 pyHandler->hadError = true;
555 }
556 }
557
558 pyDiagnostic->invalidate();
559 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
560 };
561 auto deleteCallback = +[](void *userData) {
562 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
563 assert(pyHandler->registeredID && "handler is not registered");
564 pyHandler->registeredID.reset();
565
566 // Decrement reference, balancing the inc_ref() above.
567 py::object pyHandlerObject =
568 py::cast(pyHandler, py::return_value_policy::reference);
569 pyHandlerObject.dec_ref();
570 };
571
572 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
573 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
574 return pyHandlerObject;
575 }
576
resolve()577 PyMlirContext &DefaultingPyMlirContext::resolve() {
578 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
579 if (!context) {
580 throw SetPyError(
581 PyExc_RuntimeError,
582 "An MLIR function requires a Context but none was provided in the call "
583 "or from the surrounding environment. Either pass to the function with "
584 "a 'context=' argument or establish a default using 'with Context():'");
585 }
586 return *context;
587 }
588
589 //------------------------------------------------------------------------------
590 // PyThreadContextEntry management
591 //------------------------------------------------------------------------------
592
getStack()593 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
594 static thread_local std::vector<PyThreadContextEntry> stack;
595 return stack;
596 }
597
getTopOfStack()598 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
599 auto &stack = getStack();
600 if (stack.empty())
601 return nullptr;
602 return &stack.back();
603 }
604
push(FrameKind frameKind,py::object context,py::object insertionPoint,py::object location)605 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
606 py::object insertionPoint,
607 py::object location) {
608 auto &stack = getStack();
609 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
610 std::move(location));
611 // If the new stack has more than one entry and the context of the new top
612 // entry matches the previous, copy the insertionPoint and location from the
613 // previous entry if missing from the new top entry.
614 if (stack.size() > 1) {
615 auto &prev = *(stack.rbegin() + 1);
616 auto ¤t = stack.back();
617 if (current.context.is(prev.context)) {
618 // Default non-context objects from the previous entry.
619 if (!current.insertionPoint)
620 current.insertionPoint = prev.insertionPoint;
621 if (!current.location)
622 current.location = prev.location;
623 }
624 }
625 }
626
getContext()627 PyMlirContext *PyThreadContextEntry::getContext() {
628 if (!context)
629 return nullptr;
630 return py::cast<PyMlirContext *>(context);
631 }
632
getInsertionPoint()633 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
634 if (!insertionPoint)
635 return nullptr;
636 return py::cast<PyInsertionPoint *>(insertionPoint);
637 }
638
getLocation()639 PyLocation *PyThreadContextEntry::getLocation() {
640 if (!location)
641 return nullptr;
642 return py::cast<PyLocation *>(location);
643 }
644
getDefaultContext()645 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
646 auto *tos = getTopOfStack();
647 return tos ? tos->getContext() : nullptr;
648 }
649
getDefaultInsertionPoint()650 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
651 auto *tos = getTopOfStack();
652 return tos ? tos->getInsertionPoint() : nullptr;
653 }
654
getDefaultLocation()655 PyLocation *PyThreadContextEntry::getDefaultLocation() {
656 auto *tos = getTopOfStack();
657 return tos ? tos->getLocation() : nullptr;
658 }
659
pushContext(PyMlirContext & context)660 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
661 py::object contextObj = py::cast(context);
662 push(FrameKind::Context, /*context=*/contextObj,
663 /*insertionPoint=*/py::object(),
664 /*location=*/py::object());
665 return contextObj;
666 }
667
popContext(PyMlirContext & context)668 void PyThreadContextEntry::popContext(PyMlirContext &context) {
669 auto &stack = getStack();
670 if (stack.empty())
671 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
672 auto &tos = stack.back();
673 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
674 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
675 stack.pop_back();
676 }
677
678 py::object
pushInsertionPoint(PyInsertionPoint & insertionPoint)679 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
680 py::object contextObj =
681 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
682 py::object insertionPointObj = py::cast(insertionPoint);
683 push(FrameKind::InsertionPoint,
684 /*context=*/contextObj,
685 /*insertionPoint=*/insertionPointObj,
686 /*location=*/py::object());
687 return insertionPointObj;
688 }
689
popInsertionPoint(PyInsertionPoint & insertionPoint)690 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
691 auto &stack = getStack();
692 if (stack.empty())
693 throw SetPyError(PyExc_RuntimeError,
694 "Unbalanced InsertionPoint enter/exit");
695 auto &tos = stack.back();
696 if (tos.frameKind != FrameKind::InsertionPoint &&
697 tos.getInsertionPoint() != &insertionPoint)
698 throw SetPyError(PyExc_RuntimeError,
699 "Unbalanced InsertionPoint enter/exit");
700 stack.pop_back();
701 }
702
pushLocation(PyLocation & location)703 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
704 py::object contextObj = location.getContext().getObject();
705 py::object locationObj = py::cast(location);
706 push(FrameKind::Location, /*context=*/contextObj,
707 /*insertionPoint=*/py::object(),
708 /*location=*/locationObj);
709 return locationObj;
710 }
711
popLocation(PyLocation & location)712 void PyThreadContextEntry::popLocation(PyLocation &location) {
713 auto &stack = getStack();
714 if (stack.empty())
715 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
716 auto &tos = stack.back();
717 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
718 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
719 stack.pop_back();
720 }
721
722 //------------------------------------------------------------------------------
723 // PyDiagnostic*
724 //------------------------------------------------------------------------------
725
invalidate()726 void PyDiagnostic::invalidate() {
727 valid = false;
728 if (materializedNotes) {
729 for (auto ¬eObject : *materializedNotes) {
730 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
731 note->invalidate();
732 }
733 }
734 }
735
PyDiagnosticHandler(MlirContext context,py::object callback)736 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
737 py::object callback)
738 : context(context), callback(std::move(callback)) {}
739
740 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
741
/// Detaches this handler from its MLIR context. Does nothing if the handler is
/// not currently registered (`registeredID` is empty).
void PyDiagnosticHandler::detach() {
  if (!registeredID)
    return;
  // Copy the ID out first: detaching is expected to run the delete callback
  // installed by PyMlirContext::attachDiagnosticHandler, which resets
  // `registeredID` (the assert below checks exactly that).
  // NOTE(review): that callback also drops a Python reference to this
  // handler; the lines after the call rely on the caller still holding a
  // reference so `this` stays alive — confirm for all call sites.
  MlirDiagnosticHandlerID localID = *registeredID;
  mlirContextDetachDiagnosticHandler(context, localID);
  assert(!registeredID && "should have unregistered");
  // Not strictly necessary but keeps stale pointers from being around to cause
  // issues.
  context = {nullptr};
}
752
checkValid()753 void PyDiagnostic::checkValid() {
754 if (!valid) {
755 throw std::invalid_argument(
756 "Diagnostic is invalid (used outside of callback)");
757 }
758 }
759
getSeverity()760 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
761 checkValid();
762 return mlirDiagnosticGetSeverity(diagnostic);
763 }
764
getLocation()765 PyLocation PyDiagnostic::getLocation() {
766 checkValid();
767 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
768 MlirContext context = mlirLocationGetContext(loc);
769 return PyLocation(PyMlirContext::forContext(context), loc);
770 }
771
getMessage()772 py::str PyDiagnostic::getMessage() {
773 checkValid();
774 py::object fileObject = py::module::import("io").attr("StringIO")();
775 PyFileAccumulator accum(fileObject, /*binary=*/false);
776 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
777 return fileObject.attr("getvalue")();
778 }
779
getNotes()780 py::tuple PyDiagnostic::getNotes() {
781 checkValid();
782 if (materializedNotes)
783 return *materializedNotes;
784 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
785 materializedNotes = py::tuple(numNotes);
786 for (intptr_t i = 0; i < numNotes; ++i) {
787 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
788 materializedNotes.value()[i] = PyDiagnostic(noteDiag);
789 }
790 return *materializedNotes;
791 }
792
793 //------------------------------------------------------------------------------
794 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
795 //------------------------------------------------------------------------------
796
getDialectForKey(const std::string & key,bool attrError)797 MlirDialect PyDialects::getDialectForKey(const std::string &key,
798 bool attrError) {
799 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
800 {key.data(), key.size()});
801 if (mlirDialectIsNull(dialect)) {
802 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
803 Twine("Dialect '") + key + "' not found");
804 }
805 return dialect;
806 }
807
getCapsule()808 py::object PyDialectRegistry::getCapsule() {
809 return py::reinterpret_steal<py::object>(
810 mlirPythonDialectRegistryToCapsule(*this));
811 }
812
createFromCapsule(py::object capsule)813 PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
814 MlirDialectRegistry rawRegistry =
815 mlirPythonCapsuleToDialectRegistry(capsule.ptr());
816 if (mlirDialectRegistryIsNull(rawRegistry))
817 throw py::error_already_set();
818 return PyDialectRegistry(rawRegistry);
819 }
820
821 //------------------------------------------------------------------------------
822 // PyLocation
823 //------------------------------------------------------------------------------
824
getCapsule()825 py::object PyLocation::getCapsule() {
826 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
827 }
828
createFromCapsule(py::object capsule)829 PyLocation PyLocation::createFromCapsule(py::object capsule) {
830 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
831 if (mlirLocationIsNull(rawLoc))
832 throw py::error_already_set();
833 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
834 rawLoc);
835 }
836
contextEnter()837 py::object PyLocation::contextEnter() {
838 return PyThreadContextEntry::pushLocation(*this);
839 }
840
contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)841 void PyLocation::contextExit(const pybind11::object &excType,
842 const pybind11::object &excVal,
843 const pybind11::object &excTb) {
844 PyThreadContextEntry::popLocation(*this);
845 }
846
resolve()847 PyLocation &DefaultingPyLocation::resolve() {
848 auto *location = PyThreadContextEntry::getDefaultLocation();
849 if (!location) {
850 throw SetPyError(
851 PyExc_RuntimeError,
852 "An MLIR function requires a Location but none was provided in the "
853 "call or from the surrounding environment. Either pass to the function "
854 "with a 'loc=' argument or establish a default using 'with loc:'");
855 }
856 return *location;
857 }
858
859 //------------------------------------------------------------------------------
860 // PyModule
861 //------------------------------------------------------------------------------
862
/// Constructs a wrapper around `module`, holding a reference to its context.
/// In this file, instances are created by PyModule::forModule, which also
/// registers them in the context's live-module map; ~PyModule destroys the
/// wrapped MlirModule.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
865
~PyModule()866 PyModule::~PyModule() {
867 py::gil_scoped_acquire acquire;
868 auto &liveModules = getContext()->liveModules;
869 assert(liveModules.count(module.ptr) == 1 &&
870 "destroying module not in live map");
871 liveModules.erase(module.ptr);
872 mlirModuleDestroy(module);
873 }
874
/// Returns the unique Python wrapper for `module`, creating and registering a
/// new Python-owned PyModule if none is live yet. Repeated calls with the same
/// MlirModule return the same Python object while it is alive.
PyModuleRef PyModule::forModule(MlirModule module) {
  MlirContext context = mlirModuleGetContext(module);
  PyMlirContextRef contextRef = PyMlirContext::forContext(context);

  py::gil_scoped_acquire acquire;
  auto &liveModules = contextRef->liveModules;
  auto it = liveModules.find(module.ptr);
  if (it == liveModules.end()) {
    // Create.
    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
    // Note that the default return value policy on cast is automatic_reference,
    // which does not take ownership (delete will not be called).
    // Just be explicit: with take_ownership, Python deletes the PyModule
    // (running ~PyModule) when the object is collected.
    py::object pyRef =
        py::cast(unownedModule, py::return_value_policy::take_ownership);
    unownedModule->handle = pyRef;
    // The map stores the handle and raw pointer (no extra strong reference in
    // this function); ~PyModule removes the entry.
    liveModules[module.ptr] =
        std::make_pair(unownedModule->handle, unownedModule);
    return PyModuleRef(unownedModule, std::move(pyRef));
  }
  // Use existing: take a new strong reference to the stored handle.
  PyModule *existing = it->second.second;
  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
  return PyModuleRef(existing, std::move(pyRef));
}
900
createFromCapsule(py::object capsule)901 py::object PyModule::createFromCapsule(py::object capsule) {
902 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
903 if (mlirModuleIsNull(rawModule))
904 throw py::error_already_set();
905 return forModule(rawModule).releaseObject();
906 }
907
getCapsule()908 py::object PyModule::getCapsule() {
909 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
910 }
911
912 //------------------------------------------------------------------------------
913 // PyOperation
914 //------------------------------------------------------------------------------
915
/// Constructs a wrapper around `operation`. In this file, instances are
/// created through PyOperation::createInstance, which registers them in the
/// context's live-operation map.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
918
~PyOperation()919 PyOperation::~PyOperation() {
920 // If the operation has already been invalidated there is nothing to do.
921 if (!valid)
922 return;
923 auto &liveOperations = getContext()->liveOperations;
924 assert(liveOperations.count(operation.ptr) == 1 &&
925 "destroying operation not in live map");
926 liveOperations.erase(operation.ptr);
927 if (!isAttached()) {
928 mlirOperationDestroy(operation);
929 }
930 }
931
/// Creates a new Python-owned PyOperation wrapper for `operation` and
/// registers it in the context's live-operation map. The insertion is
/// unconditional; callers (forOperation, createDetached) are responsible for
/// ensuring no wrapper is live yet.
///
/// `parentKeepAlive`, if non-null, is stored on the wrapper (presumably to
/// keep the owning parent Python object alive while this one is referenced).
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
                                           MlirOperation operation,
                                           py::object parentKeepAlive) {
  // Moving `contextRef` below moves only the reference handle, not the
  // context object this map lives in (presumably pointer-like) — so this
  // reference stays usable afterwards.
  auto &liveOperations = contextRef->liveOperations;
  // Create.
  PyOperation *unownedOperation =
      new PyOperation(std::move(contextRef), operation);
  // Note that the default return value policy on cast is automatic_reference,
  // which does not take ownership (delete will not be called).
  // Just be explicit: with take_ownership, Python deletes the PyOperation
  // (running ~PyOperation) when the object is collected.
  py::object pyRef =
      py::cast(unownedOperation, py::return_value_policy::take_ownership);
  unownedOperation->handle = pyRef;
  if (parentKeepAlive) {
    unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
  }
  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
  return PyOperationRef(unownedOperation, std::move(pyRef));
}
951
forOperation(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)952 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
953 MlirOperation operation,
954 py::object parentKeepAlive) {
955 auto &liveOperations = contextRef->liveOperations;
956 auto it = liveOperations.find(operation.ptr);
957 if (it == liveOperations.end()) {
958 // Create.
959 return createInstance(std::move(contextRef), operation,
960 std::move(parentKeepAlive));
961 }
962 // Use existing.
963 PyOperation *existing = it->second.second;
964 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
965 return PyOperationRef(existing, std::move(pyRef));
966 }
967
createDetached(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)968 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
969 MlirOperation operation,
970 py::object parentKeepAlive) {
971 auto &liveOperations = contextRef->liveOperations;
972 assert(liveOperations.count(operation.ptr) == 0 &&
973 "cannot create detached operation that already exists");
974 (void)liveOperations;
975
976 PyOperationRef created = createInstance(std::move(contextRef), operation,
977 std::move(parentKeepAlive));
978 created->attached = false;
979 return created;
980 }
981
checkValid() const982 void PyOperation::checkValid() const {
983 if (!valid) {
984 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
985 }
986 }
987
print(py::object fileObject,bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope,bool assumeVerified)988 void PyOperationBase::print(py::object fileObject, bool binary,
989 llvm::Optional<int64_t> largeElementsLimit,
990 bool enableDebugInfo, bool prettyDebugInfo,
991 bool printGenericOpForm, bool useLocalScope,
992 bool assumeVerified) {
993 PyOperation &operation = getOperation();
994 operation.checkValid();
995 if (fileObject.is_none())
996 fileObject = py::module::import("sys").attr("stdout");
997
998 if (!assumeVerified && !printGenericOpForm &&
999 !mlirOperationVerify(operation)) {
1000 std::string message("// Verification failed, printing generic form\n");
1001 if (binary) {
1002 fileObject.attr("write")(py::bytes(message));
1003 } else {
1004 fileObject.attr("write")(py::str(message));
1005 }
1006 printGenericOpForm = true;
1007 }
1008
1009 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1010 if (largeElementsLimit)
1011 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1012 if (enableDebugInfo)
1013 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
1014 if (printGenericOpForm)
1015 mlirOpPrintingFlagsPrintGenericOpForm(flags);
1016 if (useLocalScope)
1017 mlirOpPrintingFlagsUseLocalScope(flags);
1018
1019 PyFileAccumulator accum(fileObject, binary);
1020 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1021 accum.getUserData());
1022 mlirOpPrintingFlagsDestroy(flags);
1023 }
1024
getAsm(bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope,bool assumeVerified)1025 py::object PyOperationBase::getAsm(bool binary,
1026 llvm::Optional<int64_t> largeElementsLimit,
1027 bool enableDebugInfo, bool prettyDebugInfo,
1028 bool printGenericOpForm, bool useLocalScope,
1029 bool assumeVerified) {
1030 py::object fileObject;
1031 if (binary) {
1032 fileObject = py::module::import("io").attr("BytesIO")();
1033 } else {
1034 fileObject = py::module::import("io").attr("StringIO")();
1035 }
1036 print(fileObject, /*binary=*/binary,
1037 /*largeElementsLimit=*/largeElementsLimit,
1038 /*enableDebugInfo=*/enableDebugInfo,
1039 /*prettyDebugInfo=*/prettyDebugInfo,
1040 /*printGenericOpForm=*/printGenericOpForm,
1041 /*useLocalScope=*/useLocalScope,
1042 /*assumeVerified=*/assumeVerified);
1043
1044 return fileObject.attr("getvalue")();
1045 }
1046
moveAfter(PyOperationBase & other)1047 void PyOperationBase::moveAfter(PyOperationBase &other) {
1048 PyOperation &operation = getOperation();
1049 PyOperation &otherOp = other.getOperation();
1050 operation.checkValid();
1051 otherOp.checkValid();
1052 mlirOperationMoveAfter(operation, otherOp);
1053 operation.parentKeepAlive = otherOp.parentKeepAlive;
1054 }
1055
moveBefore(PyOperationBase & other)1056 void PyOperationBase::moveBefore(PyOperationBase &other) {
1057 PyOperation &operation = getOperation();
1058 PyOperation &otherOp = other.getOperation();
1059 operation.checkValid();
1060 otherOp.checkValid();
1061 mlirOperationMoveBefore(operation, otherOp);
1062 operation.parentKeepAlive = otherOp.parentKeepAlive;
1063 }
1064
getParentOperation()1065 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1066 checkValid();
1067 if (!isAttached())
1068 throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1069 MlirOperation operation = mlirOperationGetParentOperation(get());
1070 if (mlirOperationIsNull(operation))
1071 return {};
1072 return PyOperation::forOperation(getContext(), operation);
1073 }
1074
getBlock()1075 PyBlock PyOperation::getBlock() {
1076 checkValid();
1077 llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1078 MlirBlock block = mlirOperationGetBlock(get());
1079 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1080 assert(parentOperation && "Operation has no parent");
1081 return PyBlock{std::move(*parentOperation), block};
1082 }
1083
getCapsule()1084 py::object PyOperation::getCapsule() {
1085 checkValid();
1086 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1087 }
1088
createFromCapsule(py::object capsule)1089 py::object PyOperation::createFromCapsule(py::object capsule) {
1090 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1091 if (mlirOperationIsNull(rawOperation))
1092 throw py::error_already_set();
1093 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1094 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1095 .releaseObject();
1096 }
1097
maybeInsertOperation(PyOperationRef & op,const py::object & maybeIp)1098 static void maybeInsertOperation(PyOperationRef &op,
1099 const py::object &maybeIp) {
1100 // InsertPoint active?
1101 if (!maybeIp.is(py::cast(false))) {
1102 PyInsertionPoint *ip;
1103 if (maybeIp.is_none()) {
1104 ip = PyThreadContextEntry::getDefaultInsertionPoint();
1105 } else {
1106 ip = py::cast<PyInsertionPoint *>(maybeIp);
1107 }
1108 if (ip)
1109 ip->insert(*op.get());
1110 }
1111 }
1112
/// Generic operation constructor (Operation.create).
///
/// Builds an operation named `name` at `location` with the given result types,
/// operands, attributes (str -> Attribute dict), successor blocks and number
/// of (empty) regions. The result is wrapped as a detached operation, inserted
/// according to `maybeIp` (see maybeInsertOperation), and returned as an
/// OpView.
///
/// Raises ValueError for a negative region count or None operands, results or
/// successors, and py::cast_error for non-string attribute keys or invalid
/// attribute values. All validation happens before the MlirOperationState is
/// created, because the state would leak if an exception escaped later.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Keys are owned here as std::string so the MlirStringRefs built below stay
  // valid until the attributes are added to the state.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes. Cast failures are re-thrown as cast_error with
  // the key and operation name added for context.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    // Each name is interned as an identifier in the attribute's own context.
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh, empty regions whose ownership is handed to the state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation. It starts detached (Python-owned) and is then
  // optionally inserted, which transfers ownership to the target block.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1234
/// Clones this operation via mlirOperationClone into a new detached,
/// Python-owned operation. The clone is optionally inserted per `maybeIp`
/// (see maybeInsertOperation) and returned as an OpView.
/// Raises RuntimeError if this operation has been invalidated.
py::object PyOperation::clone(const py::object &maybeIp) {
  // erase() destroys the MlirOperation and clears `valid`; cloning after that
  // would read freed memory. Reject it like the other accessors do.
  checkValid();
  MlirOperation clonedOperation = mlirOperationClone(operation);
  PyOperationRef cloned =
      PyOperation::createDetached(getContext(), clonedOperation);
  maybeInsertOperation(cloned, maybeIp);

  return cloned->createOpView();
}
1243
createOpView()1244 py::object PyOperation::createOpView() {
1245 checkValid();
1246 MlirIdentifier ident = mlirOperationGetName(get());
1247 MlirStringRef identStr = mlirIdentifierStr(ident);
1248 auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1249 StringRef(identStr.data, identStr.length));
1250 if (opViewClass)
1251 return (*opViewClass)(getRef().getObject());
1252 return py::cast(PyOpView(getRef().getObject()));
1253 }
1254
erase()1255 void PyOperation::erase() {
1256 checkValid();
1257 // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1258 // Python reference to a child operation is live. All children should also
1259 // have their `valid` bit set to false.
1260 auto &liveOperations = getContext()->liveOperations;
1261 if (liveOperations.count(operation.ptr))
1262 liveOperations.erase(operation.ptr);
1263 mlirOperationDestroy(operation);
1264 valid = false;
1265 }
1266
1267 //------------------------------------------------------------------------------
1268 // PyOpView
1269 //------------------------------------------------------------------------------
1270
/// Generic builder used by ODS-generated OpView subclasses.
///
/// Reads class-level metadata from `cls`: OPERATION_NAME,
/// _ODS_OPERAND_SEGMENTS, _ODS_RESULT_SEGMENTS and _ODS_REGIONS. It then:
///   - unpacks `resultTypeList` / `operandList` according to the segment
///     specs;
///   - validates the region count;
///   - synthesizes `result_segment_sizes` / `operand_segment_sizes`
///     attributes when segments are used;
///   - delegates to PyOperation::create.
///
/// Raises ValueError for:
///   - region-count violations;
///   - malformed results or operands;
///   - segment-count mismatches;
///   - caller-supplied segment-size attributes.
py::object PyOpView::buildGeneric(
    const py::object &cls, py::list resultTypeList, py::list operandList,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors,
    llvm::Optional<int> regions, DefaultingPyLocation location,
    const py::object &maybeIp) {
  PyMlirContextRef context = location->getContext();
  // Class level operation construction metadata.
  std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
  // Operand and result segment specs are either None, which does no variadic
  // unpacking, or a list with one int per operand/result group:
  //    1 -> exactly one element (None is rejected),
  //    0 -> an optional single element (None yields an empty segment),
  //   -1 -> a sequence of elements (None is treated as an empty sequence).
  // Any other value raises "Unexpected segment spec".
  py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
  py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");

  std::vector<uint32_t> operandSegmentLengths;
  std::vector<uint32_t> resultSegmentLengths;

  // Validate/determine region count. _ODS_REGIONS is
  // (minimum region count, whether there are no variadic regions).
  auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
  int opMinRegionCount = std::get<0>(opRegionSpec);
  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
  if (!regions) {
    regions = opMinRegionCount;
  }
  if (*regions < opMinRegionCount) {
    throw py::value_error(
        (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
         llvm::Twine(opMinRegionCount) +
         " regions but was built with regions=" + llvm::Twine(*regions))
            .str());
  }
  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
    throw py::value_error(
        (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
         llvm::Twine(opMinRegionCount) +
         " regions but was built with regions=" + llvm::Twine(*regions))
            .str());
  }

  // Unpack results.
  std::vector<PyType *> resultTypes;
  resultTypes.reserve(resultTypeList.size());
  if (resultSegmentSpecObj.is_none()) {
    // Non-variadic result unpacking.
    for (const auto &it : llvm::enumerate(resultTypeList)) {
      try {
        resultTypes.push_back(py::cast<PyType *>(it.value()));
        if (!resultTypes.back())
          throw py::cast_error();
      } catch (py::cast_error &err) {
        throw py::value_error((llvm::Twine("Result ") +
                               llvm::Twine(it.index()) + " of operation \"" +
                               name + "\" must be a Type (" + err.what() + ")")
                                  .str());
      }
    }
  } else {
    // Sized result unpacking.
    auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
    if (resultSegmentSpec.size() != resultTypeList.size()) {
      throw py::value_error((llvm::Twine("Operation \"") + name +
                             "\" requires " +
                             llvm::Twine(resultSegmentSpec.size()) +
                             " result segments but was provided " +
                             llvm::Twine(resultTypeList.size()))
                                .str());
    }
    resultSegmentLengths.reserve(resultTypeList.size());
    for (const auto &it :
         llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
      int segmentSpec = std::get<1>(it.value());
      if (segmentSpec == 1 || segmentSpec == 0) {
        // Unpack unary element.
        try {
          auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
          if (resultType) {
            resultTypes.push_back(resultType);
            resultSegmentLengths.push_back(1);
          } else if (segmentSpec == 0) {
            // Allowed to be optional.
            resultSegmentLengths.push_back(0);
          } else {
            throw py::cast_error("was None and result is not optional");
          }
        } catch (py::cast_error &err) {
          throw py::value_error((llvm::Twine("Result ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Type (" + err.what() +
                                 ")")
                                    .str());
        }
      } else if (segmentSpec == -1) {
        // Unpack sequence by appending.
        try {
          if (std::get<0>(it.value()).is_none()) {
            // Treat it as an empty list.
            resultSegmentLengths.push_back(0);
          } else {
            // Unpack the list.
            auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
            for (py::object segmentItem : segment) {
              resultTypes.push_back(py::cast<PyType *>(segmentItem));
              if (!resultTypes.back()) {
                throw py::cast_error("contained a None item");
              }
            }
            resultSegmentLengths.push_back(segment.size());
          }
        } catch (std::exception &err) {
          // NOTE: Sloppy to be using a catch-all here, but there are at least
          // three different unrelated exceptions that can be thrown in the
          // above "casts". Just keep the scope above small and catch them all.
          throw py::value_error((llvm::Twine("Result ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Sequence of Types (" +
                                 err.what() + ")")
                                    .str());
        }
      } else {
        throw py::value_error("Unexpected segment spec");
      }
    }
  }

  // Unpack operands.
  std::vector<PyValue *> operands;
  // Reserve from the input list (the original reserved operands.size(), which
  // is always 0 here and therefore reserved nothing).
  operands.reserve(operandList.size());
  if (operandSegmentSpecObj.is_none()) {
    // Non-sized operand unpacking.
    for (const auto &it : llvm::enumerate(operandList)) {
      try {
        operands.push_back(py::cast<PyValue *>(it.value()));
        if (!operands.back())
          throw py::cast_error();
      } catch (py::cast_error &err) {
        throw py::value_error((llvm::Twine("Operand ") +
                               llvm::Twine(it.index()) + " of operation \"" +
                               name + "\" must be a Value (" + err.what() + ")")
                                  .str());
      }
    }
  } else {
    // Sized operand unpacking.
    auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
    if (operandSegmentSpec.size() != operandList.size()) {
      throw py::value_error((llvm::Twine("Operation \"") + name +
                             "\" requires " +
                             llvm::Twine(operandSegmentSpec.size()) +
                             " operand segments but was provided " +
                             llvm::Twine(operandList.size()))
                                .str());
    }
    operandSegmentLengths.reserve(operandList.size());
    for (const auto &it :
         llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
      int segmentSpec = std::get<1>(it.value());
      if (segmentSpec == 1 || segmentSpec == 0) {
        // Unpack unary element.
        try {
          auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
          if (operandValue) {
            operands.push_back(operandValue);
            operandSegmentLengths.push_back(1);
          } else if (segmentSpec == 0) {
            // Allowed to be optional.
            operandSegmentLengths.push_back(0);
          } else {
            throw py::cast_error("was None and operand is not optional");
          }
        } catch (py::cast_error &err) {
          throw py::value_error((llvm::Twine("Operand ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Value (" + err.what() +
                                 ")")
                                    .str());
        }
      } else if (segmentSpec == -1) {
        // Unpack sequence by appending.
        try {
          if (std::get<0>(it.value()).is_none()) {
            // Treat it as an empty list.
            operandSegmentLengths.push_back(0);
          } else {
            // Unpack the list.
            auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
            for (py::object segmentItem : segment) {
              operands.push_back(py::cast<PyValue *>(segmentItem));
              if (!operands.back()) {
                throw py::cast_error("contained a None item");
              }
            }
            operandSegmentLengths.push_back(segment.size());
          }
        } catch (std::exception &err) {
          // NOTE: Sloppy to be using a catch-all here, but there are at least
          // three different unrelated exceptions that can be thrown in the
          // above "casts". Just keep the scope above small and catch them all.
          throw py::value_error((llvm::Twine("Operand ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Sequence of Values (" +
                                 err.what() + ")")
                                    .str());
        }
      } else {
        throw py::value_error("Unexpected segment spec");
      }
    }
  }

  // Merge operand/result segment lengths into attributes if needed.
  if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
    // Dup, so that the caller's dict is not mutated.
    if (attributes) {
      attributes = py::dict(*attributes);
    } else {
      attributes = py::dict();
    }
    if (attributes->contains("result_segment_sizes") ||
        attributes->contains("operand_segment_sizes")) {
      throw py::value_error("Manually setting a 'result_segment_sizes' or "
                            "'operand_segment_sizes' attribute is unsupported. "
                            "Use Operation.create for such low-level access.");
    }

    // Add result_segment_sizes attribute (a 1-D vector<Nxi32> dense attr).
    if (!resultSegmentLengths.empty()) {
      int64_t size = resultSegmentLengths.size();
      MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
          mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
          resultSegmentLengths.size(), resultSegmentLengths.data());
      (*attributes)["result_segment_sizes"] =
          PyAttribute(context, segmentLengthAttr);
    }

    // Add operand_segment_sizes attribute (a 1-D vector<Nxi32> dense attr).
    if (!operandSegmentLengths.empty()) {
      int64_t size = operandSegmentLengths.size();
      MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
          mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
          operandSegmentLengths.size(), operandSegmentLengths.data());
      (*attributes)["operand_segment_sizes"] =
          PyAttribute(context, segmentLengthAttr);
    }
  }

  // Delegate to create.
  return PyOperation::create(name,
                             /*results=*/std::move(resultTypes),
                             /*operands=*/std::move(operands),
                             /*attributes=*/std::move(attributes),
                             /*successors=*/std::move(successors),
                             /*regions=*/*regions, location, maybeIp);
}
1527
/// Wraps an existing operation object as a generic OpView.
///
/// `operationObject` may be any PyOperationBase (an Operation or another
/// OpView). The underlying PyOperation is resolved, and the Python object
/// returned by its getRef().getObject() is retained.
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    // In the first initializer, `operationObject` names the constructor
    // parameter (it shadows the member of the same name). The second
    // initializer stores into the member and reads the `operation` member.
    // NOTE(review): this relies on `operation` being declared before
    // `operationObject` in the class definition (not visible here); confirm.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1533
/// Synthesizes a "_<Name>" subclass of `userClass` whose __init__ is OpView's
/// own __init__, so that it can be constructed directly from an operation.
///
/// Rationale: user-facing OpView subclasses typically define a convenience
/// constructor, e.g.
///   class AddFOp(_cext.ir.OpView):
///     def __init__(self, loc, lhs, rhs):
///       operation = loc.context.create_operation(
///           "addf", lhs, rhs, results=[lhs.type])
///       super().__init__(operation)
/// That constructor has full freedom over its signature, which conflicts with
/// also wanting to build the same class from just an operation. Rather than
/// munging *args/**kwargs, a new leaf class is created with `type(name,
/// bases, dict)` that extends the user class and installs the base OpView
/// __init__, bypassing the user's __init__. While slightly underhanded, this
/// is safe because the hierarchy only gains a new leaf and __new__ is
/// untouched. The result is typically stored on the original class as "_Raw"
/// and used for casts and other operation-initialized construction.
py::object PyOpView::createRawSubclass(const py::object &userClass) {
  py::object typeMetaclass = py::reinterpret_borrow<py::object>(
      reinterpret_cast<PyObject *>(&PyType_Type));
  py::dict classDict;
  // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
  // now.
  // auto opViewType = py::type::of<PyOpView>();
  auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
  classDict["__init__"] = opViewType.attr("__init__");
  py::str userName = userClass.attr("__name__");
  py::str rawName = py::str("_") + userName;
  py::tuple bases = py::make_tuple(userClass);
  return typeMetaclass(rawName, bases, classDict);
}
1568
1569 //------------------------------------------------------------------------------
1570 // PyInsertionPoint.
1571 //------------------------------------------------------------------------------
1572
/// Creates an insertion point at the end of `block` (no reference operation;
/// see insert()).
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1574
/// Creates an insertion point that inserts immediately before
/// `beforeOperationBase`, inside that operation's parent block. Raises
/// ValueError (via PyOperation::getBlock -> getParentOperation) if the
/// operation is detached.
// NOTE(review): `block` is initialized from `refOperation`, so this relies on
// `refOperation` being declared before `block` in the class definition (not
// visible here); confirm in IRModule.h.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1578
insert(PyOperationBase & operationBase)1579 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1580 PyOperation &operation = operationBase.getOperation();
1581 if (operation.isAttached())
1582 throw SetPyError(PyExc_ValueError,
1583 "Attempt to insert operation that is already attached");
1584 block.getParentOperation()->checkValid();
1585 MlirOperation beforeOp = {nullptr};
1586 if (refOperation) {
1587 // Insert before operation.
1588 (*refOperation)->checkValid();
1589 beforeOp = (*refOperation)->get();
1590 } else {
1591 // Insert at end (before null) is only valid if the block does not
1592 // already end in a known terminator (violating this will cause assertion
1593 // failures later).
1594 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1595 throw py::index_error("Cannot insert operation at the end of a block "
1596 "that already has a terminator. Did you mean to "
1597 "use 'InsertionPoint.at_block_terminator(block)' "
1598 "versus 'InsertionPoint(block)'?");
1599 }
1600 }
1601 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1602 operation.setAttached();
1603 }
1604
atBlockBegin(PyBlock & block)1605 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1606 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1607 if (mlirOperationIsNull(firstOp)) {
1608 // Just insert at end.
1609 return PyInsertionPoint(block);
1610 }
1611
1612 // Insert before first op.
1613 PyOperationRef firstOpRef = PyOperation::forOperation(
1614 block.getParentOperation()->getContext(), firstOp);
1615 return PyInsertionPoint{block, std::move(firstOpRef)};
1616 }
1617
atBlockTerminator(PyBlock & block)1618 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1619 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1620 if (mlirOperationIsNull(terminator))
1621 throw SetPyError(PyExc_ValueError, "Block has no terminator");
1622 PyOperationRef terminatorOpRef = PyOperation::forOperation(
1623 block.getParentOperation()->getContext(), terminator);
1624 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1625 }
1626
contextEnter()1627 py::object PyInsertionPoint::contextEnter() {
1628 return PyThreadContextEntry::pushInsertionPoint(*this);
1629 }
1630
contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)1631 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1632 const pybind11::object &excVal,
1633 const pybind11::object &excTb) {
1634 PyThreadContextEntry::popInsertionPoint(*this);
1635 }
1636
1637 //------------------------------------------------------------------------------
1638 // PyAttribute.
1639 //------------------------------------------------------------------------------
1640
operator ==(const PyAttribute & other)1641 bool PyAttribute::operator==(const PyAttribute &other) {
1642 return mlirAttributeEqual(attr, other.attr);
1643 }
1644
getCapsule()1645 py::object PyAttribute::getCapsule() {
1646 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1647 }
1648
createFromCapsule(py::object capsule)1649 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1650 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1651 if (mlirAttributeIsNull(rawAttr))
1652 throw py::error_already_set();
1653 return PyAttribute(
1654 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1655 }
1656
1657 //------------------------------------------------------------------------------
1658 // PyNamedAttribute.
1659 //------------------------------------------------------------------------------
1660
PyNamedAttribute(MlirAttribute attr,std::string ownedName)1661 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1662 : ownedName(new std::string(std::move(ownedName))) {
1663 namedAttr = mlirNamedAttributeGet(
1664 mlirIdentifierGet(mlirAttributeGetContext(attr),
1665 toMlirStringRef(*this->ownedName)),
1666 attr);
1667 }
1668
1669 //------------------------------------------------------------------------------
1670 // PyType.
1671 //------------------------------------------------------------------------------
1672
operator ==(const PyType & other)1673 bool PyType::operator==(const PyType &other) {
1674 return mlirTypeEqual(type, other.type);
1675 }
1676
getCapsule()1677 py::object PyType::getCapsule() {
1678 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1679 }
1680
createFromCapsule(py::object capsule)1681 PyType PyType::createFromCapsule(py::object capsule) {
1682 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1683 if (mlirTypeIsNull(rawType))
1684 throw py::error_already_set();
1685 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1686 rawType);
1687 }
1688
1689 //------------------------------------------------------------------------------
1690 // PyValue and subclases.
1691 //------------------------------------------------------------------------------
1692
getCapsule()1693 pybind11::object PyValue::getCapsule() {
1694 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1695 }
1696
createFromCapsule(pybind11::object capsule)1697 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1698 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1699 if (mlirValueIsNull(value))
1700 throw py::error_already_set();
1701 MlirOperation owner;
1702 if (mlirValueIsAOpResult(value))
1703 owner = mlirOpResultGetOwner(value);
1704 if (mlirValueIsABlockArgument(value))
1705 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1706 if (mlirOperationIsNull(owner))
1707 throw py::error_already_set();
1708 MlirContext ctx = mlirOperationGetContext(owner);
1709 PyOperationRef ownerRef =
1710 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1711 return PyValue(ownerRef, value);
1712 }
1713
1714 //------------------------------------------------------------------------------
1715 // PySymbolTable.
1716 //------------------------------------------------------------------------------
1717
PySymbolTable(PyOperationBase & operation)1718 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1719 : operation(operation.getOperation().getRef()) {
1720 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1721 if (mlirSymbolTableIsNull(symbolTable)) {
1722 throw py::cast_error("Operation is not a Symbol Table.");
1723 }
1724 }
1725
dunderGetItem(const std::string & name)1726 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1727 operation->checkValid();
1728 MlirOperation symbol = mlirSymbolTableLookup(
1729 symbolTable, mlirStringRefCreate(name.data(), name.length()));
1730 if (mlirOperationIsNull(symbol))
1731 throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1732
1733 return PyOperation::forOperation(operation->getContext(), symbol,
1734 operation.getObject())
1735 ->createOpView();
1736 }
1737
erase(PyOperationBase & symbol)1738 void PySymbolTable::erase(PyOperationBase &symbol) {
1739 operation->checkValid();
1740 symbol.getOperation().checkValid();
1741 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1742 // The operation is also erased, so we must invalidate it. There may be Python
1743 // references to this operation so we don't want to delete it from the list of
1744 // live operations here.
1745 symbol.getOperation().valid = false;
1746 }
1747
dunderDel(const std::string & name)1748 void PySymbolTable::dunderDel(const std::string &name) {
1749 py::object operation = dunderGetItem(name);
1750 erase(py::cast<PyOperationBase &>(operation));
1751 }
1752
insert(PyOperationBase & symbol)1753 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1754 operation->checkValid();
1755 symbol.getOperation().checkValid();
1756 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1757 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1758 if (mlirAttributeIsNull(symbolAttr))
1759 throw py::value_error("Expected operation to have a symbol name.");
1760 return PyAttribute(
1761 symbol.getOperation().getContext(),
1762 mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1763 }
1764
getSymbolName(PyOperationBase & symbol)1765 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1766 // Op must already be a symbol.
1767 PyOperation &operation = symbol.getOperation();
1768 operation.checkValid();
1769 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1770 MlirAttribute existingNameAttr =
1771 mlirOperationGetAttributeByName(operation.get(), attrName);
1772 if (mlirAttributeIsNull(existingNameAttr))
1773 throw py::value_error("Expected operation to have a symbol name.");
1774 return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1775 }
1776
setSymbolName(PyOperationBase & symbol,const std::string & name)1777 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1778 const std::string &name) {
1779 // Op must already be a symbol.
1780 PyOperation &operation = symbol.getOperation();
1781 operation.checkValid();
1782 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1783 MlirAttribute existingNameAttr =
1784 mlirOperationGetAttributeByName(operation.get(), attrName);
1785 if (mlirAttributeIsNull(existingNameAttr))
1786 throw py::value_error("Expected operation to have a symbol name.");
1787 MlirAttribute newNameAttr =
1788 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1789 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1790 }
1791
getVisibility(PyOperationBase & symbol)1792 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1793 PyOperation &operation = symbol.getOperation();
1794 operation.checkValid();
1795 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1796 MlirAttribute existingVisAttr =
1797 mlirOperationGetAttributeByName(operation.get(), attrName);
1798 if (mlirAttributeIsNull(existingVisAttr))
1799 throw py::value_error("Expected operation to have a symbol visibility.");
1800 return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1801 }
1802
setVisibility(PyOperationBase & symbol,const std::string & visibility)1803 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1804 const std::string &visibility) {
1805 if (visibility != "public" && visibility != "private" &&
1806 visibility != "nested")
1807 throw py::value_error(
1808 "Expected visibility to be 'public', 'private' or 'nested'");
1809 PyOperation &operation = symbol.getOperation();
1810 operation.checkValid();
1811 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1812 MlirAttribute existingVisAttr =
1813 mlirOperationGetAttributeByName(operation.get(), attrName);
1814 if (mlirAttributeIsNull(existingVisAttr))
1815 throw py::value_error("Expected operation to have a symbol visibility.");
1816 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1817 toMlirStringRef(visibility));
1818 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1819 }
1820
replaceAllSymbolUses(const std::string & oldSymbol,const std::string & newSymbol,PyOperationBase & from)1821 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1822 const std::string &newSymbol,
1823 PyOperationBase &from) {
1824 PyOperation &fromOperation = from.getOperation();
1825 fromOperation.checkValid();
1826 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1827 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1828 from.getOperation())))
1829
1830 throw py::value_error("Symbol rename failed");
1831 }
1832
walkSymbolTables(PyOperationBase & from,bool allSymUsesVisible,py::object callback)1833 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1834 bool allSymUsesVisible,
1835 py::object callback) {
1836 PyOperation &fromOperation = from.getOperation();
1837 fromOperation.checkValid();
1838 struct UserData {
1839 PyMlirContextRef context;
1840 py::object callback;
1841 bool gotException;
1842 std::string exceptionWhat;
1843 py::object exceptionType;
1844 };
1845 UserData userData{
1846 fromOperation.getContext(), std::move(callback), false, {}, {}};
1847 mlirSymbolTableWalkSymbolTables(
1848 fromOperation.get(), allSymUsesVisible,
1849 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1850 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1851 auto pyFoundOp =
1852 PyOperation::forOperation(calleeUserData->context, foundOp);
1853 if (calleeUserData->gotException)
1854 return;
1855 try {
1856 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1857 } catch (py::error_already_set &e) {
1858 calleeUserData->gotException = true;
1859 calleeUserData->exceptionWhat = e.what();
1860 calleeUserData->exceptionType = e.type();
1861 }
1862 },
1863 static_cast<void *>(&userData));
1864 if (userData.gotException) {
1865 std::string message("Exception raised in callback: ");
1866 message.append(userData.exceptionWhat);
1867 throw std::runtime_error(message);
1868 }
1869 }
1870
1871 namespace {
1872 /// CRTP base class for Python MLIR values that subclass Value and should be
1873 /// castable from it. The value hierarchy is one level deep and is not supposed
1874 /// to accommodate other levels unless core MLIR changes.
1875 template <typename DerivedTy>
1876 class PyConcreteValue : public PyValue {
1877 public:
1878 // Derived classes must define statics for:
1879 // IsAFunctionTy isaFunction
1880 // const char *pyClassName
1881 // and redefine bindDerived.
1882 using ClassTy = py::class_<DerivedTy, PyValue>;
1883 using IsAFunctionTy = bool (*)(MlirValue);
1884
1885 PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef,MlirValue value)1886 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1887 : PyValue(operationRef, value) {}
PyConcreteValue(PyValue & orig)1888 PyConcreteValue(PyValue &orig)
1889 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1890
1891 /// Attempts to cast the original value to the derived type and throws on
1892 /// type mismatches.
castFrom(PyValue & orig)1893 static MlirValue castFrom(PyValue &orig) {
1894 if (!DerivedTy::isaFunction(orig.get())) {
1895 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1896 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1897 DerivedTy::pyClassName +
1898 " (from " + origRepr + ")");
1899 }
1900 return orig.get();
1901 }
1902
1903 /// Binds the Python module objects to functions of this class.
bind(py::module & m)1904 static void bind(py::module &m) {
1905 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1906 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1907 cls.def_static(
1908 "isinstance",
1909 [](PyValue &otherValue) -> bool {
1910 return DerivedTy::isaFunction(otherValue);
1911 },
1912 py::arg("other_value"));
1913 DerivedTy::bindDerived(cls);
1914 }
1915
1916 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1917 static void bindDerived(ClassTy &m) {}
1918 };
1919
1920 /// Python wrapper for MlirBlockArgument.
1921 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1922 public:
1923 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1924 static constexpr const char *pyClassName = "BlockArgument";
1925 using PyConcreteValue::PyConcreteValue;
1926
bindDerived(ClassTy & c)1927 static void bindDerived(ClassTy &c) {
1928 c.def_property_readonly("owner", [](PyBlockArgument &self) {
1929 return PyBlock(self.getParentOperation(),
1930 mlirBlockArgumentGetOwner(self.get()));
1931 });
1932 c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1933 return mlirBlockArgumentGetArgNumber(self.get());
1934 });
1935 c.def(
1936 "set_type",
1937 [](PyBlockArgument &self, PyType type) {
1938 return mlirBlockArgumentSetType(self.get(), type);
1939 },
1940 py::arg("type"));
1941 }
1942 };
1943
1944 /// Python wrapper for MlirOpResult.
1945 class PyOpResult : public PyConcreteValue<PyOpResult> {
1946 public:
1947 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1948 static constexpr const char *pyClassName = "OpResult";
1949 using PyConcreteValue::PyConcreteValue;
1950
bindDerived(ClassTy & c)1951 static void bindDerived(ClassTy &c) {
1952 c.def_property_readonly("owner", [](PyOpResult &self) {
1953 assert(
1954 mlirOperationEqual(self.getParentOperation()->get(),
1955 mlirOpResultGetOwner(self.get())) &&
1956 "expected the owner of the value in Python to match that in the IR");
1957 return self.getParentOperation().getObject();
1958 });
1959 c.def_property_readonly("result_number", [](PyOpResult &self) {
1960 return mlirOpResultGetResultNumber(self.get());
1961 });
1962 }
1963 };
1964
1965 /// Returns the list of types of the values held by container.
1966 template <typename Container>
getValueTypes(Container & container,PyMlirContextRef & context)1967 static std::vector<PyType> getValueTypes(Container &container,
1968 PyMlirContextRef &context) {
1969 std::vector<PyType> result;
1970 result.reserve(container.size());
1971 for (int i = 0, e = container.size(); i < e; ++i) {
1972 result.push_back(
1973 PyType(context, mlirValueGetType(container.getElement(i).get())));
1974 }
1975 return result;
1976 }
1977
1978 /// A list of block arguments. Internally, these are stored as consecutive
1979 /// elements, random access is cheap. The argument list is associated with the
1980 /// operation that contains the block (detached blocks are not allowed in
1981 /// Python bindings) and extends its lifetime.
1982 class PyBlockArgumentList
1983 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1984 public:
1985 static constexpr const char *pyClassName = "BlockArgumentList";
1986
PyBlockArgumentList(PyOperationRef operation,MlirBlock block,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1987 PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1988 intptr_t startIndex = 0, intptr_t length = -1,
1989 intptr_t step = 1)
1990 : Sliceable(startIndex,
1991 length == -1 ? mlirBlockGetNumArguments(block) : length,
1992 step),
1993 operation(std::move(operation)), block(block) {}
1994
bindDerived(ClassTy & c)1995 static void bindDerived(ClassTy &c) {
1996 c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1997 return getValueTypes(self, self.operation->getContext());
1998 });
1999 }
2000
2001 private:
2002 /// Give the parent CRTP class access to hook implementations below.
2003 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2004
2005 /// Returns the number of arguments in the list.
getRawNumElements()2006 intptr_t getRawNumElements() {
2007 operation->checkValid();
2008 return mlirBlockGetNumArguments(block);
2009 }
2010
2011 /// Returns `pos`-the element in the list.
getRawElement(intptr_t pos)2012 PyBlockArgument getRawElement(intptr_t pos) {
2013 MlirValue argument = mlirBlockGetArgument(block, pos);
2014 return PyBlockArgument(operation, argument);
2015 }
2016
2017 /// Returns a sublist of this list.
slice(intptr_t startIndex,intptr_t length,intptr_t step)2018 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2019 intptr_t step) {
2020 return PyBlockArgumentList(operation, block, startIndex, length, step);
2021 }
2022
2023 PyOperationRef operation;
2024 MlirBlock block;
2025 };
2026
2027 /// A list of operation operands. Internally, these are stored as consecutive
2028 /// elements, random access is cheap. The result list is associated with the
2029 /// operation whose results these are, and extends the lifetime of this
2030 /// operation.
2031 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2032 public:
2033 static constexpr const char *pyClassName = "OpOperandList";
2034
PyOpOperandList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)2035 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2036 intptr_t length = -1, intptr_t step = 1)
2037 : Sliceable(startIndex,
2038 length == -1 ? mlirOperationGetNumOperands(operation->get())
2039 : length,
2040 step),
2041 operation(operation) {}
2042
dunderSetItem(intptr_t index,PyValue value)2043 void dunderSetItem(intptr_t index, PyValue value) {
2044 index = wrapIndex(index);
2045 mlirOperationSetOperand(operation->get(), index, value.get());
2046 }
2047
bindDerived(ClassTy & c)2048 static void bindDerived(ClassTy &c) {
2049 c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2050 }
2051
2052 private:
2053 /// Give the parent CRTP class access to hook implementations below.
2054 friend class Sliceable<PyOpOperandList, PyValue>;
2055
getRawNumElements()2056 intptr_t getRawNumElements() {
2057 operation->checkValid();
2058 return mlirOperationGetNumOperands(operation->get());
2059 }
2060
getRawElement(intptr_t pos)2061 PyValue getRawElement(intptr_t pos) {
2062 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2063 MlirOperation owner;
2064 if (mlirValueIsAOpResult(operand))
2065 owner = mlirOpResultGetOwner(operand);
2066 else if (mlirValueIsABlockArgument(operand))
2067 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2068 else
2069 assert(false && "Value must be an block arg or op result.");
2070 PyOperationRef pyOwner =
2071 PyOperation::forOperation(operation->getContext(), owner);
2072 return PyValue(pyOwner, operand);
2073 }
2074
slice(intptr_t startIndex,intptr_t length,intptr_t step)2075 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2076 return PyOpOperandList(operation, startIndex, length, step);
2077 }
2078
2079 PyOperationRef operation;
2080 };
2081
2082 /// A list of operation results. Internally, these are stored as consecutive
2083 /// elements, random access is cheap. The result list is associated with the
2084 /// operation whose results these are, and extends the lifetime of this
2085 /// operation.
2086 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2087 public:
2088 static constexpr const char *pyClassName = "OpResultList";
2089
PyOpResultList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)2090 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2091 intptr_t length = -1, intptr_t step = 1)
2092 : Sliceable(startIndex,
2093 length == -1 ? mlirOperationGetNumResults(operation->get())
2094 : length,
2095 step),
2096 operation(operation) {}
2097
bindDerived(ClassTy & c)2098 static void bindDerived(ClassTy &c) {
2099 c.def_property_readonly("types", [](PyOpResultList &self) {
2100 return getValueTypes(self, self.operation->getContext());
2101 });
2102 }
2103
2104 private:
2105 /// Give the parent CRTP class access to hook implementations below.
2106 friend class Sliceable<PyOpResultList, PyOpResult>;
2107
getRawNumElements()2108 intptr_t getRawNumElements() {
2109 operation->checkValid();
2110 return mlirOperationGetNumResults(operation->get());
2111 }
2112
getRawElement(intptr_t index)2113 PyOpResult getRawElement(intptr_t index) {
2114 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2115 return PyOpResult(value);
2116 }
2117
slice(intptr_t startIndex,intptr_t length,intptr_t step)2118 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2119 return PyOpResultList(operation, startIndex, length, step);
2120 }
2121
2122 PyOperationRef operation;
2123 };
2124
2125 /// A list of operation attributes. Can be indexed by name, producing
2126 /// attributes, or by index, producing named attributes.
2127 class PyOpAttributeMap {
2128 public:
PyOpAttributeMap(PyOperationRef operation)2129 PyOpAttributeMap(PyOperationRef operation)
2130 : operation(std::move(operation)) {}
2131
dunderGetItemNamed(const std::string & name)2132 PyAttribute dunderGetItemNamed(const std::string &name) {
2133 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2134 toMlirStringRef(name));
2135 if (mlirAttributeIsNull(attr)) {
2136 throw SetPyError(PyExc_KeyError,
2137 "attempt to access a non-existent attribute");
2138 }
2139 return PyAttribute(operation->getContext(), attr);
2140 }
2141
dunderGetItemIndexed(intptr_t index)2142 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2143 if (index < 0 || index >= dunderLen()) {
2144 throw SetPyError(PyExc_IndexError,
2145 "attempt to access out of bounds attribute");
2146 }
2147 MlirNamedAttribute namedAttr =
2148 mlirOperationGetAttribute(operation->get(), index);
2149 return PyNamedAttribute(
2150 namedAttr.attribute,
2151 std::string(mlirIdentifierStr(namedAttr.name).data,
2152 mlirIdentifierStr(namedAttr.name).length));
2153 }
2154
dunderSetItem(const std::string & name,const PyAttribute & attr)2155 void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2156 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2157 attr);
2158 }
2159
dunderDelItem(const std::string & name)2160 void dunderDelItem(const std::string &name) {
2161 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2162 toMlirStringRef(name));
2163 if (!removed)
2164 throw SetPyError(PyExc_KeyError,
2165 "attempt to delete a non-existent attribute");
2166 }
2167
dunderLen()2168 intptr_t dunderLen() {
2169 return mlirOperationGetNumAttributes(operation->get());
2170 }
2171
dunderContains(const std::string & name)2172 bool dunderContains(const std::string &name) {
2173 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2174 operation->get(), toMlirStringRef(name)));
2175 }
2176
bind(py::module & m)2177 static void bind(py::module &m) {
2178 py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2179 .def("__contains__", &PyOpAttributeMap::dunderContains)
2180 .def("__len__", &PyOpAttributeMap::dunderLen)
2181 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2182 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2183 .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2184 .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2185 }
2186
2187 private:
2188 PyOperationRef operation;
2189 };
2190
2191 } // namespace
2192
2193 //------------------------------------------------------------------------------
2194 // Populates the core exports of the 'ir' submodule.
2195 //------------------------------------------------------------------------------
2196
populateIRCore(py::module & m)2197 void mlir::python::populateIRCore(py::module &m) {
2198 //----------------------------------------------------------------------------
2199 // Enums.
2200 //----------------------------------------------------------------------------
2201 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2202 .value("ERROR", MlirDiagnosticError)
2203 .value("WARNING", MlirDiagnosticWarning)
2204 .value("NOTE", MlirDiagnosticNote)
2205 .value("REMARK", MlirDiagnosticRemark);
2206
2207 //----------------------------------------------------------------------------
2208 // Mapping of Diagnostics.
2209 //----------------------------------------------------------------------------
2210 py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2211 .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2212 .def_property_readonly("location", &PyDiagnostic::getLocation)
2213 .def_property_readonly("message", &PyDiagnostic::getMessage)
2214 .def_property_readonly("notes", &PyDiagnostic::getNotes)
2215 .def("__str__", [](PyDiagnostic &self) -> py::str {
2216 if (!self.isValid())
2217 return "<Invalid Diagnostic>";
2218 return self.getMessage();
2219 });
2220
2221 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2222 .def("detach", &PyDiagnosticHandler::detach)
2223 .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2224 .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2225 .def("__enter__", &PyDiagnosticHandler::contextEnter)
2226 .def("__exit__", &PyDiagnosticHandler::contextExit);
2227
2228 //----------------------------------------------------------------------------
2229 // Mapping of MlirContext.
2230 // Note that this is exported as _BaseContext. The containing, Python level
2231 // __init__.py will subclass it with site-specific functionality and set a
2232 // "Context" attribute on this module.
2233 //----------------------------------------------------------------------------
2234 py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2235 .def(py::init<>(&PyMlirContext::createNewContextForInit))
2236 .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2237 .def("_get_context_again",
2238 [](PyMlirContext &self) {
2239 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2240 return ref.releaseObject();
2241 })
2242 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2243 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2244 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2245 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2246 &PyMlirContext::getCapsule)
2247 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2248 .def("__enter__", &PyMlirContext::contextEnter)
2249 .def("__exit__", &PyMlirContext::contextExit)
2250 .def_property_readonly_static(
2251 "current",
2252 [](py::object & /*class*/) {
2253 auto *context = PyThreadContextEntry::getDefaultContext();
2254 if (!context)
2255 throw SetPyError(PyExc_ValueError, "No current Context");
2256 return context;
2257 },
2258 "Gets the Context bound to the current thread or raises ValueError")
2259 .def_property_readonly(
2260 "dialects",
2261 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2262 "Gets a container for accessing dialects by name")
2263 .def_property_readonly(
2264 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2265 "Alias for 'dialect'")
2266 .def(
2267 "get_dialect_descriptor",
2268 [=](PyMlirContext &self, std::string &name) {
2269 MlirDialect dialect = mlirContextGetOrLoadDialect(
2270 self.get(), {name.data(), name.size()});
2271 if (mlirDialectIsNull(dialect)) {
2272 throw SetPyError(PyExc_ValueError,
2273 Twine("Dialect '") + name + "' not found");
2274 }
2275 return PyDialectDescriptor(self.getRef(), dialect);
2276 },
2277 py::arg("dialect_name"),
2278 "Gets or loads a dialect by name, returning its descriptor object")
2279 .def_property(
2280 "allow_unregistered_dialects",
2281 [](PyMlirContext &self) -> bool {
2282 return mlirContextGetAllowUnregisteredDialects(self.get());
2283 },
2284 [](PyMlirContext &self, bool value) {
2285 mlirContextSetAllowUnregisteredDialects(self.get(), value);
2286 })
2287 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2288 py::arg("callback"),
2289 "Attaches a diagnostic handler that will receive callbacks")
2290 .def(
2291 "enable_multithreading",
2292 [](PyMlirContext &self, bool enable) {
2293 mlirContextEnableMultithreading(self.get(), enable);
2294 },
2295 py::arg("enable"))
2296 .def(
2297 "is_registered_operation",
2298 [](PyMlirContext &self, std::string &name) {
2299 return mlirContextIsRegisteredOperation(
2300 self.get(), MlirStringRef{name.data(), name.size()});
2301 },
2302 py::arg("operation_name"))
2303 .def(
2304 "append_dialect_registry",
2305 [](PyMlirContext &self, PyDialectRegistry ®istry) {
2306 mlirContextAppendDialectRegistry(self.get(), registry);
2307 },
2308 py::arg("registry"))
2309 .def("load_all_available_dialects", [](PyMlirContext &self) {
2310 mlirContextLoadAllAvailableDialects(self.get());
2311 });
2312
2313 //----------------------------------------------------------------------------
2314 // Mapping of PyDialectDescriptor
2315 //----------------------------------------------------------------------------
2316 py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2317 .def_property_readonly("namespace",
2318 [](PyDialectDescriptor &self) {
2319 MlirStringRef ns =
2320 mlirDialectGetNamespace(self.get());
2321 return py::str(ns.data, ns.length);
2322 })
2323 .def("__repr__", [](PyDialectDescriptor &self) {
2324 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2325 std::string repr("<DialectDescriptor ");
2326 repr.append(ns.data, ns.length);
2327 repr.append(">");
2328 return repr;
2329 });
2330
2331 //----------------------------------------------------------------------------
2332 // Mapping of PyDialects
2333 //----------------------------------------------------------------------------
2334 py::class_<PyDialects>(m, "Dialects", py::module_local())
2335 .def("__getitem__",
2336 [=](PyDialects &self, std::string keyName) {
2337 MlirDialect dialect =
2338 self.getDialectForKey(keyName, /*attrError=*/false);
2339 py::object descriptor =
2340 py::cast(PyDialectDescriptor{self.getContext(), dialect});
2341 return createCustomDialectWrapper(keyName, std::move(descriptor));
2342 })
2343 .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2344 MlirDialect dialect =
2345 self.getDialectForKey(attrName, /*attrError=*/true);
2346 py::object descriptor =
2347 py::cast(PyDialectDescriptor{self.getContext(), dialect});
2348 return createCustomDialectWrapper(attrName, std::move(descriptor));
2349 });
2350
2351 //----------------------------------------------------------------------------
2352 // Mapping of PyDialect
2353 //----------------------------------------------------------------------------
2354 py::class_<PyDialect>(m, "Dialect", py::module_local())
2355 .def(py::init<py::object>(), py::arg("descriptor"))
2356 .def_property_readonly(
2357 "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2358 .def("__repr__", [](py::object self) {
2359 auto clazz = self.attr("__class__");
2360 return py::str("<Dialect ") +
2361 self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2362 clazz.attr("__module__") + py::str(".") +
2363 clazz.attr("__name__") + py::str(")>");
2364 });
2365
2366 //----------------------------------------------------------------------------
2367 // Mapping of PyDialectRegistry
2368 //----------------------------------------------------------------------------
2369 py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2370 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2371 &PyDialectRegistry::getCapsule)
2372 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2373 .def(py::init<>());
2374
2375 //----------------------------------------------------------------------------
2376 // Mapping of Location
2377 //----------------------------------------------------------------------------
2378 py::class_<PyLocation>(m, "Location", py::module_local())
2379 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2380 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2381 .def("__enter__", &PyLocation::contextEnter)
2382 .def("__exit__", &PyLocation::contextExit)
2383 .def("__eq__",
2384 [](PyLocation &self, PyLocation &other) -> bool {
2385 return mlirLocationEqual(self, other);
2386 })
2387 .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2388 .def_property_readonly_static(
2389 "current",
2390 [](py::object & /*class*/) {
2391 auto *loc = PyThreadContextEntry::getDefaultLocation();
2392 if (!loc)
2393 throw SetPyError(PyExc_ValueError, "No current Location");
2394 return loc;
2395 },
2396 "Gets the Location bound to the current thread or raises ValueError")
2397 .def_static(
2398 "unknown",
2399 [](DefaultingPyMlirContext context) {
2400 return PyLocation(context->getRef(),
2401 mlirLocationUnknownGet(context->get()));
2402 },
2403 py::arg("context") = py::none(),
2404 "Gets a Location representing an unknown location")
2405 .def_static(
2406 "callsite",
2407 [](PyLocation callee, const std::vector<PyLocation> &frames,
2408 DefaultingPyMlirContext context) {
2409 if (frames.empty())
2410 throw py::value_error("No caller frames provided");
2411 MlirLocation caller = frames.back().get();
2412 for (const PyLocation &frame :
2413 llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2414 caller = mlirLocationCallSiteGet(frame.get(), caller);
2415 return PyLocation(context->getRef(),
2416 mlirLocationCallSiteGet(callee.get(), caller));
2417 },
2418 py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2419 kContextGetCallSiteLocationDocstring)
2420 .def_static(
2421 "file",
2422 [](std::string filename, int line, int col,
2423 DefaultingPyMlirContext context) {
2424 return PyLocation(
2425 context->getRef(),
2426 mlirLocationFileLineColGet(
2427 context->get(), toMlirStringRef(filename), line, col));
2428 },
2429 py::arg("filename"), py::arg("line"), py::arg("col"),
2430 py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2431 .def_static(
2432 "fused",
2433 [](const std::vector<PyLocation> &pyLocations,
2434 llvm::Optional<PyAttribute> metadata,
2435 DefaultingPyMlirContext context) {
2436 llvm::SmallVector<MlirLocation, 4> locations;
2437 locations.reserve(pyLocations.size());
2438 for (auto &pyLocation : pyLocations)
2439 locations.push_back(pyLocation.get());
2440 MlirLocation location = mlirLocationFusedGet(
2441 context->get(), locations.size(), locations.data(),
2442 metadata ? metadata->get() : MlirAttribute{0});
2443 return PyLocation(context->getRef(), location);
2444 },
2445 py::arg("locations"), py::arg("metadata") = py::none(),
2446 py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2447 .def_static(
2448 "name",
2449 [](std::string name, llvm::Optional<PyLocation> childLoc,
2450 DefaultingPyMlirContext context) {
2451 return PyLocation(
2452 context->getRef(),
2453 mlirLocationNameGet(
2454 context->get(), toMlirStringRef(name),
2455 childLoc ? childLoc->get()
2456 : mlirLocationUnknownGet(context->get())));
2457 },
2458 py::arg("name"), py::arg("childLoc") = py::none(),
2459 py::arg("context") = py::none(), kContextGetNameLocationDocString)
2460 .def_property_readonly(
2461 "context",
2462 [](PyLocation &self) { return self.getContext().getObject(); },
2463 "Context that owns the Location")
2464 .def(
2465 "emit_error",
2466 [](PyLocation &self, std::string message) {
2467 mlirEmitError(self, message.c_str());
2468 },
2469 py::arg("message"), "Emits an error at this location")
2470 .def("__repr__", [](PyLocation &self) {
2471 PyPrintAccumulator printAccum;
2472 mlirLocationPrint(self, printAccum.getCallback(),
2473 printAccum.getUserData());
2474 return printAccum.join();
2475 });
2476
2477 //----------------------------------------------------------------------------
2478 // Mapping of Module
2479 //----------------------------------------------------------------------------
2480 py::class_<PyModule>(m, "Module", py::module_local())
2481 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2482 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2483 .def_static(
2484 "parse",
2485 [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2486 MlirModule module = mlirModuleCreateParse(
2487 context->get(), toMlirStringRef(moduleAsm));
2488 // TODO: Rework error reporting once diagnostic engine is exposed
2489 // in C API.
2490 if (mlirModuleIsNull(module)) {
2491 throw SetPyError(
2492 PyExc_ValueError,
2493 "Unable to parse module assembly (see diagnostics)");
2494 }
2495 return PyModule::forModule(module).releaseObject();
2496 },
2497 py::arg("asm"), py::arg("context") = py::none(),
2498 kModuleParseDocstring)
2499 .def_static(
2500 "create",
2501 [](DefaultingPyLocation loc) {
2502 MlirModule module = mlirModuleCreateEmpty(loc);
2503 return PyModule::forModule(module).releaseObject();
2504 },
2505 py::arg("loc") = py::none(), "Creates an empty module")
2506 .def_property_readonly(
2507 "context",
2508 [](PyModule &self) { return self.getContext().getObject(); },
2509 "Context that created the Module")
2510 .def_property_readonly(
2511 "operation",
2512 [](PyModule &self) {
2513 return PyOperation::forOperation(self.getContext(),
2514 mlirModuleGetOperation(self.get()),
2515 self.getRef().releaseObject())
2516 .releaseObject();
2517 },
2518 "Accesses the module as an operation")
2519 .def_property_readonly(
2520 "body",
2521 [](PyModule &self) {
2522 PyOperationRef moduleOp = PyOperation::forOperation(
2523 self.getContext(), mlirModuleGetOperation(self.get()),
2524 self.getRef().releaseObject());
2525 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2526 return returnBlock;
2527 },
2528 "Return the block for this module")
2529 .def(
2530 "dump",
2531 [](PyModule &self) {
2532 mlirOperationDump(mlirModuleGetOperation(self.get()));
2533 },
2534 kDumpDocstring)
2535 .def(
2536 "__str__",
2537 [](py::object self) {
2538 // Defer to the operation's __str__.
2539 return self.attr("operation").attr("__str__")();
2540 },
2541 kOperationStrDunderDocstring);
2542
2543 //----------------------------------------------------------------------------
2544 // Mapping of Operation.
2545 //----------------------------------------------------------------------------
2546 py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2547 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2548 [](PyOperationBase &self) {
2549 return self.getOperation().getCapsule();
2550 })
2551 .def("__eq__",
2552 [](PyOperationBase &self, PyOperationBase &other) {
2553 return &self.getOperation() == &other.getOperation();
2554 })
2555 .def("__eq__",
2556 [](PyOperationBase &self, py::object other) { return false; })
2557 .def("__hash__",
2558 [](PyOperationBase &self) {
2559 return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2560 })
2561 .def_property_readonly("attributes",
2562 [](PyOperationBase &self) {
2563 return PyOpAttributeMap(
2564 self.getOperation().getRef());
2565 })
2566 .def_property_readonly("operands",
2567 [](PyOperationBase &self) {
2568 return PyOpOperandList(
2569 self.getOperation().getRef());
2570 })
2571 .def_property_readonly("regions",
2572 [](PyOperationBase &self) {
2573 return PyRegionList(
2574 self.getOperation().getRef());
2575 })
2576 .def_property_readonly(
2577 "results",
2578 [](PyOperationBase &self) {
2579 return PyOpResultList(self.getOperation().getRef());
2580 },
2581 "Returns the list of Operation results.")
2582 .def_property_readonly(
2583 "result",
2584 [](PyOperationBase &self) {
2585 auto &operation = self.getOperation();
2586 auto numResults = mlirOperationGetNumResults(operation);
2587 if (numResults != 1) {
2588 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2589 throw SetPyError(
2590 PyExc_ValueError,
2591 Twine("Cannot call .result on operation ") +
2592 StringRef(name.data, name.length) + " which has " +
2593 Twine(numResults) +
2594 " results (it is only valid for operations with a "
2595 "single result)");
2596 }
2597 return PyOpResult(operation.getRef(),
2598 mlirOperationGetResult(operation, 0));
2599 },
2600 "Shortcut to get an op result if it has only one (throws an error "
2601 "otherwise).")
2602 .def_property_readonly(
2603 "location",
2604 [](PyOperationBase &self) {
2605 PyOperation &operation = self.getOperation();
2606 return PyLocation(operation.getContext(),
2607 mlirOperationGetLocation(operation.get()));
2608 },
2609 "Returns the source location the operation was defined or derived "
2610 "from.")
2611 .def(
2612 "__str__",
2613 [](PyOperationBase &self) {
2614 return self.getAsm(/*binary=*/false,
2615 /*largeElementsLimit=*/llvm::None,
2616 /*enableDebugInfo=*/false,
2617 /*prettyDebugInfo=*/false,
2618 /*printGenericOpForm=*/false,
2619 /*useLocalScope=*/false,
2620 /*assumeVerified=*/false);
2621 },
2622 "Returns the assembly form of the operation.")
2623 .def("print", &PyOperationBase::print,
2624 // Careful: Lots of arguments must match up with print method.
2625 py::arg("file") = py::none(), py::arg("binary") = false,
2626 py::arg("large_elements_limit") = py::none(),
2627 py::arg("enable_debug_info") = false,
2628 py::arg("pretty_debug_info") = false,
2629 py::arg("print_generic_op_form") = false,
2630 py::arg("use_local_scope") = false,
2631 py::arg("assume_verified") = false, kOperationPrintDocstring)
2632 .def("get_asm", &PyOperationBase::getAsm,
2633 // Careful: Lots of arguments must match up with get_asm method.
2634 py::arg("binary") = false,
2635 py::arg("large_elements_limit") = py::none(),
2636 py::arg("enable_debug_info") = false,
2637 py::arg("pretty_debug_info") = false,
2638 py::arg("print_generic_op_form") = false,
2639 py::arg("use_local_scope") = false,
2640 py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2641 .def(
2642 "verify",
2643 [](PyOperationBase &self) {
2644 return mlirOperationVerify(self.getOperation());
2645 },
2646 "Verify the operation and return true if it passes, false if it "
2647 "fails.")
2648 .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2649 "Puts self immediately after the other operation in its parent "
2650 "block.")
2651 .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2652 "Puts self immediately before the other operation in its parent "
2653 "block.")
2654 .def(
2655 "detach_from_parent",
2656 [](PyOperationBase &self) {
2657 PyOperation &operation = self.getOperation();
2658 operation.checkValid();
2659 if (!operation.isAttached())
2660 throw py::value_error("Detached operation has no parent.");
2661
2662 operation.detachFromParent();
2663 return operation.createOpView();
2664 },
2665 "Detaches the operation from its parent block.");
2666
2667 py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2668 .def_static("create", &PyOperation::create, py::arg("name"),
2669 py::arg("results") = py::none(),
2670 py::arg("operands") = py::none(),
2671 py::arg("attributes") = py::none(),
2672 py::arg("successors") = py::none(), py::arg("regions") = 0,
2673 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2674 kOperationCreateDocstring)
2675 .def_property_readonly("parent",
2676 [](PyOperation &self) -> py::object {
2677 auto parent = self.getParentOperation();
2678 if (parent)
2679 return parent->getObject();
2680 return py::none();
2681 })
2682 .def("erase", &PyOperation::erase)
2683 .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
2684 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2685 &PyOperation::getCapsule)
2686 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2687 .def_property_readonly("name",
2688 [](PyOperation &self) {
2689 self.checkValid();
2690 MlirOperation operation = self.get();
2691 MlirStringRef name = mlirIdentifierStr(
2692 mlirOperationGetName(operation));
2693 return py::str(name.data, name.length);
2694 })
2695 .def_property_readonly(
2696 "context",
2697 [](PyOperation &self) {
2698 self.checkValid();
2699 return self.getContext().getObject();
2700 },
2701 "Context that owns the Operation")
2702 .def_property_readonly("opview", &PyOperation::createOpView);
2703
2704 auto opViewClass =
2705 py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2706 .def(py::init<py::object>(), py::arg("operation"))
2707 .def_property_readonly("operation", &PyOpView::getOperationObject)
2708 .def_property_readonly(
2709 "context",
2710 [](PyOpView &self) {
2711 return self.getOperation().getContext().getObject();
2712 },
2713 "Context that owns the Operation")
2714 .def("__str__", [](PyOpView &self) {
2715 return py::str(self.getOperationObject());
2716 });
2717 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2718 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2719 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2720 opViewClass.attr("build_generic") = classmethod(
2721 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2722 py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2723 py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2724 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2725 "Builds a specific, generated OpView based on class level attributes.");
2726
2727 //----------------------------------------------------------------------------
2728 // Mapping of PyRegion.
2729 //----------------------------------------------------------------------------
2730 py::class_<PyRegion>(m, "Region", py::module_local())
2731 .def_property_readonly(
2732 "blocks",
2733 [](PyRegion &self) {
2734 return PyBlockList(self.getParentOperation(), self.get());
2735 },
2736 "Returns a forward-optimized sequence of blocks.")
2737 .def_property_readonly(
2738 "owner",
2739 [](PyRegion &self) {
2740 return self.getParentOperation()->createOpView();
2741 },
2742 "Returns the operation owning this region.")
2743 .def(
2744 "__iter__",
2745 [](PyRegion &self) {
2746 self.checkValid();
2747 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2748 return PyBlockIterator(self.getParentOperation(), firstBlock);
2749 },
2750 "Iterates over blocks in the region.")
2751 .def("__eq__",
2752 [](PyRegion &self, PyRegion &other) {
2753 return self.get().ptr == other.get().ptr;
2754 })
2755 .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2756
2757 //----------------------------------------------------------------------------
2758 // Mapping of PyBlock.
2759 //----------------------------------------------------------------------------
2760 py::class_<PyBlock>(m, "Block", py::module_local())
2761 .def_property_readonly(
2762 "owner",
2763 [](PyBlock &self) {
2764 return self.getParentOperation()->createOpView();
2765 },
2766 "Returns the owning operation of this block.")
2767 .def_property_readonly(
2768 "region",
2769 [](PyBlock &self) {
2770 MlirRegion region = mlirBlockGetParentRegion(self.get());
2771 return PyRegion(self.getParentOperation(), region);
2772 },
2773 "Returns the owning region of this block.")
2774 .def_property_readonly(
2775 "arguments",
2776 [](PyBlock &self) {
2777 return PyBlockArgumentList(self.getParentOperation(), self.get());
2778 },
2779 "Returns a list of block arguments.")
2780 .def_property_readonly(
2781 "operations",
2782 [](PyBlock &self) {
2783 return PyOperationList(self.getParentOperation(), self.get());
2784 },
2785 "Returns a forward-optimized sequence of operations.")
2786 .def_static(
2787 "create_at_start",
2788 [](PyRegion &parent, py::list pyArgTypes) {
2789 parent.checkValid();
2790 llvm::SmallVector<MlirType, 4> argTypes;
2791 llvm::SmallVector<MlirLocation, 4> argLocs;
2792 argTypes.reserve(pyArgTypes.size());
2793 argLocs.reserve(pyArgTypes.size());
2794 for (auto &pyArg : pyArgTypes) {
2795 argTypes.push_back(pyArg.cast<PyType &>());
2796 // TODO: Pass in a proper location here.
2797 argLocs.push_back(
2798 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2799 }
2800
2801 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2802 argLocs.data());
2803 mlirRegionInsertOwnedBlock(parent, 0, block);
2804 return PyBlock(parent.getParentOperation(), block);
2805 },
2806 py::arg("parent"), py::arg("arg_types") = py::list(),
2807 "Creates and returns a new Block at the beginning of the given "
2808 "region (with given argument types).")
2809 .def(
2810 "append_to",
2811 [](PyBlock &self, PyRegion ®ion) {
2812 MlirBlock b = self.get();
2813 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
2814 mlirBlockDetach(b);
2815 mlirRegionAppendOwnedBlock(region.get(), b);
2816 },
2817 "Append this block to a region, transferring ownership if necessary")
2818 .def(
2819 "create_before",
2820 [](PyBlock &self, py::args pyArgTypes) {
2821 self.checkValid();
2822 llvm::SmallVector<MlirType, 4> argTypes;
2823 llvm::SmallVector<MlirLocation, 4> argLocs;
2824 argTypes.reserve(pyArgTypes.size());
2825 argLocs.reserve(pyArgTypes.size());
2826 for (auto &pyArg : pyArgTypes) {
2827 argTypes.push_back(pyArg.cast<PyType &>());
2828 // TODO: Pass in a proper location here.
2829 argLocs.push_back(
2830 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2831 }
2832
2833 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2834 argLocs.data());
2835 MlirRegion region = mlirBlockGetParentRegion(self.get());
2836 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2837 return PyBlock(self.getParentOperation(), block);
2838 },
2839 "Creates and returns a new Block before this block "
2840 "(with given argument types).")
2841 .def(
2842 "create_after",
2843 [](PyBlock &self, py::args pyArgTypes) {
2844 self.checkValid();
2845 llvm::SmallVector<MlirType, 4> argTypes;
2846 llvm::SmallVector<MlirLocation, 4> argLocs;
2847 argTypes.reserve(pyArgTypes.size());
2848 argLocs.reserve(pyArgTypes.size());
2849 for (auto &pyArg : pyArgTypes) {
2850 argTypes.push_back(pyArg.cast<PyType &>());
2851
2852 // TODO: Pass in a proper location here.
2853 argLocs.push_back(
2854 mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2855 }
2856 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2857 argLocs.data());
2858 MlirRegion region = mlirBlockGetParentRegion(self.get());
2859 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2860 return PyBlock(self.getParentOperation(), block);
2861 },
2862 "Creates and returns a new Block after this block "
2863 "(with given argument types).")
2864 .def(
2865 "__iter__",
2866 [](PyBlock &self) {
2867 self.checkValid();
2868 MlirOperation firstOperation =
2869 mlirBlockGetFirstOperation(self.get());
2870 return PyOperationIterator(self.getParentOperation(),
2871 firstOperation);
2872 },
2873 "Iterates over operations in the block.")
2874 .def("__eq__",
2875 [](PyBlock &self, PyBlock &other) {
2876 return self.get().ptr == other.get().ptr;
2877 })
2878 .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2879 .def(
2880 "__str__",
2881 [](PyBlock &self) {
2882 self.checkValid();
2883 PyPrintAccumulator printAccum;
2884 mlirBlockPrint(self.get(), printAccum.getCallback(),
2885 printAccum.getUserData());
2886 return printAccum.join();
2887 },
2888 "Returns the assembly form of the block.")
2889 .def(
2890 "append",
2891 [](PyBlock &self, PyOperationBase &operation) {
2892 if (operation.getOperation().isAttached())
2893 operation.getOperation().detachFromParent();
2894
2895 MlirOperation mlirOperation = operation.getOperation().get();
2896 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2897 operation.getOperation().setAttached(
2898 self.getParentOperation().getObject());
2899 },
2900 py::arg("operation"),
2901 "Appends an operation to this block. If the operation is currently "
2902 "in another block, it will be moved.");
2903
2904 //----------------------------------------------------------------------------
2905 // Mapping of PyInsertionPoint.
2906 //----------------------------------------------------------------------------
2907
2908 py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2909 .def(py::init<PyBlock &>(), py::arg("block"),
2910 "Inserts after the last operation but still inside the block.")
2911 .def("__enter__", &PyInsertionPoint::contextEnter)
2912 .def("__exit__", &PyInsertionPoint::contextExit)
2913 .def_property_readonly_static(
2914 "current",
2915 [](py::object & /*class*/) {
2916 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2917 if (!ip)
2918 throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2919 return ip;
2920 },
2921 "Gets the InsertionPoint bound to the current thread or raises "
2922 "ValueError if none has been set")
2923 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2924 "Inserts before a referenced operation.")
2925 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2926 py::arg("block"), "Inserts at the beginning of the block.")
2927 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2928 py::arg("block"), "Inserts before the block terminator.")
2929 .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2930 "Inserts an operation.")
2931 .def_property_readonly(
2932 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2933 "Returns the block that this InsertionPoint points to.");
2934
2935 //----------------------------------------------------------------------------
2936 // Mapping of PyAttribute.
2937 //----------------------------------------------------------------------------
2938 py::class_<PyAttribute>(m, "Attribute", py::module_local())
2939 // Delegate to the PyAttribute copy constructor, which will also lifetime
2940 // extend the backing context which owns the MlirAttribute.
2941 .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2942 "Casts the passed attribute to the generic Attribute")
2943 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2944 &PyAttribute::getCapsule)
2945 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2946 .def_static(
2947 "parse",
2948 [](std::string attrSpec, DefaultingPyMlirContext context) {
2949 MlirAttribute type = mlirAttributeParseGet(
2950 context->get(), toMlirStringRef(attrSpec));
2951 // TODO: Rework error reporting once diagnostic engine is exposed
2952 // in C API.
2953 if (mlirAttributeIsNull(type)) {
2954 throw SetPyError(PyExc_ValueError,
2955 Twine("Unable to parse attribute: '") +
2956 attrSpec + "'");
2957 }
2958 return PyAttribute(context->getRef(), type);
2959 },
2960 py::arg("asm"), py::arg("context") = py::none(),
2961 "Parses an attribute from an assembly form")
2962 .def_property_readonly(
2963 "context",
2964 [](PyAttribute &self) { return self.getContext().getObject(); },
2965 "Context that owns the Attribute")
2966 .def_property_readonly("type",
2967 [](PyAttribute &self) {
2968 return PyType(self.getContext()->getRef(),
2969 mlirAttributeGetType(self));
2970 })
2971 .def(
2972 "get_named",
2973 [](PyAttribute &self, std::string name) {
2974 return PyNamedAttribute(self, std::move(name));
2975 },
2976 py::keep_alive<0, 1>(), "Binds a name to the attribute")
2977 .def("__eq__",
2978 [](PyAttribute &self, PyAttribute &other) { return self == other; })
2979 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2980 .def("__hash__",
2981 [](PyAttribute &self) {
2982 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2983 })
2984 .def(
2985 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2986 kDumpDocstring)
2987 .def(
2988 "__str__",
2989 [](PyAttribute &self) {
2990 PyPrintAccumulator printAccum;
2991 mlirAttributePrint(self, printAccum.getCallback(),
2992 printAccum.getUserData());
2993 return printAccum.join();
2994 },
2995 "Returns the assembly form of the Attribute.")
2996 .def("__repr__", [](PyAttribute &self) {
2997 // Generally, assembly formats are not printed for __repr__ because
2998 // this can cause exceptionally long debug output and exceptions.
2999 // However, attribute values are generally considered useful and are
3000 // printed. This may need to be re-evaluated if debug dumps end up
3001 // being excessive.
3002 PyPrintAccumulator printAccum;
3003 printAccum.parts.append("Attribute(");
3004 mlirAttributePrint(self, printAccum.getCallback(),
3005 printAccum.getUserData());
3006 printAccum.parts.append(")");
3007 return printAccum.join();
3008 });
3009
3010 //----------------------------------------------------------------------------
3011 // Mapping of PyNamedAttribute
3012 //----------------------------------------------------------------------------
3013 py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3014 .def("__repr__",
3015 [](PyNamedAttribute &self) {
3016 PyPrintAccumulator printAccum;
3017 printAccum.parts.append("NamedAttribute(");
3018 printAccum.parts.append(
3019 py::str(mlirIdentifierStr(self.namedAttr.name).data,
3020 mlirIdentifierStr(self.namedAttr.name).length));
3021 printAccum.parts.append("=");
3022 mlirAttributePrint(self.namedAttr.attribute,
3023 printAccum.getCallback(),
3024 printAccum.getUserData());
3025 printAccum.parts.append(")");
3026 return printAccum.join();
3027 })
3028 .def_property_readonly(
3029 "name",
3030 [](PyNamedAttribute &self) {
3031 return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3032 mlirIdentifierStr(self.namedAttr.name).length);
3033 },
3034 "The name of the NamedAttribute binding")
3035 .def_property_readonly(
3036 "attr",
3037 [](PyNamedAttribute &self) {
3038 // TODO: When named attribute is removed/refactored, also remove
3039 // this constructor (it does an inefficient table lookup).
3040 auto contextRef = PyMlirContext::forContext(
3041 mlirAttributeGetContext(self.namedAttr.attribute));
3042 return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3043 },
3044 py::keep_alive<0, 1>(),
3045 "The underlying generic attribute of the NamedAttribute binding");
3046
3047 //----------------------------------------------------------------------------
3048 // Mapping of PyType.
3049 //----------------------------------------------------------------------------
3050 py::class_<PyType>(m, "Type", py::module_local())
3051 // Delegate to the PyType copy constructor, which will also lifetime
3052 // extend the backing context which owns the MlirType.
3053 .def(py::init<PyType &>(), py::arg("cast_from_type"),
3054 "Casts the passed type to the generic Type")
3055 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3056 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3057 .def_static(
3058 "parse",
3059 [](std::string typeSpec, DefaultingPyMlirContext context) {
3060 MlirType type =
3061 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3062 // TODO: Rework error reporting once diagnostic engine is exposed
3063 // in C API.
3064 if (mlirTypeIsNull(type)) {
3065 throw SetPyError(PyExc_ValueError,
3066 Twine("Unable to parse type: '") + typeSpec +
3067 "'");
3068 }
3069 return PyType(context->getRef(), type);
3070 },
3071 py::arg("asm"), py::arg("context") = py::none(),
3072 kContextParseTypeDocstring)
3073 .def_property_readonly(
3074 "context", [](PyType &self) { return self.getContext().getObject(); },
3075 "Context that owns the Type")
3076 .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3077 .def("__eq__", [](PyType &self, py::object &other) { return false; })
3078 .def("__hash__",
3079 [](PyType &self) {
3080 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3081 })
3082 .def(
3083 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3084 .def(
3085 "__str__",
3086 [](PyType &self) {
3087 PyPrintAccumulator printAccum;
3088 mlirTypePrint(self, printAccum.getCallback(),
3089 printAccum.getUserData());
3090 return printAccum.join();
3091 },
3092 "Returns the assembly form of the type.")
3093 .def("__repr__", [](PyType &self) {
3094 // Generally, assembly formats are not printed for __repr__ because
3095 // this can cause exceptionally long debug output and exceptions.
3096 // However, types are an exception as they typically have compact
3097 // assembly forms and printing them is useful.
3098 PyPrintAccumulator printAccum;
3099 printAccum.parts.append("Type(");
3100 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3101 printAccum.parts.append(")");
3102 return printAccum.join();
3103 });
3104
3105 //----------------------------------------------------------------------------
3106 // Mapping of Value.
3107 //----------------------------------------------------------------------------
3108 py::class_<PyValue>(m, "Value", py::module_local())
3109 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3110 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3111 .def_property_readonly(
3112 "context",
3113 [](PyValue &self) { return self.getParentOperation()->getContext(); },
3114 "Context in which the value lives.")
3115 .def(
3116 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3117 kDumpDocstring)
3118 .def_property_readonly(
3119 "owner",
3120 [](PyValue &self) {
3121 assert(mlirOperationEqual(self.getParentOperation()->get(),
3122 mlirOpResultGetOwner(self.get())) &&
3123 "expected the owner of the value in Python to match that in "
3124 "the IR");
3125 return self.getParentOperation().getObject();
3126 })
3127 .def("__eq__",
3128 [](PyValue &self, PyValue &other) {
3129 return self.get().ptr == other.get().ptr;
3130 })
3131 .def("__eq__", [](PyValue &self, py::object other) { return false; })
3132 .def("__hash__",
3133 [](PyValue &self) {
3134 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3135 })
3136 .def(
3137 "__str__",
3138 [](PyValue &self) {
3139 PyPrintAccumulator printAccum;
3140 printAccum.parts.append("Value(");
3141 mlirValuePrint(self.get(), printAccum.getCallback(),
3142 printAccum.getUserData());
3143 printAccum.parts.append(")");
3144 return printAccum.join();
3145 },
3146 kValueDunderStrDocstring)
3147 .def_property_readonly("type", [](PyValue &self) {
3148 return PyType(self.getParentOperation()->getContext(),
3149 mlirValueGetType(self.get()));
3150 });
3151 PyBlockArgument::bind(m);
3152 PyOpResult::bind(m);
3153
3154 //----------------------------------------------------------------------------
3155 // Mapping of SymbolTable.
3156 //----------------------------------------------------------------------------
3157 py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3158 .def(py::init<PyOperationBase &>())
3159 .def("__getitem__", &PySymbolTable::dunderGetItem)
3160 .def("insert", &PySymbolTable::insert, py::arg("operation"))
3161 .def("erase", &PySymbolTable::erase, py::arg("operation"))
3162 .def("__delitem__", &PySymbolTable::dunderDel)
3163 .def("__contains__",
3164 [](PySymbolTable &table, const std::string &name) {
3165 return !mlirOperationIsNull(mlirSymbolTableLookup(
3166 table, mlirStringRefCreate(name.data(), name.length())));
3167 })
3168 // Static helpers.
3169 .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3170 py::arg("symbol"), py::arg("name"))
3171 .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3172 py::arg("symbol"))
3173 .def_static("get_visibility", &PySymbolTable::getVisibility,
3174 py::arg("symbol"))
3175 .def_static("set_visibility", &PySymbolTable::setVisibility,
3176 py::arg("symbol"), py::arg("visibility"))
3177 .def_static("replace_all_symbol_uses",
3178 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3179 py::arg("new_symbol"), py::arg("from_op"))
3180 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3181 py::arg("from_op"), py::arg("all_sym_uses_visible"),
3182 py::arg("callback"));
3183
// Container bindings: the Python helper classes returned by the properties
// bound above, e.g. Block.arguments (PyBlockArgumentList), Block.operations
// (PyOperationList), Region.blocks (PyBlockList) and Operation.attributes
// (PyOpAttributeMap).
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyRegionIterator::bind(m);
PyRegionList::bind(m);

// Debug bindings (PyGlobalDebugFlag).
PyGlobalDebugFlag::bind(m);
3198 }
3199