1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/AffineMap.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/IntegerSet.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 using llvm::SmallVector;
22 using llvm::StringRef;
23 using llvm::Twine;
24 
/// Docstring shared by the `dump` methods bound on AffineExpr, AffineMap and
/// IntegerSet.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
27 
28 /// Attempts to populate `result` with the content of `list` casted to the
29 /// appropriate type (Python and C types are provided as template arguments).
30 /// Throws errors in case of failure, using "action" to describe what the caller
31 /// was attempting to do.
32 template <typename PyType, typename CType>
33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34                            StringRef action) {
35   result.reserve(py::len(list));
36   for (py::handle item : list) {
37     try {
38       result.push_back(item.cast<PyType>());
39     } catch (py::cast_error &err) {
40       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41                          " (" + err.what() + ")")
42                             .str();
43       throw py::cast_error(msg);
44     } catch (py::reference_cast_error &err) {
45       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46                          action + " (" + err.what() + ")")
47                             .str();
48       throw py::cast_error(msg);
49     }
50   }
51 }
52 
53 template <typename PermutationTy>
54 static bool isPermutation(std::vector<PermutationTy> permutation) {
55   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
56   for (auto val : permutation) {
57     if (val < permutation.size()) {
58       if (seen[val])
59         return false;
60       seen[val] = true;
61       continue;
62     }
63     return false;
64   }
65   return true;
66 }
67 
68 namespace {
69 
70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71 /// and should be castable from it. Intermediate hierarchy classes can be
72 /// modeled by specifying BaseTy.
73 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74 class PyConcreteAffineExpr : public BaseTy {
75 public:
76   // Derived classes must define statics for:
77   //   IsAFunctionTy isaFunction
78   //   const char *pyClassName
79   // and redefine bindDerived.
80   using ClassTy = py::class_<DerivedTy, BaseTy>;
81   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82 
83   PyConcreteAffineExpr() = default;
84   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85       : BaseTy(std::move(contextRef), affineExpr) {}
86   PyConcreteAffineExpr(PyAffineExpr &orig)
87       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88 
89   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90     if (!DerivedTy::isaFunction(orig)) {
91       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92       throw SetPyError(PyExc_ValueError,
93                        Twine("Cannot cast affine expression to ") +
94                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95     }
96     return orig;
97   }
98 
99   static void bind(py::module &m) {
100     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
101     cls.def(py::init<PyAffineExpr &>());
102     cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool {
103       return DerivedTy::isaFunction(otherAffineExpr);
104     });
105     DerivedTy::bindDerived(cls);
106   }
107 
108   /// Implemented by derived classes to add methods to the Python subclass.
109   static void bindDerived(ClassTy &m) {}
110 };
111 
112 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
113 public:
114   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
115   static constexpr const char *pyClassName = "AffineConstantExpr";
116   using PyConcreteAffineExpr::PyConcreteAffineExpr;
117 
118   static PyAffineConstantExpr get(intptr_t value,
119                                   DefaultingPyMlirContext context) {
120     MlirAffineExpr affineExpr =
121         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
122     return PyAffineConstantExpr(context->getRef(), affineExpr);
123   }
124 
125   static void bindDerived(ClassTy &c) {
126     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
127                  py::arg("context") = py::none());
128     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
129       return mlirAffineConstantExprGetValue(self);
130     });
131   }
132 };
133 
134 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
135 public:
136   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
137   static constexpr const char *pyClassName = "AffineDimExpr";
138   using PyConcreteAffineExpr::PyConcreteAffineExpr;
139 
140   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
141     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
142     return PyAffineDimExpr(context->getRef(), affineExpr);
143   }
144 
145   static void bindDerived(ClassTy &c) {
146     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
147                  py::arg("context") = py::none());
148     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
149       return mlirAffineDimExprGetPosition(self);
150     });
151   }
152 };
153 
154 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
155 public:
156   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
157   static constexpr const char *pyClassName = "AffineSymbolExpr";
158   using PyConcreteAffineExpr::PyConcreteAffineExpr;
159 
160   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
161     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
162     return PyAffineSymbolExpr(context->getRef(), affineExpr);
163   }
164 
165   static void bindDerived(ClassTy &c) {
166     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
167                  py::arg("context") = py::none());
168     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
169       return mlirAffineSymbolExprGetPosition(self);
170     });
171   }
172 };
173 
174 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
175 public:
176   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
177   static constexpr const char *pyClassName = "AffineBinaryExpr";
178   using PyConcreteAffineExpr::PyConcreteAffineExpr;
179 
180   PyAffineExpr lhs() {
181     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
182     return PyAffineExpr(getContext(), lhsExpr);
183   }
184 
185   PyAffineExpr rhs() {
186     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
187     return PyAffineExpr(getContext(), rhsExpr);
188   }
189 
190   static void bindDerived(ClassTy &c) {
191     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
192     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
193   }
194 };
195 
196 class PyAffineAddExpr
197     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
198 public:
199   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
200   static constexpr const char *pyClassName = "AffineAddExpr";
201   using PyConcreteAffineExpr::PyConcreteAffineExpr;
202 
203   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
204     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
205     return PyAffineAddExpr(lhs.getContext(), expr);
206   }
207 
208   static void bindDerived(ClassTy &c) {
209     c.def_static("get", &PyAffineAddExpr::get);
210   }
211 };
212 
213 class PyAffineMulExpr
214     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
215 public:
216   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
217   static constexpr const char *pyClassName = "AffineMulExpr";
218   using PyConcreteAffineExpr::PyConcreteAffineExpr;
219 
220   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
221     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
222     return PyAffineMulExpr(lhs.getContext(), expr);
223   }
224 
225   static void bindDerived(ClassTy &c) {
226     c.def_static("get", &PyAffineMulExpr::get);
227   }
228 };
229 
230 class PyAffineModExpr
231     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
232 public:
233   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
234   static constexpr const char *pyClassName = "AffineModExpr";
235   using PyConcreteAffineExpr::PyConcreteAffineExpr;
236 
237   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
238     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
239     return PyAffineModExpr(lhs.getContext(), expr);
240   }
241 
242   static void bindDerived(ClassTy &c) {
243     c.def_static("get", &PyAffineModExpr::get);
244   }
245 };
246 
247 class PyAffineFloorDivExpr
248     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
249 public:
250   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
251   static constexpr const char *pyClassName = "AffineFloorDivExpr";
252   using PyConcreteAffineExpr::PyConcreteAffineExpr;
253 
254   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
255     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
256     return PyAffineFloorDivExpr(lhs.getContext(), expr);
257   }
258 
259   static void bindDerived(ClassTy &c) {
260     c.def_static("get", &PyAffineFloorDivExpr::get);
261   }
262 };
263 
264 class PyAffineCeilDivExpr
265     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
266 public:
267   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
268   static constexpr const char *pyClassName = "AffineCeilDivExpr";
269   using PyConcreteAffineExpr::PyConcreteAffineExpr;
270 
271   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
272     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
273     return PyAffineCeilDivExpr(lhs.getContext(), expr);
274   }
275 
276   static void bindDerived(ClassTy &c) {
277     c.def_static("get", &PyAffineCeilDivExpr::get);
278   }
279 };
280 
281 } // namespace
282 
283 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
284   return mlirAffineExprEqual(affineExpr, other.affineExpr);
285 }
286 
287 py::object PyAffineExpr::getCapsule() {
288   return py::reinterpret_steal<py::object>(
289       mlirPythonAffineExprToCapsule(*this));
290 }
291 
292 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
293   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
294   if (mlirAffineExprIsNull(rawAffineExpr))
295     throw py::error_already_set();
296   return PyAffineExpr(
297       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
298       rawAffineExpr);
299 }
300 
301 //------------------------------------------------------------------------------
302 // PyAffineMap and utilities.
303 //------------------------------------------------------------------------------
304 namespace {
305 
306 /// A list of expressions contained in an affine map. Internally these are
307 /// stored as a consecutive array leading to inexpensive random access. Both
308 /// the map and the expression are owned by the context so we need not bother
309 /// with lifetime extension.
310 class PyAffineMapExprList
311     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
312 public:
313   static constexpr const char *pyClassName = "AffineExprList";
314 
315   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
316                       intptr_t length = -1, intptr_t step = 1)
317       : Sliceable(startIndex,
318                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
319                   step),
320         affineMap(map) {}
321 
322   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
323 
324   PyAffineExpr getElement(intptr_t pos) {
325     return PyAffineExpr(affineMap.getContext(),
326                         mlirAffineMapGetResult(affineMap, pos));
327   }
328 
329   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
330                             intptr_t step) {
331     return PyAffineMapExprList(affineMap, startIndex, length, step);
332   }
333 
334 private:
335   PyAffineMap affineMap;
336 };
337 } // end namespace
338 
339 bool PyAffineMap::operator==(const PyAffineMap &other) {
340   return mlirAffineMapEqual(affineMap, other.affineMap);
341 }
342 
343 py::object PyAffineMap::getCapsule() {
344   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
345 }
346 
347 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
348   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
349   if (mlirAffineMapIsNull(rawAffineMap))
350     throw py::error_already_set();
351   return PyAffineMap(
352       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
353       rawAffineMap);
354 }
355 
356 //------------------------------------------------------------------------------
357 // PyIntegerSet and utilities.
358 //------------------------------------------------------------------------------
359 namespace {
360 
361 class PyIntegerSetConstraint {
362 public:
363   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
364 
365   PyAffineExpr getExpr() {
366     return PyAffineExpr(set.getContext(),
367                         mlirIntegerSetGetConstraint(set, pos));
368   }
369 
370   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
371 
372   static void bind(py::module &m) {
373     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
374                                        py::module_local())
375         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
376         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
377   }
378 
379 private:
380   PyIntegerSet set;
381   intptr_t pos;
382 };
383 
384 class PyIntegerSetConstraintList
385     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
386 public:
387   static constexpr const char *pyClassName = "IntegerSetConstraintList";
388 
389   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
390                              intptr_t length = -1, intptr_t step = 1)
391       : Sliceable(startIndex,
392                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
393                   step),
394         set(set) {}
395 
396   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
397 
398   PyIntegerSetConstraint getElement(intptr_t pos) {
399     return PyIntegerSetConstraint(set, pos);
400   }
401 
402   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
403                                    intptr_t step) {
404     return PyIntegerSetConstraintList(set, startIndex, length, step);
405   }
406 
407 private:
408   PyIntegerSet set;
409 };
410 } // namespace
411 
412 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
413   return mlirIntegerSetEqual(integerSet, other.integerSet);
414 }
415 
416 py::object PyIntegerSet::getCapsule() {
417   return py::reinterpret_steal<py::object>(
418       mlirPythonIntegerSetToCapsule(*this));
419 }
420 
421 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
422   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
423   if (mlirIntegerSetIsNull(rawIntegerSet))
424     throw py::error_already_set();
425   return PyIntegerSet(
426       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
427       rawIntegerSet);
428 }
429 
430 void mlir::python::populateIRAffine(py::module &m) {
431   //----------------------------------------------------------------------------
432   // Mapping of PyAffineExpr and derived classes.
433   //----------------------------------------------------------------------------
434   py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
435       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
436                              &PyAffineExpr::getCapsule)
437       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
438       .def("__add__",
439            [](PyAffineExpr &self, PyAffineExpr &other) {
440              return PyAffineAddExpr::get(self, other);
441            })
442       .def("__mul__",
443            [](PyAffineExpr &self, PyAffineExpr &other) {
444              return PyAffineMulExpr::get(self, other);
445            })
446       .def("__mod__",
447            [](PyAffineExpr &self, PyAffineExpr &other) {
448              return PyAffineModExpr::get(self, other);
449            })
450       .def("__sub__",
451            [](PyAffineExpr &self, PyAffineExpr &other) {
452              auto negOne =
453                  PyAffineConstantExpr::get(-1, *self.getContext().get());
454              return PyAffineAddExpr::get(self,
455                                          PyAffineMulExpr::get(negOne, other));
456            })
457       .def("__eq__", [](PyAffineExpr &self,
458                         PyAffineExpr &other) { return self == other; })
459       .def("__eq__",
460            [](PyAffineExpr &self, py::object &other) { return false; })
461       .def("__str__",
462            [](PyAffineExpr &self) {
463              PyPrintAccumulator printAccum;
464              mlirAffineExprPrint(self, printAccum.getCallback(),
465                                  printAccum.getUserData());
466              return printAccum.join();
467            })
468       .def("__repr__",
469            [](PyAffineExpr &self) {
470              PyPrintAccumulator printAccum;
471              printAccum.parts.append("AffineExpr(");
472              mlirAffineExprPrint(self, printAccum.getCallback(),
473                                  printAccum.getUserData());
474              printAccum.parts.append(")");
475              return printAccum.join();
476            })
477       .def_property_readonly(
478           "context",
479           [](PyAffineExpr &self) { return self.getContext().getObject(); })
480       .def_static(
481           "get_add", &PyAffineAddExpr::get,
482           "Gets an affine expression containing a sum of two expressions.")
483       .def_static(
484           "get_mul", &PyAffineMulExpr::get,
485           "Gets an affine expression containing a product of two expressions.")
486       .def_static("get_mod", &PyAffineModExpr::get,
487                   "Gets an affine expression containing the modulo of dividing "
488                   "one expression by another.")
489       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
490                   "Gets an affine expression containing the rounded-down "
491                   "result of dividing one expression by another.")
492       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
493                   "Gets an affine expression containing the rounded-up result "
494                   "of dividing one expression by another.")
495       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
496                   py::arg("context") = py::none(),
497                   "Gets a constant affine expression with the given value.")
498       .def_static(
499           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
500           py::arg("context") = py::none(),
501           "Gets an affine expression of a dimension at the given position.")
502       .def_static(
503           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
504           py::arg("context") = py::none(),
505           "Gets an affine expression of a symbol at the given position.")
506       .def(
507           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
508           kDumpDocstring);
509   PyAffineConstantExpr::bind(m);
510   PyAffineDimExpr::bind(m);
511   PyAffineSymbolExpr::bind(m);
512   PyAffineBinaryExpr::bind(m);
513   PyAffineAddExpr::bind(m);
514   PyAffineMulExpr::bind(m);
515   PyAffineModExpr::bind(m);
516   PyAffineFloorDivExpr::bind(m);
517   PyAffineCeilDivExpr::bind(m);
518 
519   //----------------------------------------------------------------------------
520   // Mapping of PyAffineMap.
521   //----------------------------------------------------------------------------
522   py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
523       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
524                              &PyAffineMap::getCapsule)
525       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
526       .def("__eq__",
527            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
528       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
529       .def("__str__",
530            [](PyAffineMap &self) {
531              PyPrintAccumulator printAccum;
532              mlirAffineMapPrint(self, printAccum.getCallback(),
533                                 printAccum.getUserData());
534              return printAccum.join();
535            })
536       .def("__repr__",
537            [](PyAffineMap &self) {
538              PyPrintAccumulator printAccum;
539              printAccum.parts.append("AffineMap(");
540              mlirAffineMapPrint(self, printAccum.getCallback(),
541                                 printAccum.getUserData());
542              printAccum.parts.append(")");
543              return printAccum.join();
544            })
545       .def_static("compress_unused_symbols",
546                   [](py::list affineMaps, DefaultingPyMlirContext context) {
547                     SmallVector<MlirAffineMap> maps;
548                     pyListToVector<PyAffineMap, MlirAffineMap>(
549                         affineMaps, maps, "attempting to create an AffineMap");
550                     std::vector<MlirAffineMap> compressed(affineMaps.size());
551                     auto populate = [](void *result, intptr_t idx,
552                                        MlirAffineMap m) {
553                       static_cast<MlirAffineMap *>(result)[idx] = (m);
554                     };
555                     mlirAffineMapCompressUnusedSymbols(
556                         maps.data(), maps.size(), compressed.data(), populate);
557                     std::vector<PyAffineMap> res;
558                     res.reserve(compressed.size());
559                     for (auto m : compressed)
560                       res.push_back(PyAffineMap(context->getRef(), m));
561                     return res;
562                   })
563       .def_property_readonly(
564           "context",
565           [](PyAffineMap &self) { return self.getContext().getObject(); },
566           "Context that owns the Affine Map")
567       .def(
568           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
569           kDumpDocstring)
570       .def_static(
571           "get",
572           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
573              DefaultingPyMlirContext context) {
574             SmallVector<MlirAffineExpr> affineExprs;
575             pyListToVector<PyAffineExpr, MlirAffineExpr>(
576                 exprs, affineExprs, "attempting to create an AffineMap");
577             MlirAffineMap map =
578                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
579                                  affineExprs.size(), affineExprs.data());
580             return PyAffineMap(context->getRef(), map);
581           },
582           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
583           py::arg("context") = py::none(),
584           "Gets a map with the given expressions as results.")
585       .def_static(
586           "get_constant",
587           [](intptr_t value, DefaultingPyMlirContext context) {
588             MlirAffineMap affineMap =
589                 mlirAffineMapConstantGet(context->get(), value);
590             return PyAffineMap(context->getRef(), affineMap);
591           },
592           py::arg("value"), py::arg("context") = py::none(),
593           "Gets an affine map with a single constant result")
594       .def_static(
595           "get_empty",
596           [](DefaultingPyMlirContext context) {
597             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
598             return PyAffineMap(context->getRef(), affineMap);
599           },
600           py::arg("context") = py::none(), "Gets an empty affine map.")
601       .def_static(
602           "get_identity",
603           [](intptr_t nDims, DefaultingPyMlirContext context) {
604             MlirAffineMap affineMap =
605                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
606             return PyAffineMap(context->getRef(), affineMap);
607           },
608           py::arg("n_dims"), py::arg("context") = py::none(),
609           "Gets an identity map with the given number of dimensions.")
610       .def_static(
611           "get_minor_identity",
612           [](intptr_t nDims, intptr_t nResults,
613              DefaultingPyMlirContext context) {
614             MlirAffineMap affineMap =
615                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
616             return PyAffineMap(context->getRef(), affineMap);
617           },
618           py::arg("n_dims"), py::arg("n_results"),
619           py::arg("context") = py::none(),
620           "Gets a minor identity map with the given number of dimensions and "
621           "results.")
622       .def_static(
623           "get_permutation",
624           [](std::vector<unsigned> permutation,
625              DefaultingPyMlirContext context) {
626             if (!isPermutation(permutation))
627               throw py::cast_error("Invalid permutation when attempting to "
628                                    "create an AffineMap");
629             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
630                 context->get(), permutation.size(), permutation.data());
631             return PyAffineMap(context->getRef(), affineMap);
632           },
633           py::arg("permutation"), py::arg("context") = py::none(),
634           "Gets an affine map that permutes its inputs.")
635       .def("get_submap",
636            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
637              intptr_t numResults = mlirAffineMapGetNumResults(self);
638              for (intptr_t pos : resultPos) {
639                if (pos < 0 || pos >= numResults)
640                  throw py::value_error("result position out of bounds");
641              }
642              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
643                  self, resultPos.size(), resultPos.data());
644              return PyAffineMap(self.getContext(), affineMap);
645            })
646       .def("get_major_submap",
647            [](PyAffineMap &self, intptr_t nResults) {
648              if (nResults >= mlirAffineMapGetNumResults(self))
649                throw py::value_error("number of results out of bounds");
650              MlirAffineMap affineMap =
651                  mlirAffineMapGetMajorSubMap(self, nResults);
652              return PyAffineMap(self.getContext(), affineMap);
653            })
654       .def("get_minor_submap",
655            [](PyAffineMap &self, intptr_t nResults) {
656              if (nResults >= mlirAffineMapGetNumResults(self))
657                throw py::value_error("number of results out of bounds");
658              MlirAffineMap affineMap =
659                  mlirAffineMapGetMinorSubMap(self, nResults);
660              return PyAffineMap(self.getContext(), affineMap);
661            })
662       .def("replace",
663            [](PyAffineMap &self, PyAffineExpr &expression,
664               PyAffineExpr &replacement, intptr_t numResultDims,
665               intptr_t numResultSyms) {
666              MlirAffineMap affineMap = mlirAffineMapReplace(
667                  self, expression, replacement, numResultDims, numResultSyms);
668              return PyAffineMap(self.getContext(), affineMap);
669            })
670       .def_property_readonly(
671           "is_permutation",
672           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
673       .def_property_readonly("is_projected_permutation",
674                              [](PyAffineMap &self) {
675                                return mlirAffineMapIsProjectedPermutation(self);
676                              })
677       .def_property_readonly(
678           "n_dims",
679           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
680       .def_property_readonly(
681           "n_inputs",
682           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
683       .def_property_readonly(
684           "n_symbols",
685           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
686       .def_property_readonly("results", [](PyAffineMap &self) {
687         return PyAffineMapExprList(self);
688       });
689   PyAffineMapExprList::bind(m);
690 
691   //----------------------------------------------------------------------------
692   // Mapping of PyIntegerSet.
693   //----------------------------------------------------------------------------
694   py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
695       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
696                              &PyIntegerSet::getCapsule)
697       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
698       .def("__eq__", [](PyIntegerSet &self,
699                         PyIntegerSet &other) { return self == other; })
700       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
701       .def("__str__",
702            [](PyIntegerSet &self) {
703              PyPrintAccumulator printAccum;
704              mlirIntegerSetPrint(self, printAccum.getCallback(),
705                                  printAccum.getUserData());
706              return printAccum.join();
707            })
708       .def("__repr__",
709            [](PyIntegerSet &self) {
710              PyPrintAccumulator printAccum;
711              printAccum.parts.append("IntegerSet(");
712              mlirIntegerSetPrint(self, printAccum.getCallback(),
713                                  printAccum.getUserData());
714              printAccum.parts.append(")");
715              return printAccum.join();
716            })
717       .def_property_readonly(
718           "context",
719           [](PyIntegerSet &self) { return self.getContext().getObject(); })
720       .def(
721           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
722           kDumpDocstring)
723       .def_static(
724           "get",
725           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
726              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
727             if (exprs.size() != eqFlags.size())
728               throw py::value_error(
729                   "Expected the number of constraints to match "
730                   "that of equality flags");
731             if (exprs.empty())
732               throw py::value_error("Expected non-empty list of constraints");
733 
734             // Copy over to a SmallVector because std::vector has a
735             // specialization for booleans that packs data and does not
736             // expose a `bool *`.
737             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
738 
739             SmallVector<MlirAffineExpr> affineExprs;
740             pyListToVector<PyAffineExpr>(exprs, affineExprs,
741                                          "attempting to create an IntegerSet");
742             MlirIntegerSet set = mlirIntegerSetGet(
743                 context->get(), numDims, numSymbols, exprs.size(),
744                 affineExprs.data(), flags.data());
745             return PyIntegerSet(context->getRef(), set);
746           },
747           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
748           py::arg("eq_flags"), py::arg("context") = py::none())
749       .def_static(
750           "get_empty",
751           [](intptr_t numDims, intptr_t numSymbols,
752              DefaultingPyMlirContext context) {
753             MlirIntegerSet set =
754                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
755             return PyIntegerSet(context->getRef(), set);
756           },
757           py::arg("num_dims"), py::arg("num_symbols"),
758           py::arg("context") = py::none())
759       .def("get_replaced",
760            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
761               intptr_t numResultDims, intptr_t numResultSymbols) {
762              if (static_cast<intptr_t>(dimExprs.size()) !=
763                  mlirIntegerSetGetNumDims(self))
764                throw py::value_error(
765                    "Expected the number of dimension replacement expressions "
766                    "to match that of dimensions");
767              if (static_cast<intptr_t>(symbolExprs.size()) !=
768                  mlirIntegerSetGetNumSymbols(self))
769                throw py::value_error(
770                    "Expected the number of symbol replacement expressions "
771                    "to match that of symbols");
772 
773              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
774              pyListToVector<PyAffineExpr>(
775                  dimExprs, dimAffineExprs,
776                  "attempting to create an IntegerSet by replacing dimensions");
777              pyListToVector<PyAffineExpr>(
778                  symbolExprs, symbolAffineExprs,
779                  "attempting to create an IntegerSet by replacing symbols");
780              MlirIntegerSet set = mlirIntegerSetReplaceGet(
781                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
782                  numResultDims, numResultSymbols);
783              return PyIntegerSet(self.getContext(), set);
784            })
785       .def_property_readonly("is_canonical_empty",
786                              [](PyIntegerSet &self) {
787                                return mlirIntegerSetIsCanonicalEmpty(self);
788                              })
789       .def_property_readonly(
790           "n_dims",
791           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
792       .def_property_readonly(
793           "n_symbols",
794           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
795       .def_property_readonly(
796           "n_inputs",
797           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
798       .def_property_readonly("n_equalities",
799                              [](PyIntegerSet &self) {
800                                return mlirIntegerSetGetNumEqualities(self);
801                              })
802       .def_property_readonly("n_inequalities",
803                              [](PyIntegerSet &self) {
804                                return mlirIntegerSetGetNumInequalities(self);
805                              })
806       .def_property_readonly("constraints", [](PyIntegerSet &self) {
807         return PyIntegerSetConstraintList(self);
808       });
809   PyIntegerSetConstraint::bind(m);
810   PyIntegerSetConstraintList::bind(m);
811 }
812