1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/AffineMap.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/IntegerSet.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 using llvm::SmallVector;
22 using llvm::StringRef;
23 using llvm::Twine;
24 
/// Shared docstring attached to the `dump` method of every affine wrapper.
static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
27 
28 /// Attempts to populate `result` with the content of `list` casted to the
29 /// appropriate type (Python and C types are provided as template arguments).
30 /// Throws errors in case of failure, using "action" to describe what the caller
31 /// was attempting to do.
32 template <typename PyType, typename CType>
33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34                            StringRef action) {
35   result.reserve(py::len(list));
36   for (py::handle item : list) {
37     try {
38       result.push_back(item.cast<PyType>());
39     } catch (py::cast_error &err) {
40       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41                          " (" + err.what() + ")")
42                             .str();
43       throw py::cast_error(msg);
44     } catch (py::reference_cast_error &err) {
45       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46                          action + " (" + err.what() + ")")
47                             .str();
48       throw py::cast_error(msg);
49     }
50   }
51 }
52 
/// Returns true if `permutation` contains every index in [0, size()) exactly
/// once, i.e. it is a valid reordering of `permutation.size()` positions.
/// Entries that do not compare less than size() and repeated entries make it
/// return false; an empty vector is accepted.
///
/// The vector is taken by const reference: the check is read-only, so there
/// is no reason to copy the caller's vector on every call.
/// NOTE(review): intended for unsigned index types (the visible caller passes
/// std::vector<unsigned>); signed types rely on the usual arithmetic
/// conversions in `val < size`.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const auto size = permutation.size();
  // One flag per position; char avoids the std::vector<bool> proxy type.
  std::vector<char> seen(size, 0);
  for (const PermutationTy &val : permutation) {
    if (!(val < size) || seen[val])
      return false;
    seen[val] = 1;
  }
  return true;
}
67 
68 namespace {
69 
70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71 /// and should be castable from it. Intermediate hierarchy classes can be
72 /// modeled by specifying BaseTy.
73 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74 class PyConcreteAffineExpr : public BaseTy {
75 public:
76   // Derived classes must define statics for:
77   //   IsAFunctionTy isaFunction
78   //   const char *pyClassName
79   // and redefine bindDerived.
80   using ClassTy = py::class_<DerivedTy, BaseTy>;
81   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82 
83   PyConcreteAffineExpr() = default;
84   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85       : BaseTy(std::move(contextRef), affineExpr) {}
86   PyConcreteAffineExpr(PyAffineExpr &orig)
87       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88 
89   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90     if (!DerivedTy::isaFunction(orig)) {
91       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92       throw SetPyError(PyExc_ValueError,
93                        Twine("Cannot cast affine expression to ") +
94                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95     }
96     return orig;
97   }
98 
99   static void bind(py::module &m) {
100     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
101     cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
102     cls.def_static(
103         "isinstance",
104         [](PyAffineExpr &otherAffineExpr) -> bool {
105           return DerivedTy::isaFunction(otherAffineExpr);
106         },
107         py::arg("other"));
108     DerivedTy::bindDerived(cls);
109   }
110 
111   /// Implemented by derived classes to add methods to the Python subclass.
112   static void bindDerived(ClassTy &m) {}
113 };
114 
115 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
116 public:
117   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
118   static constexpr const char *pyClassName = "AffineConstantExpr";
119   using PyConcreteAffineExpr::PyConcreteAffineExpr;
120 
121   static PyAffineConstantExpr get(intptr_t value,
122                                   DefaultingPyMlirContext context) {
123     MlirAffineExpr affineExpr =
124         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
125     return PyAffineConstantExpr(context->getRef(), affineExpr);
126   }
127 
128   static void bindDerived(ClassTy &c) {
129     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
130                  py::arg("context") = py::none());
131     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
132       return mlirAffineConstantExprGetValue(self);
133     });
134   }
135 };
136 
137 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
138 public:
139   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
140   static constexpr const char *pyClassName = "AffineDimExpr";
141   using PyConcreteAffineExpr::PyConcreteAffineExpr;
142 
143   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
144     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
145     return PyAffineDimExpr(context->getRef(), affineExpr);
146   }
147 
148   static void bindDerived(ClassTy &c) {
149     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
150                  py::arg("context") = py::none());
151     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
152       return mlirAffineDimExprGetPosition(self);
153     });
154   }
155 };
156 
157 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
158 public:
159   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
160   static constexpr const char *pyClassName = "AffineSymbolExpr";
161   using PyConcreteAffineExpr::PyConcreteAffineExpr;
162 
163   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
164     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
165     return PyAffineSymbolExpr(context->getRef(), affineExpr);
166   }
167 
168   static void bindDerived(ClassTy &c) {
169     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
170                  py::arg("context") = py::none());
171     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
172       return mlirAffineSymbolExprGetPosition(self);
173     });
174   }
175 };
176 
177 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
178 public:
179   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
180   static constexpr const char *pyClassName = "AffineBinaryExpr";
181   using PyConcreteAffineExpr::PyConcreteAffineExpr;
182 
183   PyAffineExpr lhs() {
184     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
185     return PyAffineExpr(getContext(), lhsExpr);
186   }
187 
188   PyAffineExpr rhs() {
189     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
190     return PyAffineExpr(getContext(), rhsExpr);
191   }
192 
193   static void bindDerived(ClassTy &c) {
194     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
195     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
196   }
197 };
198 
199 class PyAffineAddExpr
200     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
201 public:
202   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
203   static constexpr const char *pyClassName = "AffineAddExpr";
204   using PyConcreteAffineExpr::PyConcreteAffineExpr;
205 
206   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
207     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
208     return PyAffineAddExpr(lhs.getContext(), expr);
209   }
210 
211   static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
212     MlirAffineExpr expr = mlirAffineAddExprGet(
213         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
214     return PyAffineAddExpr(lhs.getContext(), expr);
215   }
216 
217   static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
218     MlirAffineExpr expr = mlirAffineAddExprGet(
219         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
220     return PyAffineAddExpr(rhs.getContext(), expr);
221   }
222 
223   static void bindDerived(ClassTy &c) {
224     c.def_static("get", &PyAffineAddExpr::get);
225   }
226 };
227 
228 class PyAffineMulExpr
229     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
230 public:
231   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
232   static constexpr const char *pyClassName = "AffineMulExpr";
233   using PyConcreteAffineExpr::PyConcreteAffineExpr;
234 
235   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
236     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
237     return PyAffineMulExpr(lhs.getContext(), expr);
238   }
239 
240   static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
241     MlirAffineExpr expr = mlirAffineMulExprGet(
242         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
243     return PyAffineMulExpr(lhs.getContext(), expr);
244   }
245 
246   static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
247     MlirAffineExpr expr = mlirAffineMulExprGet(
248         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
249     return PyAffineMulExpr(rhs.getContext(), expr);
250   }
251 
252   static void bindDerived(ClassTy &c) {
253     c.def_static("get", &PyAffineMulExpr::get);
254   }
255 };
256 
257 class PyAffineModExpr
258     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
259 public:
260   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
261   static constexpr const char *pyClassName = "AffineModExpr";
262   using PyConcreteAffineExpr::PyConcreteAffineExpr;
263 
264   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
265     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
266     return PyAffineModExpr(lhs.getContext(), expr);
267   }
268 
269   static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
270     MlirAffineExpr expr = mlirAffineModExprGet(
271         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
272     return PyAffineModExpr(lhs.getContext(), expr);
273   }
274 
275   static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
276     MlirAffineExpr expr = mlirAffineModExprGet(
277         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
278     return PyAffineModExpr(rhs.getContext(), expr);
279   }
280 
281   static void bindDerived(ClassTy &c) {
282     c.def_static("get", &PyAffineModExpr::get);
283   }
284 };
285 
286 class PyAffineFloorDivExpr
287     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
288 public:
289   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
290   static constexpr const char *pyClassName = "AffineFloorDivExpr";
291   using PyConcreteAffineExpr::PyConcreteAffineExpr;
292 
293   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
294     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
295     return PyAffineFloorDivExpr(lhs.getContext(), expr);
296   }
297 
298   static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
299     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
300         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
301     return PyAffineFloorDivExpr(lhs.getContext(), expr);
302   }
303 
304   static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
305     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
306         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
307     return PyAffineFloorDivExpr(rhs.getContext(), expr);
308   }
309 
310   static void bindDerived(ClassTy &c) {
311     c.def_static("get", &PyAffineFloorDivExpr::get);
312   }
313 };
314 
315 class PyAffineCeilDivExpr
316     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
317 public:
318   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
319   static constexpr const char *pyClassName = "AffineCeilDivExpr";
320   using PyConcreteAffineExpr::PyConcreteAffineExpr;
321 
322   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
323     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
324     return PyAffineCeilDivExpr(lhs.getContext(), expr);
325   }
326 
327   static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
328     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
329         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
330     return PyAffineCeilDivExpr(lhs.getContext(), expr);
331   }
332 
333   static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
334     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
335         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
336     return PyAffineCeilDivExpr(rhs.getContext(), expr);
337   }
338 
339   static void bindDerived(ClassTy &c) {
340     c.def_static("get", &PyAffineCeilDivExpr::get);
341   }
342 };
343 
344 } // namespace
345 
346 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
347   return mlirAffineExprEqual(affineExpr, other.affineExpr);
348 }
349 
350 py::object PyAffineExpr::getCapsule() {
351   return py::reinterpret_steal<py::object>(
352       mlirPythonAffineExprToCapsule(*this));
353 }
354 
355 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
356   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
357   if (mlirAffineExprIsNull(rawAffineExpr))
358     throw py::error_already_set();
359   return PyAffineExpr(
360       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
361       rawAffineExpr);
362 }
363 
364 //------------------------------------------------------------------------------
365 // PyAffineMap and utilities.
366 //------------------------------------------------------------------------------
367 namespace {
368 
369 /// A list of expressions contained in an affine map. Internally these are
370 /// stored as a consecutive array leading to inexpensive random access. Both
371 /// the map and the expression are owned by the context so we need not bother
372 /// with lifetime extension.
373 class PyAffineMapExprList
374     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
375 public:
376   static constexpr const char *pyClassName = "AffineExprList";
377 
378   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
379                       intptr_t length = -1, intptr_t step = 1)
380       : Sliceable(startIndex,
381                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
382                   step),
383         affineMap(map) {}
384 
385   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
386 
387   PyAffineExpr getElement(intptr_t pos) {
388     return PyAffineExpr(affineMap.getContext(),
389                         mlirAffineMapGetResult(affineMap, pos));
390   }
391 
392   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
393                             intptr_t step) {
394     return PyAffineMapExprList(affineMap, startIndex, length, step);
395   }
396 
397 private:
398   PyAffineMap affineMap;
399 };
400 } // namespace
401 
402 bool PyAffineMap::operator==(const PyAffineMap &other) {
403   return mlirAffineMapEqual(affineMap, other.affineMap);
404 }
405 
406 py::object PyAffineMap::getCapsule() {
407   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
408 }
409 
410 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
411   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
412   if (mlirAffineMapIsNull(rawAffineMap))
413     throw py::error_already_set();
414   return PyAffineMap(
415       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
416       rawAffineMap);
417 }
418 
419 //------------------------------------------------------------------------------
420 // PyIntegerSet and utilities.
421 //------------------------------------------------------------------------------
422 namespace {
423 
424 class PyIntegerSetConstraint {
425 public:
426   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
427 
428   PyAffineExpr getExpr() {
429     return PyAffineExpr(set.getContext(),
430                         mlirIntegerSetGetConstraint(set, pos));
431   }
432 
433   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
434 
435   static void bind(py::module &m) {
436     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
437                                        py::module_local())
438         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
439         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
440   }
441 
442 private:
443   PyIntegerSet set;
444   intptr_t pos;
445 };
446 
447 class PyIntegerSetConstraintList
448     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
449 public:
450   static constexpr const char *pyClassName = "IntegerSetConstraintList";
451 
452   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
453                              intptr_t length = -1, intptr_t step = 1)
454       : Sliceable(startIndex,
455                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
456                   step),
457         set(set) {}
458 
459   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
460 
461   PyIntegerSetConstraint getElement(intptr_t pos) {
462     return PyIntegerSetConstraint(set, pos);
463   }
464 
465   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
466                                    intptr_t step) {
467     return PyIntegerSetConstraintList(set, startIndex, length, step);
468   }
469 
470 private:
471   PyIntegerSet set;
472 };
473 } // namespace
474 
475 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
476   return mlirIntegerSetEqual(integerSet, other.integerSet);
477 }
478 
479 py::object PyIntegerSet::getCapsule() {
480   return py::reinterpret_steal<py::object>(
481       mlirPythonIntegerSetToCapsule(*this));
482 }
483 
484 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
485   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
486   if (mlirIntegerSetIsNull(rawIntegerSet))
487     throw py::error_already_set();
488   return PyIntegerSet(
489       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
490       rawIntegerSet);
491 }
492 
493 void mlir::python::populateIRAffine(py::module &m) {
494   //----------------------------------------------------------------------------
495   // Mapping of PyAffineExpr and derived classes.
496   //----------------------------------------------------------------------------
497   py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
498       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
499                              &PyAffineExpr::getCapsule)
500       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
501       .def("__add__", &PyAffineAddExpr::get)
502       .def("__add__", &PyAffineAddExpr::getRHSConstant)
503       .def("__radd__", &PyAffineAddExpr::getRHSConstant)
504       .def("__mul__", &PyAffineMulExpr::get)
505       .def("__mul__", &PyAffineMulExpr::getRHSConstant)
506       .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
507       .def("__mod__", &PyAffineModExpr::get)
508       .def("__mod__", &PyAffineModExpr::getRHSConstant)
509       .def("__rmod__",
510            [](PyAffineExpr &self, intptr_t other) {
511              return PyAffineModExpr::get(
512                  PyAffineConstantExpr::get(other, *self.getContext().get()),
513                  self);
514            })
515       .def("__sub__",
516            [](PyAffineExpr &self, PyAffineExpr &other) {
517              auto negOne =
518                  PyAffineConstantExpr::get(-1, *self.getContext().get());
519              return PyAffineAddExpr::get(self,
520                                          PyAffineMulExpr::get(negOne, other));
521            })
522       .def("__sub__",
523            [](PyAffineExpr &self, intptr_t other) {
524              return PyAffineAddExpr::get(
525                  self,
526                  PyAffineConstantExpr::get(-other, *self.getContext().get()));
527            })
528       .def("__rsub__",
529            [](PyAffineExpr &self, intptr_t other) {
530              return PyAffineAddExpr::getLHSConstant(
531                  other, PyAffineMulExpr::getLHSConstant(-1, self));
532            })
533       .def("__eq__", [](PyAffineExpr &self,
534                         PyAffineExpr &other) { return self == other; })
535       .def("__eq__",
536            [](PyAffineExpr &self, py::object &other) { return false; })
537       .def("__str__",
538            [](PyAffineExpr &self) {
539              PyPrintAccumulator printAccum;
540              mlirAffineExprPrint(self, printAccum.getCallback(),
541                                  printAccum.getUserData());
542              return printAccum.join();
543            })
544       .def("__repr__",
545            [](PyAffineExpr &self) {
546              PyPrintAccumulator printAccum;
547              printAccum.parts.append("AffineExpr(");
548              mlirAffineExprPrint(self, printAccum.getCallback(),
549                                  printAccum.getUserData());
550              printAccum.parts.append(")");
551              return printAccum.join();
552            })
553       .def("__hash__",
554            [](PyAffineExpr &self) {
555              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
556            })
557       .def_property_readonly(
558           "context",
559           [](PyAffineExpr &self) { return self.getContext().getObject(); })
560       .def("compose",
561            [](PyAffineExpr &self, PyAffineMap &other) {
562              return PyAffineExpr(self.getContext(),
563                                  mlirAffineExprCompose(self, other));
564            })
565       .def_static(
566           "get_add", &PyAffineAddExpr::get,
567           "Gets an affine expression containing a sum of two expressions.")
568       .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
569                   "Gets an affine expression containing a sum of a constant "
570                   "and another expression.")
571       .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
572                   "Gets an affine expression containing a sum of an expression "
573                   "and a constant.")
574       .def_static(
575           "get_mul", &PyAffineMulExpr::get,
576           "Gets an affine expression containing a product of two expressions.")
577       .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
578                   "Gets an affine expression containing a product of a "
579                   "constant and another expression.")
580       .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
581                   "Gets an affine expression containing a product of an "
582                   "expression and a constant.")
583       .def_static("get_mod", &PyAffineModExpr::get,
584                   "Gets an affine expression containing the modulo of dividing "
585                   "one expression by another.")
586       .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
587                   "Gets a semi-affine expression containing the modulo of "
588                   "dividing a constant by an expression.")
589       .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
590                   "Gets an affine expression containing the module of dividing"
591                   "an expression by a constant.")
592       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
593                   "Gets an affine expression containing the rounded-down "
594                   "result of dividing one expression by another.")
595       .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
596                   "Gets a semi-affine expression containing the rounded-down "
597                   "result of dividing a constant by an expression.")
598       .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
599                   "Gets an affine expression containing the rounded-down "
600                   "result of dividing an expression by a constant.")
601       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
602                   "Gets an affine expression containing the rounded-up result "
603                   "of dividing one expression by another.")
604       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
605                   "Gets a semi-affine expression containing the rounded-up "
606                   "result of dividing a constant by an expression.")
607       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
608                   "Gets an affine expression containing the rounded-up result "
609                   "of dividing an expression by a constant.")
610       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
611                   py::arg("context") = py::none(),
612                   "Gets a constant affine expression with the given value.")
613       .def_static(
614           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
615           py::arg("context") = py::none(),
616           "Gets an affine expression of a dimension at the given position.")
617       .def_static(
618           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
619           py::arg("context") = py::none(),
620           "Gets an affine expression of a symbol at the given position.")
621       .def(
622           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
623           kDumpDocstring);
624   PyAffineConstantExpr::bind(m);
625   PyAffineDimExpr::bind(m);
626   PyAffineSymbolExpr::bind(m);
627   PyAffineBinaryExpr::bind(m);
628   PyAffineAddExpr::bind(m);
629   PyAffineMulExpr::bind(m);
630   PyAffineModExpr::bind(m);
631   PyAffineFloorDivExpr::bind(m);
632   PyAffineCeilDivExpr::bind(m);
633 
634   //----------------------------------------------------------------------------
635   // Mapping of PyAffineMap.
636   //----------------------------------------------------------------------------
637   py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
638       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
639                              &PyAffineMap::getCapsule)
640       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
641       .def("__eq__",
642            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
643       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
644       .def("__str__",
645            [](PyAffineMap &self) {
646              PyPrintAccumulator printAccum;
647              mlirAffineMapPrint(self, printAccum.getCallback(),
648                                 printAccum.getUserData());
649              return printAccum.join();
650            })
651       .def("__repr__",
652            [](PyAffineMap &self) {
653              PyPrintAccumulator printAccum;
654              printAccum.parts.append("AffineMap(");
655              mlirAffineMapPrint(self, printAccum.getCallback(),
656                                 printAccum.getUserData());
657              printAccum.parts.append(")");
658              return printAccum.join();
659            })
660       .def("__hash__",
661            [](PyAffineMap &self) {
662              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
663            })
664       .def_static("compress_unused_symbols",
665                   [](py::list affineMaps, DefaultingPyMlirContext context) {
666                     SmallVector<MlirAffineMap> maps;
667                     pyListToVector<PyAffineMap, MlirAffineMap>(
668                         affineMaps, maps, "attempting to create an AffineMap");
669                     std::vector<MlirAffineMap> compressed(affineMaps.size());
670                     auto populate = [](void *result, intptr_t idx,
671                                        MlirAffineMap m) {
672                       static_cast<MlirAffineMap *>(result)[idx] = (m);
673                     };
674                     mlirAffineMapCompressUnusedSymbols(
675                         maps.data(), maps.size(), compressed.data(), populate);
676                     std::vector<PyAffineMap> res;
677                     res.reserve(compressed.size());
678                     for (auto m : compressed)
679                       res.push_back(PyAffineMap(context->getRef(), m));
680                     return res;
681                   })
682       .def_property_readonly(
683           "context",
684           [](PyAffineMap &self) { return self.getContext().getObject(); },
685           "Context that owns the Affine Map")
686       .def(
687           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
688           kDumpDocstring)
689       .def_static(
690           "get",
691           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
692              DefaultingPyMlirContext context) {
693             SmallVector<MlirAffineExpr> affineExprs;
694             pyListToVector<PyAffineExpr, MlirAffineExpr>(
695                 exprs, affineExprs, "attempting to create an AffineMap");
696             MlirAffineMap map =
697                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
698                                  affineExprs.size(), affineExprs.data());
699             return PyAffineMap(context->getRef(), map);
700           },
701           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
702           py::arg("context") = py::none(),
703           "Gets a map with the given expressions as results.")
704       .def_static(
705           "get_constant",
706           [](intptr_t value, DefaultingPyMlirContext context) {
707             MlirAffineMap affineMap =
708                 mlirAffineMapConstantGet(context->get(), value);
709             return PyAffineMap(context->getRef(), affineMap);
710           },
711           py::arg("value"), py::arg("context") = py::none(),
712           "Gets an affine map with a single constant result")
713       .def_static(
714           "get_empty",
715           [](DefaultingPyMlirContext context) {
716             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
717             return PyAffineMap(context->getRef(), affineMap);
718           },
719           py::arg("context") = py::none(), "Gets an empty affine map.")
720       .def_static(
721           "get_identity",
722           [](intptr_t nDims, DefaultingPyMlirContext context) {
723             MlirAffineMap affineMap =
724                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
725             return PyAffineMap(context->getRef(), affineMap);
726           },
727           py::arg("n_dims"), py::arg("context") = py::none(),
728           "Gets an identity map with the given number of dimensions.")
729       .def_static(
730           "get_minor_identity",
731           [](intptr_t nDims, intptr_t nResults,
732              DefaultingPyMlirContext context) {
733             MlirAffineMap affineMap =
734                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
735             return PyAffineMap(context->getRef(), affineMap);
736           },
737           py::arg("n_dims"), py::arg("n_results"),
738           py::arg("context") = py::none(),
739           "Gets a minor identity map with the given number of dimensions and "
740           "results.")
741       .def_static(
742           "get_permutation",
743           [](std::vector<unsigned> permutation,
744              DefaultingPyMlirContext context) {
745             if (!isPermutation(permutation))
746               throw py::cast_error("Invalid permutation when attempting to "
747                                    "create an AffineMap");
748             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
749                 context->get(), permutation.size(), permutation.data());
750             return PyAffineMap(context->getRef(), affineMap);
751           },
752           py::arg("permutation"), py::arg("context") = py::none(),
753           "Gets an affine map that permutes its inputs.")
754       .def(
755           "get_submap",
756           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
757             intptr_t numResults = mlirAffineMapGetNumResults(self);
758             for (intptr_t pos : resultPos) {
759               if (pos < 0 || pos >= numResults)
760                 throw py::value_error("result position out of bounds");
761             }
762             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
763                 self, resultPos.size(), resultPos.data());
764             return PyAffineMap(self.getContext(), affineMap);
765           },
766           py::arg("result_positions"))
767       .def(
768           "get_major_submap",
769           [](PyAffineMap &self, intptr_t nResults) {
770             if (nResults >= mlirAffineMapGetNumResults(self))
771               throw py::value_error("number of results out of bounds");
772             MlirAffineMap affineMap =
773                 mlirAffineMapGetMajorSubMap(self, nResults);
774             return PyAffineMap(self.getContext(), affineMap);
775           },
776           py::arg("n_results"))
777       .def(
778           "get_minor_submap",
779           [](PyAffineMap &self, intptr_t nResults) {
780             if (nResults >= mlirAffineMapGetNumResults(self))
781               throw py::value_error("number of results out of bounds");
782             MlirAffineMap affineMap =
783                 mlirAffineMapGetMinorSubMap(self, nResults);
784             return PyAffineMap(self.getContext(), affineMap);
785           },
786           py::arg("n_results"))
787       .def(
788           "replace",
789           [](PyAffineMap &self, PyAffineExpr &expression,
790              PyAffineExpr &replacement, intptr_t numResultDims,
791              intptr_t numResultSyms) {
792             MlirAffineMap affineMap = mlirAffineMapReplace(
793                 self, expression, replacement, numResultDims, numResultSyms);
794             return PyAffineMap(self.getContext(), affineMap);
795           },
796           py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
797           py::arg("n_result_syms"))
798       .def_property_readonly(
799           "is_permutation",
800           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
801       .def_property_readonly("is_projected_permutation",
802                              [](PyAffineMap &self) {
803                                return mlirAffineMapIsProjectedPermutation(self);
804                              })
805       .def_property_readonly(
806           "n_dims",
807           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
808       .def_property_readonly(
809           "n_inputs",
810           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
811       .def_property_readonly(
812           "n_symbols",
813           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
814       .def_property_readonly("results", [](PyAffineMap &self) {
815         return PyAffineMapExprList(self);
816       });
817   PyAffineMapExprList::bind(m);
818 
819   //----------------------------------------------------------------------------
820   // Mapping of PyIntegerSet.
821   //----------------------------------------------------------------------------
822   py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
823       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
824                              &PyIntegerSet::getCapsule)
825       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
826       .def("__eq__", [](PyIntegerSet &self,
827                         PyIntegerSet &other) { return self == other; })
828       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
829       .def("__str__",
830            [](PyIntegerSet &self) {
831              PyPrintAccumulator printAccum;
832              mlirIntegerSetPrint(self, printAccum.getCallback(),
833                                  printAccum.getUserData());
834              return printAccum.join();
835            })
836       .def("__repr__",
837            [](PyIntegerSet &self) {
838              PyPrintAccumulator printAccum;
839              printAccum.parts.append("IntegerSet(");
840              mlirIntegerSetPrint(self, printAccum.getCallback(),
841                                  printAccum.getUserData());
842              printAccum.parts.append(")");
843              return printAccum.join();
844            })
845       .def("__hash__",
846            [](PyIntegerSet &self) {
847              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
848            })
849       .def_property_readonly(
850           "context",
851           [](PyIntegerSet &self) { return self.getContext().getObject(); })
852       .def(
853           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
854           kDumpDocstring)
855       .def_static(
856           "get",
857           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
858              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
859             if (exprs.size() != eqFlags.size())
860               throw py::value_error(
861                   "Expected the number of constraints to match "
862                   "that of equality flags");
863             if (exprs.empty())
864               throw py::value_error("Expected non-empty list of constraints");
865 
866             // Copy over to a SmallVector because std::vector has a
867             // specialization for booleans that packs data and does not
868             // expose a `bool *`.
869             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
870 
871             SmallVector<MlirAffineExpr> affineExprs;
872             pyListToVector<PyAffineExpr>(exprs, affineExprs,
873                                          "attempting to create an IntegerSet");
874             MlirIntegerSet set = mlirIntegerSetGet(
875                 context->get(), numDims, numSymbols, exprs.size(),
876                 affineExprs.data(), flags.data());
877             return PyIntegerSet(context->getRef(), set);
878           },
879           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
880           py::arg("eq_flags"), py::arg("context") = py::none())
881       .def_static(
882           "get_empty",
883           [](intptr_t numDims, intptr_t numSymbols,
884              DefaultingPyMlirContext context) {
885             MlirIntegerSet set =
886                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
887             return PyIntegerSet(context->getRef(), set);
888           },
889           py::arg("num_dims"), py::arg("num_symbols"),
890           py::arg("context") = py::none())
891       .def(
892           "get_replaced",
893           [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
894              intptr_t numResultDims, intptr_t numResultSymbols) {
895             if (static_cast<intptr_t>(dimExprs.size()) !=
896                 mlirIntegerSetGetNumDims(self))
897               throw py::value_error(
898                   "Expected the number of dimension replacement expressions "
899                   "to match that of dimensions");
900             if (static_cast<intptr_t>(symbolExprs.size()) !=
901                 mlirIntegerSetGetNumSymbols(self))
902               throw py::value_error(
903                   "Expected the number of symbol replacement expressions "
904                   "to match that of symbols");
905 
906             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
907             pyListToVector<PyAffineExpr>(
908                 dimExprs, dimAffineExprs,
909                 "attempting to create an IntegerSet by replacing dimensions");
910             pyListToVector<PyAffineExpr>(
911                 symbolExprs, symbolAffineExprs,
912                 "attempting to create an IntegerSet by replacing symbols");
913             MlirIntegerSet set = mlirIntegerSetReplaceGet(
914                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
915                 numResultDims, numResultSymbols);
916             return PyIntegerSet(self.getContext(), set);
917           },
918           py::arg("dim_exprs"), py::arg("symbol_exprs"),
919           py::arg("num_result_dims"), py::arg("num_result_symbols"))
920       .def_property_readonly("is_canonical_empty",
921                              [](PyIntegerSet &self) {
922                                return mlirIntegerSetIsCanonicalEmpty(self);
923                              })
924       .def_property_readonly(
925           "n_dims",
926           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
927       .def_property_readonly(
928           "n_symbols",
929           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
930       .def_property_readonly(
931           "n_inputs",
932           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
933       .def_property_readonly("n_equalities",
934                              [](PyIntegerSet &self) {
935                                return mlirIntegerSetGetNumEqualities(self);
936                              })
937       .def_property_readonly("n_inequalities",
938                              [](PyIntegerSet &self) {
939                                return mlirIntegerSetGetNumInequalities(self);
940                              })
941       .def_property_readonly("constraints", [](PyIntegerSet &self) {
942         return PyIntegerSetConstraintList(self);
943       });
944   PyIntegerSetConstraint::bind(m);
945   PyIntegerSetConstraintList::bind(m);
946 }
947