1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/AffineMap.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/IntegerSet.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 using llvm::SmallVector;
22 using llvm::StringRef;
23 using llvm::Twine;
24 
/// Docstring shared by the `dump` methods bound on AffineExpr, AffineMap and
/// IntegerSet.
static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
27 
28 /// Attempts to populate `result` with the content of `list` casted to the
29 /// appropriate type (Python and C types are provided as template arguments).
30 /// Throws errors in case of failure, using "action" to describe what the caller
31 /// was attempting to do.
32 template <typename PyType, typename CType>
33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
34                            StringRef action) {
35   result.reserve(py::len(list));
36   for (py::handle item : list) {
37     try {
38       result.push_back(item.cast<PyType>());
39     } catch (py::cast_error &err) {
40       std::string msg = (llvm::Twine("Invalid expression when ") + action +
41                          " (" + err.what() + ")")
42                             .str();
43       throw py::cast_error(msg);
44     } catch (py::reference_cast_error &err) {
45       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
46                          action + " (" + err.what() + ")")
47                             .str();
48       throw py::cast_error(msg);
49     }
50   }
51 }
52 
/// Returns true if `permutation` contains every integer in [0, N) exactly
/// once, where N is `permutation.size()`. An empty vector is trivially a
/// permutation.
///
/// The vector is taken by const reference so callers no longer pay for a copy.
/// Each entry is explicitly converted to `size_t` before the bounds check; for
/// signed element types a negative entry converts to a value >= N and is
/// rejected.
template <typename PermutationTy>
static bool isPermutation(const std::vector<PermutationTy> &permutation) {
  const size_t size = permutation.size();
  std::vector<bool> seen(size, false);
  for (const PermutationTy &val : permutation) {
    const auto index = static_cast<size_t>(val);
    // Out-of-range or repeated entries both disqualify the sequence.
    if (index >= size || seen[index])
      return false;
    seen[index] = true;
  }
  return true;
}
67 
68 namespace {
69 
70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
71 /// and should be castable from it. Intermediate hierarchy classes can be
72 /// modeled by specifying BaseTy.
73 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
74 class PyConcreteAffineExpr : public BaseTy {
75 public:
76   // Derived classes must define statics for:
77   //   IsAFunctionTy isaFunction
78   //   const char *pyClassName
79   // and redefine bindDerived.
80   using ClassTy = py::class_<DerivedTy, BaseTy>;
81   using IsAFunctionTy = bool (*)(MlirAffineExpr);
82 
83   PyConcreteAffineExpr() = default;
84   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
85       : BaseTy(std::move(contextRef), affineExpr) {}
86   PyConcreteAffineExpr(PyAffineExpr &orig)
87       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
88 
89   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
90     if (!DerivedTy::isaFunction(orig)) {
91       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
92       throw SetPyError(PyExc_ValueError,
93                        Twine("Cannot cast affine expression to ") +
94                            DerivedTy::pyClassName + " (from " + origRepr + ")");
95     }
96     return orig;
97   }
98 
99   static void bind(py::module &m) {
100     auto cls = ClassTy(m, DerivedTy::pyClassName);
101     cls.def(py::init<PyAffineExpr &>());
102     DerivedTy::bindDerived(cls);
103   }
104 
105   /// Implemented by derived classes to add methods to the Python subclass.
106   static void bindDerived(ClassTy &m) {}
107 };
108 
109 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
110 public:
111   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
112   static constexpr const char *pyClassName = "AffineConstantExpr";
113   using PyConcreteAffineExpr::PyConcreteAffineExpr;
114 
115   static PyAffineConstantExpr get(intptr_t value,
116                                   DefaultingPyMlirContext context) {
117     MlirAffineExpr affineExpr =
118         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
119     return PyAffineConstantExpr(context->getRef(), affineExpr);
120   }
121 
122   static void bindDerived(ClassTy &c) {
123     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
124                  py::arg("context") = py::none());
125     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
126       return mlirAffineConstantExprGetValue(self);
127     });
128   }
129 };
130 
131 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
132 public:
133   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
134   static constexpr const char *pyClassName = "AffineDimExpr";
135   using PyConcreteAffineExpr::PyConcreteAffineExpr;
136 
137   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
138     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
139     return PyAffineDimExpr(context->getRef(), affineExpr);
140   }
141 
142   static void bindDerived(ClassTy &c) {
143     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
144                  py::arg("context") = py::none());
145     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
146       return mlirAffineDimExprGetPosition(self);
147     });
148   }
149 };
150 
151 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
152 public:
153   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
154   static constexpr const char *pyClassName = "AffineSymbolExpr";
155   using PyConcreteAffineExpr::PyConcreteAffineExpr;
156 
157   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
158     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
159     return PyAffineSymbolExpr(context->getRef(), affineExpr);
160   }
161 
162   static void bindDerived(ClassTy &c) {
163     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
164                  py::arg("context") = py::none());
165     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
166       return mlirAffineSymbolExprGetPosition(self);
167     });
168   }
169 };
170 
171 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
172 public:
173   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
174   static constexpr const char *pyClassName = "AffineBinaryExpr";
175   using PyConcreteAffineExpr::PyConcreteAffineExpr;
176 
177   PyAffineExpr lhs() {
178     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
179     return PyAffineExpr(getContext(), lhsExpr);
180   }
181 
182   PyAffineExpr rhs() {
183     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
184     return PyAffineExpr(getContext(), rhsExpr);
185   }
186 
187   static void bindDerived(ClassTy &c) {
188     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
189     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
190   }
191 };
192 
193 class PyAffineAddExpr
194     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
195 public:
196   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
197   static constexpr const char *pyClassName = "AffineAddExpr";
198   using PyConcreteAffineExpr::PyConcreteAffineExpr;
199 
200   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
201     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
202     return PyAffineAddExpr(lhs.getContext(), expr);
203   }
204 
205   static void bindDerived(ClassTy &c) {
206     c.def_static("get", &PyAffineAddExpr::get);
207   }
208 };
209 
210 class PyAffineMulExpr
211     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
212 public:
213   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
214   static constexpr const char *pyClassName = "AffineMulExpr";
215   using PyConcreteAffineExpr::PyConcreteAffineExpr;
216 
217   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
218     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
219     return PyAffineMulExpr(lhs.getContext(), expr);
220   }
221 
222   static void bindDerived(ClassTy &c) {
223     c.def_static("get", &PyAffineMulExpr::get);
224   }
225 };
226 
227 class PyAffineModExpr
228     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
229 public:
230   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
231   static constexpr const char *pyClassName = "AffineModExpr";
232   using PyConcreteAffineExpr::PyConcreteAffineExpr;
233 
234   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
235     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
236     return PyAffineModExpr(lhs.getContext(), expr);
237   }
238 
239   static void bindDerived(ClassTy &c) {
240     c.def_static("get", &PyAffineModExpr::get);
241   }
242 };
243 
244 class PyAffineFloorDivExpr
245     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
246 public:
247   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
248   static constexpr const char *pyClassName = "AffineFloorDivExpr";
249   using PyConcreteAffineExpr::PyConcreteAffineExpr;
250 
251   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
252     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
253     return PyAffineFloorDivExpr(lhs.getContext(), expr);
254   }
255 
256   static void bindDerived(ClassTy &c) {
257     c.def_static("get", &PyAffineFloorDivExpr::get);
258   }
259 };
260 
261 class PyAffineCeilDivExpr
262     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
263 public:
264   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
265   static constexpr const char *pyClassName = "AffineCeilDivExpr";
266   using PyConcreteAffineExpr::PyConcreteAffineExpr;
267 
268   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
269     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
270     return PyAffineCeilDivExpr(lhs.getContext(), expr);
271   }
272 
273   static void bindDerived(ClassTy &c) {
274     c.def_static("get", &PyAffineCeilDivExpr::get);
275   }
276 };
277 
278 } // namespace
279 
280 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
281   return mlirAffineExprEqual(affineExpr, other.affineExpr);
282 }
283 
284 py::object PyAffineExpr::getCapsule() {
285   return py::reinterpret_steal<py::object>(
286       mlirPythonAffineExprToCapsule(*this));
287 }
288 
289 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
290   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
291   if (mlirAffineExprIsNull(rawAffineExpr))
292     throw py::error_already_set();
293   return PyAffineExpr(
294       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
295       rawAffineExpr);
296 }
297 
298 //------------------------------------------------------------------------------
299 // PyAffineMap and utilities.
300 //------------------------------------------------------------------------------
301 namespace {
302 
303 /// A list of expressions contained in an affine map. Internally these are
304 /// stored as a consecutive array leading to inexpensive random access. Both
305 /// the map and the expression are owned by the context so we need not bother
306 /// with lifetime extension.
307 class PyAffineMapExprList
308     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
309 public:
310   static constexpr const char *pyClassName = "AffineExprList";
311 
312   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
313                       intptr_t length = -1, intptr_t step = 1)
314       : Sliceable(startIndex,
315                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
316                   step),
317         affineMap(map) {}
318 
319   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
320 
321   PyAffineExpr getElement(intptr_t pos) {
322     return PyAffineExpr(affineMap.getContext(),
323                         mlirAffineMapGetResult(affineMap, pos));
324   }
325 
326   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
327                             intptr_t step) {
328     return PyAffineMapExprList(affineMap, startIndex, length, step);
329   }
330 
331 private:
332   PyAffineMap affineMap;
333 };
334 } // end namespace
335 
336 bool PyAffineMap::operator==(const PyAffineMap &other) {
337   return mlirAffineMapEqual(affineMap, other.affineMap);
338 }
339 
340 py::object PyAffineMap::getCapsule() {
341   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
342 }
343 
344 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
345   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
346   if (mlirAffineMapIsNull(rawAffineMap))
347     throw py::error_already_set();
348   return PyAffineMap(
349       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
350       rawAffineMap);
351 }
352 
353 //------------------------------------------------------------------------------
354 // PyIntegerSet and utilities.
355 //------------------------------------------------------------------------------
356 namespace {
357 
358 class PyIntegerSetConstraint {
359 public:
360   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
361 
362   PyAffineExpr getExpr() {
363     return PyAffineExpr(set.getContext(),
364                         mlirIntegerSetGetConstraint(set, pos));
365   }
366 
367   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
368 
369   static void bind(py::module &m) {
370     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
371         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
372         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
373   }
374 
375 private:
376   PyIntegerSet set;
377   intptr_t pos;
378 };
379 
380 class PyIntegerSetConstraintList
381     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
382 public:
383   static constexpr const char *pyClassName = "IntegerSetConstraintList";
384 
385   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
386                              intptr_t length = -1, intptr_t step = 1)
387       : Sliceable(startIndex,
388                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
389                   step),
390         set(set) {}
391 
392   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
393 
394   PyIntegerSetConstraint getElement(intptr_t pos) {
395     return PyIntegerSetConstraint(set, pos);
396   }
397 
398   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
399                                    intptr_t step) {
400     return PyIntegerSetConstraintList(set, startIndex, length, step);
401   }
402 
403 private:
404   PyIntegerSet set;
405 };
406 } // namespace
407 
408 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
409   return mlirIntegerSetEqual(integerSet, other.integerSet);
410 }
411 
412 py::object PyIntegerSet::getCapsule() {
413   return py::reinterpret_steal<py::object>(
414       mlirPythonIntegerSetToCapsule(*this));
415 }
416 
417 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
418   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
419   if (mlirIntegerSetIsNull(rawIntegerSet))
420     throw py::error_already_set();
421   return PyIntegerSet(
422       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
423       rawIntegerSet);
424 }
425 
426 void mlir::python::populateIRAffine(py::module &m) {
427   //----------------------------------------------------------------------------
428   // Mapping of PyAffineExpr and derived classes.
429   //----------------------------------------------------------------------------
430   py::class_<PyAffineExpr>(m, "AffineExpr")
431       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
432                              &PyAffineExpr::getCapsule)
433       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
434       .def("__add__",
435            [](PyAffineExpr &self, PyAffineExpr &other) {
436              return PyAffineAddExpr::get(self, other);
437            })
438       .def("__mul__",
439            [](PyAffineExpr &self, PyAffineExpr &other) {
440              return PyAffineMulExpr::get(self, other);
441            })
442       .def("__mod__",
443            [](PyAffineExpr &self, PyAffineExpr &other) {
444              return PyAffineModExpr::get(self, other);
445            })
446       .def("__sub__",
447            [](PyAffineExpr &self, PyAffineExpr &other) {
448              auto negOne =
449                  PyAffineConstantExpr::get(-1, *self.getContext().get());
450              return PyAffineAddExpr::get(self,
451                                          PyAffineMulExpr::get(negOne, other));
452            })
453       .def("__eq__", [](PyAffineExpr &self,
454                         PyAffineExpr &other) { return self == other; })
455       .def("__eq__",
456            [](PyAffineExpr &self, py::object &other) { return false; })
457       .def("__str__",
458            [](PyAffineExpr &self) {
459              PyPrintAccumulator printAccum;
460              mlirAffineExprPrint(self, printAccum.getCallback(),
461                                  printAccum.getUserData());
462              return printAccum.join();
463            })
464       .def("__repr__",
465            [](PyAffineExpr &self) {
466              PyPrintAccumulator printAccum;
467              printAccum.parts.append("AffineExpr(");
468              mlirAffineExprPrint(self, printAccum.getCallback(),
469                                  printAccum.getUserData());
470              printAccum.parts.append(")");
471              return printAccum.join();
472            })
473       .def_property_readonly(
474           "context",
475           [](PyAffineExpr &self) { return self.getContext().getObject(); })
476       .def_static(
477           "get_add", &PyAffineAddExpr::get,
478           "Gets an affine expression containing a sum of two expressions.")
479       .def_static(
480           "get_mul", &PyAffineMulExpr::get,
481           "Gets an affine expression containing a product of two expressions.")
482       .def_static("get_mod", &PyAffineModExpr::get,
483                   "Gets an affine expression containing the modulo of dividing "
484                   "one expression by another.")
485       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
486                   "Gets an affine expression containing the rounded-down "
487                   "result of dividing one expression by another.")
488       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
489                   "Gets an affine expression containing the rounded-up result "
490                   "of dividing one expression by another.")
491       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
492                   py::arg("context") = py::none(),
493                   "Gets a constant affine expression with the given value.")
494       .def_static(
495           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
496           py::arg("context") = py::none(),
497           "Gets an affine expression of a dimension at the given position.")
498       .def_static(
499           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
500           py::arg("context") = py::none(),
501           "Gets an affine expression of a symbol at the given position.")
502       .def(
503           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
504           kDumpDocstring);
505   PyAffineConstantExpr::bind(m);
506   PyAffineDimExpr::bind(m);
507   PyAffineSymbolExpr::bind(m);
508   PyAffineBinaryExpr::bind(m);
509   PyAffineAddExpr::bind(m);
510   PyAffineMulExpr::bind(m);
511   PyAffineModExpr::bind(m);
512   PyAffineFloorDivExpr::bind(m);
513   PyAffineCeilDivExpr::bind(m);
514 
515   //----------------------------------------------------------------------------
516   // Mapping of PyAffineMap.
517   //----------------------------------------------------------------------------
518   py::class_<PyAffineMap>(m, "AffineMap")
519       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
520                              &PyAffineMap::getCapsule)
521       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
522       .def("__eq__",
523            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
524       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
525       .def("__str__",
526            [](PyAffineMap &self) {
527              PyPrintAccumulator printAccum;
528              mlirAffineMapPrint(self, printAccum.getCallback(),
529                                 printAccum.getUserData());
530              return printAccum.join();
531            })
532       .def("__repr__",
533            [](PyAffineMap &self) {
534              PyPrintAccumulator printAccum;
535              printAccum.parts.append("AffineMap(");
536              mlirAffineMapPrint(self, printAccum.getCallback(),
537                                 printAccum.getUserData());
538              printAccum.parts.append(")");
539              return printAccum.join();
540            })
541       .def_property_readonly(
542           "context",
543           [](PyAffineMap &self) { return self.getContext().getObject(); },
544           "Context that owns the Affine Map")
545       .def(
546           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
547           kDumpDocstring)
548       .def_static(
549           "get",
550           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
551              DefaultingPyMlirContext context) {
552             SmallVector<MlirAffineExpr> affineExprs;
553             pyListToVector<PyAffineExpr, MlirAffineExpr>(
554                 exprs, affineExprs, "attempting to create an AffineMap");
555             MlirAffineMap map =
556                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
557                                  affineExprs.size(), affineExprs.data());
558             return PyAffineMap(context->getRef(), map);
559           },
560           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
561           py::arg("context") = py::none(),
562           "Gets a map with the given expressions as results.")
563       .def_static(
564           "get_constant",
565           [](intptr_t value, DefaultingPyMlirContext context) {
566             MlirAffineMap affineMap =
567                 mlirAffineMapConstantGet(context->get(), value);
568             return PyAffineMap(context->getRef(), affineMap);
569           },
570           py::arg("value"), py::arg("context") = py::none(),
571           "Gets an affine map with a single constant result")
572       .def_static(
573           "get_empty",
574           [](DefaultingPyMlirContext context) {
575             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
576             return PyAffineMap(context->getRef(), affineMap);
577           },
578           py::arg("context") = py::none(), "Gets an empty affine map.")
579       .def_static(
580           "get_identity",
581           [](intptr_t nDims, DefaultingPyMlirContext context) {
582             MlirAffineMap affineMap =
583                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
584             return PyAffineMap(context->getRef(), affineMap);
585           },
586           py::arg("n_dims"), py::arg("context") = py::none(),
587           "Gets an identity map with the given number of dimensions.")
588       .def_static(
589           "get_minor_identity",
590           [](intptr_t nDims, intptr_t nResults,
591              DefaultingPyMlirContext context) {
592             MlirAffineMap affineMap =
593                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
594             return PyAffineMap(context->getRef(), affineMap);
595           },
596           py::arg("n_dims"), py::arg("n_results"),
597           py::arg("context") = py::none(),
598           "Gets a minor identity map with the given number of dimensions and "
599           "results.")
600       .def_static(
601           "get_permutation",
602           [](std::vector<unsigned> permutation,
603              DefaultingPyMlirContext context) {
604             if (!isPermutation(permutation))
605               throw py::cast_error("Invalid permutation when attempting to "
606                                    "create an AffineMap");
607             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
608                 context->get(), permutation.size(), permutation.data());
609             return PyAffineMap(context->getRef(), affineMap);
610           },
611           py::arg("permutation"), py::arg("context") = py::none(),
612           "Gets an affine map that permutes its inputs.")
613       .def("get_submap",
614            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
615              intptr_t numResults = mlirAffineMapGetNumResults(self);
616              for (intptr_t pos : resultPos) {
617                if (pos < 0 || pos >= numResults)
618                  throw py::value_error("result position out of bounds");
619              }
620              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
621                  self, resultPos.size(), resultPos.data());
622              return PyAffineMap(self.getContext(), affineMap);
623            })
624       .def("get_major_submap",
625            [](PyAffineMap &self, intptr_t nResults) {
626              if (nResults >= mlirAffineMapGetNumResults(self))
627                throw py::value_error("number of results out of bounds");
628              MlirAffineMap affineMap =
629                  mlirAffineMapGetMajorSubMap(self, nResults);
630              return PyAffineMap(self.getContext(), affineMap);
631            })
632       .def("get_minor_submap",
633            [](PyAffineMap &self, intptr_t nResults) {
634              if (nResults >= mlirAffineMapGetNumResults(self))
635                throw py::value_error("number of results out of bounds");
636              MlirAffineMap affineMap =
637                  mlirAffineMapGetMinorSubMap(self, nResults);
638              return PyAffineMap(self.getContext(), affineMap);
639            })
640       .def_property_readonly(
641           "is_permutation",
642           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
643       .def_property_readonly("is_projected_permutation",
644                              [](PyAffineMap &self) {
645                                return mlirAffineMapIsProjectedPermutation(self);
646                              })
647       .def_property_readonly(
648           "n_dims",
649           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
650       .def_property_readonly(
651           "n_inputs",
652           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
653       .def_property_readonly(
654           "n_symbols",
655           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
656       .def_property_readonly("results", [](PyAffineMap &self) {
657         return PyAffineMapExprList(self);
658       });
659   PyAffineMapExprList::bind(m);
660 
661   //----------------------------------------------------------------------------
662   // Mapping of PyIntegerSet.
663   //----------------------------------------------------------------------------
664   py::class_<PyIntegerSet>(m, "IntegerSet")
665       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
666                              &PyIntegerSet::getCapsule)
667       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
668       .def("__eq__", [](PyIntegerSet &self,
669                         PyIntegerSet &other) { return self == other; })
670       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
671       .def("__str__",
672            [](PyIntegerSet &self) {
673              PyPrintAccumulator printAccum;
674              mlirIntegerSetPrint(self, printAccum.getCallback(),
675                                  printAccum.getUserData());
676              return printAccum.join();
677            })
678       .def("__repr__",
679            [](PyIntegerSet &self) {
680              PyPrintAccumulator printAccum;
681              printAccum.parts.append("IntegerSet(");
682              mlirIntegerSetPrint(self, printAccum.getCallback(),
683                                  printAccum.getUserData());
684              printAccum.parts.append(")");
685              return printAccum.join();
686            })
687       .def_property_readonly(
688           "context",
689           [](PyIntegerSet &self) { return self.getContext().getObject(); })
690       .def(
691           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
692           kDumpDocstring)
693       .def_static(
694           "get",
695           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
696              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
697             if (exprs.size() != eqFlags.size())
698               throw py::value_error(
699                   "Expected the number of constraints to match "
700                   "that of equality flags");
701             if (exprs.empty())
702               throw py::value_error("Expected non-empty list of constraints");
703 
704             // Copy over to a SmallVector because std::vector has a
705             // specialization for booleans that packs data and does not
706             // expose a `bool *`.
707             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
708 
709             SmallVector<MlirAffineExpr> affineExprs;
710             pyListToVector<PyAffineExpr>(exprs, affineExprs,
711                                          "attempting to create an IntegerSet");
712             MlirIntegerSet set = mlirIntegerSetGet(
713                 context->get(), numDims, numSymbols, exprs.size(),
714                 affineExprs.data(), flags.data());
715             return PyIntegerSet(context->getRef(), set);
716           },
717           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
718           py::arg("eq_flags"), py::arg("context") = py::none())
719       .def_static(
720           "get_empty",
721           [](intptr_t numDims, intptr_t numSymbols,
722              DefaultingPyMlirContext context) {
723             MlirIntegerSet set =
724                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
725             return PyIntegerSet(context->getRef(), set);
726           },
727           py::arg("num_dims"), py::arg("num_symbols"),
728           py::arg("context") = py::none())
729       .def("get_replaced",
730            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
731               intptr_t numResultDims, intptr_t numResultSymbols) {
732              if (static_cast<intptr_t>(dimExprs.size()) !=
733                  mlirIntegerSetGetNumDims(self))
734                throw py::value_error(
735                    "Expected the number of dimension replacement expressions "
736                    "to match that of dimensions");
737              if (static_cast<intptr_t>(symbolExprs.size()) !=
738                  mlirIntegerSetGetNumSymbols(self))
739                throw py::value_error(
740                    "Expected the number of symbol replacement expressions "
741                    "to match that of symbols");
742 
743              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
744              pyListToVector<PyAffineExpr>(
745                  dimExprs, dimAffineExprs,
746                  "attempting to create an IntegerSet by replacing dimensions");
747              pyListToVector<PyAffineExpr>(
748                  symbolExprs, symbolAffineExprs,
749                  "attempting to create an IntegerSet by replacing symbols");
750              MlirIntegerSet set = mlirIntegerSetReplaceGet(
751                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
752                  numResultDims, numResultSymbols);
753              return PyIntegerSet(self.getContext(), set);
754            })
755       .def_property_readonly("is_canonical_empty",
756                              [](PyIntegerSet &self) {
757                                return mlirIntegerSetIsCanonicalEmpty(self);
758                              })
759       .def_property_readonly(
760           "n_dims",
761           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
762       .def_property_readonly(
763           "n_symbols",
764           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
765       .def_property_readonly(
766           "n_inputs",
767           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
768       .def_property_readonly("n_equalities",
769                              [](PyIntegerSet &self) {
770                                return mlirIntegerSetGetNumEqualities(self);
771                              })
772       .def_property_readonly("n_inequalities",
773                              [](PyIntegerSet &self) {
774                                return mlirIntegerSetGetNumInequalities(self);
775                              })
776       .def_property_readonly("constraints", [](PyIntegerSet &self) {
777         return PyIntegerSetConstraintList(self);
778       });
779   PyIntegerSetConstraint::bind(m);
780   PyIntegerSetConstraintList::bind(m);
781 }
782