1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/AffineMap.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/IntegerSet.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 using llvm::SmallVector;
22 using llvm::StringRef;
23 using llvm::Twine;
24 
/// Docstring text shared by the Python-visible `dump` methods registered in
/// this file.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";
27 
28 /// Attempts to populate `result` with the content of `list` casted to the
29 /// appropriate type (Python and C types are provided as template arguments).
30 /// Throws errors in case of failure, using "action" to describe what the caller
31 /// was attempting to do.
32 template <typename PyType, typename CType>
33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34                            StringRef action) {
35   result.reserve(py::len(list));
36   for (py::handle item : list) {
37     try {
38       result.push_back(item.cast<PyType>());
39     } catch (py::cast_error &err) {
40       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41                          " (" + err.what() + ")")
42                             .str();
43       throw py::cast_error(msg);
44     } catch (py::reference_cast_error &err) {
45       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46                          action + " (" + err.what() + ")")
47                             .str();
48       throw py::cast_error(msg);
49     }
50   }
51 }
52 
53 template <typename PermutationTy>
54 static bool isPermutation(std::vector<PermutationTy> permutation) {
55   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
56   for (auto val : permutation) {
57     if (val < permutation.size()) {
58       if (seen[val])
59         return false;
60       seen[val] = true;
61       continue;
62     }
63     return false;
64   }
65   return true;
66 }
67 
68 namespace {
69 
70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71 /// and should be castable from it. Intermediate hierarchy classes can be
72 /// modeled by specifying BaseTy.
73 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74 class PyConcreteAffineExpr : public BaseTy {
75 public:
76   // Derived classes must define statics for:
77   //   IsAFunctionTy isaFunction
78   //   const char *pyClassName
79   // and redefine bindDerived.
80   using ClassTy = py::class_<DerivedTy, BaseTy>;
81   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82 
83   PyConcreteAffineExpr() = default;
84   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85       : BaseTy(std::move(contextRef), affineExpr) {}
86   PyConcreteAffineExpr(PyAffineExpr &orig)
87       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88 
89   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90     if (!DerivedTy::isaFunction(orig)) {
91       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92       throw SetPyError(PyExc_ValueError,
93                        Twine("Cannot cast affine expression to ") +
94                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95     }
96     return orig;
97   }
98 
99   static void bind(py::module &m) {
100     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
101     cls.def(py::init<PyAffineExpr &>());
102     DerivedTy::bindDerived(cls);
103   }
104 
105   /// Implemented by derived classes to add methods to the Python subclass.
106   static void bindDerived(ClassTy &m) {}
107 };
108 
109 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
110 public:
111   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
112   static constexpr const char *pyClassName = "AffineConstantExpr";
113   using PyConcreteAffineExpr::PyConcreteAffineExpr;
114 
115   static PyAffineConstantExpr get(intptr_t value,
116                                   DefaultingPyMlirContext context) {
117     MlirAffineExpr affineExpr =
118         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
119     return PyAffineConstantExpr(context->getRef(), affineExpr);
120   }
121 
122   static void bindDerived(ClassTy &c) {
123     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
124                  py::arg("context") = py::none());
125     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
126       return mlirAffineConstantExprGetValue(self);
127     });
128   }
129 };
130 
131 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
132 public:
133   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
134   static constexpr const char *pyClassName = "AffineDimExpr";
135   using PyConcreteAffineExpr::PyConcreteAffineExpr;
136 
137   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
138     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
139     return PyAffineDimExpr(context->getRef(), affineExpr);
140   }
141 
142   static void bindDerived(ClassTy &c) {
143     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
144                  py::arg("context") = py::none());
145     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
146       return mlirAffineDimExprGetPosition(self);
147     });
148   }
149 };
150 
151 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
152 public:
153   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
154   static constexpr const char *pyClassName = "AffineSymbolExpr";
155   using PyConcreteAffineExpr::PyConcreteAffineExpr;
156 
157   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
158     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
159     return PyAffineSymbolExpr(context->getRef(), affineExpr);
160   }
161 
162   static void bindDerived(ClassTy &c) {
163     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
164                  py::arg("context") = py::none());
165     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
166       return mlirAffineSymbolExprGetPosition(self);
167     });
168   }
169 };
170 
171 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
172 public:
173   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
174   static constexpr const char *pyClassName = "AffineBinaryExpr";
175   using PyConcreteAffineExpr::PyConcreteAffineExpr;
176 
177   PyAffineExpr lhs() {
178     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
179     return PyAffineExpr(getContext(), lhsExpr);
180   }
181 
182   PyAffineExpr rhs() {
183     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
184     return PyAffineExpr(getContext(), rhsExpr);
185   }
186 
187   static void bindDerived(ClassTy &c) {
188     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
189     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
190   }
191 };
192 
193 class PyAffineAddExpr
194     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
195 public:
196   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
197   static constexpr const char *pyClassName = "AffineAddExpr";
198   using PyConcreteAffineExpr::PyConcreteAffineExpr;
199 
200   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
201     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
202     return PyAffineAddExpr(lhs.getContext(), expr);
203   }
204 
205   static void bindDerived(ClassTy &c) {
206     c.def_static("get", &PyAffineAddExpr::get);
207   }
208 };
209 
210 class PyAffineMulExpr
211     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
212 public:
213   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
214   static constexpr const char *pyClassName = "AffineMulExpr";
215   using PyConcreteAffineExpr::PyConcreteAffineExpr;
216 
217   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
218     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
219     return PyAffineMulExpr(lhs.getContext(), expr);
220   }
221 
222   static void bindDerived(ClassTy &c) {
223     c.def_static("get", &PyAffineMulExpr::get);
224   }
225 };
226 
227 class PyAffineModExpr
228     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
229 public:
230   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
231   static constexpr const char *pyClassName = "AffineModExpr";
232   using PyConcreteAffineExpr::PyConcreteAffineExpr;
233 
234   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
235     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
236     return PyAffineModExpr(lhs.getContext(), expr);
237   }
238 
239   static void bindDerived(ClassTy &c) {
240     c.def_static("get", &PyAffineModExpr::get);
241   }
242 };
243 
244 class PyAffineFloorDivExpr
245     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
246 public:
247   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
248   static constexpr const char *pyClassName = "AffineFloorDivExpr";
249   using PyConcreteAffineExpr::PyConcreteAffineExpr;
250 
251   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
252     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
253     return PyAffineFloorDivExpr(lhs.getContext(), expr);
254   }
255 
256   static void bindDerived(ClassTy &c) {
257     c.def_static("get", &PyAffineFloorDivExpr::get);
258   }
259 };
260 
261 class PyAffineCeilDivExpr
262     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
263 public:
264   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
265   static constexpr const char *pyClassName = "AffineCeilDivExpr";
266   using PyConcreteAffineExpr::PyConcreteAffineExpr;
267 
268   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
269     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
270     return PyAffineCeilDivExpr(lhs.getContext(), expr);
271   }
272 
273   static void bindDerived(ClassTy &c) {
274     c.def_static("get", &PyAffineCeilDivExpr::get);
275   }
276 };
277 
278 } // namespace
279 
280 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
281   return mlirAffineExprEqual(affineExpr, other.affineExpr);
282 }
283 
284 py::object PyAffineExpr::getCapsule() {
285   return py::reinterpret_steal<py::object>(
286       mlirPythonAffineExprToCapsule(*this));
287 }
288 
289 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
290   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
291   if (mlirAffineExprIsNull(rawAffineExpr))
292     throw py::error_already_set();
293   return PyAffineExpr(
294       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
295       rawAffineExpr);
296 }
297 
298 //------------------------------------------------------------------------------
299 // PyAffineMap and utilities.
300 //------------------------------------------------------------------------------
301 namespace {
302 
303 /// A list of expressions contained in an affine map. Internally these are
304 /// stored as a consecutive array leading to inexpensive random access. Both
305 /// the map and the expression are owned by the context so we need not bother
306 /// with lifetime extension.
307 class PyAffineMapExprList
308     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
309 public:
310   static constexpr const char *pyClassName = "AffineExprList";
311 
312   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
313                       intptr_t length = -1, intptr_t step = 1)
314       : Sliceable(startIndex,
315                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
316                   step),
317         affineMap(map) {}
318 
319   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
320 
321   PyAffineExpr getElement(intptr_t pos) {
322     return PyAffineExpr(affineMap.getContext(),
323                         mlirAffineMapGetResult(affineMap, pos));
324   }
325 
326   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
327                             intptr_t step) {
328     return PyAffineMapExprList(affineMap, startIndex, length, step);
329   }
330 
331 private:
332   PyAffineMap affineMap;
333 };
334 } // end namespace
335 
336 bool PyAffineMap::operator==(const PyAffineMap &other) {
337   return mlirAffineMapEqual(affineMap, other.affineMap);
338 }
339 
340 py::object PyAffineMap::getCapsule() {
341   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
342 }
343 
344 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
345   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
346   if (mlirAffineMapIsNull(rawAffineMap))
347     throw py::error_already_set();
348   return PyAffineMap(
349       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
350       rawAffineMap);
351 }
352 
353 //------------------------------------------------------------------------------
354 // PyIntegerSet and utilities.
355 //------------------------------------------------------------------------------
356 namespace {
357 
358 class PyIntegerSetConstraint {
359 public:
360   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
361 
362   PyAffineExpr getExpr() {
363     return PyAffineExpr(set.getContext(),
364                         mlirIntegerSetGetConstraint(set, pos));
365   }
366 
367   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
368 
369   static void bind(py::module &m) {
370     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
371                                        py::module_local())
372         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
373         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
374   }
375 
376 private:
377   PyIntegerSet set;
378   intptr_t pos;
379 };
380 
381 class PyIntegerSetConstraintList
382     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
383 public:
384   static constexpr const char *pyClassName = "IntegerSetConstraintList";
385 
386   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
387                              intptr_t length = -1, intptr_t step = 1)
388       : Sliceable(startIndex,
389                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
390                   step),
391         set(set) {}
392 
393   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
394 
395   PyIntegerSetConstraint getElement(intptr_t pos) {
396     return PyIntegerSetConstraint(set, pos);
397   }
398 
399   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
400                                    intptr_t step) {
401     return PyIntegerSetConstraintList(set, startIndex, length, step);
402   }
403 
404 private:
405   PyIntegerSet set;
406 };
407 } // namespace
408 
409 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
410   return mlirIntegerSetEqual(integerSet, other.integerSet);
411 }
412 
413 py::object PyIntegerSet::getCapsule() {
414   return py::reinterpret_steal<py::object>(
415       mlirPythonIntegerSetToCapsule(*this));
416 }
417 
418 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
419   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
420   if (mlirIntegerSetIsNull(rawIntegerSet))
421     throw py::error_already_set();
422   return PyIntegerSet(
423       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
424       rawIntegerSet);
425 }
426 
427 void mlir::python::populateIRAffine(py::module &m) {
428   //----------------------------------------------------------------------------
429   // Mapping of PyAffineExpr and derived classes.
430   //----------------------------------------------------------------------------
431   py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
432       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
433                              &PyAffineExpr::getCapsule)
434       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
435       .def("__add__",
436            [](PyAffineExpr &self, PyAffineExpr &other) {
437              return PyAffineAddExpr::get(self, other);
438            })
439       .def("__mul__",
440            [](PyAffineExpr &self, PyAffineExpr &other) {
441              return PyAffineMulExpr::get(self, other);
442            })
443       .def("__mod__",
444            [](PyAffineExpr &self, PyAffineExpr &other) {
445              return PyAffineModExpr::get(self, other);
446            })
447       .def("__sub__",
448            [](PyAffineExpr &self, PyAffineExpr &other) {
449              auto negOne =
450                  PyAffineConstantExpr::get(-1, *self.getContext().get());
451              return PyAffineAddExpr::get(self,
452                                          PyAffineMulExpr::get(negOne, other));
453            })
454       .def("__eq__", [](PyAffineExpr &self,
455                         PyAffineExpr &other) { return self == other; })
456       .def("__eq__",
457            [](PyAffineExpr &self, py::object &other) { return false; })
458       .def("__str__",
459            [](PyAffineExpr &self) {
460              PyPrintAccumulator printAccum;
461              mlirAffineExprPrint(self, printAccum.getCallback(),
462                                  printAccum.getUserData());
463              return printAccum.join();
464            })
465       .def("__repr__",
466            [](PyAffineExpr &self) {
467              PyPrintAccumulator printAccum;
468              printAccum.parts.append("AffineExpr(");
469              mlirAffineExprPrint(self, printAccum.getCallback(),
470                                  printAccum.getUserData());
471              printAccum.parts.append(")");
472              return printAccum.join();
473            })
474       .def_property_readonly(
475           "context",
476           [](PyAffineExpr &self) { return self.getContext().getObject(); })
477       .def_static(
478           "get_add", &PyAffineAddExpr::get,
479           "Gets an affine expression containing a sum of two expressions.")
480       .def_static(
481           "get_mul", &PyAffineMulExpr::get,
482           "Gets an affine expression containing a product of two expressions.")
483       .def_static("get_mod", &PyAffineModExpr::get,
484                   "Gets an affine expression containing the modulo of dividing "
485                   "one expression by another.")
486       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
487                   "Gets an affine expression containing the rounded-down "
488                   "result of dividing one expression by another.")
489       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
490                   "Gets an affine expression containing the rounded-up result "
491                   "of dividing one expression by another.")
492       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
493                   py::arg("context") = py::none(),
494                   "Gets a constant affine expression with the given value.")
495       .def_static(
496           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
497           py::arg("context") = py::none(),
498           "Gets an affine expression of a dimension at the given position.")
499       .def_static(
500           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
501           py::arg("context") = py::none(),
502           "Gets an affine expression of a symbol at the given position.")
503       .def(
504           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
505           kDumpDocstring);
506   PyAffineConstantExpr::bind(m);
507   PyAffineDimExpr::bind(m);
508   PyAffineSymbolExpr::bind(m);
509   PyAffineBinaryExpr::bind(m);
510   PyAffineAddExpr::bind(m);
511   PyAffineMulExpr::bind(m);
512   PyAffineModExpr::bind(m);
513   PyAffineFloorDivExpr::bind(m);
514   PyAffineCeilDivExpr::bind(m);
515 
516   //----------------------------------------------------------------------------
517   // Mapping of PyAffineMap.
518   //----------------------------------------------------------------------------
519   py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
520       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
521                              &PyAffineMap::getCapsule)
522       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
523       .def("__eq__",
524            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
525       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
526       .def("__str__",
527            [](PyAffineMap &self) {
528              PyPrintAccumulator printAccum;
529              mlirAffineMapPrint(self, printAccum.getCallback(),
530                                 printAccum.getUserData());
531              return printAccum.join();
532            })
533       .def("__repr__",
534            [](PyAffineMap &self) {
535              PyPrintAccumulator printAccum;
536              printAccum.parts.append("AffineMap(");
537              mlirAffineMapPrint(self, printAccum.getCallback(),
538                                 printAccum.getUserData());
539              printAccum.parts.append(")");
540              return printAccum.join();
541            })
542       .def_static("compress_unused_symbols",
543                   [](py::list affineMaps, DefaultingPyMlirContext context) {
544                     SmallVector<MlirAffineMap> maps;
545                     pyListToVector<PyAffineMap, MlirAffineMap>(
546                         affineMaps, maps, "attempting to create an AffineMap");
547                     std::vector<MlirAffineMap> compressed(affineMaps.size());
548                     auto populate = [](void *result, intptr_t idx,
549                                        MlirAffineMap m) {
550                       static_cast<MlirAffineMap *>(result)[idx] = (m);
551                     };
552                     mlirAffineMapCompressUnusedSymbols(
553                         maps.data(), maps.size(), compressed.data(), populate);
554                     std::vector<PyAffineMap> res;
555                     for (auto m : compressed)
556                       res.push_back(PyAffineMap(context->getRef(), m));
557                     return res;
558                   })
559       .def_property_readonly(
560           "context",
561           [](PyAffineMap &self) { return self.getContext().getObject(); },
562           "Context that owns the Affine Map")
563       .def(
564           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
565           kDumpDocstring)
566       .def_static(
567           "get",
568           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
569              DefaultingPyMlirContext context) {
570             SmallVector<MlirAffineExpr> affineExprs;
571             pyListToVector<PyAffineExpr, MlirAffineExpr>(
572                 exprs, affineExprs, "attempting to create an AffineMap");
573             MlirAffineMap map =
574                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
575                                  affineExprs.size(), affineExprs.data());
576             return PyAffineMap(context->getRef(), map);
577           },
578           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
579           py::arg("context") = py::none(),
580           "Gets a map with the given expressions as results.")
581       .def_static(
582           "get_constant",
583           [](intptr_t value, DefaultingPyMlirContext context) {
584             MlirAffineMap affineMap =
585                 mlirAffineMapConstantGet(context->get(), value);
586             return PyAffineMap(context->getRef(), affineMap);
587           },
588           py::arg("value"), py::arg("context") = py::none(),
589           "Gets an affine map with a single constant result")
590       .def_static(
591           "get_empty",
592           [](DefaultingPyMlirContext context) {
593             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
594             return PyAffineMap(context->getRef(), affineMap);
595           },
596           py::arg("context") = py::none(), "Gets an empty affine map.")
597       .def_static(
598           "get_identity",
599           [](intptr_t nDims, DefaultingPyMlirContext context) {
600             MlirAffineMap affineMap =
601                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
602             return PyAffineMap(context->getRef(), affineMap);
603           },
604           py::arg("n_dims"), py::arg("context") = py::none(),
605           "Gets an identity map with the given number of dimensions.")
606       .def_static(
607           "get_minor_identity",
608           [](intptr_t nDims, intptr_t nResults,
609              DefaultingPyMlirContext context) {
610             MlirAffineMap affineMap =
611                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
612             return PyAffineMap(context->getRef(), affineMap);
613           },
614           py::arg("n_dims"), py::arg("n_results"),
615           py::arg("context") = py::none(),
616           "Gets a minor identity map with the given number of dimensions and "
617           "results.")
618       .def_static(
619           "get_permutation",
620           [](std::vector<unsigned> permutation,
621              DefaultingPyMlirContext context) {
622             if (!isPermutation(permutation))
623               throw py::cast_error("Invalid permutation when attempting to "
624                                    "create an AffineMap");
625             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
626                 context->get(), permutation.size(), permutation.data());
627             return PyAffineMap(context->getRef(), affineMap);
628           },
629           py::arg("permutation"), py::arg("context") = py::none(),
630           "Gets an affine map that permutes its inputs.")
631       .def("get_submap",
632            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
633              intptr_t numResults = mlirAffineMapGetNumResults(self);
634              for (intptr_t pos : resultPos) {
635                if (pos < 0 || pos >= numResults)
636                  throw py::value_error("result position out of bounds");
637              }
638              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
639                  self, resultPos.size(), resultPos.data());
640              return PyAffineMap(self.getContext(), affineMap);
641            })
642       .def("get_major_submap",
643            [](PyAffineMap &self, intptr_t nResults) {
644              if (nResults >= mlirAffineMapGetNumResults(self))
645                throw py::value_error("number of results out of bounds");
646              MlirAffineMap affineMap =
647                  mlirAffineMapGetMajorSubMap(self, nResults);
648              return PyAffineMap(self.getContext(), affineMap);
649            })
650       .def("get_minor_submap",
651            [](PyAffineMap &self, intptr_t nResults) {
652              if (nResults >= mlirAffineMapGetNumResults(self))
653                throw py::value_error("number of results out of bounds");
654              MlirAffineMap affineMap =
655                  mlirAffineMapGetMinorSubMap(self, nResults);
656              return PyAffineMap(self.getContext(), affineMap);
657            })
658       .def("replace",
659            [](PyAffineMap &self, PyAffineExpr &expression,
660               PyAffineExpr &replacement, intptr_t numResultDims,
661               intptr_t numResultSyms) {
662              MlirAffineMap affineMap = mlirAffineMapReplace(
663                  self, expression, replacement, numResultDims, numResultSyms);
664              return PyAffineMap(self.getContext(), affineMap);
665            })
666       .def_property_readonly(
667           "is_permutation",
668           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
669       .def_property_readonly("is_projected_permutation",
670                              [](PyAffineMap &self) {
671                                return mlirAffineMapIsProjectedPermutation(self);
672                              })
673       .def_property_readonly(
674           "n_dims",
675           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
676       .def_property_readonly(
677           "n_inputs",
678           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
679       .def_property_readonly(
680           "n_symbols",
681           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
682       .def_property_readonly("results", [](PyAffineMap &self) {
683         return PyAffineMapExprList(self);
684       });
685   PyAffineMapExprList::bind(m);
686 
687   //----------------------------------------------------------------------------
688   // Mapping of PyIntegerSet.
689   //----------------------------------------------------------------------------
690   py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
691       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
692                              &PyIntegerSet::getCapsule)
693       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
694       .def("__eq__", [](PyIntegerSet &self,
695                         PyIntegerSet &other) { return self == other; })
696       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
697       .def("__str__",
698            [](PyIntegerSet &self) {
699              PyPrintAccumulator printAccum;
700              mlirIntegerSetPrint(self, printAccum.getCallback(),
701                                  printAccum.getUserData());
702              return printAccum.join();
703            })
704       .def("__repr__",
705            [](PyIntegerSet &self) {
706              PyPrintAccumulator printAccum;
707              printAccum.parts.append("IntegerSet(");
708              mlirIntegerSetPrint(self, printAccum.getCallback(),
709                                  printAccum.getUserData());
710              printAccum.parts.append(")");
711              return printAccum.join();
712            })
713       .def_property_readonly(
714           "context",
715           [](PyIntegerSet &self) { return self.getContext().getObject(); })
716       .def(
717           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
718           kDumpDocstring)
719       .def_static(
720           "get",
721           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
722              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
723             if (exprs.size() != eqFlags.size())
724               throw py::value_error(
725                   "Expected the number of constraints to match "
726                   "that of equality flags");
727             if (exprs.empty())
728               throw py::value_error("Expected non-empty list of constraints");
729 
730             // Copy over to a SmallVector because std::vector has a
731             // specialization for booleans that packs data and does not
732             // expose a `bool *`.
733             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
734 
735             SmallVector<MlirAffineExpr> affineExprs;
736             pyListToVector<PyAffineExpr>(exprs, affineExprs,
737                                          "attempting to create an IntegerSet");
738             MlirIntegerSet set = mlirIntegerSetGet(
739                 context->get(), numDims, numSymbols, exprs.size(),
740                 affineExprs.data(), flags.data());
741             return PyIntegerSet(context->getRef(), set);
742           },
743           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
744           py::arg("eq_flags"), py::arg("context") = py::none())
745       .def_static(
746           "get_empty",
747           [](intptr_t numDims, intptr_t numSymbols,
748              DefaultingPyMlirContext context) {
749             MlirIntegerSet set =
750                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
751             return PyIntegerSet(context->getRef(), set);
752           },
753           py::arg("num_dims"), py::arg("num_symbols"),
754           py::arg("context") = py::none())
755       .def("get_replaced",
756            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
757               intptr_t numResultDims, intptr_t numResultSymbols) {
758              if (static_cast<intptr_t>(dimExprs.size()) !=
759                  mlirIntegerSetGetNumDims(self))
760                throw py::value_error(
761                    "Expected the number of dimension replacement expressions "
762                    "to match that of dimensions");
763              if (static_cast<intptr_t>(symbolExprs.size()) !=
764                  mlirIntegerSetGetNumSymbols(self))
765                throw py::value_error(
766                    "Expected the number of symbol replacement expressions "
767                    "to match that of symbols");
768 
769              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
770              pyListToVector<PyAffineExpr>(
771                  dimExprs, dimAffineExprs,
772                  "attempting to create an IntegerSet by replacing dimensions");
773              pyListToVector<PyAffineExpr>(
774                  symbolExprs, symbolAffineExprs,
775                  "attempting to create an IntegerSet by replacing symbols");
776              MlirIntegerSet set = mlirIntegerSetReplaceGet(
777                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
778                  numResultDims, numResultSymbols);
779              return PyIntegerSet(self.getContext(), set);
780            })
781       .def_property_readonly("is_canonical_empty",
782                              [](PyIntegerSet &self) {
783                                return mlirIntegerSetIsCanonicalEmpty(self);
784                              })
785       .def_property_readonly(
786           "n_dims",
787           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
788       .def_property_readonly(
789           "n_symbols",
790           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
791       .def_property_readonly(
792           "n_inputs",
793           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
794       .def_property_readonly("n_equalities",
795                              [](PyIntegerSet &self) {
796                                return mlirIntegerSetGetNumEqualities(self);
797                              })
798       .def_property_readonly("n_inequalities",
799                              [](PyIntegerSet &self) {
800                                return mlirIntegerSetGetNumInequalities(self);
801                              })
802       .def_property_readonly("constraints", [](PyIntegerSet &self) {
803         return PyIntegerSetConstraintList(self);
804       });
  // Invoke the static `bind` hooks of the constraint and constraint-list
  // helper types on the same module `m` (presumably registering the types
  // returned by the `constraints` property -- confirm against their
  // definitions).
  PyIntegerSetConstraint::bind(m);
  PyIntegerSetConstraintList::bind(m);
807 }
808