1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "IRModule.h"
12 
13 #include "PybindUtils.h"
14 
15 #include "mlir-c/AffineMap.h"
16 #include "mlir-c/Bindings/Python/Interop.h"
17 #include "mlir-c/IntegerSet.h"
18 
19 namespace py = pybind11;
20 using namespace mlir;
21 using namespace mlir::python;
22 
23 using llvm::SmallVector;
24 using llvm::StringRef;
25 using llvm::Twine;
26 
/// Docstring attached to the Python `dump` methods registered in this file.
static constexpr char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";
29 
30 /// Attempts to populate `result` with the content of `list` casted to the
31 /// appropriate type (Python and C types are provided as template arguments).
32 /// Throws errors in case of failure, using "action" to describe what the caller
33 /// was attempting to do.
34 template <typename PyType, typename CType>
35 static void pyListToVector(const py::list &list,
36                            llvm::SmallVectorImpl<CType> &result,
37                            StringRef action) {
38   result.reserve(py::len(list));
39   for (py::handle item : list) {
40     try {
41       result.push_back(item.cast<PyType>());
42     } catch (py::cast_error &err) {
43       std::string msg = (llvm::Twine("Invalid expression when ") + action +
44                          " (" + err.what() + ")")
45                             .str();
46       throw py::cast_error(msg);
47     } catch (py::reference_cast_error &err) {
48       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
49                          action + " (" + err.what() + ")")
50                             .str();
51       throw py::cast_error(msg);
52     }
53   }
54 }
55 
/// Returns true if `permutation` contains every index in
/// [0, permutation.size()) exactly once. Out-of-range entries and repeated
/// entries make it not a permutation. The input is taken by const reference
/// to avoid copying the vector.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const size_t size = permutation.size();
  // One flag per index; char instead of bool avoids the vector<bool> proxy.
  std::vector<char> seen(size, 0);
  for (const PermutationTy &val : permutation) {
    // For a signed PermutationTy no wider than size_t, a negative `val` is
    // converted to a large unsigned value by this comparison and is rejected.
    if (!(val < size) || seen[val])
      return false;
    seen[val] = 1;
  }
  return true;
}
70 
71 namespace {
72 
73 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
74 /// and should be castable from it. Intermediate hierarchy classes can be
75 /// modeled by specifying BaseTy.
76 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
77 class PyConcreteAffineExpr : public BaseTy {
78 public:
79   // Derived classes must define statics for:
80   //   IsAFunctionTy isaFunction
81   //   const char *pyClassName
82   // and redefine bindDerived.
83   using ClassTy = py::class_<DerivedTy, BaseTy>;
84   using IsAFunctionTy = bool (*)(MlirAffineExpr);
85 
86   PyConcreteAffineExpr() = default;
87   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
88       : BaseTy(std::move(contextRef), affineExpr) {}
89   PyConcreteAffineExpr(PyAffineExpr &orig)
90       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
91 
92   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
93     if (!DerivedTy::isaFunction(orig)) {
94       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
95       throw SetPyError(PyExc_ValueError,
96                        Twine("Cannot cast affine expression to ") +
97                            DerivedTy::pyClassName + " (from " + origRepr + ")");
98     }
99     return orig;
100   }
101 
102   static void bind(py::module &m) {
103     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
104     cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
105     cls.def_static(
106         "isinstance",
107         [](PyAffineExpr &otherAffineExpr) -> bool {
108           return DerivedTy::isaFunction(otherAffineExpr);
109         },
110         py::arg("other"));
111     DerivedTy::bindDerived(cls);
112   }
113 
114   /// Implemented by derived classes to add methods to the Python subclass.
115   static void bindDerived(ClassTy &m) {}
116 };
117 
118 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
119 public:
120   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
121   static constexpr const char *pyClassName = "AffineConstantExpr";
122   using PyConcreteAffineExpr::PyConcreteAffineExpr;
123 
124   static PyAffineConstantExpr get(intptr_t value,
125                                   DefaultingPyMlirContext context) {
126     MlirAffineExpr affineExpr =
127         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
128     return PyAffineConstantExpr(context->getRef(), affineExpr);
129   }
130 
131   static void bindDerived(ClassTy &c) {
132     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
133                  py::arg("context") = py::none());
134     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
135       return mlirAffineConstantExprGetValue(self);
136     });
137   }
138 };
139 
140 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
141 public:
142   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
143   static constexpr const char *pyClassName = "AffineDimExpr";
144   using PyConcreteAffineExpr::PyConcreteAffineExpr;
145 
146   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
147     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
148     return PyAffineDimExpr(context->getRef(), affineExpr);
149   }
150 
151   static void bindDerived(ClassTy &c) {
152     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
153                  py::arg("context") = py::none());
154     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
155       return mlirAffineDimExprGetPosition(self);
156     });
157   }
158 };
159 
160 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
161 public:
162   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
163   static constexpr const char *pyClassName = "AffineSymbolExpr";
164   using PyConcreteAffineExpr::PyConcreteAffineExpr;
165 
166   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
167     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
168     return PyAffineSymbolExpr(context->getRef(), affineExpr);
169   }
170 
171   static void bindDerived(ClassTy &c) {
172     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
173                  py::arg("context") = py::none());
174     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
175       return mlirAffineSymbolExprGetPosition(self);
176     });
177   }
178 };
179 
180 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
181 public:
182   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
183   static constexpr const char *pyClassName = "AffineBinaryExpr";
184   using PyConcreteAffineExpr::PyConcreteAffineExpr;
185 
186   PyAffineExpr lhs() {
187     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
188     return PyAffineExpr(getContext(), lhsExpr);
189   }
190 
191   PyAffineExpr rhs() {
192     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
193     return PyAffineExpr(getContext(), rhsExpr);
194   }
195 
196   static void bindDerived(ClassTy &c) {
197     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
198     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
199   }
200 };
201 
202 class PyAffineAddExpr
203     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
204 public:
205   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
206   static constexpr const char *pyClassName = "AffineAddExpr";
207   using PyConcreteAffineExpr::PyConcreteAffineExpr;
208 
209   static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
210     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
211     return PyAffineAddExpr(lhs.getContext(), expr);
212   }
213 
214   static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
215     MlirAffineExpr expr = mlirAffineAddExprGet(
216         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
217     return PyAffineAddExpr(lhs.getContext(), expr);
218   }
219 
220   static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
221     MlirAffineExpr expr = mlirAffineAddExprGet(
222         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
223     return PyAffineAddExpr(rhs.getContext(), expr);
224   }
225 
226   static void bindDerived(ClassTy &c) {
227     c.def_static("get", &PyAffineAddExpr::get);
228   }
229 };
230 
231 class PyAffineMulExpr
232     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
233 public:
234   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
235   static constexpr const char *pyClassName = "AffineMulExpr";
236   using PyConcreteAffineExpr::PyConcreteAffineExpr;
237 
238   static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
239     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
240     return PyAffineMulExpr(lhs.getContext(), expr);
241   }
242 
243   static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
244     MlirAffineExpr expr = mlirAffineMulExprGet(
245         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
246     return PyAffineMulExpr(lhs.getContext(), expr);
247   }
248 
249   static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
250     MlirAffineExpr expr = mlirAffineMulExprGet(
251         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
252     return PyAffineMulExpr(rhs.getContext(), expr);
253   }
254 
255   static void bindDerived(ClassTy &c) {
256     c.def_static("get", &PyAffineMulExpr::get);
257   }
258 };
259 
260 class PyAffineModExpr
261     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
262 public:
263   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
264   static constexpr const char *pyClassName = "AffineModExpr";
265   using PyConcreteAffineExpr::PyConcreteAffineExpr;
266 
267   static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
268     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
269     return PyAffineModExpr(lhs.getContext(), expr);
270   }
271 
272   static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
273     MlirAffineExpr expr = mlirAffineModExprGet(
274         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
275     return PyAffineModExpr(lhs.getContext(), expr);
276   }
277 
278   static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
279     MlirAffineExpr expr = mlirAffineModExprGet(
280         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
281     return PyAffineModExpr(rhs.getContext(), expr);
282   }
283 
284   static void bindDerived(ClassTy &c) {
285     c.def_static("get", &PyAffineModExpr::get);
286   }
287 };
288 
289 class PyAffineFloorDivExpr
290     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
291 public:
292   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
293   static constexpr const char *pyClassName = "AffineFloorDivExpr";
294   using PyConcreteAffineExpr::PyConcreteAffineExpr;
295 
296   static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
297     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
298     return PyAffineFloorDivExpr(lhs.getContext(), expr);
299   }
300 
301   static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
302     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
303         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
304     return PyAffineFloorDivExpr(lhs.getContext(), expr);
305   }
306 
307   static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
308     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
309         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
310     return PyAffineFloorDivExpr(rhs.getContext(), expr);
311   }
312 
313   static void bindDerived(ClassTy &c) {
314     c.def_static("get", &PyAffineFloorDivExpr::get);
315   }
316 };
317 
318 class PyAffineCeilDivExpr
319     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
320 public:
321   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
322   static constexpr const char *pyClassName = "AffineCeilDivExpr";
323   using PyConcreteAffineExpr::PyConcreteAffineExpr;
324 
325   static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
326     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
327     return PyAffineCeilDivExpr(lhs.getContext(), expr);
328   }
329 
330   static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
331     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
332         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
333     return PyAffineCeilDivExpr(lhs.getContext(), expr);
334   }
335 
336   static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
337     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
338         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
339     return PyAffineCeilDivExpr(rhs.getContext(), expr);
340   }
341 
342   static void bindDerived(ClassTy &c) {
343     c.def_static("get", &PyAffineCeilDivExpr::get);
344   }
345 };
346 
347 } // namespace
348 
349 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
350   return mlirAffineExprEqual(affineExpr, other.affineExpr);
351 }
352 
353 py::object PyAffineExpr::getCapsule() {
354   return py::reinterpret_steal<py::object>(
355       mlirPythonAffineExprToCapsule(*this));
356 }
357 
358 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
359   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
360   if (mlirAffineExprIsNull(rawAffineExpr))
361     throw py::error_already_set();
362   return PyAffineExpr(
363       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
364       rawAffineExpr);
365 }
366 
367 //------------------------------------------------------------------------------
368 // PyAffineMap and utilities.
369 //------------------------------------------------------------------------------
370 namespace {
371 
372 /// A list of expressions contained in an affine map. Internally these are
373 /// stored as a consecutive array leading to inexpensive random access. Both
374 /// the map and the expression are owned by the context so we need not bother
375 /// with lifetime extension.
376 class PyAffineMapExprList
377     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
378 public:
379   static constexpr const char *pyClassName = "AffineExprList";
380 
381   PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0,
382                       intptr_t length = -1, intptr_t step = 1)
383       : Sliceable(startIndex,
384                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
385                   step),
386         affineMap(map) {}
387 
388 private:
389   /// Give the parent CRTP class access to hook implementations below.
390   friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
391 
392   intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
393 
394   PyAffineExpr getRawElement(intptr_t pos) {
395     return PyAffineExpr(affineMap.getContext(),
396                         mlirAffineMapGetResult(affineMap, pos));
397   }
398 
399   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
400                             intptr_t step) {
401     return PyAffineMapExprList(affineMap, startIndex, length, step);
402   }
403 
404   PyAffineMap affineMap;
405 };
406 } // namespace
407 
408 bool PyAffineMap::operator==(const PyAffineMap &other) {
409   return mlirAffineMapEqual(affineMap, other.affineMap);
410 }
411 
412 py::object PyAffineMap::getCapsule() {
413   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
414 }
415 
416 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
417   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
418   if (mlirAffineMapIsNull(rawAffineMap))
419     throw py::error_already_set();
420   return PyAffineMap(
421       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
422       rawAffineMap);
423 }
424 
425 //------------------------------------------------------------------------------
426 // PyIntegerSet and utilities.
427 //------------------------------------------------------------------------------
428 namespace {
429 
430 class PyIntegerSetConstraint {
431 public:
432   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
433       : set(std::move(set)), pos(pos) {}
434 
435   PyAffineExpr getExpr() {
436     return PyAffineExpr(set.getContext(),
437                         mlirIntegerSetGetConstraint(set, pos));
438   }
439 
440   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
441 
442   static void bind(py::module &m) {
443     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
444                                        py::module_local())
445         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
446         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
447   }
448 
449 private:
450   PyIntegerSet set;
451   intptr_t pos;
452 };
453 
454 class PyIntegerSetConstraintList
455     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
456 public:
457   static constexpr const char *pyClassName = "IntegerSetConstraintList";
458 
459   PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0,
460                              intptr_t length = -1, intptr_t step = 1)
461       : Sliceable(startIndex,
462                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
463                   step),
464         set(set) {}
465 
466 private:
467   /// Give the parent CRTP class access to hook implementations below.
468   friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
469 
470   intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
471 
472   PyIntegerSetConstraint getRawElement(intptr_t pos) {
473     return PyIntegerSetConstraint(set, pos);
474   }
475 
476   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
477                                    intptr_t step) {
478     return PyIntegerSetConstraintList(set, startIndex, length, step);
479   }
480 
481   PyIntegerSet set;
482 };
483 } // namespace
484 
485 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
486   return mlirIntegerSetEqual(integerSet, other.integerSet);
487 }
488 
489 py::object PyIntegerSet::getCapsule() {
490   return py::reinterpret_steal<py::object>(
491       mlirPythonIntegerSetToCapsule(*this));
492 }
493 
494 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
495   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
496   if (mlirIntegerSetIsNull(rawIntegerSet))
497     throw py::error_already_set();
498   return PyIntegerSet(
499       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
500       rawIntegerSet);
501 }
502 
503 void mlir::python::populateIRAffine(py::module &m) {
504   //----------------------------------------------------------------------------
505   // Mapping of PyAffineExpr and derived classes.
506   //----------------------------------------------------------------------------
507   py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
508       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
509                              &PyAffineExpr::getCapsule)
510       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
511       .def("__add__", &PyAffineAddExpr::get)
512       .def("__add__", &PyAffineAddExpr::getRHSConstant)
513       .def("__radd__", &PyAffineAddExpr::getRHSConstant)
514       .def("__mul__", &PyAffineMulExpr::get)
515       .def("__mul__", &PyAffineMulExpr::getRHSConstant)
516       .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
517       .def("__mod__", &PyAffineModExpr::get)
518       .def("__mod__", &PyAffineModExpr::getRHSConstant)
519       .def("__rmod__",
520            [](PyAffineExpr &self, intptr_t other) {
521              return PyAffineModExpr::get(
522                  PyAffineConstantExpr::get(other, *self.getContext().get()),
523                  self);
524            })
525       .def("__sub__",
526            [](PyAffineExpr &self, PyAffineExpr &other) {
527              auto negOne =
528                  PyAffineConstantExpr::get(-1, *self.getContext().get());
529              return PyAffineAddExpr::get(self,
530                                          PyAffineMulExpr::get(negOne, other));
531            })
532       .def("__sub__",
533            [](PyAffineExpr &self, intptr_t other) {
534              return PyAffineAddExpr::get(
535                  self,
536                  PyAffineConstantExpr::get(-other, *self.getContext().get()));
537            })
538       .def("__rsub__",
539            [](PyAffineExpr &self, intptr_t other) {
540              return PyAffineAddExpr::getLHSConstant(
541                  other, PyAffineMulExpr::getLHSConstant(-1, self));
542            })
543       .def("__eq__", [](PyAffineExpr &self,
544                         PyAffineExpr &other) { return self == other; })
545       .def("__eq__",
546            [](PyAffineExpr &self, py::object &other) { return false; })
547       .def("__str__",
548            [](PyAffineExpr &self) {
549              PyPrintAccumulator printAccum;
550              mlirAffineExprPrint(self, printAccum.getCallback(),
551                                  printAccum.getUserData());
552              return printAccum.join();
553            })
554       .def("__repr__",
555            [](PyAffineExpr &self) {
556              PyPrintAccumulator printAccum;
557              printAccum.parts.append("AffineExpr(");
558              mlirAffineExprPrint(self, printAccum.getCallback(),
559                                  printAccum.getUserData());
560              printAccum.parts.append(")");
561              return printAccum.join();
562            })
563       .def("__hash__",
564            [](PyAffineExpr &self) {
565              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
566            })
567       .def_property_readonly(
568           "context",
569           [](PyAffineExpr &self) { return self.getContext().getObject(); })
570       .def("compose",
571            [](PyAffineExpr &self, PyAffineMap &other) {
572              return PyAffineExpr(self.getContext(),
573                                  mlirAffineExprCompose(self, other));
574            })
575       .def_static(
576           "get_add", &PyAffineAddExpr::get,
577           "Gets an affine expression containing a sum of two expressions.")
578       .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
579                   "Gets an affine expression containing a sum of a constant "
580                   "and another expression.")
581       .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
582                   "Gets an affine expression containing a sum of an expression "
583                   "and a constant.")
584       .def_static(
585           "get_mul", &PyAffineMulExpr::get,
586           "Gets an affine expression containing a product of two expressions.")
587       .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
588                   "Gets an affine expression containing a product of a "
589                   "constant and another expression.")
590       .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
591                   "Gets an affine expression containing a product of an "
592                   "expression and a constant.")
593       .def_static("get_mod", &PyAffineModExpr::get,
594                   "Gets an affine expression containing the modulo of dividing "
595                   "one expression by another.")
596       .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
597                   "Gets a semi-affine expression containing the modulo of "
598                   "dividing a constant by an expression.")
599       .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
600                   "Gets an affine expression containing the module of dividing"
601                   "an expression by a constant.")
602       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
603                   "Gets an affine expression containing the rounded-down "
604                   "result of dividing one expression by another.")
605       .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
606                   "Gets a semi-affine expression containing the rounded-down "
607                   "result of dividing a constant by an expression.")
608       .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
609                   "Gets an affine expression containing the rounded-down "
610                   "result of dividing an expression by a constant.")
611       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
612                   "Gets an affine expression containing the rounded-up result "
613                   "of dividing one expression by another.")
614       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
615                   "Gets a semi-affine expression containing the rounded-up "
616                   "result of dividing a constant by an expression.")
617       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
618                   "Gets an affine expression containing the rounded-up result "
619                   "of dividing an expression by a constant.")
620       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
621                   py::arg("context") = py::none(),
622                   "Gets a constant affine expression with the given value.")
623       .def_static(
624           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
625           py::arg("context") = py::none(),
626           "Gets an affine expression of a dimension at the given position.")
627       .def_static(
628           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
629           py::arg("context") = py::none(),
630           "Gets an affine expression of a symbol at the given position.")
631       .def(
632           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
633           kDumpDocstring);
634   PyAffineConstantExpr::bind(m);
635   PyAffineDimExpr::bind(m);
636   PyAffineSymbolExpr::bind(m);
637   PyAffineBinaryExpr::bind(m);
638   PyAffineAddExpr::bind(m);
639   PyAffineMulExpr::bind(m);
640   PyAffineModExpr::bind(m);
641   PyAffineFloorDivExpr::bind(m);
642   PyAffineCeilDivExpr::bind(m);
643 
644   //----------------------------------------------------------------------------
645   // Mapping of PyAffineMap.
646   //----------------------------------------------------------------------------
647   py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
648       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
649                              &PyAffineMap::getCapsule)
650       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
651       .def("__eq__",
652            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
653       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
654       .def("__str__",
655            [](PyAffineMap &self) {
656              PyPrintAccumulator printAccum;
657              mlirAffineMapPrint(self, printAccum.getCallback(),
658                                 printAccum.getUserData());
659              return printAccum.join();
660            })
661       .def("__repr__",
662            [](PyAffineMap &self) {
663              PyPrintAccumulator printAccum;
664              printAccum.parts.append("AffineMap(");
665              mlirAffineMapPrint(self, printAccum.getCallback(),
666                                 printAccum.getUserData());
667              printAccum.parts.append(")");
668              return printAccum.join();
669            })
670       .def("__hash__",
671            [](PyAffineMap &self) {
672              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
673            })
674       .def_static("compress_unused_symbols",
675                   [](py::list affineMaps, DefaultingPyMlirContext context) {
676                     SmallVector<MlirAffineMap> maps;
677                     pyListToVector<PyAffineMap, MlirAffineMap>(
678                         affineMaps, maps, "attempting to create an AffineMap");
679                     std::vector<MlirAffineMap> compressed(affineMaps.size());
680                     auto populate = [](void *result, intptr_t idx,
681                                        MlirAffineMap m) {
682                       static_cast<MlirAffineMap *>(result)[idx] = (m);
683                     };
684                     mlirAffineMapCompressUnusedSymbols(
685                         maps.data(), maps.size(), compressed.data(), populate);
686                     std::vector<PyAffineMap> res;
687                     res.reserve(compressed.size());
688                     for (auto m : compressed)
689                       res.emplace_back(context->getRef(), m);
690                     return res;
691                   })
692       .def_property_readonly(
693           "context",
694           [](PyAffineMap &self) { return self.getContext().getObject(); },
695           "Context that owns the Affine Map")
696       .def(
697           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
698           kDumpDocstring)
699       .def_static(
700           "get",
701           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
702              DefaultingPyMlirContext context) {
703             SmallVector<MlirAffineExpr> affineExprs;
704             pyListToVector<PyAffineExpr, MlirAffineExpr>(
705                 exprs, affineExprs, "attempting to create an AffineMap");
706             MlirAffineMap map =
707                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
708                                  affineExprs.size(), affineExprs.data());
709             return PyAffineMap(context->getRef(), map);
710           },
711           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
712           py::arg("context") = py::none(),
713           "Gets a map with the given expressions as results.")
714       .def_static(
715           "get_constant",
716           [](intptr_t value, DefaultingPyMlirContext context) {
717             MlirAffineMap affineMap =
718                 mlirAffineMapConstantGet(context->get(), value);
719             return PyAffineMap(context->getRef(), affineMap);
720           },
721           py::arg("value"), py::arg("context") = py::none(),
722           "Gets an affine map with a single constant result")
723       .def_static(
724           "get_empty",
725           [](DefaultingPyMlirContext context) {
726             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
727             return PyAffineMap(context->getRef(), affineMap);
728           },
729           py::arg("context") = py::none(), "Gets an empty affine map.")
730       .def_static(
731           "get_identity",
732           [](intptr_t nDims, DefaultingPyMlirContext context) {
733             MlirAffineMap affineMap =
734                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
735             return PyAffineMap(context->getRef(), affineMap);
736           },
737           py::arg("n_dims"), py::arg("context") = py::none(),
738           "Gets an identity map with the given number of dimensions.")
739       .def_static(
740           "get_minor_identity",
741           [](intptr_t nDims, intptr_t nResults,
742              DefaultingPyMlirContext context) {
743             MlirAffineMap affineMap =
744                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
745             return PyAffineMap(context->getRef(), affineMap);
746           },
747           py::arg("n_dims"), py::arg("n_results"),
748           py::arg("context") = py::none(),
749           "Gets a minor identity map with the given number of dimensions and "
750           "results.")
751       .def_static(
752           "get_permutation",
753           [](std::vector<unsigned> permutation,
754              DefaultingPyMlirContext context) {
755             if (!isPermutation(permutation))
756               throw py::cast_error("Invalid permutation when attempting to "
757                                    "create an AffineMap");
758             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
759                 context->get(), permutation.size(), permutation.data());
760             return PyAffineMap(context->getRef(), affineMap);
761           },
762           py::arg("permutation"), py::arg("context") = py::none(),
763           "Gets an affine map that permutes its inputs.")
764       .def(
765           "get_submap",
766           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
767             intptr_t numResults = mlirAffineMapGetNumResults(self);
768             for (intptr_t pos : resultPos) {
769               if (pos < 0 || pos >= numResults)
770                 throw py::value_error("result position out of bounds");
771             }
772             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
773                 self, resultPos.size(), resultPos.data());
774             return PyAffineMap(self.getContext(), affineMap);
775           },
776           py::arg("result_positions"))
777       .def(
778           "get_major_submap",
779           [](PyAffineMap &self, intptr_t nResults) {
780             if (nResults >= mlirAffineMapGetNumResults(self))
781               throw py::value_error("number of results out of bounds");
782             MlirAffineMap affineMap =
783                 mlirAffineMapGetMajorSubMap(self, nResults);
784             return PyAffineMap(self.getContext(), affineMap);
785           },
786           py::arg("n_results"))
787       .def(
788           "get_minor_submap",
789           [](PyAffineMap &self, intptr_t nResults) {
790             if (nResults >= mlirAffineMapGetNumResults(self))
791               throw py::value_error("number of results out of bounds");
792             MlirAffineMap affineMap =
793                 mlirAffineMapGetMinorSubMap(self, nResults);
794             return PyAffineMap(self.getContext(), affineMap);
795           },
796           py::arg("n_results"))
797       .def(
798           "replace",
799           [](PyAffineMap &self, PyAffineExpr &expression,
800              PyAffineExpr &replacement, intptr_t numResultDims,
801              intptr_t numResultSyms) {
802             MlirAffineMap affineMap = mlirAffineMapReplace(
803                 self, expression, replacement, numResultDims, numResultSyms);
804             return PyAffineMap(self.getContext(), affineMap);
805           },
806           py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
807           py::arg("n_result_syms"))
808       .def_property_readonly(
809           "is_permutation",
810           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
811       .def_property_readonly("is_projected_permutation",
812                              [](PyAffineMap &self) {
813                                return mlirAffineMapIsProjectedPermutation(self);
814                              })
815       .def_property_readonly(
816           "n_dims",
817           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
818       .def_property_readonly(
819           "n_inputs",
820           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
821       .def_property_readonly(
822           "n_symbols",
823           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
824       .def_property_readonly("results", [](PyAffineMap &self) {
825         return PyAffineMapExprList(self);
826       });
827   PyAffineMapExprList::bind(m);
828 
829   //----------------------------------------------------------------------------
830   // Mapping of PyIntegerSet.
831   //----------------------------------------------------------------------------
832   py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
833       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
834                              &PyIntegerSet::getCapsule)
835       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
836       .def("__eq__", [](PyIntegerSet &self,
837                         PyIntegerSet &other) { return self == other; })
838       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
839       .def("__str__",
840            [](PyIntegerSet &self) {
841              PyPrintAccumulator printAccum;
842              mlirIntegerSetPrint(self, printAccum.getCallback(),
843                                  printAccum.getUserData());
844              return printAccum.join();
845            })
846       .def("__repr__",
847            [](PyIntegerSet &self) {
848              PyPrintAccumulator printAccum;
849              printAccum.parts.append("IntegerSet(");
850              mlirIntegerSetPrint(self, printAccum.getCallback(),
851                                  printAccum.getUserData());
852              printAccum.parts.append(")");
853              return printAccum.join();
854            })
855       .def("__hash__",
856            [](PyIntegerSet &self) {
857              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
858            })
859       .def_property_readonly(
860           "context",
861           [](PyIntegerSet &self) { return self.getContext().getObject(); })
862       .def(
863           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
864           kDumpDocstring)
865       .def_static(
866           "get",
867           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
868              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
869             if (exprs.size() != eqFlags.size())
870               throw py::value_error(
871                   "Expected the number of constraints to match "
872                   "that of equality flags");
873             if (exprs.empty())
874               throw py::value_error("Expected non-empty list of constraints");
875 
876             // Copy over to a SmallVector because std::vector has a
877             // specialization for booleans that packs data and does not
878             // expose a `bool *`.
879             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
880 
881             SmallVector<MlirAffineExpr> affineExprs;
882             pyListToVector<PyAffineExpr>(exprs, affineExprs,
883                                          "attempting to create an IntegerSet");
884             MlirIntegerSet set = mlirIntegerSetGet(
885                 context->get(), numDims, numSymbols, exprs.size(),
886                 affineExprs.data(), flags.data());
887             return PyIntegerSet(context->getRef(), set);
888           },
889           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
890           py::arg("eq_flags"), py::arg("context") = py::none())
891       .def_static(
892           "get_empty",
893           [](intptr_t numDims, intptr_t numSymbols,
894              DefaultingPyMlirContext context) {
895             MlirIntegerSet set =
896                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
897             return PyIntegerSet(context->getRef(), set);
898           },
899           py::arg("num_dims"), py::arg("num_symbols"),
900           py::arg("context") = py::none())
901       .def(
902           "get_replaced",
903           [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
904              intptr_t numResultDims, intptr_t numResultSymbols) {
905             if (static_cast<intptr_t>(dimExprs.size()) !=
906                 mlirIntegerSetGetNumDims(self))
907               throw py::value_error(
908                   "Expected the number of dimension replacement expressions "
909                   "to match that of dimensions");
910             if (static_cast<intptr_t>(symbolExprs.size()) !=
911                 mlirIntegerSetGetNumSymbols(self))
912               throw py::value_error(
913                   "Expected the number of symbol replacement expressions "
914                   "to match that of symbols");
915 
916             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
917             pyListToVector<PyAffineExpr>(
918                 dimExprs, dimAffineExprs,
919                 "attempting to create an IntegerSet by replacing dimensions");
920             pyListToVector<PyAffineExpr>(
921                 symbolExprs, symbolAffineExprs,
922                 "attempting to create an IntegerSet by replacing symbols");
923             MlirIntegerSet set = mlirIntegerSetReplaceGet(
924                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
925                 numResultDims, numResultSymbols);
926             return PyIntegerSet(self.getContext(), set);
927           },
928           py::arg("dim_exprs"), py::arg("symbol_exprs"),
929           py::arg("num_result_dims"), py::arg("num_result_symbols"))
930       .def_property_readonly("is_canonical_empty",
931                              [](PyIntegerSet &self) {
932                                return mlirIntegerSetIsCanonicalEmpty(self);
933                              })
934       .def_property_readonly(
935           "n_dims",
936           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
937       .def_property_readonly(
938           "n_symbols",
939           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
940       .def_property_readonly(
941           "n_inputs",
942           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
943       .def_property_readonly("n_equalities",
944                              [](PyIntegerSet &self) {
945                                return mlirIntegerSetGetNumEqualities(self);
946                              })
947       .def_property_readonly("n_inequalities",
948                              [](PyIntegerSet &self) {
949                                return mlirIntegerSetGetNumInequalities(self);
950                              })
951       .def_property_readonly("constraints", [](PyIntegerSet &self) {
952         return PyIntegerSetConstraintList(self);
953       });
954   PyIntegerSetConstraint::bind(m);
955   PyIntegerSetConstraintList::bind(m);
956 }
957