1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/AffineMap.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/IntegerSet.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 using llvm::SmallVector;
22 using llvm::StringRef;
23 using llvm::Twine;
24 
/// Docstring attached to the `dump` methods of AffineExpr, AffineMap and
/// IntegerSet.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
27 
28 /// Attempts to populate `result` with the content of `list` casted to the
29 /// appropriate type (Python and C types are provided as template arguments).
30 /// Throws errors in case of failure, using "action" to describe what the caller
31 /// was attempting to do.
32 template <typename PyType, typename CType>
33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34                            StringRef action) {
35   result.reserve(py::len(list));
36   for (py::handle item : list) {
37     try {
38       result.push_back(item.cast<PyType>());
39     } catch (py::cast_error &err) {
40       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41                          " (" + err.what() + ")")
42                             .str();
43       throw py::cast_error(msg);
44     } catch (py::reference_cast_error &err) {
45       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46                          action + " (" + err.what() + ")")
47                             .str();
48       throw py::cast_error(msg);
49     }
50   }
51 }
52 
53 template <typename PermutationTy>
54 static bool isPermutation(std::vector<PermutationTy> permutation) {
55   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
56   for (auto val : permutation) {
57     if (val < permutation.size()) {
58       if (seen[val])
59         return false;
60       seen[val] = true;
61       continue;
62     }
63     return false;
64   }
65   return true;
66 }
67 
68 namespace {
69 
70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71 /// and should be castable from it. Intermediate hierarchy classes can be
72 /// modeled by specifying BaseTy.
73 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74 class PyConcreteAffineExpr : public BaseTy {
75 public:
76   // Derived classes must define statics for:
77   //   IsAFunctionTy isaFunction
78   //   const char *pyClassName
79   // and redefine bindDerived.
80   using ClassTy = py::class_<DerivedTy, BaseTy>;
81   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82 
83   PyConcreteAffineExpr() = default;
84   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85       : BaseTy(std::move(contextRef), affineExpr) {}
86   PyConcreteAffineExpr(PyAffineExpr &orig)
87       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88 
89   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90     if (!DerivedTy::isaFunction(orig)) {
91       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92       throw SetPyError(PyExc_ValueError,
93                        Twine("Cannot cast affine expression to ") +
94                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95     }
96     return orig;
97   }
98 
99   static void bind(py::module &m) {
100     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
101     cls.def(py::init<PyAffineExpr &>());
102     cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool {
103       return DerivedTy::isaFunction(otherAffineExpr);
104     });
105     DerivedTy::bindDerived(cls);
106   }
107 
108   /// Implemented by derived classes to add methods to the Python subclass.
109   static void bindDerived(ClassTy &m) {}
110 };
111 
112 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
113 public:
114   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
115   static constexpr const char *pyClassName = "AffineConstantExpr";
116   using PyConcreteAffineExpr::PyConcreteAffineExpr;
117 
118   static PyAffineConstantExpr get(intptr_t value,
119                                   DefaultingPyMlirContext context) {
120     MlirAffineExpr affineExpr =
121         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
122     return PyAffineConstantExpr(context->getRef(), affineExpr);
123   }
124 
125   static void bindDerived(ClassTy &c) {
126     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
127                  py::arg("context") = py::none());
128     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
129       return mlirAffineConstantExprGetValue(self);
130     });
131   }
132 };
133 
134 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
135 public:
136   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
137   static constexpr const char *pyClassName = "AffineDimExpr";
138   using PyConcreteAffineExpr::PyConcreteAffineExpr;
139 
140   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
141     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
142     return PyAffineDimExpr(context->getRef(), affineExpr);
143   }
144 
145   static void bindDerived(ClassTy &c) {
146     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
147                  py::arg("context") = py::none());
148     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
149       return mlirAffineDimExprGetPosition(self);
150     });
151   }
152 };
153 
154 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
155 public:
156   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
157   static constexpr const char *pyClassName = "AffineSymbolExpr";
158   using PyConcreteAffineExpr::PyConcreteAffineExpr;
159 
160   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
161     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
162     return PyAffineSymbolExpr(context->getRef(), affineExpr);
163   }
164 
165   static void bindDerived(ClassTy &c) {
166     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
167                  py::arg("context") = py::none());
168     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
169       return mlirAffineSymbolExprGetPosition(self);
170     });
171   }
172 };
173 
174 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
175 public:
176   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
177   static constexpr const char *pyClassName = "AffineBinaryExpr";
178   using PyConcreteAffineExpr::PyConcreteAffineExpr;
179 
180   PyAffineExpr lhs() {
181     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
182     return PyAffineExpr(getContext(), lhsExpr);
183   }
184 
185   PyAffineExpr rhs() {
186     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
187     return PyAffineExpr(getContext(), rhsExpr);
188   }
189 
190   static void bindDerived(ClassTy &c) {
191     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
192     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
193   }
194 };
195 
196 class PyAffineAddExpr
197     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
198 public:
199   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
200   static constexpr const char *pyClassName = "AffineAddExpr";
201   using PyConcreteAffineExpr::PyConcreteAffineExpr;
202 
203   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
204     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
205     return PyAffineAddExpr(lhs.getContext(), expr);
206   }
207 
208   static void bindDerived(ClassTy &c) {
209     c.def_static("get", &PyAffineAddExpr::get);
210   }
211 };
212 
213 class PyAffineMulExpr
214     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
215 public:
216   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
217   static constexpr const char *pyClassName = "AffineMulExpr";
218   using PyConcreteAffineExpr::PyConcreteAffineExpr;
219 
220   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
221     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
222     return PyAffineMulExpr(lhs.getContext(), expr);
223   }
224 
225   static void bindDerived(ClassTy &c) {
226     c.def_static("get", &PyAffineMulExpr::get);
227   }
228 };
229 
230 class PyAffineModExpr
231     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
232 public:
233   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
234   static constexpr const char *pyClassName = "AffineModExpr";
235   using PyConcreteAffineExpr::PyConcreteAffineExpr;
236 
237   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
238     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
239     return PyAffineModExpr(lhs.getContext(), expr);
240   }
241 
242   static void bindDerived(ClassTy &c) {
243     c.def_static("get", &PyAffineModExpr::get);
244   }
245 };
246 
247 class PyAffineFloorDivExpr
248     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
249 public:
250   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
251   static constexpr const char *pyClassName = "AffineFloorDivExpr";
252   using PyConcreteAffineExpr::PyConcreteAffineExpr;
253 
254   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
255     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
256     return PyAffineFloorDivExpr(lhs.getContext(), expr);
257   }
258 
259   static void bindDerived(ClassTy &c) {
260     c.def_static("get", &PyAffineFloorDivExpr::get);
261   }
262 };
263 
264 class PyAffineCeilDivExpr
265     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
266 public:
267   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
268   static constexpr const char *pyClassName = "AffineCeilDivExpr";
269   using PyConcreteAffineExpr::PyConcreteAffineExpr;
270 
271   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
272     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
273     return PyAffineCeilDivExpr(lhs.getContext(), expr);
274   }
275 
276   static void bindDerived(ClassTy &c) {
277     c.def_static("get", &PyAffineCeilDivExpr::get);
278   }
279 };
280 
281 } // namespace
282 
283 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
284   return mlirAffineExprEqual(affineExpr, other.affineExpr);
285 }
286 
287 py::object PyAffineExpr::getCapsule() {
288   return py::reinterpret_steal<py::object>(
289       mlirPythonAffineExprToCapsule(*this));
290 }
291 
292 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
293   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
294   if (mlirAffineExprIsNull(rawAffineExpr))
295     throw py::error_already_set();
296   return PyAffineExpr(
297       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
298       rawAffineExpr);
299 }
300 
301 //------------------------------------------------------------------------------
302 // PyAffineMap and utilities.
303 //------------------------------------------------------------------------------
304 namespace {
305 
306 /// A list of expressions contained in an affine map. Internally these are
307 /// stored as a consecutive array leading to inexpensive random access. Both
308 /// the map and the expression are owned by the context so we need not bother
309 /// with lifetime extension.
310 class PyAffineMapExprList
311     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
312 public:
313   static constexpr const char *pyClassName = "AffineExprList";
314 
315   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
316                       intptr_t length = -1, intptr_t step = 1)
317       : Sliceable(startIndex,
318                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
319                   step),
320         affineMap(map) {}
321 
322   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
323 
324   PyAffineExpr getElement(intptr_t pos) {
325     return PyAffineExpr(affineMap.getContext(),
326                         mlirAffineMapGetResult(affineMap, pos));
327   }
328 
329   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
330                             intptr_t step) {
331     return PyAffineMapExprList(affineMap, startIndex, length, step);
332   }
333 
334 private:
335   PyAffineMap affineMap;
336 };
337 } // end namespace
338 
339 bool PyAffineMap::operator==(const PyAffineMap &other) {
340   return mlirAffineMapEqual(affineMap, other.affineMap);
341 }
342 
343 py::object PyAffineMap::getCapsule() {
344   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
345 }
346 
347 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
348   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
349   if (mlirAffineMapIsNull(rawAffineMap))
350     throw py::error_already_set();
351   return PyAffineMap(
352       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
353       rawAffineMap);
354 }
355 
356 //------------------------------------------------------------------------------
357 // PyIntegerSet and utilities.
358 //------------------------------------------------------------------------------
359 namespace {
360 
361 class PyIntegerSetConstraint {
362 public:
363   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
364 
365   PyAffineExpr getExpr() {
366     return PyAffineExpr(set.getContext(),
367                         mlirIntegerSetGetConstraint(set, pos));
368   }
369 
370   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
371 
372   static void bind(py::module &m) {
373     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
374                                        py::module_local())
375         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
376         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
377   }
378 
379 private:
380   PyIntegerSet set;
381   intptr_t pos;
382 };
383 
384 class PyIntegerSetConstraintList
385     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
386 public:
387   static constexpr const char *pyClassName = "IntegerSetConstraintList";
388 
389   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
390                              intptr_t length = -1, intptr_t step = 1)
391       : Sliceable(startIndex,
392                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
393                   step),
394         set(set) {}
395 
396   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
397 
398   PyIntegerSetConstraint getElement(intptr_t pos) {
399     return PyIntegerSetConstraint(set, pos);
400   }
401 
402   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
403                                    intptr_t step) {
404     return PyIntegerSetConstraintList(set, startIndex, length, step);
405   }
406 
407 private:
408   PyIntegerSet set;
409 };
410 } // namespace
411 
412 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
413   return mlirIntegerSetEqual(integerSet, other.integerSet);
414 }
415 
416 py::object PyIntegerSet::getCapsule() {
417   return py::reinterpret_steal<py::object>(
418       mlirPythonIntegerSetToCapsule(*this));
419 }
420 
421 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
422   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
423   if (mlirIntegerSetIsNull(rawIntegerSet))
424     throw py::error_already_set();
425   return PyIntegerSet(
426       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
427       rawIntegerSet);
428 }
429 
430 void mlir::python::populateIRAffine(py::module &m) {
431   //----------------------------------------------------------------------------
432   // Mapping of PyAffineExpr and derived classes.
433   //----------------------------------------------------------------------------
434   py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
435       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
436                              &PyAffineExpr::getCapsule)
437       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
438       .def("__add__",
439            [](PyAffineExpr &self, PyAffineExpr &other) {
440              return PyAffineAddExpr::get(self, other);
441            })
442       .def("__mul__",
443            [](PyAffineExpr &self, PyAffineExpr &other) {
444              return PyAffineMulExpr::get(self, other);
445            })
446       .def("__mod__",
447            [](PyAffineExpr &self, PyAffineExpr &other) {
448              return PyAffineModExpr::get(self, other);
449            })
450       .def("__sub__",
451            [](PyAffineExpr &self, PyAffineExpr &other) {
452              auto negOne =
453                  PyAffineConstantExpr::get(-1, *self.getContext().get());
454              return PyAffineAddExpr::get(self,
455                                          PyAffineMulExpr::get(negOne, other));
456            })
457       .def("__eq__", [](PyAffineExpr &self,
458                         PyAffineExpr &other) { return self == other; })
459       .def("__eq__",
460            [](PyAffineExpr &self, py::object &other) { return false; })
461       .def("__str__",
462            [](PyAffineExpr &self) {
463              PyPrintAccumulator printAccum;
464              mlirAffineExprPrint(self, printAccum.getCallback(),
465                                  printAccum.getUserData());
466              return printAccum.join();
467            })
468       .def("__repr__",
469            [](PyAffineExpr &self) {
470              PyPrintAccumulator printAccum;
471              printAccum.parts.append("AffineExpr(");
472              mlirAffineExprPrint(self, printAccum.getCallback(),
473                                  printAccum.getUserData());
474              printAccum.parts.append(")");
475              return printAccum.join();
476            })
477       .def_property_readonly(
478           "context",
479           [](PyAffineExpr &self) { return self.getContext().getObject(); })
480       .def_static(
481           "get_add", &PyAffineAddExpr::get,
482           "Gets an affine expression containing a sum of two expressions.")
483       .def_static(
484           "get_mul", &PyAffineMulExpr::get,
485           "Gets an affine expression containing a product of two expressions.")
486       .def_static("get_mod", &PyAffineModExpr::get,
487                   "Gets an affine expression containing the modulo of dividing "
488                   "one expression by another.")
489       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
490                   "Gets an affine expression containing the rounded-down "
491                   "result of dividing one expression by another.")
492       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
493                   "Gets an affine expression containing the rounded-up result "
494                   "of dividing one expression by another.")
495       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
496                   py::arg("context") = py::none(),
497                   "Gets a constant affine expression with the given value.")
498       .def_static(
499           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
500           py::arg("context") = py::none(),
501           "Gets an affine expression of a dimension at the given position.")
502       .def_static(
503           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
504           py::arg("context") = py::none(),
505           "Gets an affine expression of a symbol at the given position.")
506       .def(
507           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
508           kDumpDocstring);
509   PyAffineConstantExpr::bind(m);
510   PyAffineDimExpr::bind(m);
511   PyAffineSymbolExpr::bind(m);
512   PyAffineBinaryExpr::bind(m);
513   PyAffineAddExpr::bind(m);
514   PyAffineMulExpr::bind(m);
515   PyAffineModExpr::bind(m);
516   PyAffineFloorDivExpr::bind(m);
517   PyAffineCeilDivExpr::bind(m);
518 
519   //----------------------------------------------------------------------------
520   // Mapping of PyAffineMap.
521   //----------------------------------------------------------------------------
522   py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
523       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
524                              &PyAffineMap::getCapsule)
525       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
526       .def("__eq__",
527            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
528       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
529       .def("__str__",
530            [](PyAffineMap &self) {
531              PyPrintAccumulator printAccum;
532              mlirAffineMapPrint(self, printAccum.getCallback(),
533                                 printAccum.getUserData());
534              return printAccum.join();
535            })
536       .def("__repr__",
537            [](PyAffineMap &self) {
538              PyPrintAccumulator printAccum;
539              printAccum.parts.append("AffineMap(");
540              mlirAffineMapPrint(self, printAccum.getCallback(),
541                                 printAccum.getUserData());
542              printAccum.parts.append(")");
543              return printAccum.join();
544            })
545       .def_static("compress_unused_symbols",
546                   [](py::list affineMaps, DefaultingPyMlirContext context) {
547                     SmallVector<MlirAffineMap> maps;
548                     pyListToVector<PyAffineMap, MlirAffineMap>(
549                         affineMaps, maps, "attempting to create an AffineMap");
550                     std::vector<MlirAffineMap> compressed(affineMaps.size());
551                     auto populate = [](void *result, intptr_t idx,
552                                        MlirAffineMap m) {
553                       static_cast<MlirAffineMap *>(result)[idx] = (m);
554                     };
555                     mlirAffineMapCompressUnusedSymbols(
556                         maps.data(), maps.size(), compressed.data(), populate);
557                     std::vector<PyAffineMap> res;
558                     for (auto m : compressed)
559                       res.push_back(PyAffineMap(context->getRef(), m));
560                     return res;
561                   })
562       .def_property_readonly(
563           "context",
564           [](PyAffineMap &self) { return self.getContext().getObject(); },
565           "Context that owns the Affine Map")
566       .def(
567           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
568           kDumpDocstring)
569       .def_static(
570           "get",
571           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
572              DefaultingPyMlirContext context) {
573             SmallVector<MlirAffineExpr> affineExprs;
574             pyListToVector<PyAffineExpr, MlirAffineExpr>(
575                 exprs, affineExprs, "attempting to create an AffineMap");
576             MlirAffineMap map =
577                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
578                                  affineExprs.size(), affineExprs.data());
579             return PyAffineMap(context->getRef(), map);
580           },
581           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
582           py::arg("context") = py::none(),
583           "Gets a map with the given expressions as results.")
584       .def_static(
585           "get_constant",
586           [](intptr_t value, DefaultingPyMlirContext context) {
587             MlirAffineMap affineMap =
588                 mlirAffineMapConstantGet(context->get(), value);
589             return PyAffineMap(context->getRef(), affineMap);
590           },
591           py::arg("value"), py::arg("context") = py::none(),
592           "Gets an affine map with a single constant result")
593       .def_static(
594           "get_empty",
595           [](DefaultingPyMlirContext context) {
596             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
597             return PyAffineMap(context->getRef(), affineMap);
598           },
599           py::arg("context") = py::none(), "Gets an empty affine map.")
600       .def_static(
601           "get_identity",
602           [](intptr_t nDims, DefaultingPyMlirContext context) {
603             MlirAffineMap affineMap =
604                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
605             return PyAffineMap(context->getRef(), affineMap);
606           },
607           py::arg("n_dims"), py::arg("context") = py::none(),
608           "Gets an identity map with the given number of dimensions.")
609       .def_static(
610           "get_minor_identity",
611           [](intptr_t nDims, intptr_t nResults,
612              DefaultingPyMlirContext context) {
613             MlirAffineMap affineMap =
614                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
615             return PyAffineMap(context->getRef(), affineMap);
616           },
617           py::arg("n_dims"), py::arg("n_results"),
618           py::arg("context") = py::none(),
619           "Gets a minor identity map with the given number of dimensions and "
620           "results.")
621       .def_static(
622           "get_permutation",
623           [](std::vector<unsigned> permutation,
624              DefaultingPyMlirContext context) {
625             if (!isPermutation(permutation))
626               throw py::cast_error("Invalid permutation when attempting to "
627                                    "create an AffineMap");
628             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
629                 context->get(), permutation.size(), permutation.data());
630             return PyAffineMap(context->getRef(), affineMap);
631           },
632           py::arg("permutation"), py::arg("context") = py::none(),
633           "Gets an affine map that permutes its inputs.")
634       .def("get_submap",
635            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
636              intptr_t numResults = mlirAffineMapGetNumResults(self);
637              for (intptr_t pos : resultPos) {
638                if (pos < 0 || pos >= numResults)
639                  throw py::value_error("result position out of bounds");
640              }
641              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
642                  self, resultPos.size(), resultPos.data());
643              return PyAffineMap(self.getContext(), affineMap);
644            })
645       .def("get_major_submap",
646            [](PyAffineMap &self, intptr_t nResults) {
647              if (nResults >= mlirAffineMapGetNumResults(self))
648                throw py::value_error("number of results out of bounds");
649              MlirAffineMap affineMap =
650                  mlirAffineMapGetMajorSubMap(self, nResults);
651              return PyAffineMap(self.getContext(), affineMap);
652            })
653       .def("get_minor_submap",
654            [](PyAffineMap &self, intptr_t nResults) {
655              if (nResults >= mlirAffineMapGetNumResults(self))
656                throw py::value_error("number of results out of bounds");
657              MlirAffineMap affineMap =
658                  mlirAffineMapGetMinorSubMap(self, nResults);
659              return PyAffineMap(self.getContext(), affineMap);
660            })
661       .def("replace",
662            [](PyAffineMap &self, PyAffineExpr &expression,
663               PyAffineExpr &replacement, intptr_t numResultDims,
664               intptr_t numResultSyms) {
665              MlirAffineMap affineMap = mlirAffineMapReplace(
666                  self, expression, replacement, numResultDims, numResultSyms);
667              return PyAffineMap(self.getContext(), affineMap);
668            })
669       .def_property_readonly(
670           "is_permutation",
671           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
672       .def_property_readonly("is_projected_permutation",
673                              [](PyAffineMap &self) {
674                                return mlirAffineMapIsProjectedPermutation(self);
675                              })
676       .def_property_readonly(
677           "n_dims",
678           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
679       .def_property_readonly(
680           "n_inputs",
681           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
682       .def_property_readonly(
683           "n_symbols",
684           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
685       .def_property_readonly("results", [](PyAffineMap &self) {
686         return PyAffineMapExprList(self);
687       });
688   PyAffineMapExprList::bind(m);
689 
690   //----------------------------------------------------------------------------
691   // Mapping of PyIntegerSet.
692   //----------------------------------------------------------------------------
693   py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
694       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
695                              &PyIntegerSet::getCapsule)
696       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
697       .def("__eq__", [](PyIntegerSet &self,
698                         PyIntegerSet &other) { return self == other; })
699       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
700       .def("__str__",
701            [](PyIntegerSet &self) {
702              PyPrintAccumulator printAccum;
703              mlirIntegerSetPrint(self, printAccum.getCallback(),
704                                  printAccum.getUserData());
705              return printAccum.join();
706            })
707       .def("__repr__",
708            [](PyIntegerSet &self) {
709              PyPrintAccumulator printAccum;
710              printAccum.parts.append("IntegerSet(");
711              mlirIntegerSetPrint(self, printAccum.getCallback(),
712                                  printAccum.getUserData());
713              printAccum.parts.append(")");
714              return printAccum.join();
715            })
716       .def_property_readonly(
717           "context",
718           [](PyIntegerSet &self) { return self.getContext().getObject(); })
719       .def(
720           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
721           kDumpDocstring)
722       .def_static(
723           "get",
724           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
725              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
726             if (exprs.size() != eqFlags.size())
727               throw py::value_error(
728                   "Expected the number of constraints to match "
729                   "that of equality flags");
730             if (exprs.empty())
731               throw py::value_error("Expected non-empty list of constraints");
732 
733             // Copy over to a SmallVector because std::vector has a
734             // specialization for booleans that packs data and does not
735             // expose a `bool *`.
736             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
737 
738             SmallVector<MlirAffineExpr> affineExprs;
739             pyListToVector<PyAffineExpr>(exprs, affineExprs,
740                                          "attempting to create an IntegerSet");
741             MlirIntegerSet set = mlirIntegerSetGet(
742                 context->get(), numDims, numSymbols, exprs.size(),
743                 affineExprs.data(), flags.data());
744             return PyIntegerSet(context->getRef(), set);
745           },
746           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
747           py::arg("eq_flags"), py::arg("context") = py::none())
748       .def_static(
749           "get_empty",
750           [](intptr_t numDims, intptr_t numSymbols,
751              DefaultingPyMlirContext context) {
752             MlirIntegerSet set =
753                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
754             return PyIntegerSet(context->getRef(), set);
755           },
756           py::arg("num_dims"), py::arg("num_symbols"),
757           py::arg("context") = py::none())
758       .def("get_replaced",
759            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
760               intptr_t numResultDims, intptr_t numResultSymbols) {
761              if (static_cast<intptr_t>(dimExprs.size()) !=
762                  mlirIntegerSetGetNumDims(self))
763                throw py::value_error(
764                    "Expected the number of dimension replacement expressions "
765                    "to match that of dimensions");
766              if (static_cast<intptr_t>(symbolExprs.size()) !=
767                  mlirIntegerSetGetNumSymbols(self))
768                throw py::value_error(
769                    "Expected the number of symbol replacement expressions "
770                    "to match that of symbols");
771 
772              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
773              pyListToVector<PyAffineExpr>(
774                  dimExprs, dimAffineExprs,
775                  "attempting to create an IntegerSet by replacing dimensions");
776              pyListToVector<PyAffineExpr>(
777                  symbolExprs, symbolAffineExprs,
778                  "attempting to create an IntegerSet by replacing symbols");
779              MlirIntegerSet set = mlirIntegerSetReplaceGet(
780                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
781                  numResultDims, numResultSymbols);
782              return PyIntegerSet(self.getContext(), set);
783            })
784       .def_property_readonly("is_canonical_empty",
785                              [](PyIntegerSet &self) {
786                                return mlirIntegerSetIsCanonicalEmpty(self);
787                              })
788       .def_property_readonly(
789           "n_dims",
790           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
791       .def_property_readonly(
792           "n_symbols",
793           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
794       .def_property_readonly(
795           "n_inputs",
796           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
797       .def_property_readonly("n_equalities",
798                              [](PyIntegerSet &self) {
799                                return mlirIntegerSetGetNumEqualities(self);
800                              })
801       .def_property_readonly("n_inequalities",
802                              [](PyIntegerSet &self) {
803                                return mlirIntegerSetGetNumInequalities(self);
804                              })
805       .def_property_readonly("constraints", [](PyIntegerSet &self) {
806         return PyIntegerSetConstraintList(self);
807       });
808   PyIntegerSetConstraint::bind(m);
809   PyIntegerSetConstraintList::bind(m);
810 }
811