1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "IRModule.h" 12 13 #include "PybindUtils.h" 14 15 #include "mlir-c/AffineMap.h" 16 #include "mlir-c/Bindings/Python/Interop.h" 17 #include "mlir-c/IntegerSet.h" 18 19 namespace py = pybind11; 20 using namespace mlir; 21 using namespace mlir::python; 22 23 using llvm::SmallVector; 24 using llvm::StringRef; 25 using llvm::Twine; 26 27 static const char kDumpDocstring[] = 28 R"(Dumps a debug representation of the object to stderr.)"; 29 30 /// Attempts to populate `result` with the content of `list` casted to the 31 /// appropriate type (Python and C types are provided as template arguments). 32 /// Throws errors in case of failure, using "action" to describe what the caller 33 /// was attempting to do. 34 template <typename PyType, typename CType> 35 static void pyListToVector(const py::list &list, 36 llvm::SmallVectorImpl<CType> &result, 37 StringRef action) { 38 result.reserve(py::len(list)); 39 for (py::handle item : list) { 40 try { 41 result.push_back(item.cast<PyType>()); 42 } catch (py::cast_error &err) { 43 std::string msg = (llvm::Twine("Invalid expression when ") + action + 44 " (" + err.what() + ")") 45 .str(); 46 throw py::cast_error(msg); 47 } catch (py::reference_cast_error &err) { 48 std::string msg = (llvm::Twine("Invalid expression (None?) when ") + 49 action + " (" + err.what() + ")") 50 .str(); 51 throw py::cast_error(msg); 52 } 53 } 54 } 55 56 template <typename PermutationTy> 57 static bool isPermutation(std::vector<PermutationTy> permutation) { 58 llvm::SmallVector<bool, 8> seen(permutation.size(), false); 59 for (auto val : permutation) { 60 if (val < permutation.size()) { 61 if (seen[val]) 62 return false; 63 seen[val] = true; 64 continue; 65 } 66 return false; 67 } 68 return true; 69 } 70 71 namespace { 72 73 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr 74 /// and should be castable from it. Intermediate hierarchy classes can be 75 /// modeled by specifying BaseTy. 76 template <typename DerivedTy, typename BaseTy = PyAffineExpr> 77 class PyConcreteAffineExpr : public BaseTy { 78 public: 79 // Derived classes must define statics for: 80 // IsAFunctionTy isaFunction 81 // const char *pyClassName 82 // and redefine bindDerived. 83 using ClassTy = py::class_<DerivedTy, BaseTy>; 84 using IsAFunctionTy = bool (*)(MlirAffineExpr); 85 86 PyConcreteAffineExpr() = default; 87 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 88 : BaseTy(std::move(contextRef), affineExpr) {} 89 PyConcreteAffineExpr(PyAffineExpr &orig) 90 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} 91 92 static MlirAffineExpr castFrom(PyAffineExpr &orig) { 93 if (!DerivedTy::isaFunction(orig)) { 94 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 95 throw SetPyError(PyExc_ValueError, 96 Twine("Cannot cast affine expression to ") + 97 DerivedTy::pyClassName + " (from " + origRepr + ")"); 98 } 99 return orig; 100 } 101 102 static void bind(py::module &m) { 103 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 104 cls.def(py::init<PyAffineExpr &>(), py::arg("expr")); 105 cls.def_static( 106 "isinstance", 107 [](PyAffineExpr &otherAffineExpr) -> bool { 108 return DerivedTy::isaFunction(otherAffineExpr); 109 }, 110 py::arg("other")); 111 DerivedTy::bindDerived(cls); 112 } 113 114 /// Implemented by derived classes to add methods to the Python subclass. 115 static void bindDerived(ClassTy &m) {} 116 }; 117 118 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { 119 public: 120 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; 121 static constexpr const char *pyClassName = "AffineConstantExpr"; 122 using PyConcreteAffineExpr::PyConcreteAffineExpr; 123 124 static PyAffineConstantExpr get(intptr_t value, 125 DefaultingPyMlirContext context) { 126 MlirAffineExpr affineExpr = 127 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); 128 return PyAffineConstantExpr(context->getRef(), affineExpr); 129 } 130 131 static void bindDerived(ClassTy &c) { 132 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), 133 py::arg("context") = py::none()); 134 c.def_property_readonly("value", [](PyAffineConstantExpr &self) { 135 return mlirAffineConstantExprGetValue(self); 136 }); 137 } 138 }; 139 140 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { 141 public: 142 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; 143 static constexpr const char *pyClassName = "AffineDimExpr"; 144 using PyConcreteAffineExpr::PyConcreteAffineExpr; 145 146 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { 147 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); 148 return PyAffineDimExpr(context->getRef(), affineExpr); 149 } 150 151 static void bindDerived(ClassTy &c) { 152 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), 153 py::arg("context") = py::none()); 154 c.def_property_readonly("position", [](PyAffineDimExpr &self) { 155 return mlirAffineDimExprGetPosition(self); 156 }); 157 } 158 }; 159 160 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { 161 public: 162 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; 163 static constexpr const char *pyClassName = "AffineSymbolExpr"; 164 using PyConcreteAffineExpr::PyConcreteAffineExpr; 165 166 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { 167 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); 168 return PyAffineSymbolExpr(context->getRef(), affineExpr); 169 } 170 171 static void bindDerived(ClassTy &c) { 172 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), 173 py::arg("context") = py::none()); 174 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { 175 return mlirAffineSymbolExprGetPosition(self); 176 }); 177 } 178 }; 179 180 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { 181 public: 182 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; 183 static constexpr const char *pyClassName = "AffineBinaryExpr"; 184 using PyConcreteAffineExpr::PyConcreteAffineExpr; 185 186 PyAffineExpr lhs() { 187 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); 188 return PyAffineExpr(getContext(), lhsExpr); 189 } 190 191 PyAffineExpr rhs() { 192 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); 193 return PyAffineExpr(getContext(), rhsExpr); 194 } 195 196 static void bindDerived(ClassTy &c) { 197 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); 198 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); 199 } 200 }; 201 202 class PyAffineAddExpr 203 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { 204 public: 205 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; 206 static constexpr const char *pyClassName = "AffineAddExpr"; 207 using PyConcreteAffineExpr::PyConcreteAffineExpr; 208 209 static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 210 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); 211 return PyAffineAddExpr(lhs.getContext(), expr); 212 } 213 214 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 215 MlirAffineExpr expr = mlirAffineAddExprGet( 216 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 217 return PyAffineAddExpr(lhs.getContext(), expr); 218 } 219 220 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 221 MlirAffineExpr expr = mlirAffineAddExprGet( 222 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 223 return PyAffineAddExpr(rhs.getContext(), expr); 224 } 225 226 static void bindDerived(ClassTy &c) { 227 c.def_static("get", &PyAffineAddExpr::get); 228 } 229 }; 230 231 class PyAffineMulExpr 232 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { 233 public: 234 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; 235 static constexpr const char *pyClassName = "AffineMulExpr"; 236 using PyConcreteAffineExpr::PyConcreteAffineExpr; 237 238 static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 239 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); 240 return PyAffineMulExpr(lhs.getContext(), expr); 241 } 242 243 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 244 MlirAffineExpr expr = mlirAffineMulExprGet( 245 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 246 return PyAffineMulExpr(lhs.getContext(), expr); 247 } 248 249 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 250 MlirAffineExpr expr = mlirAffineMulExprGet( 251 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 252 return PyAffineMulExpr(rhs.getContext(), expr); 253 } 254 255 static void bindDerived(ClassTy &c) { 256 c.def_static("get", &PyAffineMulExpr::get); 257 } 258 }; 259 260 class PyAffineModExpr 261 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { 262 public: 263 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; 264 static constexpr const char *pyClassName = "AffineModExpr"; 265 using PyConcreteAffineExpr::PyConcreteAffineExpr; 266 267 static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 268 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); 269 return PyAffineModExpr(lhs.getContext(), expr); 270 } 271 272 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 273 MlirAffineExpr expr = mlirAffineModExprGet( 274 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 275 return PyAffineModExpr(lhs.getContext(), expr); 276 } 277 278 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 279 MlirAffineExpr expr = mlirAffineModExprGet( 280 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 281 return PyAffineModExpr(rhs.getContext(), expr); 282 } 283 284 static void bindDerived(ClassTy &c) { 285 c.def_static("get", &PyAffineModExpr::get); 286 } 287 }; 288 289 class PyAffineFloorDivExpr 290 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { 291 public: 292 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; 293 static constexpr const char *pyClassName = "AffineFloorDivExpr"; 294 using PyConcreteAffineExpr::PyConcreteAffineExpr; 295 296 static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 297 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); 298 return PyAffineFloorDivExpr(lhs.getContext(), expr); 299 } 300 301 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 302 MlirAffineExpr expr = mlirAffineFloorDivExprGet( 303 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 304 return PyAffineFloorDivExpr(lhs.getContext(), expr); 305 } 306 307 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 308 MlirAffineExpr expr = mlirAffineFloorDivExprGet( 309 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 310 return PyAffineFloorDivExpr(rhs.getContext(), expr); 311 } 312 313 static void bindDerived(ClassTy &c) { 314 c.def_static("get", &PyAffineFloorDivExpr::get); 315 } 316 }; 317 318 class PyAffineCeilDivExpr 319 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { 320 public: 321 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; 322 static constexpr const char *pyClassName = "AffineCeilDivExpr"; 323 using PyConcreteAffineExpr::PyConcreteAffineExpr; 324 325 static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { 326 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); 327 return PyAffineCeilDivExpr(lhs.getContext(), expr); 328 } 329 330 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 331 MlirAffineExpr expr = mlirAffineCeilDivExprGet( 332 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 333 return PyAffineCeilDivExpr(lhs.getContext(), expr); 334 } 335 336 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 337 MlirAffineExpr expr = mlirAffineCeilDivExprGet( 338 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 339 return PyAffineCeilDivExpr(rhs.getContext(), expr); 340 } 341 342 static void bindDerived(ClassTy &c) { 343 c.def_static("get", &PyAffineCeilDivExpr::get); 344 } 345 }; 346 347 } // namespace 348 349 bool PyAffineExpr::operator==(const PyAffineExpr &other) { 350 return mlirAffineExprEqual(affineExpr, other.affineExpr); 351 } 352 353 py::object PyAffineExpr::getCapsule() { 354 return py::reinterpret_steal<py::object>( 355 mlirPythonAffineExprToCapsule(*this)); 356 } 357 358 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { 359 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); 360 if (mlirAffineExprIsNull(rawAffineExpr)) 361 throw py::error_already_set(); 362 return PyAffineExpr( 363 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), 364 rawAffineExpr); 365 } 366 367 //------------------------------------------------------------------------------ 368 // PyAffineMap and utilities. 369 //------------------------------------------------------------------------------ 370 namespace { 371 372 /// A list of expressions contained in an affine map. Internally these are 373 /// stored as a consecutive array leading to inexpensive random access. Both 374 /// the map and the expression are owned by the context so we need not bother 375 /// with lifetime extension. 376 class PyAffineMapExprList 377 : public Sliceable<PyAffineMapExprList, PyAffineExpr> { 378 public: 379 static constexpr const char *pyClassName = "AffineExprList"; 380 381 PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0, 382 intptr_t length = -1, intptr_t step = 1) 383 : Sliceable(startIndex, 384 length == -1 ? mlirAffineMapGetNumResults(map) : length, 385 step), 386 affineMap(map) {} 387 388 private: 389 /// Give the parent CRTP class access to hook implementations below. 390 friend class Sliceable<PyAffineMapExprList, PyAffineExpr>; 391 392 intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); } 393 394 PyAffineExpr getRawElement(intptr_t pos) { 395 return PyAffineExpr(affineMap.getContext(), 396 mlirAffineMapGetResult(affineMap, pos)); 397 } 398 399 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, 400 intptr_t step) { 401 return PyAffineMapExprList(affineMap, startIndex, length, step); 402 } 403 404 PyAffineMap affineMap; 405 }; 406 } // namespace 407 408 bool PyAffineMap::operator==(const PyAffineMap &other) { 409 return mlirAffineMapEqual(affineMap, other.affineMap); 410 } 411 412 py::object PyAffineMap::getCapsule() { 413 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); 414 } 415 416 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { 417 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); 418 if (mlirAffineMapIsNull(rawAffineMap)) 419 throw py::error_already_set(); 420 return PyAffineMap( 421 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), 422 rawAffineMap); 423 } 424 425 //------------------------------------------------------------------------------ 426 // PyIntegerSet and utilities. 427 //------------------------------------------------------------------------------ 428 namespace { 429 430 class PyIntegerSetConstraint { 431 public: 432 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) 433 : set(std::move(set)), pos(pos) {} 434 435 PyAffineExpr getExpr() { 436 return PyAffineExpr(set.getContext(), 437 mlirIntegerSetGetConstraint(set, pos)); 438 } 439 440 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } 441 442 static void bind(py::module &m) { 443 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint", 444 py::module_local()) 445 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) 446 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); 447 } 448 449 private: 450 PyIntegerSet set; 451 intptr_t pos; 452 }; 453 454 class PyIntegerSetConstraintList 455 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { 456 public: 457 static constexpr const char *pyClassName = "IntegerSetConstraintList"; 458 459 PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0, 460 intptr_t length = -1, intptr_t step = 1) 461 : Sliceable(startIndex, 462 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, 463 step), 464 set(set) {} 465 466 private: 467 /// Give the parent CRTP class access to hook implementations below. 468 friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>; 469 470 intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); } 471 472 PyIntegerSetConstraint getRawElement(intptr_t pos) { 473 return PyIntegerSetConstraint(set, pos); 474 } 475 476 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, 477 intptr_t step) { 478 return PyIntegerSetConstraintList(set, startIndex, length, step); 479 } 480 481 PyIntegerSet set; 482 }; 483 } // namespace 484 485 bool PyIntegerSet::operator==(const PyIntegerSet &other) { 486 return mlirIntegerSetEqual(integerSet, other.integerSet); 487 } 488 489 py::object PyIntegerSet::getCapsule() { 490 return py::reinterpret_steal<py::object>( 491 mlirPythonIntegerSetToCapsule(*this)); 492 } 493 494 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { 495 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); 496 if (mlirIntegerSetIsNull(rawIntegerSet)) 497 throw py::error_already_set(); 498 return PyIntegerSet( 499 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 500 rawIntegerSet); 501 } 502 503 void mlir::python::populateIRAffine(py::module &m) { 504 //---------------------------------------------------------------------------- 505 // Mapping of PyAffineExpr and derived classes. 506 //---------------------------------------------------------------------------- 507 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local()) 508 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 509 &PyAffineExpr::getCapsule) 510 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) 511 .def("__add__", &PyAffineAddExpr::get) 512 .def("__add__", &PyAffineAddExpr::getRHSConstant) 513 .def("__radd__", &PyAffineAddExpr::getRHSConstant) 514 .def("__mul__", &PyAffineMulExpr::get) 515 .def("__mul__", &PyAffineMulExpr::getRHSConstant) 516 .def("__rmul__", &PyAffineMulExpr::getRHSConstant) 517 .def("__mod__", &PyAffineModExpr::get) 518 .def("__mod__", &PyAffineModExpr::getRHSConstant) 519 .def("__rmod__", 520 [](PyAffineExpr &self, intptr_t other) { 521 return PyAffineModExpr::get( 522 PyAffineConstantExpr::get(other, *self.getContext().get()), 523 self); 524 }) 525 .def("__sub__", 526 [](PyAffineExpr &self, PyAffineExpr &other) { 527 auto negOne = 528 PyAffineConstantExpr::get(-1, *self.getContext().get()); 529 return PyAffineAddExpr::get(self, 530 PyAffineMulExpr::get(negOne, other)); 531 }) 532 .def("__sub__", 533 [](PyAffineExpr &self, intptr_t other) { 534 return PyAffineAddExpr::get( 535 self, 536 PyAffineConstantExpr::get(-other, *self.getContext().get())); 537 }) 538 .def("__rsub__", 539 [](PyAffineExpr &self, intptr_t other) { 540 return PyAffineAddExpr::getLHSConstant( 541 other, PyAffineMulExpr::getLHSConstant(-1, self)); 542 }) 543 .def("__eq__", [](PyAffineExpr &self, 544 PyAffineExpr &other) { return self == other; }) 545 .def("__eq__", 546 [](PyAffineExpr &self, py::object &other) { return false; }) 547 .def("__str__", 548 [](PyAffineExpr &self) { 549 PyPrintAccumulator printAccum; 550 mlirAffineExprPrint(self, printAccum.getCallback(), 551 printAccum.getUserData()); 552 return printAccum.join(); 553 }) 554 .def("__repr__", 555 [](PyAffineExpr &self) { 556 PyPrintAccumulator printAccum; 557 printAccum.parts.append("AffineExpr("); 558 mlirAffineExprPrint(self, printAccum.getCallback(), 559 printAccum.getUserData()); 560 printAccum.parts.append(")"); 561 return printAccum.join(); 562 }) 563 .def("__hash__", 564 [](PyAffineExpr &self) { 565 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 566 }) 567 .def_property_readonly( 568 "context", 569 [](PyAffineExpr &self) { return self.getContext().getObject(); }) 570 .def("compose", 571 [](PyAffineExpr &self, PyAffineMap &other) { 572 return PyAffineExpr(self.getContext(), 573 mlirAffineExprCompose(self, other)); 574 }) 575 .def_static( 576 "get_add", &PyAffineAddExpr::get, 577 "Gets an affine expression containing a sum of two expressions.") 578 .def_static("get_add", &PyAffineAddExpr::getLHSConstant, 579 "Gets an affine expression containing a sum of a constant " 580 "and another expression.") 581 .def_static("get_add", &PyAffineAddExpr::getRHSConstant, 582 "Gets an affine expression containing a sum of an expression " 583 "and a constant.") 584 .def_static( 585 "get_mul", &PyAffineMulExpr::get, 586 "Gets an affine expression containing a product of two expressions.") 587 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, 588 "Gets an affine expression containing a product of a " 589 "constant and another expression.") 590 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant, 591 "Gets an affine expression containing a product of an " 592 "expression and a constant.") 593 .def_static("get_mod", &PyAffineModExpr::get, 594 "Gets an affine expression containing the modulo of dividing " 595 "one expression by another.") 596 .def_static("get_mod", &PyAffineModExpr::getLHSConstant, 597 "Gets a semi-affine expression containing the modulo of " 598 "dividing a constant by an expression.") 599 .def_static("get_mod", &PyAffineModExpr::getRHSConstant, 600 "Gets an affine expression containing the module of dividing" 601 "an expression by a constant.") 602 .def_static("get_floor_div", &PyAffineFloorDivExpr::get, 603 "Gets an affine expression containing the rounded-down " 604 "result of dividing one expression by another.") 605 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, 606 "Gets a semi-affine expression containing the rounded-down " 607 "result of dividing a constant by an expression.") 608 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, 609 "Gets an affine expression containing the rounded-down " 610 "result of dividing an expression by a constant.") 611 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, 612 "Gets an affine expression containing the rounded-up result " 613 "of dividing one expression by another.") 614 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, 615 "Gets a semi-affine expression containing the rounded-up " 616 "result of dividing a constant by an expression.") 617 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, 618 "Gets an affine expression containing the rounded-up result " 619 "of dividing an expression by a constant.") 620 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), 621 py::arg("context") = py::none(), 622 "Gets a constant affine expression with the given value.") 623 .def_static( 624 "get_dim", &PyAffineDimExpr::get, py::arg("position"), 625 py::arg("context") = py::none(), 626 "Gets an affine expression of a dimension at the given position.") 627 .def_static( 628 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), 629 py::arg("context") = py::none(), 630 "Gets an affine expression of a symbol at the given position.") 631 .def( 632 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, 633 kDumpDocstring); 634 PyAffineConstantExpr::bind(m); 635 PyAffineDimExpr::bind(m); 636 PyAffineSymbolExpr::bind(m); 637 PyAffineBinaryExpr::bind(m); 638 PyAffineAddExpr::bind(m); 639 PyAffineMulExpr::bind(m); 640 PyAffineModExpr::bind(m); 641 PyAffineFloorDivExpr::bind(m); 642 PyAffineCeilDivExpr::bind(m); 643 644 //---------------------------------------------------------------------------- 645 // Mapping of PyAffineMap. 646 //---------------------------------------------------------------------------- 647 py::class_<PyAffineMap>(m, "AffineMap", py::module_local()) 648 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 649 &PyAffineMap::getCapsule) 650 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) 651 .def("__eq__", 652 [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) 653 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) 654 .def("__str__", 655 [](PyAffineMap &self) { 656 PyPrintAccumulator printAccum; 657 mlirAffineMapPrint(self, printAccum.getCallback(), 658 printAccum.getUserData()); 659 return printAccum.join(); 660 }) 661 .def("__repr__", 662 [](PyAffineMap &self) { 663 PyPrintAccumulator printAccum; 664 printAccum.parts.append("AffineMap("); 665 mlirAffineMapPrint(self, printAccum.getCallback(), 666 printAccum.getUserData()); 667 printAccum.parts.append(")"); 668 return printAccum.join(); 669 }) 670 .def("__hash__", 671 [](PyAffineMap &self) { 672 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 673 }) 674 .def_static("compress_unused_symbols", 675 [](py::list affineMaps, DefaultingPyMlirContext context) { 676 SmallVector<MlirAffineMap> maps; 677 pyListToVector<PyAffineMap, MlirAffineMap>( 678 affineMaps, maps, "attempting to create an AffineMap"); 679 std::vector<MlirAffineMap> compressed(affineMaps.size()); 680 auto populate = [](void *result, intptr_t idx, 681 MlirAffineMap m) { 682 static_cast<MlirAffineMap *>(result)[idx] = (m); 683 }; 684 mlirAffineMapCompressUnusedSymbols( 685 maps.data(), maps.size(), compressed.data(), populate); 686 std::vector<PyAffineMap> res; 687 res.reserve(compressed.size()); 688 for (auto m : compressed) 689 res.emplace_back(context->getRef(), m); 690 return res; 691 }) 692 .def_property_readonly( 693 "context", 694 [](PyAffineMap &self) { return self.getContext().getObject(); }, 695 "Context that owns the Affine Map") 696 .def( 697 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, 698 kDumpDocstring) 699 .def_static( 700 "get", 701 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, 702 DefaultingPyMlirContext context) { 703 SmallVector<MlirAffineExpr> affineExprs; 704 pyListToVector<PyAffineExpr, MlirAffineExpr>( 705 exprs, affineExprs, "attempting to create an AffineMap"); 706 MlirAffineMap map = 707 mlirAffineMapGet(context->get(), dimCount, symbolCount, 708 affineExprs.size(), affineExprs.data()); 709 return PyAffineMap(context->getRef(), map); 710 }, 711 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), 712 py::arg("context") = py::none(), 713 "Gets a map with the given expressions as results.") 714 .def_static( 715 "get_constant", 716 [](intptr_t value, DefaultingPyMlirContext context) { 717 MlirAffineMap affineMap = 718 mlirAffineMapConstantGet(context->get(), value); 719 return PyAffineMap(context->getRef(), affineMap); 720 }, 721 py::arg("value"), py::arg("context") = py::none(), 722 "Gets an affine map with a single constant result") 723 .def_static( 724 "get_empty", 725 [](DefaultingPyMlirContext context) { 726 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); 727 return PyAffineMap(context->getRef(), affineMap); 728 }, 729 py::arg("context") = py::none(), "Gets an empty affine map.") 730 .def_static( 731 "get_identity", 732 [](intptr_t nDims, DefaultingPyMlirContext context) { 733 MlirAffineMap affineMap = 734 mlirAffineMapMultiDimIdentityGet(context->get(), nDims); 735 return PyAffineMap(context->getRef(), affineMap); 736 }, 737 py::arg("n_dims"), py::arg("context") = py::none(), 738 "Gets an identity map with the given number of dimensions.") 739 .def_static( 740 "get_minor_identity", 741 [](intptr_t nDims, intptr_t nResults, 742 DefaultingPyMlirContext context) { 743 MlirAffineMap affineMap = 744 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); 745 return PyAffineMap(context->getRef(), affineMap); 746 }, 747 py::arg("n_dims"), py::arg("n_results"), 748 py::arg("context") = py::none(), 749 "Gets a minor identity map with the given number of dimensions and " 750 "results.") 751 .def_static( 752 "get_permutation", 753 [](std::vector<unsigned> permutation, 754 DefaultingPyMlirContext context) { 755 if (!isPermutation(permutation)) 756 throw py::cast_error("Invalid permutation when attempting to " 757 "create an AffineMap"); 758 MlirAffineMap affineMap = mlirAffineMapPermutationGet( 759 context->get(), permutation.size(), permutation.data()); 760 return PyAffineMap(context->getRef(), affineMap); 761 }, 762 py::arg("permutation"), py::arg("context") = py::none(), 763 "Gets an affine map that permutes its inputs.") 764 .def( 765 "get_submap", 766 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { 767 intptr_t numResults = mlirAffineMapGetNumResults(self); 768 for (intptr_t pos : resultPos) { 769 if (pos < 0 || pos >= numResults) 770 throw py::value_error("result position out of bounds"); 771 } 772 MlirAffineMap affineMap = mlirAffineMapGetSubMap( 773 self, resultPos.size(), resultPos.data()); 774 return PyAffineMap(self.getContext(), affineMap); 775 }, 776 py::arg("result_positions")) 777 .def( 778 "get_major_submap", 779 [](PyAffineMap &self, intptr_t nResults) { 780 if (nResults >= mlirAffineMapGetNumResults(self)) 781 throw py::value_error("number of results out of bounds"); 782 MlirAffineMap affineMap = 783 mlirAffineMapGetMajorSubMap(self, nResults); 784 return PyAffineMap(self.getContext(), affineMap); 785 }, 786 py::arg("n_results")) 787 .def( 788 "get_minor_submap", 789 [](PyAffineMap &self, intptr_t nResults) { 790 if (nResults >= mlirAffineMapGetNumResults(self)) 791 throw py::value_error("number of results out of bounds"); 792 MlirAffineMap affineMap = 793 mlirAffineMapGetMinorSubMap(self, nResults); 794 return PyAffineMap(self.getContext(), affineMap); 795 }, 796 py::arg("n_results")) 797 .def( 798 "replace", 799 [](PyAffineMap &self, PyAffineExpr &expression, 800 PyAffineExpr &replacement, intptr_t numResultDims, 801 intptr_t numResultSyms) { 802 MlirAffineMap affineMap = mlirAffineMapReplace( 803 self, expression, replacement, numResultDims, numResultSyms); 804 return PyAffineMap(self.getContext(), affineMap); 805 }, 806 py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"), 807 py::arg("n_result_syms")) 808 .def_property_readonly( 809 "is_permutation", 810 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) 811 .def_property_readonly("is_projected_permutation", 812 [](PyAffineMap &self) { 813 return mlirAffineMapIsProjectedPermutation(self); 814 }) 815 .def_property_readonly( 816 "n_dims", 817 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) 818 .def_property_readonly( 819 "n_inputs", 820 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) 821 .def_property_readonly( 822 "n_symbols", 823 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) 824 .def_property_readonly("results", [](PyAffineMap &self) { 825 return PyAffineMapExprList(self); 826 }); 827 PyAffineMapExprList::bind(m); 828 829 //---------------------------------------------------------------------------- 830 // Mapping of PyIntegerSet. 831 //---------------------------------------------------------------------------- 832 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local()) 833 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 834 &PyIntegerSet::getCapsule) 835 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) 836 .def("__eq__", [](PyIntegerSet &self, 837 PyIntegerSet &other) { return self == other; }) 838 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) 839 .def("__str__", 840 [](PyIntegerSet &self) { 841 PyPrintAccumulator printAccum; 842 mlirIntegerSetPrint(self, printAccum.getCallback(), 843 printAccum.getUserData()); 844 return printAccum.join(); 845 }) 846 .def("__repr__", 847 [](PyIntegerSet &self) { 848 PyPrintAccumulator printAccum; 849 printAccum.parts.append("IntegerSet("); 850 mlirIntegerSetPrint(self, printAccum.getCallback(), 851 printAccum.getUserData()); 852 printAccum.parts.append(")"); 853 return printAccum.join(); 854 }) 855 .def("__hash__", 856 [](PyIntegerSet &self) { 857 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 858 }) 859 .def_property_readonly( 860 "context", 861 [](PyIntegerSet &self) { return self.getContext().getObject(); }) 862 .def( 863 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, 864 kDumpDocstring) 865 .def_static( 866 "get", 867 [](intptr_t numDims, intptr_t numSymbols, py::list exprs, 868 std::vector<bool> eqFlags, DefaultingPyMlirContext context) { 869 if (exprs.size() != eqFlags.size()) 870 throw py::value_error( 871 "Expected the number of constraints to match " 872 "that of equality flags"); 873 if (exprs.empty()) 874 throw py::value_error("Expected non-empty list of constraints"); 875 876 // Copy over to a SmallVector because std::vector has a 877 // specialization for booleans that packs data and does not 878 // expose a `bool *`. 879 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); 880 881 SmallVector<MlirAffineExpr> affineExprs; 882 pyListToVector<PyAffineExpr>(exprs, affineExprs, 883 "attempting to create an IntegerSet"); 884 MlirIntegerSet set = mlirIntegerSetGet( 885 context->get(), numDims, numSymbols, exprs.size(), 886 affineExprs.data(), flags.data()); 887 return PyIntegerSet(context->getRef(), set); 888 }, 889 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), 890 py::arg("eq_flags"), py::arg("context") = py::none()) 891 .def_static( 892 "get_empty", 893 [](intptr_t numDims, intptr_t numSymbols, 894 DefaultingPyMlirContext context) { 895 MlirIntegerSet set = 896 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); 897 return PyIntegerSet(context->getRef(), set); 898 }, 899 py::arg("num_dims"), py::arg("num_symbols"), 900 py::arg("context") = py::none()) 901 .def( 902 "get_replaced", 903 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, 904 intptr_t numResultDims, intptr_t numResultSymbols) { 905 if (static_cast<intptr_t>(dimExprs.size()) != 906 mlirIntegerSetGetNumDims(self)) 907 throw py::value_error( 908 "Expected the number of dimension replacement expressions " 909 "to match that of dimensions"); 910 if (static_cast<intptr_t>(symbolExprs.size()) != 911 mlirIntegerSetGetNumSymbols(self)) 912 throw py::value_error( 913 "Expected the number of symbol replacement expressions " 914 "to match that of symbols"); 915 916 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; 917 pyListToVector<PyAffineExpr>( 918 dimExprs, dimAffineExprs, 919 "attempting to create an IntegerSet by replacing dimensions"); 920 pyListToVector<PyAffineExpr>( 921 symbolExprs, symbolAffineExprs, 922 "attempting to create an IntegerSet by replacing symbols"); 923 MlirIntegerSet set = mlirIntegerSetReplaceGet( 924 self, dimAffineExprs.data(), symbolAffineExprs.data(), 925 numResultDims, numResultSymbols); 926 return PyIntegerSet(self.getContext(), set); 927 }, 928 py::arg("dim_exprs"), py::arg("symbol_exprs"), 929 py::arg("num_result_dims"), py::arg("num_result_symbols")) 930 .def_property_readonly("is_canonical_empty", 931 [](PyIntegerSet &self) { 932 return mlirIntegerSetIsCanonicalEmpty(self); 933 }) 934 .def_property_readonly( 935 "n_dims", 936 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) 937 .def_property_readonly( 938 "n_symbols", 939 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) 940 .def_property_readonly( 941 "n_inputs", 942 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) 943 .def_property_readonly("n_equalities", 944 [](PyIntegerSet &self) { 945 return mlirIntegerSetGetNumEqualities(self); 946 }) 947 .def_property_readonly("n_inequalities", 948 [](PyIntegerSet &self) { 949 return mlirIntegerSetGetNumInequalities(self); 950 }) 951 .def_property_readonly("constraints", [](PyIntegerSet &self) { 952 return PyIntegerSetConstraintList(self); 953 }); 954 PyIntegerSetConstraint::bind(m); 955 PyIntegerSetConstraintList::bind(m); 956 } 957