1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/AffineMap.h" 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/IntegerSet.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 using llvm::SmallVector; 22 using llvm::StringRef; 23 using llvm::Twine; 24 25 static const char kDumpDocstring[] = 26 R"(Dumps a debug representation of the object to stderr.)"; 27 28 /// Attempts to populate `result` with the content of `list` casted to the 29 /// appropriate type (Python and C types are provided as template arguments). 30 /// Throws errors in case of failure, using "action" to describe what the caller 31 /// was attempting to do. 32 template <typename PyType, typename CType> 33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, 34 StringRef action) { 35 result.reserve(py::len(list)); 36 for (py::handle item : list) { 37 try { 38 result.push_back(item.cast<PyType>()); 39 } catch (py::cast_error &err) { 40 std::string msg = (llvm::Twine("Invalid expression when ") + action + 41 " (" + err.what() + ")") 42 .str(); 43 throw py::cast_error(msg); 44 } catch (py::reference_cast_error &err) { 45 std::string msg = (llvm::Twine("Invalid expression (None?) when ") + 46 action + " (" + err.what() + ")") 47 .str(); 48 throw py::cast_error(msg); 49 } 50 } 51 } 52 53 template <typename PermutationTy> 54 static bool isPermutation(std::vector<PermutationTy> permutation) { 55 llvm::SmallVector<bool, 8> seen(permutation.size(), false); 56 for (auto val : permutation) { 57 if (val < permutation.size()) { 58 if (seen[val]) 59 return false; 60 seen[val] = true; 61 continue; 62 } 63 return false; 64 } 65 return true; 66 } 67 68 namespace { 69 70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr 71 /// and should be castable from it. Intermediate hierarchy classes can be 72 /// modeled by specifying BaseTy. 73 template <typename DerivedTy, typename BaseTy = PyAffineExpr> 74 class PyConcreteAffineExpr : public BaseTy { 75 public: 76 // Derived classes must define statics for: 77 // IsAFunctionTy isaFunction 78 // const char *pyClassName 79 // and redefine bindDerived. 80 using ClassTy = py::class_<DerivedTy, BaseTy>; 81 using IsAFunctionTy = bool (*)(MlirAffineExpr); 82 83 PyConcreteAffineExpr() = default; 84 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 85 : BaseTy(std::move(contextRef), affineExpr) {} 86 PyConcreteAffineExpr(PyAffineExpr &orig) 87 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} 88 89 static MlirAffineExpr castFrom(PyAffineExpr &orig) { 90 if (!DerivedTy::isaFunction(orig)) { 91 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 92 throw SetPyError(PyExc_ValueError, 93 Twine("Cannot cast affine expression to ") + 94 DerivedTy::pyClassName + " (from " + origRepr + ")"); 95 } 96 return orig; 97 } 98 99 static void bind(py::module &m) { 100 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 101 cls.def(py::init<PyAffineExpr &>(), py::arg("expr")); 102 cls.def_static( 103 "isinstance", 104 [](PyAffineExpr &otherAffineExpr) -> bool { 105 return DerivedTy::isaFunction(otherAffineExpr); 106 }, 107 py::arg("other")); 108 DerivedTy::bindDerived(cls); 109 } 110 111 /// Implemented by derived classes to add methods to the Python subclass. 112 static void bindDerived(ClassTy &m) {} 113 }; 114 115 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { 116 public: 117 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; 118 static constexpr const char *pyClassName = "AffineConstantExpr"; 119 using PyConcreteAffineExpr::PyConcreteAffineExpr; 120 121 static PyAffineConstantExpr get(intptr_t value, 122 DefaultingPyMlirContext context) { 123 MlirAffineExpr affineExpr = 124 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); 125 return PyAffineConstantExpr(context->getRef(), affineExpr); 126 } 127 128 static void bindDerived(ClassTy &c) { 129 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), 130 py::arg("context") = py::none()); 131 c.def_property_readonly("value", [](PyAffineConstantExpr &self) { 132 return mlirAffineConstantExprGetValue(self); 133 }); 134 } 135 }; 136 137 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { 138 public: 139 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; 140 static constexpr const char *pyClassName = "AffineDimExpr"; 141 using PyConcreteAffineExpr::PyConcreteAffineExpr; 142 143 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { 144 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); 145 return PyAffineDimExpr(context->getRef(), affineExpr); 146 } 147 148 static void bindDerived(ClassTy &c) { 149 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), 150 py::arg("context") = py::none()); 151 c.def_property_readonly("position", [](PyAffineDimExpr &self) { 152 return mlirAffineDimExprGetPosition(self); 153 }); 154 } 155 }; 156 157 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { 158 public: 159 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; 160 static constexpr const char *pyClassName = "AffineSymbolExpr"; 161 using PyConcreteAffineExpr::PyConcreteAffineExpr; 162 163 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { 164 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); 165 return PyAffineSymbolExpr(context->getRef(), affineExpr); 166 } 167 168 static void bindDerived(ClassTy &c) { 169 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), 170 py::arg("context") = py::none()); 171 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { 172 return mlirAffineSymbolExprGetPosition(self); 173 }); 174 } 175 }; 176 177 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { 178 public: 179 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; 180 static constexpr const char *pyClassName = "AffineBinaryExpr"; 181 using PyConcreteAffineExpr::PyConcreteAffineExpr; 182 183 PyAffineExpr lhs() { 184 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); 185 return PyAffineExpr(getContext(), lhsExpr); 186 } 187 188 PyAffineExpr rhs() { 189 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); 190 return PyAffineExpr(getContext(), rhsExpr); 191 } 192 193 static void bindDerived(ClassTy &c) { 194 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); 195 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); 196 } 197 }; 198 199 class PyAffineAddExpr 200 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { 201 public: 202 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; 203 static constexpr const char *pyClassName = "AffineAddExpr"; 204 using PyConcreteAffineExpr::PyConcreteAffineExpr; 205 206 static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 207 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); 208 return PyAffineAddExpr(lhs.getContext(), expr); 209 } 210 211 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 212 MlirAffineExpr expr = mlirAffineAddExprGet( 213 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 214 return PyAffineAddExpr(lhs.getContext(), expr); 215 } 216 217 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 218 MlirAffineExpr expr = mlirAffineAddExprGet( 219 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 220 return PyAffineAddExpr(rhs.getContext(), expr); 221 } 222 223 static void bindDerived(ClassTy &c) { 224 c.def_static("get", &PyAffineAddExpr::get); 225 } 226 }; 227 228 class PyAffineMulExpr 229 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { 230 public: 231 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; 232 static constexpr const char *pyClassName = "AffineMulExpr"; 233 using PyConcreteAffineExpr::PyConcreteAffineExpr; 234 235 static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 236 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); 237 return PyAffineMulExpr(lhs.getContext(), expr); 238 } 239 240 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 241 MlirAffineExpr expr = mlirAffineMulExprGet( 242 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 243 return PyAffineMulExpr(lhs.getContext(), expr); 244 } 245 246 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 247 MlirAffineExpr expr = mlirAffineMulExprGet( 248 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 249 return PyAffineMulExpr(rhs.getContext(), expr); 250 } 251 252 static void bindDerived(ClassTy &c) { 253 c.def_static("get", &PyAffineMulExpr::get); 254 } 255 }; 256 257 class PyAffineModExpr 258 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { 259 public: 260 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; 261 static constexpr const char *pyClassName = "AffineModExpr"; 262 using PyConcreteAffineExpr::PyConcreteAffineExpr; 263 264 static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 265 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); 266 return PyAffineModExpr(lhs.getContext(), expr); 267 } 268 269 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 270 MlirAffineExpr expr = mlirAffineModExprGet( 271 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 272 return PyAffineModExpr(lhs.getContext(), expr); 273 } 274 275 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 276 MlirAffineExpr expr = mlirAffineModExprGet( 277 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 278 return PyAffineModExpr(rhs.getContext(), expr); 279 } 280 281 static void bindDerived(ClassTy &c) { 282 c.def_static("get", &PyAffineModExpr::get); 283 } 284 }; 285 286 class PyAffineFloorDivExpr 287 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { 288 public: 289 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; 290 static constexpr const char *pyClassName = "AffineFloorDivExpr"; 291 using PyConcreteAffineExpr::PyConcreteAffineExpr; 292 293 static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 294 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); 295 return PyAffineFloorDivExpr(lhs.getContext(), expr); 296 } 297 298 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 299 MlirAffineExpr expr = mlirAffineFloorDivExprGet( 300 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 301 return PyAffineFloorDivExpr(lhs.getContext(), expr); 302 } 303 304 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 305 MlirAffineExpr expr = mlirAffineFloorDivExprGet( 306 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 307 return PyAffineFloorDivExpr(rhs.getContext(), expr); 308 } 309 310 static void bindDerived(ClassTy &c) { 311 c.def_static("get", &PyAffineFloorDivExpr::get); 312 } 313 }; 314 315 class PyAffineCeilDivExpr 316 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { 317 public: 318 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; 319 static constexpr const char *pyClassName = "AffineCeilDivExpr"; 320 using PyConcreteAffineExpr::PyConcreteAffineExpr; 321 322 static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 323 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); 324 return PyAffineCeilDivExpr(lhs.getContext(), expr); 325 } 326 327 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 328 MlirAffineExpr expr = mlirAffineCeilDivExprGet( 329 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 330 return PyAffineCeilDivExpr(lhs.getContext(), expr); 331 } 332 333 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 334 MlirAffineExpr expr = mlirAffineCeilDivExprGet( 335 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 336 return PyAffineCeilDivExpr(rhs.getContext(), expr); 337 } 338 339 static void bindDerived(ClassTy &c) { 340 c.def_static("get", &PyAffineCeilDivExpr::get); 341 } 342 }; 343 344 } // namespace 345 346 bool PyAffineExpr::operator==(const PyAffineExpr &other) { 347 return mlirAffineExprEqual(affineExpr, other.affineExpr); 348 } 349 350 py::object PyAffineExpr::getCapsule() { 351 return py::reinterpret_steal<py::object>( 352 mlirPythonAffineExprToCapsule(*this)); 353 } 354 355 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { 356 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); 357 if (mlirAffineExprIsNull(rawAffineExpr)) 358 throw py::error_already_set(); 359 return PyAffineExpr( 360 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), 361 rawAffineExpr); 362 } 363 364 //------------------------------------------------------------------------------ 365 // PyAffineMap and utilities. 366 //------------------------------------------------------------------------------ 367 namespace { 368 369 /// A list of expressions contained in an affine map. Internally these are 370 /// stored as a consecutive array leading to inexpensive random access. Both 371 /// the map and the expression are owned by the context so we need not bother 372 /// with lifetime extension. 373 class PyAffineMapExprList 374 : public Sliceable<PyAffineMapExprList, PyAffineExpr> { 375 public: 376 static constexpr const char *pyClassName = "AffineExprList"; 377 378 PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, 379 intptr_t length = -1, intptr_t step = 1) 380 : Sliceable(startIndex, 381 length == -1 ? mlirAffineMapGetNumResults(map) : length, 382 step), 383 affineMap(map) {} 384 385 intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } 386 387 PyAffineExpr getElement(intptr_t pos) { 388 return PyAffineExpr(affineMap.getContext(), 389 mlirAffineMapGetResult(affineMap, pos)); 390 } 391 392 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, 393 intptr_t step) { 394 return PyAffineMapExprList(affineMap, startIndex, length, step); 395 } 396 397 private: 398 PyAffineMap affineMap; 399 }; 400 } // namespace 401 402 bool PyAffineMap::operator==(const PyAffineMap &other) { 403 return mlirAffineMapEqual(affineMap, other.affineMap); 404 } 405 406 py::object PyAffineMap::getCapsule() { 407 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); 408 } 409 410 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { 411 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); 412 if (mlirAffineMapIsNull(rawAffineMap)) 413 throw py::error_already_set(); 414 return PyAffineMap( 415 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), 416 rawAffineMap); 417 } 418 419 //------------------------------------------------------------------------------ 420 // PyIntegerSet and utilities. 421 //------------------------------------------------------------------------------ 422 namespace { 423 424 class PyIntegerSetConstraint { 425 public: 426 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} 427 428 PyAffineExpr getExpr() { 429 return PyAffineExpr(set.getContext(), 430 mlirIntegerSetGetConstraint(set, pos)); 431 } 432 433 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } 434 435 static void bind(py::module &m) { 436 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint", 437 py::module_local()) 438 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) 439 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); 440 } 441 442 private: 443 PyIntegerSet set; 444 intptr_t pos; 445 }; 446 447 class PyIntegerSetConstraintList 448 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { 449 public: 450 static constexpr const char *pyClassName = "IntegerSetConstraintList"; 451 452 PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, 453 intptr_t length = -1, intptr_t step = 1) 454 : Sliceable(startIndex, 455 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, 456 step), 457 set(set) {} 458 459 intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } 460 461 PyIntegerSetConstraint getElement(intptr_t pos) { 462 return PyIntegerSetConstraint(set, pos); 463 } 464 465 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, 466 intptr_t step) { 467 return PyIntegerSetConstraintList(set, startIndex, length, step); 468 } 469 470 private: 471 PyIntegerSet set; 472 }; 473 } // namespace 474 475 bool PyIntegerSet::operator==(const PyIntegerSet &other) { 476 return mlirIntegerSetEqual(integerSet, other.integerSet); 477 } 478 479 py::object PyIntegerSet::getCapsule() { 480 return py::reinterpret_steal<py::object>( 481 mlirPythonIntegerSetToCapsule(*this)); 482 } 483 484 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { 485 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); 486 if (mlirIntegerSetIsNull(rawIntegerSet)) 487 throw py::error_already_set(); 488 return PyIntegerSet( 489 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 490 rawIntegerSet); 491 } 492 493 void mlir::python::populateIRAffine(py::module &m) { 494 //---------------------------------------------------------------------------- 495 // Mapping of PyAffineExpr and derived classes. 496 //---------------------------------------------------------------------------- 497 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local()) 498 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 499 &PyAffineExpr::getCapsule) 500 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) 501 .def("__add__", &PyAffineAddExpr::get) 502 .def("__add__", &PyAffineAddExpr::getRHSConstant) 503 .def("__radd__", &PyAffineAddExpr::getRHSConstant) 504 .def("__mul__", &PyAffineMulExpr::get) 505 .def("__mul__", &PyAffineMulExpr::getRHSConstant) 506 .def("__rmul__", &PyAffineMulExpr::getRHSConstant) 507 .def("__mod__", &PyAffineModExpr::get) 508 .def("__mod__", &PyAffineModExpr::getRHSConstant) 509 .def("__rmod__", 510 [](PyAffineExpr &self, intptr_t other) { 511 return PyAffineModExpr::get( 512 PyAffineConstantExpr::get(other, *self.getContext().get()), 513 self); 514 }) 515 .def("__sub__", 516 [](PyAffineExpr &self, PyAffineExpr &other) { 517 auto negOne = 518 PyAffineConstantExpr::get(-1, *self.getContext().get()); 519 return PyAffineAddExpr::get(self, 520 PyAffineMulExpr::get(negOne, other)); 521 }) 522 .def("__sub__", 523 [](PyAffineExpr &self, intptr_t other) { 524 return PyAffineAddExpr::get( 525 self, 526 PyAffineConstantExpr::get(-other, *self.getContext().get())); 527 }) 528 .def("__rsub__", 529 [](PyAffineExpr &self, intptr_t other) { 530 return PyAffineAddExpr::getLHSConstant( 531 other, PyAffineMulExpr::getLHSConstant(-1, self)); 532 }) 533 .def("__eq__", [](PyAffineExpr &self, 534 PyAffineExpr &other) { return self == other; }) 535 .def("__eq__", 536 [](PyAffineExpr &self, py::object &other) { return false; }) 537 .def("__str__", 538 [](PyAffineExpr &self) { 539 PyPrintAccumulator printAccum; 540 mlirAffineExprPrint(self, printAccum.getCallback(), 541 printAccum.getUserData()); 542 return printAccum.join(); 543 }) 544 .def("__repr__", 545 [](PyAffineExpr &self) { 546 PyPrintAccumulator printAccum; 547 printAccum.parts.append("AffineExpr("); 548 mlirAffineExprPrint(self, printAccum.getCallback(), 549 printAccum.getUserData()); 550 printAccum.parts.append(")"); 551 return printAccum.join(); 552 }) 553 .def("__hash__", 554 [](PyAffineExpr &self) { 555 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 556 }) 557 .def_property_readonly( 558 "context", 559 [](PyAffineExpr &self) { return self.getContext().getObject(); }) 560 .def("compose", 561 [](PyAffineExpr &self, PyAffineMap &other) { 562 return PyAffineExpr(self.getContext(), 563 mlirAffineExprCompose(self, other)); 564 }) 565 .def_static( 566 "get_add", &PyAffineAddExpr::get, 567 "Gets an affine expression containing a sum of two expressions.") 568 .def_static("get_add", &PyAffineAddExpr::getLHSConstant, 569 "Gets an affine expression containing a sum of a constant " 570 "and another expression.") 571 .def_static("get_add", &PyAffineAddExpr::getRHSConstant, 572 "Gets an affine expression containing a sum of an expression " 573 "and a constant.") 574 .def_static( 575 "get_mul", &PyAffineMulExpr::get, 576 "Gets an affine expression containing a product of two expressions.") 577 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, 578 "Gets an affine expression containing a product of a " 579 "constant and another expression.") 580 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant, 581 "Gets an affine expression containing a product of an " 582 "expression and a constant.") 583 .def_static("get_mod", &PyAffineModExpr::get, 584 "Gets an affine expression containing the modulo of dividing " 585 "one expression by another.") 586 .def_static("get_mod", &PyAffineModExpr::getLHSConstant, 587 "Gets a semi-affine expression containing the modulo of " 588 "dividing a constant by an expression.") 589 .def_static("get_mod", &PyAffineModExpr::getRHSConstant, 590 "Gets an affine expression containing the module of dividing" 591 "an expression by a constant.") 592 .def_static("get_floor_div", &PyAffineFloorDivExpr::get, 593 "Gets an affine expression containing the rounded-down " 594 "result of dividing one expression by another.") 595 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, 596 "Gets a semi-affine expression containing the rounded-down " 597 "result of dividing a constant by an expression.") 598 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, 599 "Gets an affine expression containing the rounded-down " 600 "result of dividing an expression by a constant.") 601 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, 602 "Gets an affine expression containing the rounded-up result " 603 "of dividing one expression by another.") 604 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, 605 "Gets a semi-affine expression containing the rounded-up " 606 "result of dividing a constant by an expression.") 607 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, 608 "Gets an affine expression containing the rounded-up result " 609 "of dividing an expression by a constant.") 610 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), 611 py::arg("context") = py::none(), 612 "Gets a constant affine expression with the given value.") 613 .def_static( 614 "get_dim", &PyAffineDimExpr::get, py::arg("position"), 615 py::arg("context") = py::none(), 616 "Gets an affine expression of a dimension at the given position.") 617 .def_static( 618 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), 619 py::arg("context") = py::none(), 620 "Gets an affine expression of a symbol at the given position.") 621 .def( 622 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, 623 kDumpDocstring); 624 PyAffineConstantExpr::bind(m); 625 PyAffineDimExpr::bind(m); 626 PyAffineSymbolExpr::bind(m); 627 PyAffineBinaryExpr::bind(m); 628 PyAffineAddExpr::bind(m); 629 PyAffineMulExpr::bind(m); 630 PyAffineModExpr::bind(m); 631 PyAffineFloorDivExpr::bind(m); 632 PyAffineCeilDivExpr::bind(m); 633 634 //---------------------------------------------------------------------------- 635 // Mapping of PyAffineMap. 636 //---------------------------------------------------------------------------- 637 py::class_<PyAffineMap>(m, "AffineMap", py::module_local()) 638 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 639 &PyAffineMap::getCapsule) 640 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) 641 .def("__eq__", 642 [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) 643 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) 644 .def("__str__", 645 [](PyAffineMap &self) { 646 PyPrintAccumulator printAccum; 647 mlirAffineMapPrint(self, printAccum.getCallback(), 648 printAccum.getUserData()); 649 return printAccum.join(); 650 }) 651 .def("__repr__", 652 [](PyAffineMap &self) { 653 PyPrintAccumulator printAccum; 654 printAccum.parts.append("AffineMap("); 655 mlirAffineMapPrint(self, printAccum.getCallback(), 656 printAccum.getUserData()); 657 printAccum.parts.append(")"); 658 return printAccum.join(); 659 }) 660 .def("__hash__", 661 [](PyAffineMap &self) { 662 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 663 }) 664 .def_static("compress_unused_symbols", 665 [](py::list affineMaps, DefaultingPyMlirContext context) { 666 SmallVector<MlirAffineMap> maps; 667 pyListToVector<PyAffineMap, MlirAffineMap>( 668 affineMaps, maps, "attempting to create an AffineMap"); 669 std::vector<MlirAffineMap> compressed(affineMaps.size()); 670 auto populate = [](void *result, intptr_t idx, 671 MlirAffineMap m) { 672 static_cast<MlirAffineMap *>(result)[idx] = (m); 673 }; 674 mlirAffineMapCompressUnusedSymbols( 675 maps.data(), maps.size(), compressed.data(), populate); 676 std::vector<PyAffineMap> res; 677 res.reserve(compressed.size()); 678 for (auto m : compressed) 679 res.push_back(PyAffineMap(context->getRef(), m)); 680 return res; 681 }) 682 .def_property_readonly( 683 "context", 684 [](PyAffineMap &self) { return self.getContext().getObject(); }, 685 "Context that owns the Affine Map") 686 .def( 687 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, 688 kDumpDocstring) 689 .def_static( 690 "get", 691 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, 692 DefaultingPyMlirContext context) { 693 SmallVector<MlirAffineExpr> affineExprs; 694 pyListToVector<PyAffineExpr, MlirAffineExpr>( 695 exprs, affineExprs, "attempting to create an AffineMap"); 696 MlirAffineMap map = 697 mlirAffineMapGet(context->get(), dimCount, symbolCount, 698 affineExprs.size(), affineExprs.data()); 699 return PyAffineMap(context->getRef(), map); 700 }, 701 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), 702 py::arg("context") = py::none(), 703 "Gets a map with the given expressions as results.") 704 .def_static( 705 "get_constant", 706 [](intptr_t value, DefaultingPyMlirContext context) { 707 MlirAffineMap affineMap = 708 mlirAffineMapConstantGet(context->get(), value); 709 return PyAffineMap(context->getRef(), affineMap); 710 }, 711 py::arg("value"), py::arg("context") = py::none(), 712 "Gets an affine map with a single constant result") 713 .def_static( 714 "get_empty", 715 [](DefaultingPyMlirContext context) { 716 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); 717 return PyAffineMap(context->getRef(), affineMap); 718 }, 719 py::arg("context") = py::none(), "Gets an empty affine map.") 720 .def_static( 721 "get_identity", 722 [](intptr_t nDims, DefaultingPyMlirContext context) { 723 MlirAffineMap affineMap = 724 mlirAffineMapMultiDimIdentityGet(context->get(), nDims); 725 return PyAffineMap(context->getRef(), affineMap); 726 }, 727 py::arg("n_dims"), py::arg("context") = py::none(), 728 "Gets an identity map with the given number of dimensions.") 729 .def_static( 730 "get_minor_identity", 731 [](intptr_t nDims, intptr_t nResults, 732 DefaultingPyMlirContext context) { 733 MlirAffineMap affineMap = 734 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); 735 return PyAffineMap(context->getRef(), affineMap); 736 }, 737 py::arg("n_dims"), py::arg("n_results"), 738 py::arg("context") = py::none(), 739 "Gets a minor identity map with the given number of dimensions and " 740 "results.") 741 .def_static( 742 "get_permutation", 743 [](std::vector<unsigned> permutation, 744 DefaultingPyMlirContext context) { 745 if (!isPermutation(permutation)) 746 throw py::cast_error("Invalid permutation when attempting to " 747 "create an AffineMap"); 748 MlirAffineMap affineMap = mlirAffineMapPermutationGet( 749 context->get(), permutation.size(), permutation.data()); 750 return PyAffineMap(context->getRef(), affineMap); 751 }, 752 py::arg("permutation"), py::arg("context") = py::none(), 753 "Gets an affine map that permutes its inputs.") 754 .def( 755 "get_submap", 756 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { 757 intptr_t numResults = mlirAffineMapGetNumResults(self); 758 for (intptr_t pos : resultPos) { 759 if (pos < 0 || pos >= numResults) 760 throw py::value_error("result position out of bounds"); 761 } 762 MlirAffineMap affineMap = mlirAffineMapGetSubMap( 763 self, resultPos.size(), resultPos.data()); 764 return PyAffineMap(self.getContext(), affineMap); 765 }, 766 py::arg("result_positions")) 767 .def( 768 "get_major_submap", 769 [](PyAffineMap &self, intptr_t nResults) { 770 if (nResults >= mlirAffineMapGetNumResults(self)) 771 throw py::value_error("number of results out of bounds"); 772 MlirAffineMap affineMap = 773 mlirAffineMapGetMajorSubMap(self, nResults); 774 return PyAffineMap(self.getContext(), affineMap); 775 }, 776 py::arg("n_results")) 777 .def( 778 "get_minor_submap", 779 [](PyAffineMap &self, intptr_t nResults) { 780 if (nResults >= mlirAffineMapGetNumResults(self)) 781 throw py::value_error("number of results out of bounds"); 782 MlirAffineMap affineMap = 783 mlirAffineMapGetMinorSubMap(self, nResults); 784 return PyAffineMap(self.getContext(), affineMap); 785 }, 786 py::arg("n_results")) 787 .def( 788 "replace", 789 [](PyAffineMap &self, PyAffineExpr &expression, 790 PyAffineExpr &replacement, intptr_t numResultDims, 791 intptr_t numResultSyms) { 792 MlirAffineMap affineMap = mlirAffineMapReplace( 793 self, expression, replacement, numResultDims, numResultSyms); 794 return PyAffineMap(self.getContext(), affineMap); 795 }, 796 py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"), 797 py::arg("n_result_syms")) 798 .def_property_readonly( 799 "is_permutation", 800 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) 801 .def_property_readonly("is_projected_permutation", 802 [](PyAffineMap &self) { 803 return mlirAffineMapIsProjectedPermutation(self); 804 }) 805 .def_property_readonly( 806 "n_dims", 807 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) 808 .def_property_readonly( 809 "n_inputs", 810 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) 811 .def_property_readonly( 812 "n_symbols", 813 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) 814 .def_property_readonly("results", [](PyAffineMap &self) { 815 return PyAffineMapExprList(self); 816 }); 817 PyAffineMapExprList::bind(m); 818 819 //---------------------------------------------------------------------------- 820 // Mapping of PyIntegerSet. 821 //---------------------------------------------------------------------------- 822 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local()) 823 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 824 &PyIntegerSet::getCapsule) 825 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) 826 .def("__eq__", [](PyIntegerSet &self, 827 PyIntegerSet &other) { return self == other; }) 828 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) 829 .def("__str__", 830 [](PyIntegerSet &self) { 831 PyPrintAccumulator printAccum; 832 mlirIntegerSetPrint(self, printAccum.getCallback(), 833 printAccum.getUserData()); 834 return printAccum.join(); 835 }) 836 .def("__repr__", 837 [](PyIntegerSet &self) { 838 PyPrintAccumulator printAccum; 839 printAccum.parts.append("IntegerSet("); 840 mlirIntegerSetPrint(self, printAccum.getCallback(), 841 printAccum.getUserData()); 842 printAccum.parts.append(")"); 843 return printAccum.join(); 844 }) 845 .def("__hash__", 846 [](PyIntegerSet &self) { 847 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 848 }) 849 .def_property_readonly( 850 "context", 851 [](PyIntegerSet &self) { return self.getContext().getObject(); }) 852 .def( 853 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, 854 kDumpDocstring) 855 .def_static( 856 "get", 857 [](intptr_t numDims, intptr_t numSymbols, py::list exprs, 858 std::vector<bool> eqFlags, DefaultingPyMlirContext context) { 859 if (exprs.size() != eqFlags.size()) 860 throw py::value_error( 861 "Expected the number of constraints to match " 862 "that of equality flags"); 863 if (exprs.empty()) 864 throw py::value_error("Expected non-empty list of constraints"); 865 866 // Copy over to a SmallVector because std::vector has a 867 // specialization for booleans that packs data and does not 868 // expose a `bool *`. 869 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); 870 871 SmallVector<MlirAffineExpr> affineExprs; 872 pyListToVector<PyAffineExpr>(exprs, affineExprs, 873 "attempting to create an IntegerSet"); 874 MlirIntegerSet set = mlirIntegerSetGet( 875 context->get(), numDims, numSymbols, exprs.size(), 876 affineExprs.data(), flags.data()); 877 return PyIntegerSet(context->getRef(), set); 878 }, 879 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), 880 py::arg("eq_flags"), py::arg("context") = py::none()) 881 .def_static( 882 "get_empty", 883 [](intptr_t numDims, intptr_t numSymbols, 884 DefaultingPyMlirContext context) { 885 MlirIntegerSet set = 886 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); 887 return PyIntegerSet(context->getRef(), set); 888 }, 889 py::arg("num_dims"), py::arg("num_symbols"), 890 py::arg("context") = py::none()) 891 .def( 892 "get_replaced", 893 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, 894 intptr_t numResultDims, intptr_t numResultSymbols) { 895 if (static_cast<intptr_t>(dimExprs.size()) != 896 mlirIntegerSetGetNumDims(self)) 897 throw py::value_error( 898 "Expected the number of dimension replacement expressions " 899 "to match that of dimensions"); 900 if (static_cast<intptr_t>(symbolExprs.size()) != 901 mlirIntegerSetGetNumSymbols(self)) 902 throw py::value_error( 903 "Expected the number of symbol replacement expressions " 904 "to match that of symbols"); 905 906 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; 907 pyListToVector<PyAffineExpr>( 908 dimExprs, dimAffineExprs, 909 "attempting to create an IntegerSet by replacing dimensions"); 910 pyListToVector<PyAffineExpr>( 911 symbolExprs, symbolAffineExprs, 912 "attempting to create an IntegerSet by replacing symbols"); 913 MlirIntegerSet set = mlirIntegerSetReplaceGet( 914 self, dimAffineExprs.data(), symbolAffineExprs.data(), 915 numResultDims, numResultSymbols); 916 return PyIntegerSet(self.getContext(), set); 917 }, 918 py::arg("dim_exprs"), py::arg("symbol_exprs"), 919 py::arg("num_result_dims"), py::arg("num_result_symbols")) 920 .def_property_readonly("is_canonical_empty", 921 [](PyIntegerSet &self) { 922 return mlirIntegerSetIsCanonicalEmpty(self); 923 }) 924 .def_property_readonly( 925 "n_dims", 926 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) 927 .def_property_readonly( 928 "n_symbols", 929 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) 930 .def_property_readonly( 931 "n_inputs", 932 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) 933 .def_property_readonly("n_equalities", 934 [](PyIntegerSet &self) { 935 return mlirIntegerSetGetNumEqualities(self); 936 }) 937 .def_property_readonly("n_inequalities", 938 [](PyIntegerSet &self) { 939 return mlirIntegerSetGetNumInequalities(self); 940 }) 941 .def_property_readonly("constraints", [](PyIntegerSet &self) { 942 return PyIntegerSetConstraintList(self); 943 }); 944 PyIntegerSetConstraint::bind(m); 945 PyIntegerSetConstraintList::bind(m); 946 } 947