1 //===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "IRModule.h" 10 11 #include "PybindUtils.h" 12 13 #include "mlir-c/AffineMap.h" 14 #include "mlir-c/Bindings/Python/Interop.h" 15 #include "mlir-c/IntegerSet.h" 16 17 namespace py = pybind11; 18 using namespace mlir; 19 using namespace mlir::python; 20 21 using llvm::SmallVector; 22 using llvm::StringRef; 23 using llvm::Twine; 24 25 static const char kDumpDocstring[] = 26 R"(Dumps a debug representation of the object to stderr.)"; 27 28 /// Attempts to populate `result` with the content of `list` casted to the 29 /// appropriate type (Python and C types are provided as template arguments). 30 /// Throws errors in case of failure, using "action" to describe what the caller 31 /// was attempting to do. 32 template <typename PyType, typename CType> 33 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, 34 StringRef action) { 35 result.reserve(py::len(list)); 36 for (py::handle item : list) { 37 try { 38 result.push_back(item.cast<PyType>()); 39 } catch (py::cast_error &err) { 40 std::string msg = (llvm::Twine("Invalid expression when ") + action + 41 " (" + err.what() + ")") 42 .str(); 43 throw py::cast_error(msg); 44 } catch (py::reference_cast_error &err) { 45 std::string msg = (llvm::Twine("Invalid expression (None?) when ") + 46 action + " (" + err.what() + ")") 47 .str(); 48 throw py::cast_error(msg); 49 } 50 } 51 } 52 53 template <typename PermutationTy> 54 static bool isPermutation(std::vector<PermutationTy> permutation) { 55 llvm::SmallVector<bool, 8> seen(permutation.size(), false); 56 for (auto val : permutation) { 57 if (val < permutation.size()) { 58 if (seen[val]) 59 return false; 60 seen[val] = true; 61 continue; 62 } 63 return false; 64 } 65 return true; 66 } 67 68 namespace { 69 70 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr 71 /// and should be castable from it. Intermediate hierarchy classes can be 72 /// modeled by specifying BaseTy. 73 template <typename DerivedTy, typename BaseTy = PyAffineExpr> 74 class PyConcreteAffineExpr : public BaseTy { 75 public: 76 // Derived classes must define statics for: 77 // IsAFunctionTy isaFunction 78 // const char *pyClassName 79 // and redefine bindDerived. 80 using ClassTy = py::class_<DerivedTy, BaseTy>; 81 using IsAFunctionTy = bool (*)(MlirAffineExpr); 82 83 PyConcreteAffineExpr() = default; 84 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 85 : BaseTy(std::move(contextRef), affineExpr) {} 86 PyConcreteAffineExpr(PyAffineExpr &orig) 87 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} 88 89 static MlirAffineExpr castFrom(PyAffineExpr &orig) { 90 if (!DerivedTy::isaFunction(orig)) { 91 auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); 92 throw SetPyError(PyExc_ValueError, 93 Twine("Cannot cast affine expression to ") + 94 DerivedTy::pyClassName + " (from " + origRepr + ")"); 95 } 96 return orig; 97 } 98 99 static void bind(py::module &m) { 100 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); 101 cls.def(py::init<PyAffineExpr &>()); 102 DerivedTy::bindDerived(cls); 103 } 104 105 /// Implemented by derived classes to add methods to the Python subclass. 106 static void bindDerived(ClassTy &m) {} 107 }; 108 109 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { 110 public: 111 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; 112 static constexpr const char *pyClassName = "AffineConstantExpr"; 113 using PyConcreteAffineExpr::PyConcreteAffineExpr; 114 115 static PyAffineConstantExpr get(intptr_t value, 116 DefaultingPyMlirContext context) { 117 MlirAffineExpr affineExpr = 118 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value)); 119 return PyAffineConstantExpr(context->getRef(), affineExpr); 120 } 121 122 static void bindDerived(ClassTy &c) { 123 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), 124 py::arg("context") = py::none()); 125 c.def_property_readonly("value", [](PyAffineConstantExpr &self) { 126 return mlirAffineConstantExprGetValue(self); 127 }); 128 } 129 }; 130 131 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { 132 public: 133 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; 134 static constexpr const char *pyClassName = "AffineDimExpr"; 135 using PyConcreteAffineExpr::PyConcreteAffineExpr; 136 137 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { 138 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); 139 return PyAffineDimExpr(context->getRef(), affineExpr); 140 } 141 142 static void bindDerived(ClassTy &c) { 143 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), 144 py::arg("context") = py::none()); 145 c.def_property_readonly("position", [](PyAffineDimExpr &self) { 146 return mlirAffineDimExprGetPosition(self); 147 }); 148 } 149 }; 150 151 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { 152 public: 153 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; 154 static constexpr const char *pyClassName = "AffineSymbolExpr"; 155 using PyConcreteAffineExpr::PyConcreteAffineExpr; 156 157 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { 158 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); 159 return PyAffineSymbolExpr(context->getRef(), affineExpr); 160 } 161 162 static void bindDerived(ClassTy &c) { 163 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), 164 py::arg("context") = py::none()); 165 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { 166 return mlirAffineSymbolExprGetPosition(self); 167 }); 168 } 169 }; 170 171 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { 172 public: 173 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; 174 static constexpr const char *pyClassName = "AffineBinaryExpr"; 175 using PyConcreteAffineExpr::PyConcreteAffineExpr; 176 177 PyAffineExpr lhs() { 178 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); 179 return PyAffineExpr(getContext(), lhsExpr); 180 } 181 182 PyAffineExpr rhs() { 183 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); 184 return PyAffineExpr(getContext(), rhsExpr); 185 } 186 187 static void bindDerived(ClassTy &c) { 188 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); 189 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); 190 } 191 }; 192 193 class PyAffineAddExpr 194 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { 195 public: 196 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; 197 static constexpr const char *pyClassName = "AffineAddExpr"; 198 using PyConcreteAffineExpr::PyConcreteAffineExpr; 199 200 static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 201 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); 202 return PyAffineAddExpr(lhs.getContext(), expr); 203 } 204 205 static void bindDerived(ClassTy &c) { 206 c.def_static("get", &PyAffineAddExpr::get); 207 } 208 }; 209 210 class PyAffineMulExpr 211 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { 212 public: 213 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; 214 static constexpr const char *pyClassName = "AffineMulExpr"; 215 using PyConcreteAffineExpr::PyConcreteAffineExpr; 216 217 static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 218 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); 219 return PyAffineMulExpr(lhs.getContext(), expr); 220 } 221 222 static void bindDerived(ClassTy &c) { 223 c.def_static("get", &PyAffineMulExpr::get); 224 } 225 }; 226 227 class PyAffineModExpr 228 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { 229 public: 230 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; 231 static constexpr const char *pyClassName = "AffineModExpr"; 232 using PyConcreteAffineExpr::PyConcreteAffineExpr; 233 234 static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 235 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); 236 return PyAffineModExpr(lhs.getContext(), expr); 237 } 238 239 static void bindDerived(ClassTy &c) { 240 c.def_static("get", &PyAffineModExpr::get); 241 } 242 }; 243 244 class PyAffineFloorDivExpr 245 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { 246 public: 247 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; 248 static constexpr const char *pyClassName = "AffineFloorDivExpr"; 249 using PyConcreteAffineExpr::PyConcreteAffineExpr; 250 251 static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 252 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); 253 return PyAffineFloorDivExpr(lhs.getContext(), expr); 254 } 255 256 static void bindDerived(ClassTy &c) { 257 c.def_static("get", &PyAffineFloorDivExpr::get); 258 } 259 }; 260 261 class PyAffineCeilDivExpr 262 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { 263 public: 264 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; 265 static constexpr const char *pyClassName = "AffineCeilDivExpr"; 266 using PyConcreteAffineExpr::PyConcreteAffineExpr; 267 268 static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { 269 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); 270 return PyAffineCeilDivExpr(lhs.getContext(), expr); 271 } 272 273 static void bindDerived(ClassTy &c) { 274 c.def_static("get", &PyAffineCeilDivExpr::get); 275 } 276 }; 277 278 } // namespace 279 280 bool PyAffineExpr::operator==(const PyAffineExpr &other) { 281 return mlirAffineExprEqual(affineExpr, other.affineExpr); 282 } 283 284 py::object PyAffineExpr::getCapsule() { 285 return py::reinterpret_steal<py::object>( 286 mlirPythonAffineExprToCapsule(*this)); 287 } 288 289 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { 290 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); 291 if (mlirAffineExprIsNull(rawAffineExpr)) 292 throw py::error_already_set(); 293 return PyAffineExpr( 294 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), 295 rawAffineExpr); 296 } 297 298 //------------------------------------------------------------------------------ 299 // PyAffineMap and utilities. 300 //------------------------------------------------------------------------------ 301 namespace { 302 303 /// A list of expressions contained in an affine map. Internally these are 304 /// stored as a consecutive array leading to inexpensive random access. Both 305 /// the map and the expression are owned by the context so we need not bother 306 /// with lifetime extension. 307 class PyAffineMapExprList 308 : public Sliceable<PyAffineMapExprList, PyAffineExpr> { 309 public: 310 static constexpr const char *pyClassName = "AffineExprList"; 311 312 PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, 313 intptr_t length = -1, intptr_t step = 1) 314 : Sliceable(startIndex, 315 length == -1 ? mlirAffineMapGetNumResults(map) : length, 316 step), 317 affineMap(map) {} 318 319 intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } 320 321 PyAffineExpr getElement(intptr_t pos) { 322 return PyAffineExpr(affineMap.getContext(), 323 mlirAffineMapGetResult(affineMap, pos)); 324 } 325 326 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, 327 intptr_t step) { 328 return PyAffineMapExprList(affineMap, startIndex, length, step); 329 } 330 331 private: 332 PyAffineMap affineMap; 333 }; 334 } // end namespace 335 336 bool PyAffineMap::operator==(const PyAffineMap &other) { 337 return mlirAffineMapEqual(affineMap, other.affineMap); 338 } 339 340 py::object PyAffineMap::getCapsule() { 341 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); 342 } 343 344 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { 345 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); 346 if (mlirAffineMapIsNull(rawAffineMap)) 347 throw py::error_already_set(); 348 return PyAffineMap( 349 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), 350 rawAffineMap); 351 } 352 353 //------------------------------------------------------------------------------ 354 // PyIntegerSet and utilities. 355 //------------------------------------------------------------------------------ 356 namespace { 357 358 class PyIntegerSetConstraint { 359 public: 360 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} 361 362 PyAffineExpr getExpr() { 363 return PyAffineExpr(set.getContext(), 364 mlirIntegerSetGetConstraint(set, pos)); 365 } 366 367 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } 368 369 static void bind(py::module &m) { 370 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint", 371 py::module_local()) 372 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) 373 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); 374 } 375 376 private: 377 PyIntegerSet set; 378 intptr_t pos; 379 }; 380 381 class PyIntegerSetConstraintList 382 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { 383 public: 384 static constexpr const char *pyClassName = "IntegerSetConstraintList"; 385 386 PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, 387 intptr_t length = -1, intptr_t step = 1) 388 : Sliceable(startIndex, 389 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, 390 step), 391 set(set) {} 392 393 intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } 394 395 PyIntegerSetConstraint getElement(intptr_t pos) { 396 return PyIntegerSetConstraint(set, pos); 397 } 398 399 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, 400 intptr_t step) { 401 return PyIntegerSetConstraintList(set, startIndex, length, step); 402 } 403 404 private: 405 PyIntegerSet set; 406 }; 407 } // namespace 408 409 bool PyIntegerSet::operator==(const PyIntegerSet &other) { 410 return mlirIntegerSetEqual(integerSet, other.integerSet); 411 } 412 413 py::object PyIntegerSet::getCapsule() { 414 return py::reinterpret_steal<py::object>( 415 mlirPythonIntegerSetToCapsule(*this)); 416 } 417 418 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { 419 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); 420 if (mlirIntegerSetIsNull(rawIntegerSet)) 421 throw py::error_already_set(); 422 return PyIntegerSet( 423 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 424 rawIntegerSet); 425 } 426 427 void mlir::python::populateIRAffine(py::module &m) { 428 //---------------------------------------------------------------------------- 429 // Mapping of PyAffineExpr and derived classes. 430 //---------------------------------------------------------------------------- 431 py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local()) 432 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 433 &PyAffineExpr::getCapsule) 434 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) 435 .def("__add__", 436 [](PyAffineExpr &self, PyAffineExpr &other) { 437 return PyAffineAddExpr::get(self, other); 438 }) 439 .def("__mul__", 440 [](PyAffineExpr &self, PyAffineExpr &other) { 441 return PyAffineMulExpr::get(self, other); 442 }) 443 .def("__mod__", 444 [](PyAffineExpr &self, PyAffineExpr &other) { 445 return PyAffineModExpr::get(self, other); 446 }) 447 .def("__sub__", 448 [](PyAffineExpr &self, PyAffineExpr &other) { 449 auto negOne = 450 PyAffineConstantExpr::get(-1, *self.getContext().get()); 451 return PyAffineAddExpr::get(self, 452 PyAffineMulExpr::get(negOne, other)); 453 }) 454 .def("__eq__", [](PyAffineExpr &self, 455 PyAffineExpr &other) { return self == other; }) 456 .def("__eq__", 457 [](PyAffineExpr &self, py::object &other) { return false; }) 458 .def("__str__", 459 [](PyAffineExpr &self) { 460 PyPrintAccumulator printAccum; 461 mlirAffineExprPrint(self, printAccum.getCallback(), 462 printAccum.getUserData()); 463 return printAccum.join(); 464 }) 465 .def("__repr__", 466 [](PyAffineExpr &self) { 467 PyPrintAccumulator printAccum; 468 printAccum.parts.append("AffineExpr("); 469 mlirAffineExprPrint(self, printAccum.getCallback(), 470 printAccum.getUserData()); 471 printAccum.parts.append(")"); 472 return printAccum.join(); 473 }) 474 .def_property_readonly( 475 "context", 476 [](PyAffineExpr &self) { return self.getContext().getObject(); }) 477 .def_static( 478 "get_add", &PyAffineAddExpr::get, 479 "Gets an affine expression containing a sum of two expressions.") 480 .def_static( 481 "get_mul", &PyAffineMulExpr::get, 482 "Gets an affine expression containing a product of two expressions.") 483 .def_static("get_mod", &PyAffineModExpr::get, 484 "Gets an affine expression containing the modulo of dividing " 485 "one expression by another.") 486 .def_static("get_floor_div", &PyAffineFloorDivExpr::get, 487 "Gets an affine expression containing the rounded-down " 488 "result of dividing one expression by another.") 489 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, 490 "Gets an affine expression containing the rounded-up result " 491 "of dividing one expression by another.") 492 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), 493 py::arg("context") = py::none(), 494 "Gets a constant affine expression with the given value.") 495 .def_static( 496 "get_dim", &PyAffineDimExpr::get, py::arg("position"), 497 py::arg("context") = py::none(), 498 "Gets an affine expression of a dimension at the given position.") 499 .def_static( 500 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), 501 py::arg("context") = py::none(), 502 "Gets an affine expression of a symbol at the given position.") 503 .def( 504 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, 505 kDumpDocstring); 506 PyAffineConstantExpr::bind(m); 507 PyAffineDimExpr::bind(m); 508 PyAffineSymbolExpr::bind(m); 509 PyAffineBinaryExpr::bind(m); 510 PyAffineAddExpr::bind(m); 511 PyAffineMulExpr::bind(m); 512 PyAffineModExpr::bind(m); 513 PyAffineFloorDivExpr::bind(m); 514 PyAffineCeilDivExpr::bind(m); 515 516 //---------------------------------------------------------------------------- 517 // Mapping of PyAffineMap. 518 //---------------------------------------------------------------------------- 519 py::class_<PyAffineMap>(m, "AffineMap", py::module_local()) 520 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 521 &PyAffineMap::getCapsule) 522 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) 523 .def("__eq__", 524 [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) 525 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) 526 .def("__str__", 527 [](PyAffineMap &self) { 528 PyPrintAccumulator printAccum; 529 mlirAffineMapPrint(self, printAccum.getCallback(), 530 printAccum.getUserData()); 531 return printAccum.join(); 532 }) 533 .def("__repr__", 534 [](PyAffineMap &self) { 535 PyPrintAccumulator printAccum; 536 printAccum.parts.append("AffineMap("); 537 mlirAffineMapPrint(self, printAccum.getCallback(), 538 printAccum.getUserData()); 539 printAccum.parts.append(")"); 540 return printAccum.join(); 541 }) 542 .def_static("compress_unused_symbols", 543 [](py::list affineMaps, DefaultingPyMlirContext context) { 544 SmallVector<MlirAffineMap> maps; 545 pyListToVector<PyAffineMap, MlirAffineMap>( 546 affineMaps, maps, "attempting to create an AffineMap"); 547 std::vector<MlirAffineMap> compressed(affineMaps.size()); 548 auto populate = [](void *result, intptr_t idx, 549 MlirAffineMap m) { 550 static_cast<MlirAffineMap *>(result)[idx] = (m); 551 }; 552 mlirAffineMapCompressUnusedSymbols( 553 maps.data(), maps.size(), compressed.data(), populate); 554 std::vector<PyAffineMap> res; 555 for (auto m : compressed) 556 res.push_back(PyAffineMap(context->getRef(), m)); 557 return res; 558 }) 559 .def_property_readonly( 560 "context", 561 [](PyAffineMap &self) { return self.getContext().getObject(); }, 562 "Context that owns the Affine Map") 563 .def( 564 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, 565 kDumpDocstring) 566 .def_static( 567 "get", 568 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, 569 DefaultingPyMlirContext context) { 570 SmallVector<MlirAffineExpr> affineExprs; 571 pyListToVector<PyAffineExpr, MlirAffineExpr>( 572 exprs, affineExprs, "attempting to create an AffineMap"); 573 MlirAffineMap map = 574 mlirAffineMapGet(context->get(), dimCount, symbolCount, 575 affineExprs.size(), affineExprs.data()); 576 return PyAffineMap(context->getRef(), map); 577 }, 578 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), 579 py::arg("context") = py::none(), 580 "Gets a map with the given expressions as results.") 581 .def_static( 582 "get_constant", 583 [](intptr_t value, DefaultingPyMlirContext context) { 584 MlirAffineMap affineMap = 585 mlirAffineMapConstantGet(context->get(), value); 586 return PyAffineMap(context->getRef(), affineMap); 587 }, 588 py::arg("value"), py::arg("context") = py::none(), 589 "Gets an affine map with a single constant result") 590 .def_static( 591 "get_empty", 592 [](DefaultingPyMlirContext context) { 593 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); 594 return PyAffineMap(context->getRef(), affineMap); 595 }, 596 py::arg("context") = py::none(), "Gets an empty affine map.") 597 .def_static( 598 "get_identity", 599 [](intptr_t nDims, DefaultingPyMlirContext context) { 600 MlirAffineMap affineMap = 601 mlirAffineMapMultiDimIdentityGet(context->get(), nDims); 602 return PyAffineMap(context->getRef(), affineMap); 603 }, 604 py::arg("n_dims"), py::arg("context") = py::none(), 605 "Gets an identity map with the given number of dimensions.") 606 .def_static( 607 "get_minor_identity", 608 [](intptr_t nDims, intptr_t nResults, 609 DefaultingPyMlirContext context) { 610 MlirAffineMap affineMap = 611 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); 612 return PyAffineMap(context->getRef(), affineMap); 613 }, 614 py::arg("n_dims"), py::arg("n_results"), 615 py::arg("context") = py::none(), 616 "Gets a minor identity map with the given number of dimensions and " 617 "results.") 618 .def_static( 619 "get_permutation", 620 [](std::vector<unsigned> permutation, 621 DefaultingPyMlirContext context) { 622 if (!isPermutation(permutation)) 623 throw py::cast_error("Invalid permutation when attempting to " 624 "create an AffineMap"); 625 MlirAffineMap affineMap = mlirAffineMapPermutationGet( 626 context->get(), permutation.size(), permutation.data()); 627 return PyAffineMap(context->getRef(), affineMap); 628 }, 629 py::arg("permutation"), py::arg("context") = py::none(), 630 "Gets an affine map that permutes its inputs.") 631 .def("get_submap", 632 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { 633 intptr_t numResults = mlirAffineMapGetNumResults(self); 634 for (intptr_t pos : resultPos) { 635 if (pos < 0 || pos >= numResults) 636 throw py::value_error("result position out of bounds"); 637 } 638 MlirAffineMap affineMap = mlirAffineMapGetSubMap( 639 self, resultPos.size(), resultPos.data()); 640 return PyAffineMap(self.getContext(), affineMap); 641 }) 642 .def("get_major_submap", 643 [](PyAffineMap &self, intptr_t nResults) { 644 if (nResults >= mlirAffineMapGetNumResults(self)) 645 throw py::value_error("number of results out of bounds"); 646 MlirAffineMap affineMap = 647 mlirAffineMapGetMajorSubMap(self, nResults); 648 return PyAffineMap(self.getContext(), affineMap); 649 }) 650 .def("get_minor_submap", 651 [](PyAffineMap &self, intptr_t nResults) { 652 if (nResults >= mlirAffineMapGetNumResults(self)) 653 throw py::value_error("number of results out of bounds"); 654 MlirAffineMap affineMap = 655 mlirAffineMapGetMinorSubMap(self, nResults); 656 return PyAffineMap(self.getContext(), affineMap); 657 }) 658 .def("replace", 659 [](PyAffineMap &self, PyAffineExpr &expression, 660 PyAffineExpr &replacement, intptr_t numResultDims, 661 intptr_t numResultSyms) { 662 MlirAffineMap affineMap = mlirAffineMapReplace( 663 self, expression, replacement, numResultDims, numResultSyms); 664 return PyAffineMap(self.getContext(), affineMap); 665 }) 666 .def_property_readonly( 667 "is_permutation", 668 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) 669 .def_property_readonly("is_projected_permutation", 670 [](PyAffineMap &self) { 671 return mlirAffineMapIsProjectedPermutation(self); 672 }) 673 .def_property_readonly( 674 "n_dims", 675 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) 676 .def_property_readonly( 677 "n_inputs", 678 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) 679 .def_property_readonly( 680 "n_symbols", 681 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) 682 .def_property_readonly("results", [](PyAffineMap &self) { 683 return PyAffineMapExprList(self); 684 }); 685 PyAffineMapExprList::bind(m); 686 687 //---------------------------------------------------------------------------- 688 // Mapping of PyIntegerSet. 689 //---------------------------------------------------------------------------- 690 py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local()) 691 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 692 &PyIntegerSet::getCapsule) 693 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) 694 .def("__eq__", [](PyIntegerSet &self, 695 PyIntegerSet &other) { return self == other; }) 696 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) 697 .def("__str__", 698 [](PyIntegerSet &self) { 699 PyPrintAccumulator printAccum; 700 mlirIntegerSetPrint(self, printAccum.getCallback(), 701 printAccum.getUserData()); 702 return printAccum.join(); 703 }) 704 .def("__repr__", 705 [](PyIntegerSet &self) { 706 PyPrintAccumulator printAccum; 707 printAccum.parts.append("IntegerSet("); 708 mlirIntegerSetPrint(self, printAccum.getCallback(), 709 printAccum.getUserData()); 710 printAccum.parts.append(")"); 711 return printAccum.join(); 712 }) 713 .def_property_readonly( 714 "context", 715 [](PyIntegerSet &self) { return self.getContext().getObject(); }) 716 .def( 717 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, 718 kDumpDocstring) 719 .def_static( 720 "get", 721 [](intptr_t numDims, intptr_t numSymbols, py::list exprs, 722 std::vector<bool> eqFlags, DefaultingPyMlirContext context) { 723 if (exprs.size() != eqFlags.size()) 724 throw py::value_error( 725 "Expected the number of constraints to match " 726 "that of equality flags"); 727 if (exprs.empty()) 728 throw py::value_error("Expected non-empty list of constraints"); 729 730 // Copy over to a SmallVector because std::vector has a 731 // specialization for booleans that packs data and does not 732 // expose a `bool *`. 733 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); 734 735 SmallVector<MlirAffineExpr> affineExprs; 736 pyListToVector<PyAffineExpr>(exprs, affineExprs, 737 "attempting to create an IntegerSet"); 738 MlirIntegerSet set = mlirIntegerSetGet( 739 context->get(), numDims, numSymbols, exprs.size(), 740 affineExprs.data(), flags.data()); 741 return PyIntegerSet(context->getRef(), set); 742 }, 743 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), 744 py::arg("eq_flags"), py::arg("context") = py::none()) 745 .def_static( 746 "get_empty", 747 [](intptr_t numDims, intptr_t numSymbols, 748 DefaultingPyMlirContext context) { 749 MlirIntegerSet set = 750 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); 751 return PyIntegerSet(context->getRef(), set); 752 }, 753 py::arg("num_dims"), py::arg("num_symbols"), 754 py::arg("context") = py::none()) 755 .def("get_replaced", 756 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, 757 intptr_t numResultDims, intptr_t numResultSymbols) { 758 if (static_cast<intptr_t>(dimExprs.size()) != 759 mlirIntegerSetGetNumDims(self)) 760 throw py::value_error( 761 "Expected the number of dimension replacement expressions " 762 "to match that of dimensions"); 763 if (static_cast<intptr_t>(symbolExprs.size()) != 764 mlirIntegerSetGetNumSymbols(self)) 765 throw py::value_error( 766 "Expected the number of symbol replacement expressions " 767 "to match that of symbols"); 768 769 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; 770 pyListToVector<PyAffineExpr>( 771 dimExprs, dimAffineExprs, 772 "attempting to create an IntegerSet by replacing dimensions"); 773 pyListToVector<PyAffineExpr>( 774 symbolExprs, symbolAffineExprs, 775 "attempting to create an IntegerSet by replacing symbols"); 776 MlirIntegerSet set = mlirIntegerSetReplaceGet( 777 self, dimAffineExprs.data(), symbolAffineExprs.data(), 778 numResultDims, numResultSymbols); 779 return PyIntegerSet(self.getContext(), set); 780 }) 781 .def_property_readonly("is_canonical_empty", 782 [](PyIntegerSet &self) { 783 return mlirIntegerSetIsCanonicalEmpty(self); 784 }) 785 .def_property_readonly( 786 "n_dims", 787 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) 788 .def_property_readonly( 789 "n_symbols", 790 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) 791 .def_property_readonly( 792 "n_inputs", 793 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) 794 .def_property_readonly("n_equalities", 795 [](PyIntegerSet &self) { 796 return mlirIntegerSetGetNumEqualities(self); 797 }) 798 .def_property_readonly("n_inequalities", 799 [](PyIntegerSet &self) { 800 return mlirIntegerSetGetNumInequalities(self); 801 }) 802 .def_property_readonly("constraints", [](PyIntegerSet &self) { 803 return PyIntegerSetConstraintList(self); 804 }); 805 PyIntegerSetConstraint::bind(m); 806 PyIntegerSetConstraintList::bind(m); 807 } 808