1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/AffineMap.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/IntegerSet.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 using llvm::SmallVector;
22 using llvm::StringRef;
23 using llvm::Twine;
24 
/// Docstring shared by the `dump` methods registered in populateIRAffine.
static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
27 
28 /// Attempts to populate `result` with the content of `list` casted to the
29 /// appropriate type (Python and C types are provided as template arguments).
30 /// Throws errors in case of failure, using "action" to describe what the caller
31 /// was attempting to do.
32 template <typename PyType, typename CType>
33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34                            StringRef action) {
35   result.reserve(py::len(list));
36   for (py::handle item : list) {
37     try {
38       result.push_back(item.cast<PyType>());
39     } catch (py::cast_error &err) {
40       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41                          " (" + err.what() + ")")
42                             .str();
43       throw py::cast_error(msg);
44     } catch (py::reference_cast_error &err) {
45       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46                          action + " (" + err.what() + ")")
47                             .str();
48       throw py::cast_error(msg);
49     }
50   }
51 }
52 
53 template <typename PermutationTy>
54 static bool isPermutation(std::vector<PermutationTy> permutation) {
55   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
56   for (auto val : permutation) {
57     if (val < permutation.size()) {
58       if (seen[val])
59         return false;
60       seen[val] = true;
61       continue;
62     }
63     return false;
64   }
65   return true;
66 }
67 
68 namespace {
69 
70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71 /// and should be castable from it. Intermediate hierarchy classes can be
72 /// modeled by specifying BaseTy.
73 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74 class PyConcreteAffineExpr : public BaseTy {
75 public:
76   // Derived classes must define statics for:
77   //   IsAFunctionTy isaFunction
78   //   const char *pyClassName
79   // and redefine bindDerived.
80   using ClassTy = py::class_<DerivedTy, BaseTy>;
81   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82 
83   PyConcreteAffineExpr() = default;
84   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85       : BaseTy(std::move(contextRef), affineExpr) {}
86   PyConcreteAffineExpr(PyAffineExpr &orig)
87       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88 
89   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90     if (!DerivedTy::isaFunction(orig)) {
91       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92       throw SetPyError(PyExc_ValueError,
93                        Twine("Cannot cast affine expression to ") +
94                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95     }
96     return orig;
97   }
98 
99   static void bind(py::module &m) {
100     auto cls = ClassTy(m, DerivedTy::pyClassName);
101     cls.def(py::init<PyAffineExpr &>());
102     DerivedTy::bindDerived(cls);
103   }
104 
105   /// Implemented by derived classes to add methods to the Python subclass.
106   static void bindDerived(ClassTy &m) {}
107 };
108 
109 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
110 public:
111   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
112   static constexpr const char *pyClassName = "AffineConstantExpr";
113   using PyConcreteAffineExpr::PyConcreteAffineExpr;
114 
115   static PyAffineConstantExpr get(intptr_t value,
116                                   DefaultingPyMlirContext context) {
117     MlirAffineExpr affineExpr =
118         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
119     return PyAffineConstantExpr(context->getRef(), affineExpr);
120   }
121 
122   static void bindDerived(ClassTy &c) {
123     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
124                  py::arg("context") = py::none());
125     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
126       return mlirAffineConstantExprGetValue(self);
127     });
128   }
129 };
130 
131 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
132 public:
133   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
134   static constexpr const char *pyClassName = "AffineDimExpr";
135   using PyConcreteAffineExpr::PyConcreteAffineExpr;
136 
137   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
138     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
139     return PyAffineDimExpr(context->getRef(), affineExpr);
140   }
141 
142   static void bindDerived(ClassTy &c) {
143     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
144                  py::arg("context") = py::none());
145     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
146       return mlirAffineDimExprGetPosition(self);
147     });
148   }
149 };
150 
151 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
152 public:
153   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
154   static constexpr const char *pyClassName = "AffineSymbolExpr";
155   using PyConcreteAffineExpr::PyConcreteAffineExpr;
156 
157   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
158     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
159     return PyAffineSymbolExpr(context->getRef(), affineExpr);
160   }
161 
162   static void bindDerived(ClassTy &c) {
163     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
164                  py::arg("context") = py::none());
165     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
166       return mlirAffineSymbolExprGetPosition(self);
167     });
168   }
169 };
170 
171 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
172 public:
173   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
174   static constexpr const char *pyClassName = "AffineBinaryExpr";
175   using PyConcreteAffineExpr::PyConcreteAffineExpr;
176 
177   PyAffineExpr lhs() {
178     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
179     return PyAffineExpr(getContext(), lhsExpr);
180   }
181 
182   PyAffineExpr rhs() {
183     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
184     return PyAffineExpr(getContext(), rhsExpr);
185   }
186 
187   static void bindDerived(ClassTy &c) {
188     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
189     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
190   }
191 };
192 
193 class PyAffineAddExpr
194     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
195 public:
196   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
197   static constexpr const char *pyClassName = "AffineAddExpr";
198   using PyConcreteAffineExpr::PyConcreteAffineExpr;
199 
200   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
201     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
202     return PyAffineAddExpr(lhs.getContext(), expr);
203   }
204 
205   static void bindDerived(ClassTy &c) {
206     c.def_static("get", &PyAffineAddExpr::get);
207   }
208 };
209 
210 class PyAffineMulExpr
211     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
212 public:
213   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
214   static constexpr const char *pyClassName = "AffineMulExpr";
215   using PyConcreteAffineExpr::PyConcreteAffineExpr;
216 
217   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
218     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
219     return PyAffineMulExpr(lhs.getContext(), expr);
220   }
221 
222   static void bindDerived(ClassTy &c) {
223     c.def_static("get", &PyAffineMulExpr::get);
224   }
225 };
226 
227 class PyAffineModExpr
228     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
229 public:
230   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
231   static constexpr const char *pyClassName = "AffineModExpr";
232   using PyConcreteAffineExpr::PyConcreteAffineExpr;
233 
234   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
235     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
236     return PyAffineModExpr(lhs.getContext(), expr);
237   }
238 
239   static void bindDerived(ClassTy &c) {
240     c.def_static("get", &PyAffineModExpr::get);
241   }
242 };
243 
244 class PyAffineFloorDivExpr
245     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
246 public:
247   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
248   static constexpr const char *pyClassName = "AffineFloorDivExpr";
249   using PyConcreteAffineExpr::PyConcreteAffineExpr;
250 
251   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
252     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
253     return PyAffineFloorDivExpr(lhs.getContext(), expr);
254   }
255 
256   static void bindDerived(ClassTy &c) {
257     c.def_static("get", &PyAffineFloorDivExpr::get);
258   }
259 };
260 
261 class PyAffineCeilDivExpr
262     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
263 public:
264   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
265   static constexpr const char *pyClassName = "AffineCeilDivExpr";
266   using PyConcreteAffineExpr::PyConcreteAffineExpr;
267 
268   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
269     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
270     return PyAffineCeilDivExpr(lhs.getContext(), expr);
271   }
272 
273   static void bindDerived(ClassTy &c) {
274     c.def_static("get", &PyAffineCeilDivExpr::get);
275   }
276 };
277 
278 } // namespace
279 
280 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
281   return mlirAffineExprEqual(affineExpr, other.affineExpr);
282 }
283 
284 py::object PyAffineExpr::getCapsule() {
285   return py::reinterpret_steal<py::object>(
286       mlirPythonAffineExprToCapsule(*this));
287 }
288 
289 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
290   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
291   if (mlirAffineExprIsNull(rawAffineExpr))
292     throw py::error_already_set();
293   return PyAffineExpr(
294       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
295       rawAffineExpr);
296 }
297 
298 //------------------------------------------------------------------------------
299 // PyAffineMap and utilities.
300 //------------------------------------------------------------------------------
301 namespace {
302 
303 /// A list of expressions contained in an affine map. Internally these are
304 /// stored as a consecutive array leading to inexpensive random access. Both
305 /// the map and the expression are owned by the context so we need not bother
306 /// with lifetime extension.
307 class PyAffineMapExprList
308     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
309 public:
310   static constexpr const char *pyClassName = "AffineExprList";
311 
312   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
313                       intptr_t length = -1, intptr_t step = 1)
314       : Sliceable(startIndex,
315                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
316                   step),
317         affineMap(map) {}
318 
319   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
320 
321   PyAffineExpr getElement(intptr_t pos) {
322     return PyAffineExpr(affineMap.getContext(),
323                         mlirAffineMapGetResult(affineMap, pos));
324   }
325 
326   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
327                             intptr_t step) {
328     return PyAffineMapExprList(affineMap, startIndex, length, step);
329   }
330 
331 private:
332   PyAffineMap affineMap;
333 };
334 } // end namespace
335 
336 bool PyAffineMap::operator==(const PyAffineMap &other) {
337   return mlirAffineMapEqual(affineMap, other.affineMap);
338 }
339 
340 py::object PyAffineMap::getCapsule() {
341   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
342 }
343 
344 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
345   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
346   if (mlirAffineMapIsNull(rawAffineMap))
347     throw py::error_already_set();
348   return PyAffineMap(
349       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
350       rawAffineMap);
351 }
352 
353 //------------------------------------------------------------------------------
354 // PyIntegerSet and utilities.
355 //------------------------------------------------------------------------------
356 namespace {
357 
358 class PyIntegerSetConstraint {
359 public:
360   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
361 
362   PyAffineExpr getExpr() {
363     return PyAffineExpr(set.getContext(),
364                         mlirIntegerSetGetConstraint(set, pos));
365   }
366 
367   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
368 
369   static void bind(py::module &m) {
370     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
371         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
372         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
373   }
374 
375 private:
376   PyIntegerSet set;
377   intptr_t pos;
378 };
379 
380 class PyIntegerSetConstraintList
381     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
382 public:
383   static constexpr const char *pyClassName = "IntegerSetConstraintList";
384 
385   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
386                              intptr_t length = -1, intptr_t step = 1)
387       : Sliceable(startIndex,
388                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
389                   step),
390         set(set) {}
391 
392   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
393 
394   PyIntegerSetConstraint getElement(intptr_t pos) {
395     return PyIntegerSetConstraint(set, pos);
396   }
397 
398   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
399                                    intptr_t step) {
400     return PyIntegerSetConstraintList(set, startIndex, length, step);
401   }
402 
403 private:
404   PyIntegerSet set;
405 };
406 } // namespace
407 
408 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
409   return mlirIntegerSetEqual(integerSet, other.integerSet);
410 }
411 
412 py::object PyIntegerSet::getCapsule() {
413   return py::reinterpret_steal<py::object>(
414       mlirPythonIntegerSetToCapsule(*this));
415 }
416 
417 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
418   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
419   if (mlirIntegerSetIsNull(rawIntegerSet))
420     throw py::error_already_set();
421   return PyIntegerSet(
422       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
423       rawIntegerSet);
424 }
425 
426 void mlir::python::populateIRAffine(py::module &m) {
427   //----------------------------------------------------------------------------
428   // Mapping of PyAffineExpr and derived classes.
429   //----------------------------------------------------------------------------
430   py::class_<PyAffineExpr>(m, "AffineExpr")
431       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
432                              &PyAffineExpr::getCapsule)
433       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
434       .def("__add__",
435            [](PyAffineExpr &self, PyAffineExpr &other) {
436              return PyAffineAddExpr::get(self, other);
437            })
438       .def("__mul__",
439            [](PyAffineExpr &self, PyAffineExpr &other) {
440              return PyAffineMulExpr::get(self, other);
441            })
442       .def("__mod__",
443            [](PyAffineExpr &self, PyAffineExpr &other) {
444              return PyAffineModExpr::get(self, other);
445            })
446       .def("__sub__",
447            [](PyAffineExpr &self, PyAffineExpr &other) {
448              auto negOne =
449                  PyAffineConstantExpr::get(-1, *self.getContext().get());
450              return PyAffineAddExpr::get(self,
451                                          PyAffineMulExpr::get(negOne, other));
452            })
453       .def("__eq__", [](PyAffineExpr &self,
454                         PyAffineExpr &other) { return self == other; })
455       .def("__eq__",
456            [](PyAffineExpr &self, py::object &other) { return false; })
457       .def("__str__",
458            [](PyAffineExpr &self) {
459              PyPrintAccumulator printAccum;
460              mlirAffineExprPrint(self, printAccum.getCallback(),
461                                  printAccum.getUserData());
462              return printAccum.join();
463            })
464       .def("__repr__",
465            [](PyAffineExpr &self) {
466              PyPrintAccumulator printAccum;
467              printAccum.parts.append("AffineExpr(");
468              mlirAffineExprPrint(self, printAccum.getCallback(),
469                                  printAccum.getUserData());
470              printAccum.parts.append(")");
471              return printAccum.join();
472            })
473       .def_property_readonly(
474           "context",
475           [](PyAffineExpr &self) { return self.getContext().getObject(); })
476       .def_static(
477           "get_add", &PyAffineAddExpr::get,
478           "Gets an affine expression containing a sum of two expressions.")
479       .def_static(
480           "get_mul", &PyAffineMulExpr::get,
481           "Gets an affine expression containing a product of two expressions.")
482       .def_static("get_mod", &PyAffineModExpr::get,
483                   "Gets an affine expression containing the modulo of dividing "
484                   "one expression by another.")
485       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
486                   "Gets an affine expression containing the rounded-down "
487                   "result of dividing one expression by another.")
488       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
489                   "Gets an affine expression containing the rounded-up result "
490                   "of dividing one expression by another.")
491       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
492                   py::arg("context") = py::none(),
493                   "Gets a constant affine expression with the given value.")
494       .def_static(
495           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
496           py::arg("context") = py::none(),
497           "Gets an affine expression of a dimension at the given position.")
498       .def_static(
499           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
500           py::arg("context") = py::none(),
501           "Gets an affine expression of a symbol at the given position.")
502       .def(
503           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
504           kDumpDocstring);
505   PyAffineConstantExpr::bind(m);
506   PyAffineDimExpr::bind(m);
507   PyAffineSymbolExpr::bind(m);
508   PyAffineBinaryExpr::bind(m);
509   PyAffineAddExpr::bind(m);
510   PyAffineMulExpr::bind(m);
511   PyAffineModExpr::bind(m);
512   PyAffineFloorDivExpr::bind(m);
513   PyAffineCeilDivExpr::bind(m);
514 
515   //----------------------------------------------------------------------------
516   // Mapping of PyAffineMap.
517   //----------------------------------------------------------------------------
518   py::class_<PyAffineMap>(m, "AffineMap")
519       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
520                              &PyAffineMap::getCapsule)
521       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
522       .def("__eq__",
523            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
524       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
525       .def("__str__",
526            [](PyAffineMap &self) {
527              PyPrintAccumulator printAccum;
528              mlirAffineMapPrint(self, printAccum.getCallback(),
529                                 printAccum.getUserData());
530              return printAccum.join();
531            })
532       .def("__repr__",
533            [](PyAffineMap &self) {
534              PyPrintAccumulator printAccum;
535              printAccum.parts.append("AffineMap(");
536              mlirAffineMapPrint(self, printAccum.getCallback(),
537                                 printAccum.getUserData());
538              printAccum.parts.append(")");
539              return printAccum.join();
540            })
541       .def_static("compress_unused_symbols",
542                   [](py::list affineMaps, DefaultingPyMlirContext context) {
543                     SmallVector<MlirAffineMap> maps;
544                     pyListToVector<PyAffineMap, MlirAffineMap>(
545                         affineMaps, maps, "attempting to create an AffineMap");
546                     std::vector<MlirAffineMap> compressed(affineMaps.size());
547                     auto populate = [](void *result, intptr_t idx,
548                                        MlirAffineMap m) {
549                       static_cast<MlirAffineMap *>(result)[idx] = (m);
550                     };
551                     mlirAffineMapCompressUnusedSymbols(
552                         maps.data(), maps.size(), compressed.data(), populate);
553                     std::vector<PyAffineMap> res;
554                     for (auto m : compressed)
555                       res.push_back(PyAffineMap(context->getRef(), m));
556                     return res;
557                   })
558       .def_property_readonly(
559           "context",
560           [](PyAffineMap &self) { return self.getContext().getObject(); },
561           "Context that owns the Affine Map")
562       .def(
563           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
564           kDumpDocstring)
565       .def_static(
566           "get",
567           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
568              DefaultingPyMlirContext context) {
569             SmallVector<MlirAffineExpr> affineExprs;
570             pyListToVector<PyAffineExpr, MlirAffineExpr>(
571                 exprs, affineExprs, "attempting to create an AffineMap");
572             MlirAffineMap map =
573                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
574                                  affineExprs.size(), affineExprs.data());
575             return PyAffineMap(context->getRef(), map);
576           },
577           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
578           py::arg("context") = py::none(),
579           "Gets a map with the given expressions as results.")
580       .def_static(
581           "get_constant",
582           [](intptr_t value, DefaultingPyMlirContext context) {
583             MlirAffineMap affineMap =
584                 mlirAffineMapConstantGet(context->get(), value);
585             return PyAffineMap(context->getRef(), affineMap);
586           },
587           py::arg("value"), py::arg("context") = py::none(),
588           "Gets an affine map with a single constant result")
589       .def_static(
590           "get_empty",
591           [](DefaultingPyMlirContext context) {
592             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
593             return PyAffineMap(context->getRef(), affineMap);
594           },
595           py::arg("context") = py::none(), "Gets an empty affine map.")
596       .def_static(
597           "get_identity",
598           [](intptr_t nDims, DefaultingPyMlirContext context) {
599             MlirAffineMap affineMap =
600                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
601             return PyAffineMap(context->getRef(), affineMap);
602           },
603           py::arg("n_dims"), py::arg("context") = py::none(),
604           "Gets an identity map with the given number of dimensions.")
605       .def_static(
606           "get_minor_identity",
607           [](intptr_t nDims, intptr_t nResults,
608              DefaultingPyMlirContext context) {
609             MlirAffineMap affineMap =
610                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
611             return PyAffineMap(context->getRef(), affineMap);
612           },
613           py::arg("n_dims"), py::arg("n_results"),
614           py::arg("context") = py::none(),
615           "Gets a minor identity map with the given number of dimensions and "
616           "results.")
617       .def_static(
618           "get_permutation",
619           [](std::vector<unsigned> permutation,
620              DefaultingPyMlirContext context) {
621             if (!isPermutation(permutation))
622               throw py::cast_error("Invalid permutation when attempting to "
623                                    "create an AffineMap");
624             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
625                 context->get(), permutation.size(), permutation.data());
626             return PyAffineMap(context->getRef(), affineMap);
627           },
628           py::arg("permutation"), py::arg("context") = py::none(),
629           "Gets an affine map that permutes its inputs.")
630       .def("get_submap",
631            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
632              intptr_t numResults = mlirAffineMapGetNumResults(self);
633              for (intptr_t pos : resultPos) {
634                if (pos < 0 || pos >= numResults)
635                  throw py::value_error("result position out of bounds");
636              }
637              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
638                  self, resultPos.size(), resultPos.data());
639              return PyAffineMap(self.getContext(), affineMap);
640            })
641       .def("get_major_submap",
642            [](PyAffineMap &self, intptr_t nResults) {
643              if (nResults >= mlirAffineMapGetNumResults(self))
644                throw py::value_error("number of results out of bounds");
645              MlirAffineMap affineMap =
646                  mlirAffineMapGetMajorSubMap(self, nResults);
647              return PyAffineMap(self.getContext(), affineMap);
648            })
649       .def("get_minor_submap",
650            [](PyAffineMap &self, intptr_t nResults) {
651              if (nResults >= mlirAffineMapGetNumResults(self))
652                throw py::value_error("number of results out of bounds");
653              MlirAffineMap affineMap =
654                  mlirAffineMapGetMinorSubMap(self, nResults);
655              return PyAffineMap(self.getContext(), affineMap);
656            })
657       .def_property_readonly(
658           "is_permutation",
659           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
660       .def_property_readonly("is_projected_permutation",
661                              [](PyAffineMap &self) {
662                                return mlirAffineMapIsProjectedPermutation(self);
663                              })
664       .def_property_readonly(
665           "n_dims",
666           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
667       .def_property_readonly(
668           "n_inputs",
669           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
670       .def_property_readonly(
671           "n_symbols",
672           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
673       .def_property_readonly("results", [](PyAffineMap &self) {
674         return PyAffineMapExprList(self);
675       });
676   PyAffineMapExprList::bind(m);
677 
678   //----------------------------------------------------------------------------
679   // Mapping of PyIntegerSet.
680   //----------------------------------------------------------------------------
681   py::class_<PyIntegerSet>(m, "IntegerSet")
682       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
683                              &PyIntegerSet::getCapsule)
684       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
685       .def("__eq__", [](PyIntegerSet &self,
686                         PyIntegerSet &other) { return self == other; })
687       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
688       .def("__str__",
689            [](PyIntegerSet &self) {
690              PyPrintAccumulator printAccum;
691              mlirIntegerSetPrint(self, printAccum.getCallback(),
692                                  printAccum.getUserData());
693              return printAccum.join();
694            })
695       .def("__repr__",
696            [](PyIntegerSet &self) {
697              PyPrintAccumulator printAccum;
698              printAccum.parts.append("IntegerSet(");
699              mlirIntegerSetPrint(self, printAccum.getCallback(),
700                                  printAccum.getUserData());
701              printAccum.parts.append(")");
702              return printAccum.join();
703            })
704       .def_property_readonly(
705           "context",
706           [](PyIntegerSet &self) { return self.getContext().getObject(); })
707       .def(
708           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
709           kDumpDocstring)
710       .def_static(
711           "get",
712           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
713              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
714             if (exprs.size() != eqFlags.size())
715               throw py::value_error(
716                   "Expected the number of constraints to match "
717                   "that of equality flags");
718             if (exprs.empty())
719               throw py::value_error("Expected non-empty list of constraints");
720 
721             // Copy over to a SmallVector because std::vector has a
722             // specialization for booleans that packs data and does not
723             // expose a `bool *`.
724             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
725 
726             SmallVector<MlirAffineExpr> affineExprs;
727             pyListToVector<PyAffineExpr>(exprs, affineExprs,
728                                          "attempting to create an IntegerSet");
729             MlirIntegerSet set = mlirIntegerSetGet(
730                 context->get(), numDims, numSymbols, exprs.size(),
731                 affineExprs.data(), flags.data());
732             return PyIntegerSet(context->getRef(), set);
733           },
734           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
735           py::arg("eq_flags"), py::arg("context") = py::none())
736       .def_static(
737           "get_empty",
738           [](intptr_t numDims, intptr_t numSymbols,
739              DefaultingPyMlirContext context) {
740             MlirIntegerSet set =
741                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
742             return PyIntegerSet(context->getRef(), set);
743           },
744           py::arg("num_dims"), py::arg("num_symbols"),
745           py::arg("context") = py::none())
746       .def("get_replaced",
747            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
748               intptr_t numResultDims, intptr_t numResultSymbols) {
749              if (static_cast<intptr_t>(dimExprs.size()) !=
750                  mlirIntegerSetGetNumDims(self))
751                throw py::value_error(
752                    "Expected the number of dimension replacement expressions "
753                    "to match that of dimensions");
754              if (static_cast<intptr_t>(symbolExprs.size()) !=
755                  mlirIntegerSetGetNumSymbols(self))
756                throw py::value_error(
757                    "Expected the number of symbol replacement expressions "
758                    "to match that of symbols");
759 
760              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
761              pyListToVector<PyAffineExpr>(
762                  dimExprs, dimAffineExprs,
763                  "attempting to create an IntegerSet by replacing dimensions");
764              pyListToVector<PyAffineExpr>(
765                  symbolExprs, symbolAffineExprs,
766                  "attempting to create an IntegerSet by replacing symbols");
767              MlirIntegerSet set = mlirIntegerSetReplaceGet(
768                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
769                  numResultDims, numResultSymbols);
770              return PyIntegerSet(self.getContext(), set);
771            })
772       .def_property_readonly("is_canonical_empty",
773                              [](PyIntegerSet &self) {
774                                return mlirIntegerSetIsCanonicalEmpty(self);
775                              })
776       .def_property_readonly(
777           "n_dims",
778           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
779       .def_property_readonly(
780           "n_symbols",
781           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
782       .def_property_readonly(
783           "n_inputs",
784           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
785       .def_property_readonly("n_equalities",
786                              [](PyIntegerSet &self) {
787                                return mlirIntegerSetGetNumEqualities(self);
788                              })
789       .def_property_readonly("n_inequalities",
790                              [](PyIntegerSet &self) {
791                                return mlirIntegerSetGetNumInequalities(self);
792                              })
793       .def_property_readonly("constraints", [](PyIntegerSet &self) {
794         return PyIntegerSetConstraintList(self);
795       });
796   PyIntegerSetConstraint::bind(m);
797   PyIntegerSetConstraintList::bind(m);
798 }
799