1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/AffineMap.h" 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/IntegerSet.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 using llvm::SmallVector; 22 using llvm::StringRef; 23 using llvm::Twine; 24 25 static const char kDumpDocstring[] = 26 R"(Dumps a debug representation of the object to stderr.)"; 27 28 /// Attempts to populate `result` with the content of `list` casted to the 29 /// appropriate type (Python and C types are provided as template arguments). 30 /// Throws errors in case of failure, using "action" to describe what the caller 31 /// was attempting to do. 32 template <typename PyType, typename CType> 33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, 34 StringRef action) { 35 result.reserve(py::len(list)); 36 for (py::handle item : list) { 37 try { 38 result.push_back(item.cast<PyType>()); 39 } catch (py::cast_error &err) { 40 std::string msg = (llvm::Twine("Invalid expression when ") + action + 41 " (" + err.what() + ")") 42 .str(); 43 throw py::cast_error(msg); 44 } catch (py::reference_cast_error &err) { 45 std::string msg = (llvm::Twine("Invalid expression (None?) when ") + 46 action + " (" + err.what() + ")") 47 .str(); 48 throw py::cast_error(msg); 49 } 50 } 51 } 52 53 template <typename PermutationTy> 54 static bool isPermutation(std::vector<PermutationTy> permutation) { 55 llvm::SmallVector<bool, 8> seen(permutation.size(), false); 56 for (auto val : permutation) { 57 if (val < permutation.size()) { 58 if (seen[val]) 59 return false; 60 seen[val] = true; 61 continue; 62 } 63 return false; 64 } 65 return true; 66 } 67 68 namespace { 69 70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr 71 /// and should be castable from it. Intermediate hierarchy classes can be 72 /// modeled by specifying BaseTy. 73 template <typename DerivedTy, typename BaseTy = PyAffineExpr> 74 class PyConcreteAffineExpr : public BaseTy { 75 public: 76 // Derived classes must define statics for: 77 // IsAFunctionTy isaFunction 78 // const char *pyClassName 79 // and redefine bindDerived. 80 using ClassTy = py::class_<DerivedTy, BaseTy>; 81 using IsAFunctionTy = bool (*)(MlirAffineExpr); 82 83 PyConcreteAffineExpr() = default; 84 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 85 : BaseTy(std::move(contextRef), affineExpr) {} 86 PyConcreteAffineExpr(PyAffineExpr &orig) 87 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} 88 89 static MlirAffineExpr castFrom(PyAffineExpr &orig) { 90 if (!DerivedTy::isaFunction(orig)) { 91 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 92 throw SetPyError(PyExc_ValueError, 93 Twine("Cannot cast affine expression to ") + 94 DerivedTy::pyClassName + " (from " + origRepr + ")"); 95 } 96 return orig; 97 } 98 99 static void bind(py::module &m) { 100 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 101 cls.def(py::init<PyAffineExpr &>()); 102 cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { 103 return DerivedTy::isaFunction(otherAffineExpr); 104 }); 105 DerivedTy::bindDerived(cls); 106 } 107 108 /// Implemented by derived classes to add methods to the Python subclass. 109 static void bindDerived(ClassTy &m) {} 110 }; 111 112 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { 113 public: 114 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; 115 static constexpr const char *pyClassName = "AffineConstantExpr"; 116 using PyConcreteAffineExpr::PyConcreteAffineExpr; 117 118 static PyAffineConstantExpr get(intptr_t value, 119 DefaultingPyMlirContext context) { 120 MlirAffineExpr affineExpr = 121 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); 122 return PyAffineConstantExpr(context->getRef(), affineExpr); 123 } 124 125 static void bindDerived(ClassTy &c) { 126 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), 127 py::arg("context") = py::none()); 128 c.def_property_readonly("value", [](PyAffineConstantExpr &self) { 129 return mlirAffineConstantExprGetValue(self); 130 }); 131 } 132 }; 133 134 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { 135 public: 136 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; 137 static constexpr const char *pyClassName = "AffineDimExpr"; 138 using PyConcreteAffineExpr::PyConcreteAffineExpr; 139 140 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { 141 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); 142 return PyAffineDimExpr(context->getRef(), affineExpr); 143 } 144 145 static void bindDerived(ClassTy &c) { 146 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), 147 py::arg("context") = py::none()); 148 c.def_property_readonly("position", [](PyAffineDimExpr &self) { 149 return mlirAffineDimExprGetPosition(self); 150 }); 151 } 152 }; 153 154 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { 155 public: 156 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; 157 static constexpr const char *pyClassName = "AffineSymbolExpr"; 158 using PyConcreteAffineExpr::PyConcreteAffineExpr; 159 160 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { 161 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); 162 return PyAffineSymbolExpr(context->getRef(), affineExpr); 163 } 164 165 static void bindDerived(ClassTy &c) { 166 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), 167 py::arg("context") = py::none()); 168 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { 169 return mlirAffineSymbolExprGetPosition(self); 170 }); 171 } 172 }; 173 174 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { 175 public: 176 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; 177 static constexpr const char *pyClassName = "AffineBinaryExpr"; 178 using PyConcreteAffineExpr::PyConcreteAffineExpr; 179 180 PyAffineExpr lhs() { 181 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); 182 return PyAffineExpr(getContext(), lhsExpr); 183 } 184 185 PyAffineExpr rhs() { 186 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); 187 return PyAffineExpr(getContext(), rhsExpr); 188 } 189 190 static void bindDerived(ClassTy &c) { 191 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); 192 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); 193 } 194 }; 195 196 class PyAffineAddExpr 197 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { 198 public: 199 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; 200 static constexpr const char *pyClassName = "AffineAddExpr"; 201 using PyConcreteAffineExpr::PyConcreteAffineExpr; 202 203 static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 204 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); 205 return PyAffineAddExpr(lhs.getContext(), expr); 206 } 207 208 static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 209 MlirAffineExpr expr = mlirAffineAddExprGet( 210 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 211 return PyAffineAddExpr(lhs.getContext(), expr); 212 } 213 214 static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 215 MlirAffineExpr expr = mlirAffineAddExprGet( 216 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 217 return PyAffineAddExpr(rhs.getContext(), expr); 218 } 219 220 static void bindDerived(ClassTy &c) { 221 c.def_static("get", &PyAffineAddExpr::get); 222 } 223 }; 224 225 class PyAffineMulExpr 226 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { 227 public: 228 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; 229 static constexpr const char *pyClassName = "AffineMulExpr"; 230 using PyConcreteAffineExpr::PyConcreteAffineExpr; 231 232 static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 233 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); 234 return PyAffineMulExpr(lhs.getContext(), expr); 235 } 236 237 static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 238 MlirAffineExpr expr = mlirAffineMulExprGet( 239 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 240 return PyAffineMulExpr(lhs.getContext(), expr); 241 } 242 243 static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 244 MlirAffineExpr expr = mlirAffineMulExprGet( 245 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 246 return PyAffineMulExpr(rhs.getContext(), expr); 247 } 248 249 static void bindDerived(ClassTy &c) { 250 c.def_static("get", &PyAffineMulExpr::get); 251 } 252 }; 253 254 class PyAffineModExpr 255 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { 256 public: 257 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; 258 static constexpr const char *pyClassName = "AffineModExpr"; 259 using PyConcreteAffineExpr::PyConcreteAffineExpr; 260 261 static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 262 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); 263 return PyAffineModExpr(lhs.getContext(), expr); 264 } 265 266 static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 267 MlirAffineExpr expr = mlirAffineModExprGet( 268 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 269 return PyAffineModExpr(lhs.getContext(), expr); 270 } 271 272 static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 273 MlirAffineExpr expr = mlirAffineModExprGet( 274 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 275 return PyAffineModExpr(rhs.getContext(), expr); 276 } 277 278 static void bindDerived(ClassTy &c) { 279 c.def_static("get", &PyAffineModExpr::get); 280 } 281 }; 282 283 class PyAffineFloorDivExpr 284 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { 285 public: 286 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; 287 static constexpr const char *pyClassName = "AffineFloorDivExpr"; 288 using PyConcreteAffineExpr::PyConcreteAffineExpr; 289 290 static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 291 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); 292 return PyAffineFloorDivExpr(lhs.getContext(), expr); 293 } 294 295 static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 296 MlirAffineExpr expr = mlirAffineFloorDivExprGet( 297 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 298 return PyAffineFloorDivExpr(lhs.getContext(), expr); 299 } 300 301 static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 302 MlirAffineExpr expr = mlirAffineFloorDivExprGet( 303 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 304 return PyAffineFloorDivExpr(rhs.getContext(), expr); 305 } 306 307 static void bindDerived(ClassTy &c) { 308 c.def_static("get", &PyAffineFloorDivExpr::get); 309 } 310 }; 311 312 class PyAffineCeilDivExpr 313 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { 314 public: 315 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; 316 static constexpr const char *pyClassName = "AffineCeilDivExpr"; 317 using PyConcreteAffineExpr::PyConcreteAffineExpr; 318 319 static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 320 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); 321 return PyAffineCeilDivExpr(lhs.getContext(), expr); 322 } 323 324 static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { 325 MlirAffineExpr expr = mlirAffineCeilDivExprGet( 326 lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); 327 return PyAffineCeilDivExpr(lhs.getContext(), expr); 328 } 329 330 static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 331 MlirAffineExpr expr = mlirAffineCeilDivExprGet( 332 mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); 333 return PyAffineCeilDivExpr(rhs.getContext(), expr); 334 } 335 336 static void bindDerived(ClassTy &c) { 337 c.def_static("get", &PyAffineCeilDivExpr::get); 338 } 339 }; 340 341 } // namespace 342 343 bool PyAffineExpr::operator==(const PyAffineExpr &other) { 344 return mlirAffineExprEqual(affineExpr, other.affineExpr); 345 } 346 347 py::object PyAffineExpr::getCapsule() { 348 return py::reinterpret_steal<py::object>( 349 mlirPythonAffineExprToCapsule(*this)); 350 } 351 352 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { 353 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); 354 if (mlirAffineExprIsNull(rawAffineExpr)) 355 throw py::error_already_set(); 356 return PyAffineExpr( 357 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), 358 rawAffineExpr); 359 } 360 361 //------------------------------------------------------------------------------ 362 // PyAffineMap and utilities. 363 //------------------------------------------------------------------------------ 364 namespace { 365 366 /// A list of expressions contained in an affine map. Internally these are 367 /// stored as a consecutive array leading to inexpensive random access. Both 368 /// the map and the expression are owned by the context so we need not bother 369 /// with lifetime extension. 370 class PyAffineMapExprList 371 : public Sliceable<PyAffineMapExprList, PyAffineExpr> { 372 public: 373 static constexpr const char *pyClassName = "AffineExprList"; 374 375 PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, 376 intptr_t length = -1, intptr_t step = 1) 377 : Sliceable(startIndex, 378 length == -1 ? mlirAffineMapGetNumResults(map) : length, 379 step), 380 affineMap(map) {} 381 382 intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } 383 384 PyAffineExpr getElement(intptr_t pos) { 385 return PyAffineExpr(affineMap.getContext(), 386 mlirAffineMapGetResult(affineMap, pos)); 387 } 388 389 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, 390 intptr_t step) { 391 return PyAffineMapExprList(affineMap, startIndex, length, step); 392 } 393 394 private: 395 PyAffineMap affineMap; 396 }; 397 } // end namespace 398 399 bool PyAffineMap::operator==(const PyAffineMap &other) { 400 return mlirAffineMapEqual(affineMap, other.affineMap); 401 } 402 403 py::object PyAffineMap::getCapsule() { 404 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); 405 } 406 407 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { 408 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); 409 if (mlirAffineMapIsNull(rawAffineMap)) 410 throw py::error_already_set(); 411 return PyAffineMap( 412 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), 413 rawAffineMap); 414 } 415 416 //------------------------------------------------------------------------------ 417 // PyIntegerSet and utilities. 418 //------------------------------------------------------------------------------ 419 namespace { 420 421 class PyIntegerSetConstraint { 422 public: 423 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} 424 425 PyAffineExpr getExpr() { 426 return PyAffineExpr(set.getContext(), 427 mlirIntegerSetGetConstraint(set, pos)); 428 } 429 430 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } 431 432 static void bind(py::module &m) { 433 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint", 434 py::module_local()) 435 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) 436 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); 437 } 438 439 private: 440 PyIntegerSet set; 441 intptr_t pos; 442 }; 443 444 class PyIntegerSetConstraintList 445 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { 446 public: 447 static constexpr const char *pyClassName = "IntegerSetConstraintList"; 448 449 PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, 450 intptr_t length = -1, intptr_t step = 1) 451 : Sliceable(startIndex, 452 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, 453 step), 454 set(set) {} 455 456 intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } 457 458 PyIntegerSetConstraint getElement(intptr_t pos) { 459 return PyIntegerSetConstraint(set, pos); 460 } 461 462 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, 463 intptr_t step) { 464 return PyIntegerSetConstraintList(set, startIndex, length, step); 465 } 466 467 private: 468 PyIntegerSet set; 469 }; 470 } // namespace 471 472 bool PyIntegerSet::operator==(const PyIntegerSet &other) { 473 return mlirIntegerSetEqual(integerSet, other.integerSet); 474 } 475 476 py::object PyIntegerSet::getCapsule() { 477 return py::reinterpret_steal<py::object>( 478 mlirPythonIntegerSetToCapsule(*this)); 479 } 480 481 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { 482 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); 483 if (mlirIntegerSetIsNull(rawIntegerSet)) 484 throw py::error_already_set(); 485 return PyIntegerSet( 486 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 487 rawIntegerSet); 488 } 489 490 void mlir::python::populateIRAffine(py::module &m) { 491 //---------------------------------------------------------------------------- 492 // Mapping of PyAffineExpr and derived classes. 493 //---------------------------------------------------------------------------- 494 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local()) 495 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 496 &PyAffineExpr::getCapsule) 497 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) 498 .def("__add__", &PyAffineAddExpr::get) 499 .def("__add__", &PyAffineAddExpr::getRHSConstant) 500 .def("__radd__", &PyAffineAddExpr::getRHSConstant) 501 .def("__mul__", &PyAffineMulExpr::get) 502 .def("__mul__", &PyAffineMulExpr::getRHSConstant) 503 .def("__rmul__", &PyAffineMulExpr::getRHSConstant) 504 .def("__mod__", &PyAffineModExpr::get) 505 .def("__mod__", &PyAffineModExpr::getRHSConstant) 506 .def("__rmod__", 507 [](PyAffineExpr &self, intptr_t other) { 508 return PyAffineModExpr::get( 509 PyAffineConstantExpr::get(other, *self.getContext().get()), 510 self); 511 }) 512 .def("__sub__", 513 [](PyAffineExpr &self, PyAffineExpr &other) { 514 auto negOne = 515 PyAffineConstantExpr::get(-1, *self.getContext().get()); 516 return PyAffineAddExpr::get(self, 517 PyAffineMulExpr::get(negOne, other)); 518 }) 519 .def("__sub__", 520 [](PyAffineExpr &self, intptr_t other) { 521 return PyAffineAddExpr::get( 522 self, 523 PyAffineConstantExpr::get(-other, *self.getContext().get())); 524 }) 525 .def("__rsub__", 526 [](PyAffineExpr &self, intptr_t other) { 527 return PyAffineAddExpr::getLHSConstant( 528 other, PyAffineMulExpr::getLHSConstant(-1, self)); 529 }) 530 .def("__eq__", [](PyAffineExpr &self, 531 PyAffineExpr &other) { return self == other; }) 532 .def("__eq__", 533 [](PyAffineExpr &self, py::object &other) { return false; }) 534 .def("__str__", 535 [](PyAffineExpr &self) { 536 PyPrintAccumulator printAccum; 537 mlirAffineExprPrint(self, printAccum.getCallback(), 538 printAccum.getUserData()); 539 return printAccum.join(); 540 }) 541 .def("__repr__", 542 [](PyAffineExpr &self) { 543 PyPrintAccumulator printAccum; 544 printAccum.parts.append("AffineExpr("); 545 mlirAffineExprPrint(self, printAccum.getCallback(), 546 printAccum.getUserData()); 547 printAccum.parts.append(")"); 548 return printAccum.join(); 549 }) 550 .def("__hash__", 551 [](PyAffineExpr &self) { 552 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 553 }) 554 .def_property_readonly( 555 "context", 556 [](PyAffineExpr &self) { return self.getContext().getObject(); }) 557 .def("compose", 558 [](PyAffineExpr &self, PyAffineMap &other) { 559 return PyAffineExpr(self.getContext(), 560 mlirAffineExprCompose(self, other)); 561 }) 562 .def_static( 563 "get_add", &PyAffineAddExpr::get, 564 "Gets an affine expression containing a sum of two expressions.") 565 .def_static("get_add", &PyAffineAddExpr::getLHSConstant, 566 "Gets an affine expression containing a sum of a constant " 567 "and another expression.") 568 .def_static("get_add", &PyAffineAddExpr::getRHSConstant, 569 "Gets an affine expression containing a sum of an expression " 570 "and a constant.") 571 .def_static( 572 "get_mul", &PyAffineMulExpr::get, 573 "Gets an affine expression containing a product of two expressions.") 574 .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, 575 "Gets an affine expression containing a product of a " 576 "constant and another expression.") 577 .def_static("get_mul", &PyAffineMulExpr::getRHSConstant, 578 "Gets an affine expression containing a product of an " 579 "expression and a constant.") 580 .def_static("get_mod", &PyAffineModExpr::get, 581 "Gets an affine expression containing the modulo of dividing " 582 "one expression by another.") 583 .def_static("get_mod", &PyAffineModExpr::getLHSConstant, 584 "Gets a semi-affine expression containing the modulo of " 585 "dividing a constant by an expression.") 586 .def_static("get_mod", &PyAffineModExpr::getRHSConstant, 587 "Gets an affine expression containing the module of dividing" 588 "an expression by a constant.") 589 .def_static("get_floor_div", &PyAffineFloorDivExpr::get, 590 "Gets an affine expression containing the rounded-down " 591 "result of dividing one expression by another.") 592 .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, 593 "Gets a semi-affine expression containing the rounded-down " 594 "result of dividing a constant by an expression.") 595 .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, 596 "Gets an affine expression containing the rounded-down " 597 "result of dividing an expression by a constant.") 598 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, 599 "Gets an affine expression containing the rounded-up result " 600 "of dividing one expression by another.") 601 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, 602 "Gets a semi-affine expression containing the rounded-up " 603 "result of dividing a constant by an expression.") 604 .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, 605 "Gets an affine expression containing the rounded-up result " 606 "of dividing an expression by a constant.") 607 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), 608 py::arg("context") = py::none(), 609 "Gets a constant affine expression with the given value.") 610 .def_static( 611 "get_dim", &PyAffineDimExpr::get, py::arg("position"), 612 py::arg("context") = py::none(), 613 "Gets an affine expression of a dimension at the given position.") 614 .def_static( 615 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), 616 py::arg("context") = py::none(), 617 "Gets an affine expression of a symbol at the given position.") 618 .def( 619 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, 620 kDumpDocstring); 621 PyAffineConstantExpr::bind(m); 622 PyAffineDimExpr::bind(m); 623 PyAffineSymbolExpr::bind(m); 624 PyAffineBinaryExpr::bind(m); 625 PyAffineAddExpr::bind(m); 626 PyAffineMulExpr::bind(m); 627 PyAffineModExpr::bind(m); 628 PyAffineFloorDivExpr::bind(m); 629 PyAffineCeilDivExpr::bind(m); 630 631 //---------------------------------------------------------------------------- 632 // Mapping of PyAffineMap. 633 //---------------------------------------------------------------------------- 634 py::class_<PyAffineMap>(m, "AffineMap", py::module_local()) 635 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 636 &PyAffineMap::getCapsule) 637 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) 638 .def("__eq__", 639 [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) 640 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) 641 .def("__str__", 642 [](PyAffineMap &self) { 643 PyPrintAccumulator printAccum; 644 mlirAffineMapPrint(self, printAccum.getCallback(), 645 printAccum.getUserData()); 646 return printAccum.join(); 647 }) 648 .def("__repr__", 649 [](PyAffineMap &self) { 650 PyPrintAccumulator printAccum; 651 printAccum.parts.append("AffineMap("); 652 mlirAffineMapPrint(self, printAccum.getCallback(), 653 printAccum.getUserData()); 654 printAccum.parts.append(")"); 655 return printAccum.join(); 656 }) 657 .def("__hash__", 658 [](PyAffineMap &self) { 659 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 660 }) 661 .def_static("compress_unused_symbols", 662 [](py::list affineMaps, DefaultingPyMlirContext context) { 663 SmallVector<MlirAffineMap> maps; 664 pyListToVector<PyAffineMap, MlirAffineMap>( 665 affineMaps, maps, "attempting to create an AffineMap"); 666 std::vector<MlirAffineMap> compressed(affineMaps.size()); 667 auto populate = [](void *result, intptr_t idx, 668 MlirAffineMap m) { 669 static_cast<MlirAffineMap *>(result)[idx] = (m); 670 }; 671 mlirAffineMapCompressUnusedSymbols( 672 maps.data(), maps.size(), compressed.data(), populate); 673 std::vector<PyAffineMap> res; 674 res.reserve(compressed.size()); 675 for (auto m : compressed) 676 res.push_back(PyAffineMap(context->getRef(), m)); 677 return res; 678 }) 679 .def_property_readonly( 680 "context", 681 [](PyAffineMap &self) { return self.getContext().getObject(); }, 682 "Context that owns the Affine Map") 683 .def( 684 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, 685 kDumpDocstring) 686 .def_static( 687 "get", 688 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, 689 DefaultingPyMlirContext context) { 690 SmallVector<MlirAffineExpr> affineExprs; 691 pyListToVector<PyAffineExpr, MlirAffineExpr>( 692 exprs, affineExprs, "attempting to create an AffineMap"); 693 MlirAffineMap map = 694 mlirAffineMapGet(context->get(), dimCount, symbolCount, 695 affineExprs.size(), affineExprs.data()); 696 return PyAffineMap(context->getRef(), map); 697 }, 698 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), 699 py::arg("context") = py::none(), 700 "Gets a map with the given expressions as results.") 701 .def_static( 702 "get_constant", 703 [](intptr_t value, DefaultingPyMlirContext context) { 704 MlirAffineMap affineMap = 705 mlirAffineMapConstantGet(context->get(), value); 706 return PyAffineMap(context->getRef(), affineMap); 707 }, 708 py::arg("value"), py::arg("context") = py::none(), 709 "Gets an affine map with a single constant result") 710 .def_static( 711 "get_empty", 712 [](DefaultingPyMlirContext context) { 713 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); 714 return PyAffineMap(context->getRef(), affineMap); 715 }, 716 py::arg("context") = py::none(), "Gets an empty affine map.") 717 .def_static( 718 "get_identity", 719 [](intptr_t nDims, DefaultingPyMlirContext context) { 720 MlirAffineMap affineMap = 721 mlirAffineMapMultiDimIdentityGet(context->get(), nDims); 722 return PyAffineMap(context->getRef(), affineMap); 723 }, 724 py::arg("n_dims"), py::arg("context") = py::none(), 725 "Gets an identity map with the given number of dimensions.") 726 .def_static( 727 "get_minor_identity", 728 [](intptr_t nDims, intptr_t nResults, 729 DefaultingPyMlirContext context) { 730 MlirAffineMap affineMap = 731 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); 732 return PyAffineMap(context->getRef(), affineMap); 733 }, 734 py::arg("n_dims"), py::arg("n_results"), 735 py::arg("context") = py::none(), 736 "Gets a minor identity map with the given number of dimensions and " 737 "results.") 738 .def_static( 739 "get_permutation", 740 [](std::vector<unsigned> permutation, 741 DefaultingPyMlirContext context) { 742 if (!isPermutation(permutation)) 743 throw py::cast_error("Invalid permutation when attempting to " 744 "create an AffineMap"); 745 MlirAffineMap affineMap = mlirAffineMapPermutationGet( 746 context->get(), permutation.size(), permutation.data()); 747 return PyAffineMap(context->getRef(), affineMap); 748 }, 749 py::arg("permutation"), py::arg("context") = py::none(), 750 "Gets an affine map that permutes its inputs.") 751 .def("get_submap", 752 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { 753 intptr_t numResults = mlirAffineMapGetNumResults(self); 754 for (intptr_t pos : resultPos) { 755 if (pos < 0 || pos >= numResults) 756 throw py::value_error("result position out of bounds"); 757 } 758 MlirAffineMap affineMap = mlirAffineMapGetSubMap( 759 self, resultPos.size(), resultPos.data()); 760 return PyAffineMap(self.getContext(), affineMap); 761 }) 762 .def("get_major_submap", 763 [](PyAffineMap &self, intptr_t nResults) { 764 if (nResults >= mlirAffineMapGetNumResults(self)) 765 throw py::value_error("number of results out of bounds"); 766 MlirAffineMap affineMap = 767 mlirAffineMapGetMajorSubMap(self, nResults); 768 return PyAffineMap(self.getContext(), affineMap); 769 }) 770 .def("get_minor_submap", 771 [](PyAffineMap &self, intptr_t nResults) { 772 if (nResults >= mlirAffineMapGetNumResults(self)) 773 throw py::value_error("number of results out of bounds"); 774 MlirAffineMap affineMap = 775 mlirAffineMapGetMinorSubMap(self, nResults); 776 return PyAffineMap(self.getContext(), affineMap); 777 }) 778 .def("replace", 779 [](PyAffineMap &self, PyAffineExpr &expression, 780 PyAffineExpr &replacement, intptr_t numResultDims, 781 intptr_t numResultSyms) { 782 MlirAffineMap affineMap = mlirAffineMapReplace( 783 self, expression, replacement, numResultDims, numResultSyms); 784 return PyAffineMap(self.getContext(), affineMap); 785 }) 786 .def_property_readonly( 787 "is_permutation", 788 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) 789 .def_property_readonly("is_projected_permutation", 790 [](PyAffineMap &self) { 791 return mlirAffineMapIsProjectedPermutation(self); 792 }) 793 .def_property_readonly( 794 "n_dims", 795 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) 796 .def_property_readonly( 797 "n_inputs", 798 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) 799 .def_property_readonly( 800 "n_symbols", 801 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) 802 .def_property_readonly("results", [](PyAffineMap &self) { 803 return PyAffineMapExprList(self); 804 }); 805 PyAffineMapExprList::bind(m); 806 807 //---------------------------------------------------------------------------- 808 // Mapping of PyIntegerSet. 809 //---------------------------------------------------------------------------- 810 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local()) 811 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 812 &PyIntegerSet::getCapsule) 813 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) 814 .def("__eq__", [](PyIntegerSet &self, 815 PyIntegerSet &other) { return self == other; }) 816 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) 817 .def("__str__", 818 [](PyIntegerSet &self) { 819 PyPrintAccumulator printAccum; 820 mlirIntegerSetPrint(self, printAccum.getCallback(), 821 printAccum.getUserData()); 822 return printAccum.join(); 823 }) 824 .def("__repr__", 825 [](PyIntegerSet &self) { 826 PyPrintAccumulator printAccum; 827 printAccum.parts.append("IntegerSet("); 828 mlirIntegerSetPrint(self, printAccum.getCallback(), 829 printAccum.getUserData()); 830 printAccum.parts.append(")"); 831 return printAccum.join(); 832 }) 833 .def("__hash__", 834 [](PyIntegerSet &self) { 835 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 836 }) 837 .def_property_readonly( 838 "context", 839 [](PyIntegerSet &self) { return self.getContext().getObject(); }) 840 .def( 841 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, 842 kDumpDocstring) 843 .def_static( 844 "get", 845 [](intptr_t numDims, intptr_t numSymbols, py::list exprs, 846 std::vector<bool> eqFlags, DefaultingPyMlirContext context) { 847 if (exprs.size() != eqFlags.size()) 848 throw py::value_error( 849 "Expected the number of constraints to match " 850 "that of equality flags"); 851 if (exprs.empty()) 852 throw py::value_error("Expected non-empty list of constraints"); 853 854 // Copy over to a SmallVector because std::vector has a 855 // specialization for booleans that packs data and does not 856 // expose a `bool *`. 857 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); 858 859 SmallVector<MlirAffineExpr> affineExprs; 860 pyListToVector<PyAffineExpr>(exprs, affineExprs, 861 "attempting to create an IntegerSet"); 862 MlirIntegerSet set = mlirIntegerSetGet( 863 context->get(), numDims, numSymbols, exprs.size(), 864 affineExprs.data(), flags.data()); 865 return PyIntegerSet(context->getRef(), set); 866 }, 867 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), 868 py::arg("eq_flags"), py::arg("context") = py::none()) 869 .def_static( 870 "get_empty", 871 [](intptr_t numDims, intptr_t numSymbols, 872 DefaultingPyMlirContext context) { 873 MlirIntegerSet set = 874 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); 875 return PyIntegerSet(context->getRef(), set); 876 }, 877 py::arg("num_dims"), py::arg("num_symbols"), 878 py::arg("context") = py::none()) 879 .def("get_replaced", 880 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, 881 intptr_t numResultDims, intptr_t numResultSymbols) { 882 if (static_cast<intptr_t>(dimExprs.size()) != 883 mlirIntegerSetGetNumDims(self)) 884 throw py::value_error( 885 "Expected the number of dimension replacement expressions " 886 "to match that of dimensions"); 887 if (static_cast<intptr_t>(symbolExprs.size()) != 888 mlirIntegerSetGetNumSymbols(self)) 889 throw py::value_error( 890 "Expected the number of symbol replacement expressions " 891 "to match that of symbols"); 892 893 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; 894 pyListToVector<PyAffineExpr>( 895 dimExprs, dimAffineExprs, 896 "attempting to create an IntegerSet by replacing dimensions"); 897 pyListToVector<PyAffineExpr>( 898 symbolExprs, symbolAffineExprs, 899 "attempting to create an IntegerSet by replacing symbols"); 900 MlirIntegerSet set = mlirIntegerSetReplaceGet( 901 self, dimAffineExprs.data(), symbolAffineExprs.data(), 902 numResultDims, numResultSymbols); 903 return PyIntegerSet(self.getContext(), set); 904 }) 905 .def_property_readonly("is_canonical_empty", 906 [](PyIntegerSet &self) { 907 return mlirIntegerSetIsCanonicalEmpty(self); 908 }) 909 .def_property_readonly( 910 "n_dims", 911 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) 912 .def_property_readonly( 913 "n_symbols", 914 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) 915 .def_property_readonly( 916 "n_inputs", 917 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) 918 .def_property_readonly("n_equalities", 919 [](PyIntegerSet &self) { 920 return mlirIntegerSetGetNumEqualities(self); 921 }) 922 .def_property_readonly("n_inequalities", 923 [](PyIntegerSet &self) { 924 return mlirIntegerSetGetNumInequalities(self); 925 }) 926 .def_property_readonly("constraints", [](PyIntegerSet &self) { 927 return PyIntegerSetConstraintList(self); 928 }); 929 PyIntegerSetConstraint::bind(m); 930 PyIntegerSetConstraintList::bind(m); 931 } 932