1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <utility>
10
11 #include "IRModule.h"
12
13 #include "PybindUtils.h"
14
15 #include "mlir-c/AffineMap.h"
16 #include "mlir-c/Bindings/Python/Interop.h"
17 #include "mlir-c/IntegerSet.h"
18
19 namespace py = pybind11;
20 using namespace mlir;
21 using namespace mlir::python;
22
23 using llvm::SmallVector;
24 using llvm::StringRef;
25 using llvm::Twine;
26
/// Docstring attached to the Python-visible `dump` methods of the affine
/// binding classes.
static constexpr const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";
29
30 /// Attempts to populate `result` with the content of `list` casted to the
31 /// appropriate type (Python and C types are provided as template arguments).
32 /// Throws errors in case of failure, using "action" to describe what the caller
33 /// was attempting to do.
34 template <typename PyType, typename CType>
pyListToVector(const py::list & list,llvm::SmallVectorImpl<CType> & result,StringRef action)35 static void pyListToVector(const py::list &list,
36 llvm::SmallVectorImpl<CType> &result,
37 StringRef action) {
38 result.reserve(py::len(list));
39 for (py::handle item : list) {
40 try {
41 result.push_back(item.cast<PyType>());
42 } catch (py::cast_error &err) {
43 std::string msg = (llvm::Twine("Invalid expression when ") + action +
44 " (" + err.what() + ")")
45 .str();
46 throw py::cast_error(msg);
47 } catch (py::reference_cast_error &err) {
48 std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
49 action + " (" + err.what() + ")")
50 .str();
51 throw py::cast_error(msg);
52 }
53 }
54 }
55
56 template <typename PermutationTy>
isPermutation(std::vector<PermutationTy> permutation)57 static bool isPermutation(std::vector<PermutationTy> permutation) {
58 llvm::SmallVector<bool, 8> seen(permutation.size(), false);
59 for (auto val : permutation) {
60 if (val < permutation.size()) {
61 if (seen[val])
62 return false;
63 seen[val] = true;
64 continue;
65 }
66 return false;
67 }
68 return true;
69 }
70
71 namespace {
72
73 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
74 /// and should be castable from it. Intermediate hierarchy classes can be
75 /// modeled by specifying BaseTy.
76 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
77 class PyConcreteAffineExpr : public BaseTy {
78 public:
79 // Derived classes must define statics for:
80 // IsAFunctionTy isaFunction
81 // const char *pyClassName
82 // and redefine bindDerived.
83 using ClassTy = py::class_<DerivedTy, BaseTy>;
84 using IsAFunctionTy = bool (*)(MlirAffineExpr);
85
86 PyConcreteAffineExpr() = default;
PyConcreteAffineExpr(PyMlirContextRef contextRef,MlirAffineExpr affineExpr)87 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
88 : BaseTy(std::move(contextRef), affineExpr) {}
PyConcreteAffineExpr(PyAffineExpr & orig)89 PyConcreteAffineExpr(PyAffineExpr &orig)
90 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
91
castFrom(PyAffineExpr & orig)92 static MlirAffineExpr castFrom(PyAffineExpr &orig) {
93 if (!DerivedTy::isaFunction(orig)) {
94 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
95 throw SetPyError(PyExc_ValueError,
96 Twine("Cannot cast affine expression to ") +
97 DerivedTy::pyClassName + " (from " + origRepr + ")");
98 }
99 return orig;
100 }
101
bind(py::module & m)102 static void bind(py::module &m) {
103 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
104 cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
105 cls.def_static(
106 "isinstance",
107 [](PyAffineExpr &otherAffineExpr) -> bool {
108 return DerivedTy::isaFunction(otherAffineExpr);
109 },
110 py::arg("other"));
111 DerivedTy::bindDerived(cls);
112 }
113
114 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)115 static void bindDerived(ClassTy &m) {}
116 };
117
118 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
119 public:
120 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
121 static constexpr const char *pyClassName = "AffineConstantExpr";
122 using PyConcreteAffineExpr::PyConcreteAffineExpr;
123
get(intptr_t value,DefaultingPyMlirContext context)124 static PyAffineConstantExpr get(intptr_t value,
125 DefaultingPyMlirContext context) {
126 MlirAffineExpr affineExpr =
127 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
128 return PyAffineConstantExpr(context->getRef(), affineExpr);
129 }
130
bindDerived(ClassTy & c)131 static void bindDerived(ClassTy &c) {
132 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
133 py::arg("context") = py::none());
134 c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
135 return mlirAffineConstantExprGetValue(self);
136 });
137 }
138 };
139
140 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
141 public:
142 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
143 static constexpr const char *pyClassName = "AffineDimExpr";
144 using PyConcreteAffineExpr::PyConcreteAffineExpr;
145
get(intptr_t pos,DefaultingPyMlirContext context)146 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
147 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
148 return PyAffineDimExpr(context->getRef(), affineExpr);
149 }
150
bindDerived(ClassTy & c)151 static void bindDerived(ClassTy &c) {
152 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
153 py::arg("context") = py::none());
154 c.def_property_readonly("position", [](PyAffineDimExpr &self) {
155 return mlirAffineDimExprGetPosition(self);
156 });
157 }
158 };
159
160 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
161 public:
162 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
163 static constexpr const char *pyClassName = "AffineSymbolExpr";
164 using PyConcreteAffineExpr::PyConcreteAffineExpr;
165
get(intptr_t pos,DefaultingPyMlirContext context)166 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
167 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
168 return PyAffineSymbolExpr(context->getRef(), affineExpr);
169 }
170
bindDerived(ClassTy & c)171 static void bindDerived(ClassTy &c) {
172 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
173 py::arg("context") = py::none());
174 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
175 return mlirAffineSymbolExprGetPosition(self);
176 });
177 }
178 };
179
180 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
181 public:
182 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
183 static constexpr const char *pyClassName = "AffineBinaryExpr";
184 using PyConcreteAffineExpr::PyConcreteAffineExpr;
185
lhs()186 PyAffineExpr lhs() {
187 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
188 return PyAffineExpr(getContext(), lhsExpr);
189 }
190
rhs()191 PyAffineExpr rhs() {
192 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
193 return PyAffineExpr(getContext(), rhsExpr);
194 }
195
bindDerived(ClassTy & c)196 static void bindDerived(ClassTy &c) {
197 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
198 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
199 }
200 };
201
202 class PyAffineAddExpr
203 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
204 public:
205 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
206 static constexpr const char *pyClassName = "AffineAddExpr";
207 using PyConcreteAffineExpr::PyConcreteAffineExpr;
208
get(PyAffineExpr lhs,const PyAffineExpr & rhs)209 static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
210 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
211 return PyAffineAddExpr(lhs.getContext(), expr);
212 }
213
getRHSConstant(PyAffineExpr lhs,intptr_t rhs)214 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
215 MlirAffineExpr expr = mlirAffineAddExprGet(
216 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
217 return PyAffineAddExpr(lhs.getContext(), expr);
218 }
219
getLHSConstant(intptr_t lhs,PyAffineExpr rhs)220 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
221 MlirAffineExpr expr = mlirAffineAddExprGet(
222 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
223 return PyAffineAddExpr(rhs.getContext(), expr);
224 }
225
bindDerived(ClassTy & c)226 static void bindDerived(ClassTy &c) {
227 c.def_static("get", &PyAffineAddExpr::get);
228 }
229 };
230
231 class PyAffineMulExpr
232 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
233 public:
234 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
235 static constexpr const char *pyClassName = "AffineMulExpr";
236 using PyConcreteAffineExpr::PyConcreteAffineExpr;
237
get(PyAffineExpr lhs,const PyAffineExpr & rhs)238 static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
239 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
240 return PyAffineMulExpr(lhs.getContext(), expr);
241 }
242
getRHSConstant(PyAffineExpr lhs,intptr_t rhs)243 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
244 MlirAffineExpr expr = mlirAffineMulExprGet(
245 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
246 return PyAffineMulExpr(lhs.getContext(), expr);
247 }
248
getLHSConstant(intptr_t lhs,PyAffineExpr rhs)249 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
250 MlirAffineExpr expr = mlirAffineMulExprGet(
251 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
252 return PyAffineMulExpr(rhs.getContext(), expr);
253 }
254
bindDerived(ClassTy & c)255 static void bindDerived(ClassTy &c) {
256 c.def_static("get", &PyAffineMulExpr::get);
257 }
258 };
259
260 class PyAffineModExpr
261 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
262 public:
263 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
264 static constexpr const char *pyClassName = "AffineModExpr";
265 using PyConcreteAffineExpr::PyConcreteAffineExpr;
266
get(PyAffineExpr lhs,const PyAffineExpr & rhs)267 static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
268 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
269 return PyAffineModExpr(lhs.getContext(), expr);
270 }
271
getRHSConstant(PyAffineExpr lhs,intptr_t rhs)272 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
273 MlirAffineExpr expr = mlirAffineModExprGet(
274 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
275 return PyAffineModExpr(lhs.getContext(), expr);
276 }
277
getLHSConstant(intptr_t lhs,PyAffineExpr rhs)278 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
279 MlirAffineExpr expr = mlirAffineModExprGet(
280 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
281 return PyAffineModExpr(rhs.getContext(), expr);
282 }
283
bindDerived(ClassTy & c)284 static void bindDerived(ClassTy &c) {
285 c.def_static("get", &PyAffineModExpr::get);
286 }
287 };
288
289 class PyAffineFloorDivExpr
290 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
291 public:
292 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
293 static constexpr const char *pyClassName = "AffineFloorDivExpr";
294 using PyConcreteAffineExpr::PyConcreteAffineExpr;
295
get(PyAffineExpr lhs,const PyAffineExpr & rhs)296 static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
297 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
298 return PyAffineFloorDivExpr(lhs.getContext(), expr);
299 }
300
getRHSConstant(PyAffineExpr lhs,intptr_t rhs)301 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
302 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
303 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
304 return PyAffineFloorDivExpr(lhs.getContext(), expr);
305 }
306
getLHSConstant(intptr_t lhs,PyAffineExpr rhs)307 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
308 MlirAffineExpr expr = mlirAffineFloorDivExprGet(
309 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
310 return PyAffineFloorDivExpr(rhs.getContext(), expr);
311 }
312
bindDerived(ClassTy & c)313 static void bindDerived(ClassTy &c) {
314 c.def_static("get", &PyAffineFloorDivExpr::get);
315 }
316 };
317
318 class PyAffineCeilDivExpr
319 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
320 public:
321 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
322 static constexpr const char *pyClassName = "AffineCeilDivExpr";
323 using PyConcreteAffineExpr::PyConcreteAffineExpr;
324
get(PyAffineExpr lhs,const PyAffineExpr & rhs)325 static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) {
326 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
327 return PyAffineCeilDivExpr(lhs.getContext(), expr);
328 }
329
getRHSConstant(PyAffineExpr lhs,intptr_t rhs)330 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) {
331 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
332 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs));
333 return PyAffineCeilDivExpr(lhs.getContext(), expr);
334 }
335
getLHSConstant(intptr_t lhs,PyAffineExpr rhs)336 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) {
337 MlirAffineExpr expr = mlirAffineCeilDivExprGet(
338 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs);
339 return PyAffineCeilDivExpr(rhs.getContext(), expr);
340 }
341
bindDerived(ClassTy & c)342 static void bindDerived(ClassTy &c) {
343 c.def_static("get", &PyAffineCeilDivExpr::get);
344 }
345 };
346
347 } // namespace
348
operator ==(const PyAffineExpr & other)349 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
350 return mlirAffineExprEqual(affineExpr, other.affineExpr);
351 }
352
getCapsule()353 py::object PyAffineExpr::getCapsule() {
354 return py::reinterpret_steal<py::object>(
355 mlirPythonAffineExprToCapsule(*this));
356 }
357
createFromCapsule(py::object capsule)358 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
359 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
360 if (mlirAffineExprIsNull(rawAffineExpr))
361 throw py::error_already_set();
362 return PyAffineExpr(
363 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
364 rawAffineExpr);
365 }
366
367 //------------------------------------------------------------------------------
368 // PyAffineMap and utilities.
369 //------------------------------------------------------------------------------
370 namespace {
371
372 /// A list of expressions contained in an affine map. Internally these are
373 /// stored as a consecutive array leading to inexpensive random access. Both
374 /// the map and the expression are owned by the context so we need not bother
375 /// with lifetime extension.
376 class PyAffineMapExprList
377 : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
378 public:
379 static constexpr const char *pyClassName = "AffineExprList";
380
PyAffineMapExprList(const PyAffineMap & map,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)381 PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0,
382 intptr_t length = -1, intptr_t step = 1)
383 : Sliceable(startIndex,
384 length == -1 ? mlirAffineMapGetNumResults(map) : length,
385 step),
386 affineMap(map) {}
387
388 private:
389 /// Give the parent CRTP class access to hook implementations below.
390 friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
391
getRawNumElements()392 intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
393
getRawElement(intptr_t pos)394 PyAffineExpr getRawElement(intptr_t pos) {
395 return PyAffineExpr(affineMap.getContext(),
396 mlirAffineMapGetResult(affineMap, pos));
397 }
398
slice(intptr_t startIndex,intptr_t length,intptr_t step)399 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
400 intptr_t step) {
401 return PyAffineMapExprList(affineMap, startIndex, length, step);
402 }
403
404 PyAffineMap affineMap;
405 };
406 } // namespace
407
operator ==(const PyAffineMap & other)408 bool PyAffineMap::operator==(const PyAffineMap &other) {
409 return mlirAffineMapEqual(affineMap, other.affineMap);
410 }
411
getCapsule()412 py::object PyAffineMap::getCapsule() {
413 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
414 }
415
createFromCapsule(py::object capsule)416 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
417 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
418 if (mlirAffineMapIsNull(rawAffineMap))
419 throw py::error_already_set();
420 return PyAffineMap(
421 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
422 rawAffineMap);
423 }
424
425 //------------------------------------------------------------------------------
426 // PyIntegerSet and utilities.
427 //------------------------------------------------------------------------------
428 namespace {
429
430 class PyIntegerSetConstraint {
431 public:
PyIntegerSetConstraint(PyIntegerSet set,intptr_t pos)432 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
433 : set(std::move(set)), pos(pos) {}
434
getExpr()435 PyAffineExpr getExpr() {
436 return PyAffineExpr(set.getContext(),
437 mlirIntegerSetGetConstraint(set, pos));
438 }
439
isEq()440 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
441
bind(py::module & m)442 static void bind(py::module &m) {
443 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
444 py::module_local())
445 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
446 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
447 }
448
449 private:
450 PyIntegerSet set;
451 intptr_t pos;
452 };
453
454 class PyIntegerSetConstraintList
455 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
456 public:
457 static constexpr const char *pyClassName = "IntegerSetConstraintList";
458
PyIntegerSetConstraintList(const PyIntegerSet & set,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)459 PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0,
460 intptr_t length = -1, intptr_t step = 1)
461 : Sliceable(startIndex,
462 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
463 step),
464 set(set) {}
465
466 private:
467 /// Give the parent CRTP class access to hook implementations below.
468 friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
469
getRawNumElements()470 intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
471
getRawElement(intptr_t pos)472 PyIntegerSetConstraint getRawElement(intptr_t pos) {
473 return PyIntegerSetConstraint(set, pos);
474 }
475
slice(intptr_t startIndex,intptr_t length,intptr_t step)476 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
477 intptr_t step) {
478 return PyIntegerSetConstraintList(set, startIndex, length, step);
479 }
480
481 PyIntegerSet set;
482 };
483 } // namespace
484
operator ==(const PyIntegerSet & other)485 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
486 return mlirIntegerSetEqual(integerSet, other.integerSet);
487 }
488
getCapsule()489 py::object PyIntegerSet::getCapsule() {
490 return py::reinterpret_steal<py::object>(
491 mlirPythonIntegerSetToCapsule(*this));
492 }
493
createFromCapsule(py::object capsule)494 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
495 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
496 if (mlirIntegerSetIsNull(rawIntegerSet))
497 throw py::error_already_set();
498 return PyIntegerSet(
499 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
500 rawIntegerSet);
501 }
502
populateIRAffine(py::module & m)503 void mlir::python::populateIRAffine(py::module &m) {
504 //----------------------------------------------------------------------------
505 // Mapping of PyAffineExpr and derived classes.
506 //----------------------------------------------------------------------------
507 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
508 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
509 &PyAffineExpr::getCapsule)
510 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
511 .def("__add__", &PyAffineAddExpr::get)
512 .def("__add__", &PyAffineAddExpr::getRHSConstant)
513 .def("__radd__", &PyAffineAddExpr::getRHSConstant)
514 .def("__mul__", &PyAffineMulExpr::get)
515 .def("__mul__", &PyAffineMulExpr::getRHSConstant)
516 .def("__rmul__", &PyAffineMulExpr::getRHSConstant)
517 .def("__mod__", &PyAffineModExpr::get)
518 .def("__mod__", &PyAffineModExpr::getRHSConstant)
519 .def("__rmod__",
520 [](PyAffineExpr &self, intptr_t other) {
521 return PyAffineModExpr::get(
522 PyAffineConstantExpr::get(other, *self.getContext().get()),
523 self);
524 })
525 .def("__sub__",
526 [](PyAffineExpr &self, PyAffineExpr &other) {
527 auto negOne =
528 PyAffineConstantExpr::get(-1, *self.getContext().get());
529 return PyAffineAddExpr::get(self,
530 PyAffineMulExpr::get(negOne, other));
531 })
532 .def("__sub__",
533 [](PyAffineExpr &self, intptr_t other) {
534 return PyAffineAddExpr::get(
535 self,
536 PyAffineConstantExpr::get(-other, *self.getContext().get()));
537 })
538 .def("__rsub__",
539 [](PyAffineExpr &self, intptr_t other) {
540 return PyAffineAddExpr::getLHSConstant(
541 other, PyAffineMulExpr::getLHSConstant(-1, self));
542 })
543 .def("__eq__", [](PyAffineExpr &self,
544 PyAffineExpr &other) { return self == other; })
545 .def("__eq__",
546 [](PyAffineExpr &self, py::object &other) { return false; })
547 .def("__str__",
548 [](PyAffineExpr &self) {
549 PyPrintAccumulator printAccum;
550 mlirAffineExprPrint(self, printAccum.getCallback(),
551 printAccum.getUserData());
552 return printAccum.join();
553 })
554 .def("__repr__",
555 [](PyAffineExpr &self) {
556 PyPrintAccumulator printAccum;
557 printAccum.parts.append("AffineExpr(");
558 mlirAffineExprPrint(self, printAccum.getCallback(),
559 printAccum.getUserData());
560 printAccum.parts.append(")");
561 return printAccum.join();
562 })
563 .def("__hash__",
564 [](PyAffineExpr &self) {
565 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
566 })
567 .def_property_readonly(
568 "context",
569 [](PyAffineExpr &self) { return self.getContext().getObject(); })
570 .def("compose",
571 [](PyAffineExpr &self, PyAffineMap &other) {
572 return PyAffineExpr(self.getContext(),
573 mlirAffineExprCompose(self, other));
574 })
575 .def_static(
576 "get_add", &PyAffineAddExpr::get,
577 "Gets an affine expression containing a sum of two expressions.")
578 .def_static("get_add", &PyAffineAddExpr::getLHSConstant,
579 "Gets an affine expression containing a sum of a constant "
580 "and another expression.")
581 .def_static("get_add", &PyAffineAddExpr::getRHSConstant,
582 "Gets an affine expression containing a sum of an expression "
583 "and a constant.")
584 .def_static(
585 "get_mul", &PyAffineMulExpr::get,
586 "Gets an affine expression containing a product of two expressions.")
587 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant,
588 "Gets an affine expression containing a product of a "
589 "constant and another expression.")
590 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant,
591 "Gets an affine expression containing a product of an "
592 "expression and a constant.")
593 .def_static("get_mod", &PyAffineModExpr::get,
594 "Gets an affine expression containing the modulo of dividing "
595 "one expression by another.")
596 .def_static("get_mod", &PyAffineModExpr::getLHSConstant,
597 "Gets a semi-affine expression containing the modulo of "
598 "dividing a constant by an expression.")
599 .def_static("get_mod", &PyAffineModExpr::getRHSConstant,
600 "Gets an affine expression containing the module of dividing"
601 "an expression by a constant.")
602 .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
603 "Gets an affine expression containing the rounded-down "
604 "result of dividing one expression by another.")
605 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant,
606 "Gets a semi-affine expression containing the rounded-down "
607 "result of dividing a constant by an expression.")
608 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant,
609 "Gets an affine expression containing the rounded-down "
610 "result of dividing an expression by a constant.")
611 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
612 "Gets an affine expression containing the rounded-up result "
613 "of dividing one expression by another.")
614 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant,
615 "Gets a semi-affine expression containing the rounded-up "
616 "result of dividing a constant by an expression.")
617 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
618 "Gets an affine expression containing the rounded-up result "
619 "of dividing an expression by a constant.")
620 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
621 py::arg("context") = py::none(),
622 "Gets a constant affine expression with the given value.")
623 .def_static(
624 "get_dim", &PyAffineDimExpr::get, py::arg("position"),
625 py::arg("context") = py::none(),
626 "Gets an affine expression of a dimension at the given position.")
627 .def_static(
628 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
629 py::arg("context") = py::none(),
630 "Gets an affine expression of a symbol at the given position.")
631 .def(
632 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
633 kDumpDocstring);
634 PyAffineConstantExpr::bind(m);
635 PyAffineDimExpr::bind(m);
636 PyAffineSymbolExpr::bind(m);
637 PyAffineBinaryExpr::bind(m);
638 PyAffineAddExpr::bind(m);
639 PyAffineMulExpr::bind(m);
640 PyAffineModExpr::bind(m);
641 PyAffineFloorDivExpr::bind(m);
642 PyAffineCeilDivExpr::bind(m);
643
644 //----------------------------------------------------------------------------
645 // Mapping of PyAffineMap.
646 //----------------------------------------------------------------------------
647 py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
648 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
649 &PyAffineMap::getCapsule)
650 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
651 .def("__eq__",
652 [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
653 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
654 .def("__str__",
655 [](PyAffineMap &self) {
656 PyPrintAccumulator printAccum;
657 mlirAffineMapPrint(self, printAccum.getCallback(),
658 printAccum.getUserData());
659 return printAccum.join();
660 })
661 .def("__repr__",
662 [](PyAffineMap &self) {
663 PyPrintAccumulator printAccum;
664 printAccum.parts.append("AffineMap(");
665 mlirAffineMapPrint(self, printAccum.getCallback(),
666 printAccum.getUserData());
667 printAccum.parts.append(")");
668 return printAccum.join();
669 })
670 .def("__hash__",
671 [](PyAffineMap &self) {
672 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
673 })
674 .def_static("compress_unused_symbols",
675 [](py::list affineMaps, DefaultingPyMlirContext context) {
676 SmallVector<MlirAffineMap> maps;
677 pyListToVector<PyAffineMap, MlirAffineMap>(
678 affineMaps, maps, "attempting to create an AffineMap");
679 std::vector<MlirAffineMap> compressed(affineMaps.size());
680 auto populate = [](void *result, intptr_t idx,
681 MlirAffineMap m) {
682 static_cast<MlirAffineMap *>(result)[idx] = (m);
683 };
684 mlirAffineMapCompressUnusedSymbols(
685 maps.data(), maps.size(), compressed.data(), populate);
686 std::vector<PyAffineMap> res;
687 res.reserve(compressed.size());
688 for (auto m : compressed)
689 res.emplace_back(context->getRef(), m);
690 return res;
691 })
692 .def_property_readonly(
693 "context",
694 [](PyAffineMap &self) { return self.getContext().getObject(); },
695 "Context that owns the Affine Map")
696 .def(
697 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
698 kDumpDocstring)
699 .def_static(
700 "get",
701 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
702 DefaultingPyMlirContext context) {
703 SmallVector<MlirAffineExpr> affineExprs;
704 pyListToVector<PyAffineExpr, MlirAffineExpr>(
705 exprs, affineExprs, "attempting to create an AffineMap");
706 MlirAffineMap map =
707 mlirAffineMapGet(context->get(), dimCount, symbolCount,
708 affineExprs.size(), affineExprs.data());
709 return PyAffineMap(context->getRef(), map);
710 },
711 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
712 py::arg("context") = py::none(),
713 "Gets a map with the given expressions as results.")
714 .def_static(
715 "get_constant",
716 [](intptr_t value, DefaultingPyMlirContext context) {
717 MlirAffineMap affineMap =
718 mlirAffineMapConstantGet(context->get(), value);
719 return PyAffineMap(context->getRef(), affineMap);
720 },
721 py::arg("value"), py::arg("context") = py::none(),
722 "Gets an affine map with a single constant result")
723 .def_static(
724 "get_empty",
725 [](DefaultingPyMlirContext context) {
726 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
727 return PyAffineMap(context->getRef(), affineMap);
728 },
729 py::arg("context") = py::none(), "Gets an empty affine map.")
730 .def_static(
731 "get_identity",
732 [](intptr_t nDims, DefaultingPyMlirContext context) {
733 MlirAffineMap affineMap =
734 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
735 return PyAffineMap(context->getRef(), affineMap);
736 },
737 py::arg("n_dims"), py::arg("context") = py::none(),
738 "Gets an identity map with the given number of dimensions.")
739 .def_static(
740 "get_minor_identity",
741 [](intptr_t nDims, intptr_t nResults,
742 DefaultingPyMlirContext context) {
743 MlirAffineMap affineMap =
744 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
745 return PyAffineMap(context->getRef(), affineMap);
746 },
747 py::arg("n_dims"), py::arg("n_results"),
748 py::arg("context") = py::none(),
749 "Gets a minor identity map with the given number of dimensions and "
750 "results.")
751 .def_static(
752 "get_permutation",
753 [](std::vector<unsigned> permutation,
754 DefaultingPyMlirContext context) {
755 if (!isPermutation(permutation))
756 throw py::cast_error("Invalid permutation when attempting to "
757 "create an AffineMap");
758 MlirAffineMap affineMap = mlirAffineMapPermutationGet(
759 context->get(), permutation.size(), permutation.data());
760 return PyAffineMap(context->getRef(), affineMap);
761 },
762 py::arg("permutation"), py::arg("context") = py::none(),
763 "Gets an affine map that permutes its inputs.")
764 .def(
765 "get_submap",
766 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
767 intptr_t numResults = mlirAffineMapGetNumResults(self);
768 for (intptr_t pos : resultPos) {
769 if (pos < 0 || pos >= numResults)
770 throw py::value_error("result position out of bounds");
771 }
772 MlirAffineMap affineMap = mlirAffineMapGetSubMap(
773 self, resultPos.size(), resultPos.data());
774 return PyAffineMap(self.getContext(), affineMap);
775 },
776 py::arg("result_positions"))
777 .def(
778 "get_major_submap",
779 [](PyAffineMap &self, intptr_t nResults) {
780 if (nResults >= mlirAffineMapGetNumResults(self))
781 throw py::value_error("number of results out of bounds");
782 MlirAffineMap affineMap =
783 mlirAffineMapGetMajorSubMap(self, nResults);
784 return PyAffineMap(self.getContext(), affineMap);
785 },
786 py::arg("n_results"))
787 .def(
788 "get_minor_submap",
789 [](PyAffineMap &self, intptr_t nResults) {
790 if (nResults >= mlirAffineMapGetNumResults(self))
791 throw py::value_error("number of results out of bounds");
792 MlirAffineMap affineMap =
793 mlirAffineMapGetMinorSubMap(self, nResults);
794 return PyAffineMap(self.getContext(), affineMap);
795 },
796 py::arg("n_results"))
797 .def(
798 "replace",
799 [](PyAffineMap &self, PyAffineExpr &expression,
800 PyAffineExpr &replacement, intptr_t numResultDims,
801 intptr_t numResultSyms) {
802 MlirAffineMap affineMap = mlirAffineMapReplace(
803 self, expression, replacement, numResultDims, numResultSyms);
804 return PyAffineMap(self.getContext(), affineMap);
805 },
806 py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
807 py::arg("n_result_syms"))
808 .def_property_readonly(
809 "is_permutation",
810 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
811 .def_property_readonly("is_projected_permutation",
812 [](PyAffineMap &self) {
813 return mlirAffineMapIsProjectedPermutation(self);
814 })
815 .def_property_readonly(
816 "n_dims",
817 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
818 .def_property_readonly(
819 "n_inputs",
820 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
821 .def_property_readonly(
822 "n_symbols",
823 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
824 .def_property_readonly("results", [](PyAffineMap &self) {
825 return PyAffineMapExprList(self);
826 });
827 PyAffineMapExprList::bind(m);
828
829 //----------------------------------------------------------------------------
830 // Mapping of PyIntegerSet.
831 //----------------------------------------------------------------------------
832 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
833 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
834 &PyIntegerSet::getCapsule)
835 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
836 .def("__eq__", [](PyIntegerSet &self,
837 PyIntegerSet &other) { return self == other; })
838 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
839 .def("__str__",
840 [](PyIntegerSet &self) {
841 PyPrintAccumulator printAccum;
842 mlirIntegerSetPrint(self, printAccum.getCallback(),
843 printAccum.getUserData());
844 return printAccum.join();
845 })
846 .def("__repr__",
847 [](PyIntegerSet &self) {
848 PyPrintAccumulator printAccum;
849 printAccum.parts.append("IntegerSet(");
850 mlirIntegerSetPrint(self, printAccum.getCallback(),
851 printAccum.getUserData());
852 printAccum.parts.append(")");
853 return printAccum.join();
854 })
855 .def("__hash__",
856 [](PyIntegerSet &self) {
857 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
858 })
859 .def_property_readonly(
860 "context",
861 [](PyIntegerSet &self) { return self.getContext().getObject(); })
862 .def(
863 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
864 kDumpDocstring)
865 .def_static(
866 "get",
867 [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
868 std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
869 if (exprs.size() != eqFlags.size())
870 throw py::value_error(
871 "Expected the number of constraints to match "
872 "that of equality flags");
873 if (exprs.empty())
874 throw py::value_error("Expected non-empty list of constraints");
875
876 // Copy over to a SmallVector because std::vector has a
877 // specialization for booleans that packs data and does not
878 // expose a `bool *`.
879 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
880
881 SmallVector<MlirAffineExpr> affineExprs;
882 pyListToVector<PyAffineExpr>(exprs, affineExprs,
883 "attempting to create an IntegerSet");
884 MlirIntegerSet set = mlirIntegerSetGet(
885 context->get(), numDims, numSymbols, exprs.size(),
886 affineExprs.data(), flags.data());
887 return PyIntegerSet(context->getRef(), set);
888 },
889 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
890 py::arg("eq_flags"), py::arg("context") = py::none())
891 .def_static(
892 "get_empty",
893 [](intptr_t numDims, intptr_t numSymbols,
894 DefaultingPyMlirContext context) {
895 MlirIntegerSet set =
896 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
897 return PyIntegerSet(context->getRef(), set);
898 },
899 py::arg("num_dims"), py::arg("num_symbols"),
900 py::arg("context") = py::none())
901 .def(
902 "get_replaced",
903 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
904 intptr_t numResultDims, intptr_t numResultSymbols) {
905 if (static_cast<intptr_t>(dimExprs.size()) !=
906 mlirIntegerSetGetNumDims(self))
907 throw py::value_error(
908 "Expected the number of dimension replacement expressions "
909 "to match that of dimensions");
910 if (static_cast<intptr_t>(symbolExprs.size()) !=
911 mlirIntegerSetGetNumSymbols(self))
912 throw py::value_error(
913 "Expected the number of symbol replacement expressions "
914 "to match that of symbols");
915
916 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
917 pyListToVector<PyAffineExpr>(
918 dimExprs, dimAffineExprs,
919 "attempting to create an IntegerSet by replacing dimensions");
920 pyListToVector<PyAffineExpr>(
921 symbolExprs, symbolAffineExprs,
922 "attempting to create an IntegerSet by replacing symbols");
923 MlirIntegerSet set = mlirIntegerSetReplaceGet(
924 self, dimAffineExprs.data(), symbolAffineExprs.data(),
925 numResultDims, numResultSymbols);
926 return PyIntegerSet(self.getContext(), set);
927 },
928 py::arg("dim_exprs"), py::arg("symbol_exprs"),
929 py::arg("num_result_dims"), py::arg("num_result_symbols"))
930 .def_property_readonly("is_canonical_empty",
931 [](PyIntegerSet &self) {
932 return mlirIntegerSetIsCanonicalEmpty(self);
933 })
934 .def_property_readonly(
935 "n_dims",
936 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
937 .def_property_readonly(
938 "n_symbols",
939 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
940 .def_property_readonly(
941 "n_inputs",
942 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
943 .def_property_readonly("n_equalities",
944 [](PyIntegerSet &self) {
945 return mlirIntegerSetGetNumEqualities(self);
946 })
947 .def_property_readonly("n_inequalities",
948 [](PyIntegerSet &self) {
949 return mlirIntegerSetGetNumInequalities(self);
950 })
951 .def_property_readonly("constraints", [](PyIntegerSet &self) {
952 return PyIntegerSetConstraintList(self);
953 });
954 PyIntegerSetConstraint::bind(m);
955 PyIntegerSetConstraintList::bind(m);
956 }
957