1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/AffineMap.h" 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/IntegerSet.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 using llvm::SmallVector; 22 using llvm::StringRef; 23 using llvm::Twine; 24 25 static const char kDumpDocstring[] = 26 R"(Dumps a debug representation of the object to stderr.)"; 27 28 /// Attempts to populate `result` with the content of `list` casted to the 29 /// appropriate type (Python and C types are provided as template arguments). 30 /// Throws errors in case of failure, using "action" to describe what the caller 31 /// was attempting to do. 32 template <typename PyType, typename CType> 33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, 34 StringRef action) { 35 result.reserve(py::len(list)); 36 for (py::handle item : list) { 37 try { 38 result.push_back(item.cast<PyType>()); 39 } catch (py::cast_error &err) { 40 std::string msg = (llvm::Twine("Invalid expression when ") + action + 41 " (" + err.what() + ")") 42 .str(); 43 throw py::cast_error(msg); 44 } catch (py::reference_cast_error &err) { 45 std::string msg = (llvm::Twine("Invalid expression (None?) when ") + 46 action + " (" + err.what() + ")") 47 .str(); 48 throw py::cast_error(msg); 49 } 50 } 51 } 52 53 template <typename PermutationTy> 54 static bool isPermutation(std::vector<PermutationTy> permutation) { 55 llvm::SmallVector<bool, 8> seen(permutation.size(), false); 56 for (auto val : permutation) { 57 if (val < permutation.size()) { 58 if (seen[val]) 59 return false; 60 seen[val] = true; 61 continue; 62 } 63 return false; 64 } 65 return true; 66 } 67 68 namespace { 69 70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr 71 /// and should be castable from it. Intermediate hierarchy classes can be 72 /// modeled by specifying BaseTy. 73 template <typename DerivedTy, typename BaseTy = PyAffineExpr> 74 class PyConcreteAffineExpr : public BaseTy { 75 public: 76 // Derived classes must define statics for: 77 // IsAFunctionTy isaFunction 78 // const char *pyClassName 79 // and redefine bindDerived. 80 using ClassTy = py::class_<DerivedTy, BaseTy>; 81 using IsAFunctionTy = bool (*)(MlirAffineExpr); 82 83 PyConcreteAffineExpr() = default; 84 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 85 : BaseTy(std::move(contextRef), affineExpr) {} 86 PyConcreteAffineExpr(PyAffineExpr &orig) 87 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} 88 89 static MlirAffineExpr castFrom(PyAffineExpr &orig) { 90 if (!DerivedTy::isaFunction(orig)) { 91 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 92 throw SetPyError(PyExc_ValueError, 93 Twine("Cannot cast affine expression to ") + 94 DerivedTy::pyClassName + " (from " + origRepr + ")"); 95 } 96 return orig; 97 } 98 99 static void bind(py::module &m) { 100 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 101 cls.def(py::init<PyAffineExpr &>()); 102 cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { 103 return DerivedTy::isaFunction(otherAffineExpr); 104 }); 105 DerivedTy::bindDerived(cls); 106 } 107 108 /// Implemented by derived classes to add methods to the Python subclass. 109 static void bindDerived(ClassTy &m) {} 110 }; 111 112 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { 113 public: 114 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; 115 static constexpr const char *pyClassName = "AffineConstantExpr"; 116 using PyConcreteAffineExpr::PyConcreteAffineExpr; 117 118 static PyAffineConstantExpr get(intptr_t value, 119 DefaultingPyMlirContext context) { 120 MlirAffineExpr affineExpr = 121 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); 122 return PyAffineConstantExpr(context->getRef(), affineExpr); 123 } 124 125 static void bindDerived(ClassTy &c) { 126 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), 127 py::arg("context") = py::none()); 128 c.def_property_readonly("value", [](PyAffineConstantExpr &self) { 129 return mlirAffineConstantExprGetValue(self); 130 }); 131 } 132 }; 133 134 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { 135 public: 136 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; 137 static constexpr const char *pyClassName = "AffineDimExpr"; 138 using PyConcreteAffineExpr::PyConcreteAffineExpr; 139 140 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { 141 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); 142 return PyAffineDimExpr(context->getRef(), affineExpr); 143 } 144 145 static void bindDerived(ClassTy &c) { 146 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), 147 py::arg("context") = py::none()); 148 c.def_property_readonly("position", [](PyAffineDimExpr &self) { 149 return mlirAffineDimExprGetPosition(self); 150 }); 151 } 152 }; 153 154 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { 155 public: 156 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; 157 static constexpr const char *pyClassName = "AffineSymbolExpr"; 158 using PyConcreteAffineExpr::PyConcreteAffineExpr; 159 160 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { 161 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); 162 return PyAffineSymbolExpr(context->getRef(), affineExpr); 163 } 164 165 static void bindDerived(ClassTy &c) { 166 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), 167 py::arg("context") = py::none()); 168 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { 169 return mlirAffineSymbolExprGetPosition(self); 170 }); 171 } 172 }; 173 174 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { 175 public: 176 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; 177 static constexpr const char *pyClassName = "AffineBinaryExpr"; 178 using PyConcreteAffineExpr::PyConcreteAffineExpr; 179 180 PyAffineExpr lhs() { 181 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); 182 return PyAffineExpr(getContext(), lhsExpr); 183 } 184 185 PyAffineExpr rhs() { 186 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); 187 return PyAffineExpr(getContext(), rhsExpr); 188 } 189 190 static void bindDerived(ClassTy &c) { 191 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); 192 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); 193 } 194 }; 195 196 class PyAffineAddExpr 197 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { 198 public: 199 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; 200 static constexpr const char *pyClassName = "AffineAddExpr"; 201 using PyConcreteAffineExpr::PyConcreteAffineExpr; 202 203 static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 204 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); 205 return PyAffineAddExpr(lhs.getContext(), expr); 206 } 207 208 static void bindDerived(ClassTy &c) { 209 c.def_static("get", &PyAffineAddExpr::get); 210 } 211 }; 212 213 class PyAffineMulExpr 214 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { 215 public: 216 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; 217 static constexpr const char *pyClassName = "AffineMulExpr"; 218 using PyConcreteAffineExpr::PyConcreteAffineExpr; 219 220 static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 221 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); 222 return PyAffineMulExpr(lhs.getContext(), expr); 223 } 224 225 static void bindDerived(ClassTy &c) { 226 c.def_static("get", &PyAffineMulExpr::get); 227 } 228 }; 229 230 class PyAffineModExpr 231 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { 232 public: 233 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; 234 static constexpr const char *pyClassName = "AffineModExpr"; 235 using PyConcreteAffineExpr::PyConcreteAffineExpr; 236 237 static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 238 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); 239 return PyAffineModExpr(lhs.getContext(), expr); 240 } 241 242 static void bindDerived(ClassTy &c) { 243 c.def_static("get", &PyAffineModExpr::get); 244 } 245 }; 246 247 class PyAffineFloorDivExpr 248 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { 249 public: 250 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; 251 static constexpr const char *pyClassName = "AffineFloorDivExpr"; 252 using PyConcreteAffineExpr::PyConcreteAffineExpr; 253 254 static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 255 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); 256 return PyAffineFloorDivExpr(lhs.getContext(), expr); 257 } 258 259 static void bindDerived(ClassTy &c) { 260 c.def_static("get", &PyAffineFloorDivExpr::get); 261 } 262 }; 263 264 class PyAffineCeilDivExpr 265 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { 266 public: 267 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; 268 static constexpr const char *pyClassName = "AffineCeilDivExpr"; 269 using PyConcreteAffineExpr::PyConcreteAffineExpr; 270 271 static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 272 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); 273 return PyAffineCeilDivExpr(lhs.getContext(), expr); 274 } 275 276 static void bindDerived(ClassTy &c) { 277 c.def_static("get", &PyAffineCeilDivExpr::get); 278 } 279 }; 280 281 } // namespace 282 283 bool PyAffineExpr::operator==(const PyAffineExpr &other) { 284 return mlirAffineExprEqual(affineExpr, other.affineExpr); 285 } 286 287 py::object PyAffineExpr::getCapsule() { 288 return py::reinterpret_steal<py::object>( 289 mlirPythonAffineExprToCapsule(*this)); 290 } 291 292 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { 293 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); 294 if (mlirAffineExprIsNull(rawAffineExpr)) 295 throw py::error_already_set(); 296 return PyAffineExpr( 297 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), 298 rawAffineExpr); 299 } 300 301 //------------------------------------------------------------------------------ 302 // PyAffineMap and utilities. 303 //------------------------------------------------------------------------------ 304 namespace { 305 306 /// A list of expressions contained in an affine map. Internally these are 307 /// stored as a consecutive array leading to inexpensive random access. Both 308 /// the map and the expression are owned by the context so we need not bother 309 /// with lifetime extension. 310 class PyAffineMapExprList 311 : public Sliceable<PyAffineMapExprList, PyAffineExpr> { 312 public: 313 static constexpr const char *pyClassName = "AffineExprList"; 314 315 PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, 316 intptr_t length = -1, intptr_t step = 1) 317 : Sliceable(startIndex, 318 length == -1 ? mlirAffineMapGetNumResults(map) : length, 319 step), 320 affineMap(map) {} 321 322 intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } 323 324 PyAffineExpr getElement(intptr_t pos) { 325 return PyAffineExpr(affineMap.getContext(), 326 mlirAffineMapGetResult(affineMap, pos)); 327 } 328 329 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, 330 intptr_t step) { 331 return PyAffineMapExprList(affineMap, startIndex, length, step); 332 } 333 334 private: 335 PyAffineMap affineMap; 336 }; 337 } // end namespace 338 339 bool PyAffineMap::operator==(const PyAffineMap &other) { 340 return mlirAffineMapEqual(affineMap, other.affineMap); 341 } 342 343 py::object PyAffineMap::getCapsule() { 344 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); 345 } 346 347 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { 348 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); 349 if (mlirAffineMapIsNull(rawAffineMap)) 350 throw py::error_already_set(); 351 return PyAffineMap( 352 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), 353 rawAffineMap); 354 } 355 356 //------------------------------------------------------------------------------ 357 // PyIntegerSet and utilities. 358 //------------------------------------------------------------------------------ 359 namespace { 360 361 class PyIntegerSetConstraint { 362 public: 363 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} 364 365 PyAffineExpr getExpr() { 366 return PyAffineExpr(set.getContext(), 367 mlirIntegerSetGetConstraint(set, pos)); 368 } 369 370 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } 371 372 static void bind(py::module &m) { 373 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint", 374 py::module_local()) 375 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) 376 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); 377 } 378 379 private: 380 PyIntegerSet set; 381 intptr_t pos; 382 }; 383 384 class PyIntegerSetConstraintList 385 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { 386 public: 387 static constexpr const char *pyClassName = "IntegerSetConstraintList"; 388 389 PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, 390 intptr_t length = -1, intptr_t step = 1) 391 : Sliceable(startIndex, 392 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, 393 step), 394 set(set) {} 395 396 intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } 397 398 PyIntegerSetConstraint getElement(intptr_t pos) { 399 return PyIntegerSetConstraint(set, pos); 400 } 401 402 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, 403 intptr_t step) { 404 return PyIntegerSetConstraintList(set, startIndex, length, step); 405 } 406 407 private: 408 PyIntegerSet set; 409 }; 410 } // namespace 411 412 bool PyIntegerSet::operator==(const PyIntegerSet &other) { 413 return mlirIntegerSetEqual(integerSet, other.integerSet); 414 } 415 416 py::object PyIntegerSet::getCapsule() { 417 return py::reinterpret_steal<py::object>( 418 mlirPythonIntegerSetToCapsule(*this)); 419 } 420 421 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { 422 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); 423 if (mlirIntegerSetIsNull(rawIntegerSet)) 424 throw py::error_already_set(); 425 return PyIntegerSet( 426 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 427 rawIntegerSet); 428 } 429 430 void mlir::python::populateIRAffine(py::module &m) { 431 //---------------------------------------------------------------------------- 432 // Mapping of PyAffineExpr and derived classes. 433 //---------------------------------------------------------------------------- 434 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local()) 435 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 436 &PyAffineExpr::getCapsule) 437 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) 438 .def("__add__", 439 [](PyAffineExpr &self, PyAffineExpr &other) { 440 return PyAffineAddExpr::get(self, other); 441 }) 442 .def("__mul__", 443 [](PyAffineExpr &self, PyAffineExpr &other) { 444 return PyAffineMulExpr::get(self, other); 445 }) 446 .def("__mod__", 447 [](PyAffineExpr &self, PyAffineExpr &other) { 448 return PyAffineModExpr::get(self, other); 449 }) 450 .def("__sub__", 451 [](PyAffineExpr &self, PyAffineExpr &other) { 452 auto negOne = 453 PyAffineConstantExpr::get(-1, *self.getContext().get()); 454 return PyAffineAddExpr::get(self, 455 PyAffineMulExpr::get(negOne, other)); 456 }) 457 .def("__eq__", [](PyAffineExpr &self, 458 PyAffineExpr &other) { return self == other; }) 459 .def("__eq__", 460 [](PyAffineExpr &self, py::object &other) { return false; }) 461 .def("__str__", 462 [](PyAffineExpr &self) { 463 PyPrintAccumulator printAccum; 464 mlirAffineExprPrint(self, printAccum.getCallback(), 465 printAccum.getUserData()); 466 return printAccum.join(); 467 }) 468 .def("__repr__", 469 [](PyAffineExpr &self) { 470 PyPrintAccumulator printAccum; 471 printAccum.parts.append("AffineExpr("); 472 mlirAffineExprPrint(self, printAccum.getCallback(), 473 printAccum.getUserData()); 474 printAccum.parts.append(")"); 475 return printAccum.join(); 476 }) 477 .def_property_readonly( 478 "context", 479 [](PyAffineExpr &self) { return self.getContext().getObject(); }) 480 .def_static( 481 "get_add", &PyAffineAddExpr::get, 482 "Gets an affine expression containing a sum of two expressions.") 483 .def_static( 484 "get_mul", &PyAffineMulExpr::get, 485 "Gets an affine expression containing a product of two expressions.") 486 .def_static("get_mod", &PyAffineModExpr::get, 487 "Gets an affine expression containing the modulo of dividing " 488 "one expression by another.") 489 .def_static("get_floor_div", &PyAffineFloorDivExpr::get, 490 "Gets an affine expression containing the rounded-down " 491 "result of dividing one expression by another.") 492 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, 493 "Gets an affine expression containing the rounded-up result " 494 "of dividing one expression by another.") 495 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), 496 py::arg("context") = py::none(), 497 "Gets a constant affine expression with the given value.") 498 .def_static( 499 "get_dim", &PyAffineDimExpr::get, py::arg("position"), 500 py::arg("context") = py::none(), 501 "Gets an affine expression of a dimension at the given position.") 502 .def_static( 503 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), 504 py::arg("context") = py::none(), 505 "Gets an affine expression of a symbol at the given position.") 506 .def( 507 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, 508 kDumpDocstring); 509 PyAffineConstantExpr::bind(m); 510 PyAffineDimExpr::bind(m); 511 PyAffineSymbolExpr::bind(m); 512 PyAffineBinaryExpr::bind(m); 513 PyAffineAddExpr::bind(m); 514 PyAffineMulExpr::bind(m); 515 PyAffineModExpr::bind(m); 516 PyAffineFloorDivExpr::bind(m); 517 PyAffineCeilDivExpr::bind(m); 518 519 //---------------------------------------------------------------------------- 520 // Mapping of PyAffineMap. 521 //---------------------------------------------------------------------------- 522 py::class_<PyAffineMap>(m, "AffineMap", py::module_local()) 523 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 524 &PyAffineMap::getCapsule) 525 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) 526 .def("__eq__", 527 [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) 528 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) 529 .def("__str__", 530 [](PyAffineMap &self) { 531 PyPrintAccumulator printAccum; 532 mlirAffineMapPrint(self, printAccum.getCallback(), 533 printAccum.getUserData()); 534 return printAccum.join(); 535 }) 536 .def("__repr__", 537 [](PyAffineMap &self) { 538 PyPrintAccumulator printAccum; 539 printAccum.parts.append("AffineMap("); 540 mlirAffineMapPrint(self, printAccum.getCallback(), 541 printAccum.getUserData()); 542 printAccum.parts.append(")"); 543 return printAccum.join(); 544 }) 545 .def_static("compress_unused_symbols", 546 [](py::list affineMaps, DefaultingPyMlirContext context) { 547 SmallVector<MlirAffineMap> maps; 548 pyListToVector<PyAffineMap, MlirAffineMap>( 549 affineMaps, maps, "attempting to create an AffineMap"); 550 std::vector<MlirAffineMap> compressed(affineMaps.size()); 551 auto populate = [](void *result, intptr_t idx, 552 MlirAffineMap m) { 553 static_cast<MlirAffineMap *>(result)[idx] = (m); 554 }; 555 mlirAffineMapCompressUnusedSymbols( 556 maps.data(), maps.size(), compressed.data(), populate); 557 std::vector<PyAffineMap> res; 558 res.reserve(compressed.size()); 559 for (auto m : compressed) 560 res.push_back(PyAffineMap(context->getRef(), m)); 561 return res; 562 }) 563 .def_property_readonly( 564 "context", 565 [](PyAffineMap &self) { return self.getContext().getObject(); }, 566 "Context that owns the Affine Map") 567 .def( 568 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, 569 kDumpDocstring) 570 .def_static( 571 "get", 572 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, 573 DefaultingPyMlirContext context) { 574 SmallVector<MlirAffineExpr> affineExprs; 575 pyListToVector<PyAffineExpr, MlirAffineExpr>( 576 exprs, affineExprs, "attempting to create an AffineMap"); 577 MlirAffineMap map = 578 mlirAffineMapGet(context->get(), dimCount, symbolCount, 579 affineExprs.size(), affineExprs.data()); 580 return PyAffineMap(context->getRef(), map); 581 }, 582 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), 583 py::arg("context") = py::none(), 584 "Gets a map with the given expressions as results.") 585 .def_static( 586 "get_constant", 587 [](intptr_t value, DefaultingPyMlirContext context) { 588 MlirAffineMap affineMap = 589 mlirAffineMapConstantGet(context->get(), value); 590 return PyAffineMap(context->getRef(), affineMap); 591 }, 592 py::arg("value"), py::arg("context") = py::none(), 593 "Gets an affine map with a single constant result") 594 .def_static( 595 "get_empty", 596 [](DefaultingPyMlirContext context) { 597 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); 598 return PyAffineMap(context->getRef(), affineMap); 599 }, 600 py::arg("context") = py::none(), "Gets an empty affine map.") 601 .def_static( 602 "get_identity", 603 [](intptr_t nDims, DefaultingPyMlirContext context) { 604 MlirAffineMap affineMap = 605 mlirAffineMapMultiDimIdentityGet(context->get(), nDims); 606 return PyAffineMap(context->getRef(), affineMap); 607 }, 608 py::arg("n_dims"), py::arg("context") = py::none(), 609 "Gets an identity map with the given number of dimensions.") 610 .def_static( 611 "get_minor_identity", 612 [](intptr_t nDims, intptr_t nResults, 613 DefaultingPyMlirContext context) { 614 MlirAffineMap affineMap = 615 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); 616 return PyAffineMap(context->getRef(), affineMap); 617 }, 618 py::arg("n_dims"), py::arg("n_results"), 619 py::arg("context") = py::none(), 620 "Gets a minor identity map with the given number of dimensions and " 621 "results.") 622 .def_static( 623 "get_permutation", 624 [](std::vector<unsigned> permutation, 625 DefaultingPyMlirContext context) { 626 if (!isPermutation(permutation)) 627 throw py::cast_error("Invalid permutation when attempting to " 628 "create an AffineMap"); 629 MlirAffineMap affineMap = mlirAffineMapPermutationGet( 630 context->get(), permutation.size(), permutation.data()); 631 return PyAffineMap(context->getRef(), affineMap); 632 }, 633 py::arg("permutation"), py::arg("context") = py::none(), 634 "Gets an affine map that permutes its inputs.") 635 .def("get_submap", 636 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { 637 intptr_t numResults = mlirAffineMapGetNumResults(self); 638 for (intptr_t pos : resultPos) { 639 if (pos < 0 || pos >= numResults) 640 throw py::value_error("result position out of bounds"); 641 } 642 MlirAffineMap affineMap = mlirAffineMapGetSubMap( 643 self, resultPos.size(), resultPos.data()); 644 return PyAffineMap(self.getContext(), affineMap); 645 }) 646 .def("get_major_submap", 647 [](PyAffineMap &self, intptr_t nResults) { 648 if (nResults >= mlirAffineMapGetNumResults(self)) 649 throw py::value_error("number of results out of bounds"); 650 MlirAffineMap affineMap = 651 mlirAffineMapGetMajorSubMap(self, nResults); 652 return PyAffineMap(self.getContext(), affineMap); 653 }) 654 .def("get_minor_submap", 655 [](PyAffineMap &self, intptr_t nResults) { 656 if (nResults >= mlirAffineMapGetNumResults(self)) 657 throw py::value_error("number of results out of bounds"); 658 MlirAffineMap affineMap = 659 mlirAffineMapGetMinorSubMap(self, nResults); 660 return PyAffineMap(self.getContext(), affineMap); 661 }) 662 .def("replace", 663 [](PyAffineMap &self, PyAffineExpr &expression, 664 PyAffineExpr &replacement, intptr_t numResultDims, 665 intptr_t numResultSyms) { 666 MlirAffineMap affineMap = mlirAffineMapReplace( 667 self, expression, replacement, numResultDims, numResultSyms); 668 return PyAffineMap(self.getContext(), affineMap); 669 }) 670 .def_property_readonly( 671 "is_permutation", 672 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) 673 .def_property_readonly("is_projected_permutation", 674 [](PyAffineMap &self) { 675 return mlirAffineMapIsProjectedPermutation(self); 676 }) 677 .def_property_readonly( 678 "n_dims", 679 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) 680 .def_property_readonly( 681 "n_inputs", 682 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) 683 .def_property_readonly( 684 "n_symbols", 685 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) 686 .def_property_readonly("results", [](PyAffineMap &self) { 687 return PyAffineMapExprList(self); 688 }); 689 PyAffineMapExprList::bind(m); 690 691 //---------------------------------------------------------------------------- 692 // Mapping of PyIntegerSet. 693 //---------------------------------------------------------------------------- 694 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local()) 695 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 696 &PyIntegerSet::getCapsule) 697 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) 698 .def("__eq__", [](PyIntegerSet &self, 699 PyIntegerSet &other) { return self == other; }) 700 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) 701 .def("__str__", 702 [](PyIntegerSet &self) { 703 PyPrintAccumulator printAccum; 704 mlirIntegerSetPrint(self, printAccum.getCallback(), 705 printAccum.getUserData()); 706 return printAccum.join(); 707 }) 708 .def("__repr__", 709 [](PyIntegerSet &self) { 710 PyPrintAccumulator printAccum; 711 printAccum.parts.append("IntegerSet("); 712 mlirIntegerSetPrint(self, printAccum.getCallback(), 713 printAccum.getUserData()); 714 printAccum.parts.append(")"); 715 return printAccum.join(); 716 }) 717 .def_property_readonly( 718 "context", 719 [](PyIntegerSet &self) { return self.getContext().getObject(); }) 720 .def( 721 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, 722 kDumpDocstring) 723 .def_static( 724 "get", 725 [](intptr_t numDims, intptr_t numSymbols, py::list exprs, 726 std::vector<bool> eqFlags, DefaultingPyMlirContext context) { 727 if (exprs.size() != eqFlags.size()) 728 throw py::value_error( 729 "Expected the number of constraints to match " 730 "that of equality flags"); 731 if (exprs.empty()) 732 throw py::value_error("Expected non-empty list of constraints"); 733 734 // Copy over to a SmallVector because std::vector has a 735 // specialization for booleans that packs data and does not 736 // expose a `bool *`. 737 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); 738 739 SmallVector<MlirAffineExpr> affineExprs; 740 pyListToVector<PyAffineExpr>(exprs, affineExprs, 741 "attempting to create an IntegerSet"); 742 MlirIntegerSet set = mlirIntegerSetGet( 743 context->get(), numDims, numSymbols, exprs.size(), 744 affineExprs.data(), flags.data()); 745 return PyIntegerSet(context->getRef(), set); 746 }, 747 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), 748 py::arg("eq_flags"), py::arg("context") = py::none()) 749 .def_static( 750 "get_empty", 751 [](intptr_t numDims, intptr_t numSymbols, 752 DefaultingPyMlirContext context) { 753 MlirIntegerSet set = 754 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); 755 return PyIntegerSet(context->getRef(), set); 756 }, 757 py::arg("num_dims"), py::arg("num_symbols"), 758 py::arg("context") = py::none()) 759 .def("get_replaced", 760 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, 761 intptr_t numResultDims, intptr_t numResultSymbols) { 762 if (static_cast<intptr_t>(dimExprs.size()) != 763 mlirIntegerSetGetNumDims(self)) 764 throw py::value_error( 765 "Expected the number of dimension replacement expressions " 766 "to match that of dimensions"); 767 if (static_cast<intptr_t>(symbolExprs.size()) != 768 mlirIntegerSetGetNumSymbols(self)) 769 throw py::value_error( 770 "Expected the number of symbol replacement expressions " 771 "to match that of symbols"); 772 773 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; 774 pyListToVector<PyAffineExpr>( 775 dimExprs, dimAffineExprs, 776 "attempting to create an IntegerSet by replacing dimensions"); 777 pyListToVector<PyAffineExpr>( 778 symbolExprs, symbolAffineExprs, 779 "attempting to create an IntegerSet by replacing symbols"); 780 MlirIntegerSet set = mlirIntegerSetReplaceGet( 781 self, dimAffineExprs.data(), symbolAffineExprs.data(), 782 numResultDims, numResultSymbols); 783 return PyIntegerSet(self.getContext(), set); 784 }) 785 .def_property_readonly("is_canonical_empty", 786 [](PyIntegerSet &self) { 787 return mlirIntegerSetIsCanonicalEmpty(self); 788 }) 789 .def_property_readonly( 790 "n_dims", 791 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) 792 .def_property_readonly( 793 "n_symbols", 794 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) 795 .def_property_readonly( 796 "n_inputs", 797 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) 798 .def_property_readonly("n_equalities", 799 [](PyIntegerSet &self) { 800 return mlirIntegerSetGetNumEqualities(self); 801 }) 802 .def_property_readonly("n_inequalities", 803 [](PyIntegerSet &self) { 804 return mlirIntegerSetGetNumInequalities(self); 805 }) 806 .def_property_readonly("constraints", [](PyIntegerSet &self) { 807 return PyIntegerSetConstraintList(self); 808 }); 809 PyIntegerSetConstraint::bind(m); 810 PyIntegerSetConstraintList::bind(m); 811 } 812