1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/AffineMap.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/IntegerSet.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 using llvm::SmallVector;
22 using llvm::StringRef;
23 using llvm::Twine;
24 
/// Docstring attached to the `dump` methods bound further down in this file.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
27 
28 /// Attempts to populate `result` with the content of `list` casted to the
29 /// appropriate type (Python and C types are provided as template arguments).
30 /// Throws errors in case of failure, using "action" to describe what the caller
31 /// was attempting to do.
32 template <typename PyType, typename CType>
33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34                            StringRef action) {
35   result.reserve(py::len(list));
36   for (py::handle item : list) {
37     try {
38       result.push_back(item.cast<PyType>());
39     } catch (py::cast_error &err) {
40       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41                          " (" + err.what() + ")")
42                             .str();
43       throw py::cast_error(msg);
44     } catch (py::reference_cast_error &err) {
45       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46                          action + " (" + err.what() + ")")
47                             .str();
48       throw py::cast_error(msg);
49     }
50   }
51 }
52 
53 template <typename PermutationTy>
54 static bool isPermutation(std::vector<PermutationTy> permutation) {
55   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
56   for (auto val : permutation) {
57     if (val < permutation.size()) {
58       if (seen[val])
59         return false;
60       seen[val] = true;
61       continue;
62     }
63     return false;
64   }
65   return true;
66 }
67 
68 namespace {
69 
70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71 /// and should be castable from it. Intermediate hierarchy classes can be
72 /// modeled by specifying BaseTy.
73 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74 class PyConcreteAffineExpr : public BaseTy {
75 public:
76   // Derived classes must define statics for:
77   //   IsAFunctionTy isaFunction
78   //   const char *pyClassName
79   // and redefine bindDerived.
80   using ClassTy = py::class_<DerivedTy, BaseTy>;
81   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82 
83   PyConcreteAffineExpr() = default;
84   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85       : BaseTy(std::move(contextRef), affineExpr) {}
86   PyConcreteAffineExpr(PyAffineExpr &orig)
87       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88 
89   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90     if (!DerivedTy::isaFunction(orig)) {
91       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92       throw SetPyError(PyExc_ValueError,
93                        Twine("Cannot cast affine expression to ") +
94                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95     }
96     return orig;
97   }
98 
99   static void bind(py::module &m) {
100     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
101     cls.def(py::init<PyAffineExpr &>());
102     cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool {
103       return DerivedTy::isaFunction(otherAffineExpr);
104     });
105     DerivedTy::bindDerived(cls);
106   }
107 
108   /// Implemented by derived classes to add methods to the Python subclass.
109   static void bindDerived(ClassTy &m) {}
110 };
111 
112 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
113 public:
114   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
115   static constexpr const char *pyClassName = "AffineConstantExpr";
116   using PyConcreteAffineExpr::PyConcreteAffineExpr;
117 
118   static PyAffineConstantExpr get(intptr_t value,
119                                   DefaultingPyMlirContext context) {
120     MlirAffineExpr affineExpr =
121         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
122     return PyAffineConstantExpr(context->getRef(), affineExpr);
123   }
124 
125   static void bindDerived(ClassTy &c) {
126     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
127                  py::arg("context") = py::none());
128     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
129       return mlirAffineConstantExprGetValue(self);
130     });
131   }
132 };
133 
134 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
135 public:
136   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
137   static constexpr const char *pyClassName = "AffineDimExpr";
138   using PyConcreteAffineExpr::PyConcreteAffineExpr;
139 
140   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
141     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
142     return PyAffineDimExpr(context->getRef(), affineExpr);
143   }
144 
145   static void bindDerived(ClassTy &c) {
146     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
147                  py::arg("context") = py::none());
148     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
149       return mlirAffineDimExprGetPosition(self);
150     });
151   }
152 };
153 
154 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
155 public:
156   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
157   static constexpr const char *pyClassName = "AffineSymbolExpr";
158   using PyConcreteAffineExpr::PyConcreteAffineExpr;
159 
160   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
161     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
162     return PyAffineSymbolExpr(context->getRef(), affineExpr);
163   }
164 
165   static void bindDerived(ClassTy &c) {
166     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
167                  py::arg("context") = py::none());
168     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
169       return mlirAffineSymbolExprGetPosition(self);
170     });
171   }
172 };
173 
174 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
175 public:
176   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
177   static constexpr const char *pyClassName = "AffineBinaryExpr";
178   using PyConcreteAffineExpr::PyConcreteAffineExpr;
179 
180   PyAffineExpr lhs() {
181     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
182     return PyAffineExpr(getContext(), lhsExpr);
183   }
184 
185   PyAffineExpr rhs() {
186     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
187     return PyAffineExpr(getContext(), rhsExpr);
188   }
189 
190   static void bindDerived(ClassTy &c) {
191     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
192     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
193   }
194 };
195 
196 class PyAffineAddExpr
197     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
198 public:
199   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
200   static constexpr const char *pyClassName = "AffineAddExpr";
201   using PyConcreteAffineExpr::PyConcreteAffineExpr;
202 
203   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
204     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
205     return PyAffineAddExpr(lhs.getContext(), expr);
206   }
207 
208   static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
209     MlirAffineExpr expr = mlirAffineAddExprGet(
210         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
211     return PyAffineAddExpr(lhs.getContext(), expr);
212   }
213 
214   static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
215     MlirAffineExpr expr = mlirAffineAddExprGet(
216         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
217     return PyAffineAddExpr(rhs.getContext(), expr);
218   }
219 
220   static void bindDerived(ClassTy &c) {
221     c.def_static("get", &PyAffineAddExpr::get);
222   }
223 };
224 
225 class PyAffineMulExpr
226     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
227 public:
228   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
229   static constexpr const char *pyClassName = "AffineMulExpr";
230   using PyConcreteAffineExpr::PyConcreteAffineExpr;
231 
232   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
233     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
234     return PyAffineMulExpr(lhs.getContext(), expr);
235   }
236 
237   static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
238     MlirAffineExpr expr = mlirAffineMulExprGet(
239         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
240     return PyAffineMulExpr(lhs.getContext(), expr);
241   }
242 
243   static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
244     MlirAffineExpr expr = mlirAffineMulExprGet(
245         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
246     return PyAffineMulExpr(rhs.getContext(), expr);
247   }
248 
249   static void bindDerived(ClassTy &c) {
250     c.def_static("get", &PyAffineMulExpr::get);
251   }
252 };
253 
254 class PyAffineModExpr
255     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
256 public:
257   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
258   static constexpr const char *pyClassName = "AffineModExpr";
259   using PyConcreteAffineExpr::PyConcreteAffineExpr;
260 
261   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
262     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
263     return PyAffineModExpr(lhs.getContext(), expr);
264   }
265 
266   static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
267     MlirAffineExpr expr = mlirAffineModExprGet(
268         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
269     return PyAffineModExpr(lhs.getContext(), expr);
270   }
271 
272   static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
273     MlirAffineExpr expr = mlirAffineModExprGet(
274         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
275     return PyAffineModExpr(rhs.getContext(), expr);
276   }
277 
278   static void bindDerived(ClassTy &c) {
279     c.def_static("get", &PyAffineModExpr::get);
280   }
281 };
282 
283 class PyAffineFloorDivExpr
284     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
285 public:
286   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
287   static constexpr const char *pyClassName = "AffineFloorDivExpr";
288   using PyConcreteAffineExpr::PyConcreteAffineExpr;
289 
290   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
291     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
292     return PyAffineFloorDivExpr(lhs.getContext(), expr);
293   }
294 
295   static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
296     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
297         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
298     return PyAffineFloorDivExpr(lhs.getContext(), expr);
299   }
300 
301   static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
302     MlirAffineExpr expr = mlirAffineFloorDivExprGet(
303         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
304     return PyAffineFloorDivExpr(rhs.getContext(), expr);
305   }
306 
307   static void bindDerived(ClassTy &c) {
308     c.def_static("get", &PyAffineFloorDivExpr::get);
309   }
310 };
311 
312 class PyAffineCeilDivExpr
313     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
314 public:
315   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
316   static constexpr const char *pyClassName = "AffineCeilDivExpr";
317   using PyConcreteAffineExpr::PyConcreteAffineExpr;
318 
319   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
320     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
321     return PyAffineCeilDivExpr(lhs.getContext(), expr);
322   }
323 
324   static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
325     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
326         lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
327     return PyAffineCeilDivExpr(lhs.getContext(), expr);
328   }
329 
330   static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
331     MlirAffineExpr expr = mlirAffineCeilDivExprGet(
332         mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
333     return PyAffineCeilDivExpr(rhs.getContext(), expr);
334   }
335 
336   static void bindDerived(ClassTy &c) {
337     c.def_static("get", &PyAffineCeilDivExpr::get);
338   }
339 };
340 
341 } // namespace
342 
343 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
344   return mlirAffineExprEqual(affineExpr, other.affineExpr);
345 }
346 
347 py::object PyAffineExpr::getCapsule() {
348   return py::reinterpret_steal<py::object>(
349       mlirPythonAffineExprToCapsule(*this));
350 }
351 
352 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
353   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
354   if (mlirAffineExprIsNull(rawAffineExpr))
355     throw py::error_already_set();
356   return PyAffineExpr(
357       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
358       rawAffineExpr);
359 }
360 
361 //------------------------------------------------------------------------------
362 // PyAffineMap and utilities.
363 //------------------------------------------------------------------------------
364 namespace {
365 
366 /// A list of expressions contained in an affine map. Internally these are
367 /// stored as a consecutive array leading to inexpensive random access. Both
368 /// the map and the expression are owned by the context so we need not bother
369 /// with lifetime extension.
370 class PyAffineMapExprList
371     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
372 public:
373   static constexpr const char *pyClassName = "AffineExprList";
374 
375   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
376                       intptr_t length = -1, intptr_t step = 1)
377       : Sliceable(startIndex,
378                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
379                   step),
380         affineMap(map) {}
381 
382   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
383 
384   PyAffineExpr getElement(intptr_t pos) {
385     return PyAffineExpr(affineMap.getContext(),
386                         mlirAffineMapGetResult(affineMap, pos));
387   }
388 
389   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
390                             intptr_t step) {
391     return PyAffineMapExprList(affineMap, startIndex, length, step);
392   }
393 
394 private:
395   PyAffineMap affineMap;
396 };
397 } // end namespace
398 
399 bool PyAffineMap::operator==(const PyAffineMap &other) {
400   return mlirAffineMapEqual(affineMap, other.affineMap);
401 }
402 
403 py::object PyAffineMap::getCapsule() {
404   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
405 }
406 
407 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
408   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
409   if (mlirAffineMapIsNull(rawAffineMap))
410     throw py::error_already_set();
411   return PyAffineMap(
412       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
413       rawAffineMap);
414 }
415 
416 //------------------------------------------------------------------------------
417 // PyIntegerSet and utilities.
418 //------------------------------------------------------------------------------
419 namespace {
420 
421 class PyIntegerSetConstraint {
422 public:
423   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
424 
425   PyAffineExpr getExpr() {
426     return PyAffineExpr(set.getContext(),
427                         mlirIntegerSetGetConstraint(set, pos));
428   }
429 
430   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
431 
432   static void bind(py::module &m) {
433     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
434                                        py::module_local())
435         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
436         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
437   }
438 
439 private:
440   PyIntegerSet set;
441   intptr_t pos;
442 };
443 
444 class PyIntegerSetConstraintList
445     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
446 public:
447   static constexpr const char *pyClassName = "IntegerSetConstraintList";
448 
449   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
450                              intptr_t length = -1, intptr_t step = 1)
451       : Sliceable(startIndex,
452                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
453                   step),
454         set(set) {}
455 
456   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
457 
458   PyIntegerSetConstraint getElement(intptr_t pos) {
459     return PyIntegerSetConstraint(set, pos);
460   }
461 
462   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
463                                    intptr_t step) {
464     return PyIntegerSetConstraintList(set, startIndex, length, step);
465   }
466 
467 private:
468   PyIntegerSet set;
469 };
470 } // namespace
471 
472 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
473   return mlirIntegerSetEqual(integerSet, other.integerSet);
474 }
475 
476 py::object PyIntegerSet::getCapsule() {
477   return py::reinterpret_steal<py::object>(
478       mlirPythonIntegerSetToCapsule(*this));
479 }
480 
481 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
482   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
483   if (mlirIntegerSetIsNull(rawIntegerSet))
484     throw py::error_already_set();
485   return PyIntegerSet(
486       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
487       rawIntegerSet);
488 }
489 
490 void mlir::python::populateIRAffine(py::module &m) {
491   //----------------------------------------------------------------------------
492   // Mapping of PyAffineExpr and derived classes.
493   //----------------------------------------------------------------------------
494   py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
495       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
496                              &PyAffineExpr::getCapsule)
497       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
498       .def("__add__", &PyAffineAddExpr::get)
499       .def("__add__", &PyAffineAddExpr::getRHSConstant)
500       .def("__radd__", &PyAffineAddExpr::getRHSConstant)
501       .def("__mul__", &PyAffineMulExpr::get)
502       .def("__mul__", &PyAffineMulExpr::getRHSConstant)
503       .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
504       .def("__mod__", &PyAffineModExpr::get)
505       .def("__mod__", &PyAffineModExpr::getRHSConstant)
506       .def("__rmod__",
507            [](PyAffineExpr &self, intptr_t other) {
508              return PyAffineModExpr::get(
509                  PyAffineConstantExpr::get(other, *self.getContext().get()),
510                  self);
511            })
512       .def("__sub__",
513            [](PyAffineExpr &self, PyAffineExpr &other) {
514              auto negOne =
515                  PyAffineConstantExpr::get(-1, *self.getContext().get());
516              return PyAffineAddExpr::get(self,
517                                          PyAffineMulExpr::get(negOne, other));
518            })
519       .def("__sub__",
520            [](PyAffineExpr &self, intptr_t other) {
521              return PyAffineAddExpr::get(
522                  self,
523                  PyAffineConstantExpr::get(-other, *self.getContext().get()));
524            })
525       .def("__rsub__",
526            [](PyAffineExpr &self, intptr_t other) {
527              return PyAffineAddExpr::getLHSConstant(
528                  other, PyAffineMulExpr::getLHSConstant(-1, self));
529            })
530       .def("__eq__", [](PyAffineExpr &self,
531                         PyAffineExpr &other) { return self == other; })
532       .def("__eq__",
533            [](PyAffineExpr &self, py::object &other) { return false; })
534       .def("__str__",
535            [](PyAffineExpr &self) {
536              PyPrintAccumulator printAccum;
537              mlirAffineExprPrint(self, printAccum.getCallback(),
538                                  printAccum.getUserData());
539              return printAccum.join();
540            })
541       .def("__repr__",
542            [](PyAffineExpr &self) {
543              PyPrintAccumulator printAccum;
544              printAccum.parts.append("AffineExpr(");
545              mlirAffineExprPrint(self, printAccum.getCallback(),
546                                  printAccum.getUserData());
547              printAccum.parts.append(")");
548              return printAccum.join();
549            })
550       .def("__hash__",
551            [](PyAffineExpr &self) {
552              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
553            })
554       .def_property_readonly(
555           "context",
556           [](PyAffineExpr &self) { return self.getContext().getObject(); })
557       .def("compose",
558            [](PyAffineExpr &self, PyAffineMap &other) {
559              return PyAffineExpr(self.getContext(),
560                                  mlirAffineExprCompose(self, other));
561            })
562       .def_static(
563           "get_add", &PyAffineAddExpr::get,
564           "Gets an affine expression containing a sum of two expressions.")
565       .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
566                   "Gets an affine expression containing a sum of a constant "
567                   "and another expression.")
568       .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
569                   "Gets an affine expression containing a sum of an expression "
570                   "and a constant.")
571       .def_static(
572           "get_mul", &PyAffineMulExpr::get,
573           "Gets an affine expression containing a product of two expressions.")
574       .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
575                   "Gets an affine expression containing a product of a "
576                   "constant and another expression.")
577       .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
578                   "Gets an affine expression containing a product of an "
579                   "expression and a constant.")
580       .def_static("get_mod", &PyAffineModExpr::get,
581                   "Gets an affine expression containing the modulo of dividing "
582                   "one expression by another.")
583       .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
584                   "Gets a semi-affine expression containing the modulo of "
585                   "dividing a constant by an expression.")
586       .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
587                   "Gets an affine expression containing the module of dividing"
588                   "an expression by a constant.")
589       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
590                   "Gets an affine expression containing the rounded-down "
591                   "result of dividing one expression by another.")
592       .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
593                   "Gets a semi-affine expression containing the rounded-down "
594                   "result of dividing a constant by an expression.")
595       .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
596                   "Gets an affine expression containing the rounded-down "
597                   "result of dividing an expression by a constant.")
598       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
599                   "Gets an affine expression containing the rounded-up result "
600                   "of dividing one expression by another.")
601       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
602                   "Gets a semi-affine expression containing the rounded-up "
603                   "result of dividing a constant by an expression.")
604       .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
605                   "Gets an affine expression containing the rounded-up result "
606                   "of dividing an expression by a constant.")
607       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
608                   py::arg("context") = py::none(),
609                   "Gets a constant affine expression with the given value.")
610       .def_static(
611           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
612           py::arg("context") = py::none(),
613           "Gets an affine expression of a dimension at the given position.")
614       .def_static(
615           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
616           py::arg("context") = py::none(),
617           "Gets an affine expression of a symbol at the given position.")
618       .def(
619           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
620           kDumpDocstring);
621   PyAffineConstantExpr::bind(m);
622   PyAffineDimExpr::bind(m);
623   PyAffineSymbolExpr::bind(m);
624   PyAffineBinaryExpr::bind(m);
625   PyAffineAddExpr::bind(m);
626   PyAffineMulExpr::bind(m);
627   PyAffineModExpr::bind(m);
628   PyAffineFloorDivExpr::bind(m);
629   PyAffineCeilDivExpr::bind(m);
630 
631   //----------------------------------------------------------------------------
632   // Mapping of PyAffineMap.
633   //----------------------------------------------------------------------------
634   py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
635       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
636                              &PyAffineMap::getCapsule)
637       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
638       .def("__eq__",
639            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
640       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
641       .def("__str__",
642            [](PyAffineMap &self) {
643              PyPrintAccumulator printAccum;
644              mlirAffineMapPrint(self, printAccum.getCallback(),
645                                 printAccum.getUserData());
646              return printAccum.join();
647            })
648       .def("__repr__",
649            [](PyAffineMap &self) {
650              PyPrintAccumulator printAccum;
651              printAccum.parts.append("AffineMap(");
652              mlirAffineMapPrint(self, printAccum.getCallback(),
653                                 printAccum.getUserData());
654              printAccum.parts.append(")");
655              return printAccum.join();
656            })
657       .def("__hash__",
658            [](PyAffineMap &self) {
659              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
660            })
661       .def_static("compress_unused_symbols",
662                   [](py::list affineMaps, DefaultingPyMlirContext context) {
663                     SmallVector<MlirAffineMap> maps;
664                     pyListToVector<PyAffineMap, MlirAffineMap>(
665                         affineMaps, maps, "attempting to create an AffineMap");
666                     std::vector<MlirAffineMap> compressed(affineMaps.size());
667                     auto populate = [](void *result, intptr_t idx,
668                                        MlirAffineMap m) {
669                       static_cast<MlirAffineMap *>(result)[idx] = (m);
670                     };
671                     mlirAffineMapCompressUnusedSymbols(
672                         maps.data(), maps.size(), compressed.data(), populate);
673                     std::vector<PyAffineMap> res;
674                     res.reserve(compressed.size());
675                     for (auto m : compressed)
676                       res.push_back(PyAffineMap(context->getRef(), m));
677                     return res;
678                   })
679       .def_property_readonly(
680           "context",
681           [](PyAffineMap &self) { return self.getContext().getObject(); },
682           "Context that owns the Affine Map")
683       .def(
684           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
685           kDumpDocstring)
686       .def_static(
687           "get",
688           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
689              DefaultingPyMlirContext context) {
690             SmallVector<MlirAffineExpr> affineExprs;
691             pyListToVector<PyAffineExpr, MlirAffineExpr>(
692                 exprs, affineExprs, "attempting to create an AffineMap");
693             MlirAffineMap map =
694                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
695                                  affineExprs.size(), affineExprs.data());
696             return PyAffineMap(context->getRef(), map);
697           },
698           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
699           py::arg("context") = py::none(),
700           "Gets a map with the given expressions as results.")
701       .def_static(
702           "get_constant",
703           [](intptr_t value, DefaultingPyMlirContext context) {
704             MlirAffineMap affineMap =
705                 mlirAffineMapConstantGet(context->get(), value);
706             return PyAffineMap(context->getRef(), affineMap);
707           },
708           py::arg("value"), py::arg("context") = py::none(),
709           "Gets an affine map with a single constant result")
710       .def_static(
711           "get_empty",
712           [](DefaultingPyMlirContext context) {
713             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
714             return PyAffineMap(context->getRef(), affineMap);
715           },
716           py::arg("context") = py::none(), "Gets an empty affine map.")
717       .def_static(
718           "get_identity",
719           [](intptr_t nDims, DefaultingPyMlirContext context) {
720             MlirAffineMap affineMap =
721                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
722             return PyAffineMap(context->getRef(), affineMap);
723           },
724           py::arg("n_dims"), py::arg("context") = py::none(),
725           "Gets an identity map with the given number of dimensions.")
726       .def_static(
727           "get_minor_identity",
728           [](intptr_t nDims, intptr_t nResults,
729              DefaultingPyMlirContext context) {
730             MlirAffineMap affineMap =
731                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
732             return PyAffineMap(context->getRef(), affineMap);
733           },
734           py::arg("n_dims"), py::arg("n_results"),
735           py::arg("context") = py::none(),
736           "Gets a minor identity map with the given number of dimensions and "
737           "results.")
738       .def_static(
739           "get_permutation",
740           [](std::vector<unsigned> permutation,
741              DefaultingPyMlirContext context) {
742             if (!isPermutation(permutation))
743               throw py::cast_error("Invalid permutation when attempting to "
744                                    "create an AffineMap");
745             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
746                 context->get(), permutation.size(), permutation.data());
747             return PyAffineMap(context->getRef(), affineMap);
748           },
749           py::arg("permutation"), py::arg("context") = py::none(),
750           "Gets an affine map that permutes its inputs.")
751       .def("get_submap",
752            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
753              intptr_t numResults = mlirAffineMapGetNumResults(self);
754              for (intptr_t pos : resultPos) {
755                if (pos < 0 || pos >= numResults)
756                  throw py::value_error("result position out of bounds");
757              }
758              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
759                  self, resultPos.size(), resultPos.data());
760              return PyAffineMap(self.getContext(), affineMap);
761            })
762       .def("get_major_submap",
763            [](PyAffineMap &self, intptr_t nResults) {
764              if (nResults >= mlirAffineMapGetNumResults(self))
765                throw py::value_error("number of results out of bounds");
766              MlirAffineMap affineMap =
767                  mlirAffineMapGetMajorSubMap(self, nResults);
768              return PyAffineMap(self.getContext(), affineMap);
769            })
770       .def("get_minor_submap",
771            [](PyAffineMap &self, intptr_t nResults) {
772              if (nResults >= mlirAffineMapGetNumResults(self))
773                throw py::value_error("number of results out of bounds");
774              MlirAffineMap affineMap =
775                  mlirAffineMapGetMinorSubMap(self, nResults);
776              return PyAffineMap(self.getContext(), affineMap);
777            })
778       .def("replace",
779            [](PyAffineMap &self, PyAffineExpr &expression,
780               PyAffineExpr &replacement, intptr_t numResultDims,
781               intptr_t numResultSyms) {
782              MlirAffineMap affineMap = mlirAffineMapReplace(
783                  self, expression, replacement, numResultDims, numResultSyms);
784              return PyAffineMap(self.getContext(), affineMap);
785            })
786       .def_property_readonly(
787           "is_permutation",
788           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
789       .def_property_readonly("is_projected_permutation",
790                              [](PyAffineMap &self) {
791                                return mlirAffineMapIsProjectedPermutation(self);
792                              })
793       .def_property_readonly(
794           "n_dims",
795           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
796       .def_property_readonly(
797           "n_inputs",
798           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
799       .def_property_readonly(
800           "n_symbols",
801           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
802       .def_property_readonly("results", [](PyAffineMap &self) {
803         return PyAffineMapExprList(self);
804       });
805   PyAffineMapExprList::bind(m);
806 
807   //----------------------------------------------------------------------------
808   // Mapping of PyIntegerSet.
809   //----------------------------------------------------------------------------
810   py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
811       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
812                              &PyIntegerSet::getCapsule)
813       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
814       .def("__eq__", [](PyIntegerSet &self,
815                         PyIntegerSet &other) { return self == other; })
816       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
817       .def("__str__",
818            [](PyIntegerSet &self) {
819              PyPrintAccumulator printAccum;
820              mlirIntegerSetPrint(self, printAccum.getCallback(),
821                                  printAccum.getUserData());
822              return printAccum.join();
823            })
824       .def("__repr__",
825            [](PyIntegerSet &self) {
826              PyPrintAccumulator printAccum;
827              printAccum.parts.append("IntegerSet(");
828              mlirIntegerSetPrint(self, printAccum.getCallback(),
829                                  printAccum.getUserData());
830              printAccum.parts.append(")");
831              return printAccum.join();
832            })
833       .def("__hash__",
834            [](PyIntegerSet &self) {
835              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
836            })
837       .def_property_readonly(
838           "context",
839           [](PyIntegerSet &self) { return self.getContext().getObject(); })
840       .def(
841           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
842           kDumpDocstring)
843       .def_static(
844           "get",
845           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
846              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
847             if (exprs.size() != eqFlags.size())
848               throw py::value_error(
849                   "Expected the number of constraints to match "
850                   "that of equality flags");
851             if (exprs.empty())
852               throw py::value_error("Expected non-empty list of constraints");
853 
854             // Copy over to a SmallVector because std::vector has a
855             // specialization for booleans that packs data and does not
856             // expose a `bool *`.
857             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
858 
859             SmallVector<MlirAffineExpr> affineExprs;
860             pyListToVector<PyAffineExpr>(exprs, affineExprs,
861                                          "attempting to create an IntegerSet");
862             MlirIntegerSet set = mlirIntegerSetGet(
863                 context->get(), numDims, numSymbols, exprs.size(),
864                 affineExprs.data(), flags.data());
865             return PyIntegerSet(context->getRef(), set);
866           },
867           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
868           py::arg("eq_flags"), py::arg("context") = py::none())
869       .def_static(
870           "get_empty",
871           [](intptr_t numDims, intptr_t numSymbols,
872              DefaultingPyMlirContext context) {
873             MlirIntegerSet set =
874                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
875             return PyIntegerSet(context->getRef(), set);
876           },
877           py::arg("num_dims"), py::arg("num_symbols"),
878           py::arg("context") = py::none())
879       .def("get_replaced",
880            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
881               intptr_t numResultDims, intptr_t numResultSymbols) {
882              if (static_cast<intptr_t>(dimExprs.size()) !=
883                  mlirIntegerSetGetNumDims(self))
884                throw py::value_error(
885                    "Expected the number of dimension replacement expressions "
886                    "to match that of dimensions");
887              if (static_cast<intptr_t>(symbolExprs.size()) !=
888                  mlirIntegerSetGetNumSymbols(self))
889                throw py::value_error(
890                    "Expected the number of symbol replacement expressions "
891                    "to match that of symbols");
892 
893              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
894              pyListToVector<PyAffineExpr>(
895                  dimExprs, dimAffineExprs,
896                  "attempting to create an IntegerSet by replacing dimensions");
897              pyListToVector<PyAffineExpr>(
898                  symbolExprs, symbolAffineExprs,
899                  "attempting to create an IntegerSet by replacing symbols");
900              MlirIntegerSet set = mlirIntegerSetReplaceGet(
901                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
902                  numResultDims, numResultSymbols);
903              return PyIntegerSet(self.getContext(), set);
904            })
905       .def_property_readonly("is_canonical_empty",
906                              [](PyIntegerSet &self) {
907                                return mlirIntegerSetIsCanonicalEmpty(self);
908                              })
909       .def_property_readonly(
910           "n_dims",
911           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
912       .def_property_readonly(
913           "n_symbols",
914           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
915       .def_property_readonly(
916           "n_inputs",
917           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
918       .def_property_readonly("n_equalities",
919                              [](PyIntegerSet &self) {
920                                return mlirIntegerSetGetNumEqualities(self);
921                              })
922       .def_property_readonly("n_inequalities",
923                              [](PyIntegerSet &self) {
924                                return mlirIntegerSetGetNumInequalities(self);
925                              })
926       .def_property_readonly("constraints", [](PyIntegerSet &self) {
927         return PyIntegerSetConstraintList(self);
928       });
929   PyIntegerSetConstraint::bind(m);
930   PyIntegerSetConstraintList::bind(m);
931 }
932