1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/IR/AffineExpr.h" 12 #include "AffineExprDetail.h" 13 #include "mlir/IR/AffineExprVisitor.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/Support/MathExtras.h" 17 #include "mlir/Support/TypeID.h" 18 #include "llvm/ADT/STLExtras.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 MLIRContext *AffineExpr::getContext() const { return expr->context; } 24 25 AffineExprKind AffineExpr::getKind() const { return expr->kind; } 26 27 /// Walk all of the AffineExprs in this subgraph in postorder. 28 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const { 29 struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> { 30 std::function<void(AffineExpr)> callback; 31 32 AffineExprWalker(std::function<void(AffineExpr)> callback) 33 : callback(std::move(callback)) {} 34 35 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); } 36 void visitConstantExpr(AffineConstantExpr expr) { callback(expr); } 37 void visitDimExpr(AffineDimExpr expr) { callback(expr); } 38 void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); } 39 }; 40 41 AffineExprWalker(std::move(callback)).walkPostOrder(*this); 42 } 43 44 // Dispatch affine expression construction based on kind. 
45 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, 46 AffineExpr rhs) { 47 if (kind == AffineExprKind::Add) 48 return lhs + rhs; 49 if (kind == AffineExprKind::Mul) 50 return lhs * rhs; 51 if (kind == AffineExprKind::FloorDiv) 52 return lhs.floorDiv(rhs); 53 if (kind == AffineExprKind::CeilDiv) 54 return lhs.ceilDiv(rhs); 55 if (kind == AffineExprKind::Mod) 56 return lhs % rhs; 57 58 llvm_unreachable("unknown binary operation on affine expressions"); 59 } 60 61 /// This method substitutes any uses of dimensions and symbols (e.g. 62 /// dim#0 with dimReplacements[0]) and returns the modified expression tree. 63 AffineExpr 64 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 65 ArrayRef<AffineExpr> symReplacements) const { 66 switch (getKind()) { 67 case AffineExprKind::Constant: 68 return *this; 69 case AffineExprKind::DimId: { 70 unsigned dimId = cast<AffineDimExpr>().getPosition(); 71 if (dimId >= dimReplacements.size()) 72 return *this; 73 return dimReplacements[dimId]; 74 } 75 case AffineExprKind::SymbolId: { 76 unsigned symId = cast<AffineSymbolExpr>().getPosition(); 77 if (symId >= symReplacements.size()) 78 return *this; 79 return symReplacements[symId]; 80 } 81 case AffineExprKind::Add: 82 case AffineExprKind::Mul: 83 case AffineExprKind::FloorDiv: 84 case AffineExprKind::CeilDiv: 85 case AffineExprKind::Mod: 86 auto binOp = cast<AffineBinaryOpExpr>(); 87 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 88 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 89 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 90 if (newLHS == lhs && newRHS == rhs) 91 return *this; 92 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 93 } 94 llvm_unreachable("Unknown AffineExpr"); 95 } 96 97 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const { 98 return replaceDimsAndSymbols(dimReplacements, {}); 99 } 100 101 AffineExpr 102 
AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const { 103 return replaceDimsAndSymbols({}, symReplacements); 104 } 105 106 /// Replace dims[offset ... numDims) 107 /// by dims[offset + shift ... shift + numDims). 108 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift, 109 unsigned offset) const { 110 SmallVector<AffineExpr, 4> dims; 111 for (unsigned idx = 0; idx < offset; ++idx) 112 dims.push_back(getAffineDimExpr(idx, getContext())); 113 for (unsigned idx = offset; idx < numDims; ++idx) 114 dims.push_back(getAffineDimExpr(idx + shift, getContext())); 115 return replaceDimsAndSymbols(dims, {}); 116 } 117 118 /// Replace symbols[offset ... numSymbols) 119 /// by symbols[offset + shift ... shift + numSymbols). 120 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift, 121 unsigned offset) const { 122 SmallVector<AffineExpr, 4> symbols; 123 for (unsigned idx = 0; idx < offset; ++idx) 124 symbols.push_back(getAffineSymbolExpr(idx, getContext())); 125 for (unsigned idx = offset; idx < numSymbols; ++idx) 126 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext())); 127 return replaceDimsAndSymbols({}, symbols); 128 } 129 130 /// Sparse replace method. Return the modified expression tree. 
131 AffineExpr 132 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 133 auto it = map.find(*this); 134 if (it != map.end()) 135 return it->second; 136 switch (getKind()) { 137 default: 138 return *this; 139 case AffineExprKind::Add: 140 case AffineExprKind::Mul: 141 case AffineExprKind::FloorDiv: 142 case AffineExprKind::CeilDiv: 143 case AffineExprKind::Mod: 144 auto binOp = cast<AffineBinaryOpExpr>(); 145 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 146 auto newLHS = lhs.replace(map); 147 auto newRHS = rhs.replace(map); 148 if (newLHS == lhs && newRHS == rhs) 149 return *this; 150 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 151 } 152 llvm_unreachable("Unknown AffineExpr"); 153 } 154 155 /// Sparse replace method. Return the modified expression tree. 156 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const { 157 DenseMap<AffineExpr, AffineExpr> map; 158 map.insert(std::make_pair(expr, replacement)); 159 return replace(map); 160 } 161 /// Returns true if this expression is made out of only symbols and 162 /// constants (no dimensional identifiers). 163 bool AffineExpr::isSymbolicOrConstant() const { 164 switch (getKind()) { 165 case AffineExprKind::Constant: 166 return true; 167 case AffineExprKind::DimId: 168 return false; 169 case AffineExprKind::SymbolId: 170 return true; 171 172 case AffineExprKind::Add: 173 case AffineExprKind::Mul: 174 case AffineExprKind::FloorDiv: 175 case AffineExprKind::CeilDiv: 176 case AffineExprKind::Mod: { 177 auto expr = this->cast<AffineBinaryOpExpr>(); 178 return expr.getLHS().isSymbolicOrConstant() && 179 expr.getRHS().isSymbolicOrConstant(); 180 } 181 } 182 llvm_unreachable("Unknown AffineExpr"); 183 } 184 185 /// Returns true if this is a pure affine expression, i.e., multiplication, 186 /// floordiv, ceildiv, and mod is only allowed w.r.t constants. 
187 bool AffineExpr::isPureAffine() const { 188 switch (getKind()) { 189 case AffineExprKind::SymbolId: 190 case AffineExprKind::DimId: 191 case AffineExprKind::Constant: 192 return true; 193 case AffineExprKind::Add: { 194 auto op = cast<AffineBinaryOpExpr>(); 195 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine(); 196 } 197 198 case AffineExprKind::Mul: { 199 // TODO: Canonicalize the constants in binary operators to the RHS when 200 // possible, allowing this to merge into the next case. 201 auto op = cast<AffineBinaryOpExpr>(); 202 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() && 203 (op.getLHS().template isa<AffineConstantExpr>() || 204 op.getRHS().template isa<AffineConstantExpr>()); 205 } 206 case AffineExprKind::FloorDiv: 207 case AffineExprKind::CeilDiv: 208 case AffineExprKind::Mod: { 209 auto op = cast<AffineBinaryOpExpr>(); 210 return op.getLHS().isPureAffine() && 211 op.getRHS().template isa<AffineConstantExpr>(); 212 } 213 } 214 llvm_unreachable("Unknown AffineExpr"); 215 } 216 217 // Returns the greatest known integral divisor of this affine expression. 
218 int64_t AffineExpr::getLargestKnownDivisor() const { 219 AffineBinaryOpExpr binExpr(nullptr); 220 switch (getKind()) { 221 case AffineExprKind::CeilDiv: 222 LLVM_FALLTHROUGH; 223 case AffineExprKind::DimId: 224 case AffineExprKind::FloorDiv: 225 case AffineExprKind::SymbolId: 226 return 1; 227 case AffineExprKind::Constant: 228 return std::abs(this->cast<AffineConstantExpr>().getValue()); 229 case AffineExprKind::Mul: { 230 binExpr = this->cast<AffineBinaryOpExpr>(); 231 return binExpr.getLHS().getLargestKnownDivisor() * 232 binExpr.getRHS().getLargestKnownDivisor(); 233 } 234 case AffineExprKind::Add: 235 LLVM_FALLTHROUGH; 236 case AffineExprKind::Mod: { 237 binExpr = cast<AffineBinaryOpExpr>(); 238 return llvm::GreatestCommonDivisor64( 239 binExpr.getLHS().getLargestKnownDivisor(), 240 binExpr.getRHS().getLargestKnownDivisor()); 241 } 242 } 243 llvm_unreachable("Unknown AffineExpr"); 244 } 245 246 bool AffineExpr::isMultipleOf(int64_t factor) const { 247 AffineBinaryOpExpr binExpr(nullptr); 248 uint64_t l, u; 249 switch (getKind()) { 250 case AffineExprKind::SymbolId: 251 LLVM_FALLTHROUGH; 252 case AffineExprKind::DimId: 253 return factor * factor == 1; 254 case AffineExprKind::Constant: 255 return cast<AffineConstantExpr>().getValue() % factor == 0; 256 case AffineExprKind::Mul: { 257 binExpr = cast<AffineBinaryOpExpr>(); 258 // It's probably not worth optimizing this further (to not traverse the 259 // whole sub-tree under - it that would require a version of isMultipleOf 260 // that on a 'false' return also returns the largest known divisor). 
261 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 || 262 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 || 263 (l * u) % factor == 0; 264 } 265 case AffineExprKind::Add: 266 case AffineExprKind::FloorDiv: 267 case AffineExprKind::CeilDiv: 268 case AffineExprKind::Mod: { 269 binExpr = cast<AffineBinaryOpExpr>(); 270 return llvm::GreatestCommonDivisor64( 271 binExpr.getLHS().getLargestKnownDivisor(), 272 binExpr.getRHS().getLargestKnownDivisor()) % 273 factor == 274 0; 275 } 276 } 277 llvm_unreachable("Unknown AffineExpr"); 278 } 279 280 bool AffineExpr::isFunctionOfDim(unsigned position) const { 281 if (getKind() == AffineExprKind::DimId) { 282 return *this == mlir::getAffineDimExpr(position, getContext()); 283 } 284 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 285 return expr.getLHS().isFunctionOfDim(position) || 286 expr.getRHS().isFunctionOfDim(position); 287 } 288 return false; 289 } 290 291 bool AffineExpr::isFunctionOfSymbol(unsigned position) const { 292 if (getKind() == AffineExprKind::SymbolId) { 293 return *this == mlir::getAffineSymbolExpr(position, getContext()); 294 } 295 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 296 return expr.getLHS().isFunctionOfSymbol(position) || 297 expr.getRHS().isFunctionOfSymbol(position); 298 } 299 return false; 300 } 301 302 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr) 303 : AffineExpr(ptr) {} 304 AffineExpr AffineBinaryOpExpr::getLHS() const { 305 return static_cast<ImplType *>(expr)->lhs; 306 } 307 AffineExpr AffineBinaryOpExpr::getRHS() const { 308 return static_cast<ImplType *>(expr)->rhs; 309 } 310 311 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {} 312 unsigned AffineDimExpr::getPosition() const { 313 return static_cast<ImplType *>(expr)->position; 314 } 315 316 /// Returns true if the expression is divisible by the given symbol with 317 /// position `symbolPos`. 
The argument `opKind` specifies here what kind of 318 /// division or mod operation called this division. It helps in implementing the 319 /// commutative property of the floordiv and ceildiv operations. If the argument 320 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv 321 /// operation, then the commutative property can be used otherwise, the floordiv 322 /// operation is not divisible. The same argument holds for ceildiv operation. 323 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, 324 AffineExprKind opKind) { 325 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 326 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 327 opKind == AffineExprKind::CeilDiv) && 328 "unexpected opKind"); 329 switch (expr.getKind()) { 330 case AffineExprKind::Constant: 331 return expr.cast<AffineConstantExpr>().getValue() == 0; 332 case AffineExprKind::DimId: 333 return false; 334 case AffineExprKind::SymbolId: 335 return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos); 336 // Checks divisibility by the given symbol for both operands. 337 case AffineExprKind::Add: { 338 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 339 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) && 340 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 341 } 342 // Checks divisibility by the given symbol for both operands. Consider the 343 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`, 344 // this is a division by s1 and both the operands of modulo are divisible by 345 // s1 but it is not divisible by s1 always. The third argument is 346 // `AffineExprKind::Mod` for this reason. 
347 case AffineExprKind::Mod: { 348 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 349 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, 350 AffineExprKind::Mod) && 351 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, 352 AffineExprKind::Mod); 353 } 354 // Checks if any of the operand divisible by the given symbol. 355 case AffineExprKind::Mul: { 356 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 357 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) || 358 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 359 } 360 // Floordiv and ceildiv are divisible by the given symbol when the first 361 // operand is divisible, and the affine expression kind of the argument expr 362 // is same as the argument `opKind`. This can be inferred from commutative 363 // property of floordiv and ceildiv operations and are as follow: 364 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 365 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 366 // It will fail if operations are not same. For example: 367 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 368 case AffineExprKind::FloorDiv: 369 case AffineExprKind::CeilDiv: { 370 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 371 if (opKind != expr.getKind()) 372 return false; 373 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()); 374 } 375 } 376 llvm_unreachable("Unknown AffineExpr"); 377 } 378 379 /// Divides the given expression by the given symbol at position `symbolPos`. It 380 /// considers the divisibility condition is checked before calling itself. A 381 /// null expression is returned whenever the divisibility condition fails. 382 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, 383 AffineExprKind opKind) { 384 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 
385 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 386 opKind == AffineExprKind::CeilDiv) && 387 "unexpected opKind"); 388 switch (expr.getKind()) { 389 case AffineExprKind::Constant: 390 if (expr.cast<AffineConstantExpr>().getValue() != 0) 391 return nullptr; 392 return getAffineConstantExpr(0, expr.getContext()); 393 case AffineExprKind::DimId: 394 return nullptr; 395 case AffineExprKind::SymbolId: 396 return getAffineConstantExpr(1, expr.getContext()); 397 // Dividing both operands by the given symbol. 398 case AffineExprKind::Add: { 399 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 400 return getAffineBinaryOpExpr( 401 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind), 402 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind)); 403 } 404 // Dividing both operands by the given symbol. 405 case AffineExprKind::Mod: { 406 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 407 return getAffineBinaryOpExpr( 408 expr.getKind(), 409 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 410 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind())); 411 } 412 // Dividing any of the operand by the given symbol. 413 case AffineExprKind::Mul: { 414 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 415 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind)) 416 return binaryExpr.getLHS() * 417 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind); 418 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) * 419 binaryExpr.getRHS(); 420 } 421 // Dividing first operand only by the given symbol. 
422 case AffineExprKind::FloorDiv: 423 case AffineExprKind::CeilDiv: { 424 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 425 return getAffineBinaryOpExpr( 426 expr.getKind(), 427 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 428 binaryExpr.getRHS()); 429 } 430 } 431 llvm_unreachable("Unknown AffineExpr"); 432 } 433 434 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv 435 /// operations when the second operand simplifies to a symbol and the first 436 /// operand is divisible by that symbol. It can be applied to any semi-affine 437 /// expression. Returned expression can either be a semi-affine or pure affine 438 /// expression. 439 static AffineExpr simplifySemiAffine(AffineExpr expr) { 440 switch (expr.getKind()) { 441 case AffineExprKind::Constant: 442 case AffineExprKind::DimId: 443 case AffineExprKind::SymbolId: 444 return expr; 445 case AffineExprKind::Add: 446 case AffineExprKind::Mul: { 447 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 448 return getAffineBinaryOpExpr(expr.getKind(), 449 simplifySemiAffine(binaryExpr.getLHS()), 450 simplifySemiAffine(binaryExpr.getRHS())); 451 } 452 // Check if the simplification of the second operand is a symbol, and the 453 // first operand is divisible by it. If the operation is a modulo, a constant 454 // zero expression is returned. In the case of floordiv and ceildiv, the 455 // symbol from the simplification of the second operand divides the first 456 // operand. Otherwise, simplification is not possible. 
457 case AffineExprKind::FloorDiv: 458 case AffineExprKind::CeilDiv: 459 case AffineExprKind::Mod: { 460 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 461 AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS()); 462 AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS()); 463 AffineSymbolExpr symbolExpr = 464 simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>(); 465 if (!symbolExpr) 466 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 467 unsigned symbolPos = symbolExpr.getPosition(); 468 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind())) 469 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 470 if (expr.getKind() == AffineExprKind::Mod) 471 return getAffineConstantExpr(0, expr.getContext()); 472 return symbolicDivide(sLHS, symbolPos, expr.getKind()); 473 } 474 } 475 llvm_unreachable("Unknown AffineExpr"); 476 } 477 478 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, 479 MLIRContext *context) { 480 auto assignCtx = [context](AffineDimExprStorage *storage) { 481 storage->context = context; 482 }; 483 484 StorageUniquer &uniquer = context->getAffineUniquer(); 485 return uniquer.get<AffineDimExprStorage>( 486 assignCtx, static_cast<unsigned>(kind), position); 487 } 488 489 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { 490 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context); 491 } 492 493 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr) 494 : AffineExpr(ptr) {} 495 unsigned AffineSymbolExpr::getPosition() const { 496 return static_cast<ImplType *>(expr)->position; 497 } 498 499 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { 500 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context); 501 ; 502 } 503 504 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr) 505 : AffineExpr(ptr) {} 506 int64_t AffineConstantExpr::getValue() const { 507 
return static_cast<ImplType *>(expr)->constant; 508 } 509 510 bool AffineExpr::operator==(int64_t v) const { 511 return *this == getAffineConstantExpr(v, getContext()); 512 } 513 514 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { 515 auto assignCtx = [context](AffineConstantExprStorage *storage) { 516 storage->context = context; 517 }; 518 519 StorageUniquer &uniquer = context->getAffineUniquer(); 520 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant); 521 } 522 523 /// Simplify add expression. Return nullptr if it can't be simplified. 524 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { 525 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 526 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 527 // Fold if both LHS, RHS are a constant. 528 if (lhsConst && rhsConst) 529 return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), 530 lhs.getContext()); 531 532 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). 533 // If only one of them is a symbolic expressions, make it the RHS. 534 if (lhs.isa<AffineConstantExpr>() || 535 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) { 536 return rhs + lhs; 537 } 538 539 // At this point, if there was a constant, it would be on the right. 540 541 // Addition with a zero is a noop, return the other input. 542 if (rhsConst) { 543 if (rhsConst.getValue() == 0) 544 return lhs; 545 } 546 // Fold successive additions like (d0 + 2) + 3 into d0 + 5. 547 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 548 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) { 549 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 550 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue()); 551 } 552 553 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr". 554 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their 555 // respective multiplicands. 
556 Optional<int64_t> rLhsConst, rRhsConst; 557 AffineExpr firstExpr, secondExpr; 558 AffineConstantExpr rLhsConstExpr; 559 auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>(); 560 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul && 561 (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) { 562 rLhsConst = rLhsConstExpr.getValue(); 563 firstExpr = lBinOpExpr.getLHS(); 564 } else { 565 rLhsConst = 1; 566 firstExpr = lhs; 567 } 568 569 auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>(); 570 AffineConstantExpr rRhsConstExpr; 571 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul && 572 (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) { 573 rRhsConst = rRhsConstExpr.getValue(); 574 secondExpr = rBinOpExpr.getLHS(); 575 } else { 576 rRhsConst = 1; 577 secondExpr = rhs; 578 } 579 580 if (rLhsConst && rRhsConst && firstExpr == secondExpr) 581 return getAffineBinaryOpExpr( 582 AffineExprKind::Mul, firstExpr, 583 getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(), 584 lhs.getContext())); 585 586 // When doing successive additions, bring constant to the right: turn (d0 + 2) 587 // + d1 into (d0 + d1) + 2. 588 if (lBin && lBin.getKind() == AffineExprKind::Add) { 589 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 590 return lBin.getLHS() + rhs + lrhs; 591 } 592 } 593 594 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where 595 // q may be a constant or symbolic expression. This leads to a much more 596 // efficient form when 'c' is a power of two, and in general a more compact 597 // and readable form. 598 599 // Process '(expr floordiv c) * (-c)'. 600 if (!rBinOpExpr) 601 return nullptr; 602 603 auto lrhs = rBinOpExpr.getLHS(); 604 auto rrhs = rBinOpExpr.getRHS(); 605 606 AffineExpr llrhs, rlrhs; 607 608 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a 609 // symbolic expression. 
610 auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); 611 // Check rrhsConstOpExpr = -1. 612 auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>(); 613 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr && 614 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) { 615 // Check llrhs = expr floordiv q. 616 llrhs = lrhsBinOpExpr.getLHS(); 617 // Check rlrhs = q. 618 rlrhs = lrhsBinOpExpr.getRHS(); 619 auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>(); 620 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv) 621 return nullptr; 622 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS()) 623 return lhs % rlrhs; 624 } 625 626 // Process lrhs, which is 'expr floordiv c'. 627 AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); 628 if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) 629 return nullptr; 630 631 llrhs = lrBinOpExpr.getLHS(); 632 rlrhs = lrBinOpExpr.getRHS(); 633 634 if (lhs == llrhs && rlrhs == -rrhs) { 635 return lhs % rlrhs; 636 } 637 return nullptr; 638 } 639 640 AffineExpr AffineExpr::operator+(int64_t v) const { 641 return *this + getAffineConstantExpr(v, getContext()); 642 } 643 AffineExpr AffineExpr::operator+(AffineExpr other) const { 644 if (auto simplified = simplifyAdd(*this, other)) 645 return simplified; 646 647 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 648 return uniquer.get<AffineBinaryOpExprStorage>( 649 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other); 650 } 651 652 /// Simplify a multiply expression. Return nullptr if it can't be simplified. 
653 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { 654 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 655 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 656 657 if (lhsConst && rhsConst) 658 return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), 659 lhs.getContext()); 660 661 assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant()); 662 663 // Canonicalize the mul expression so that the constant/symbolic term is the 664 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a 665 // constant. (Note that a constant is trivially symbolic). 666 if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) { 667 // At least one of them has to be symbolic. 668 return rhs * lhs; 669 } 670 671 // At this point, if there was a constant, it would be on the right. 672 673 // Multiplication with a one is a noop, return the other input. 674 if (rhsConst) { 675 if (rhsConst.getValue() == 1) 676 return lhs; 677 // Multiplication with zero. 678 if (rhsConst.getValue() == 0) 679 return rhsConst; 680 } 681 682 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6. 683 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 684 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) { 685 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 686 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue()); 687 } 688 689 // When doing successive multiplication, bring constant to the right: turn (d0 690 // * 2) * d1 into (d0 * d1) * 2. 
691 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 692 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 693 return (lBin.getLHS() * rhs) * lrhs; 694 } 695 } 696 697 return nullptr; 698 } 699 700 AffineExpr AffineExpr::operator*(int64_t v) const { 701 return *this * getAffineConstantExpr(v, getContext()); 702 } 703 AffineExpr AffineExpr::operator*(AffineExpr other) const { 704 if (auto simplified = simplifyMul(*this, other)) 705 return simplified; 706 707 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 708 return uniquer.get<AffineBinaryOpExprStorage>( 709 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other); 710 } 711 712 // Unary minus, delegate to operator*. 713 AffineExpr AffineExpr::operator-() const { 714 return *this * getAffineConstantExpr(-1, getContext()); 715 } 716 717 // Delegate to operator+. 718 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); } 719 AffineExpr AffineExpr::operator-(AffineExpr other) const { 720 return *this + (-other); 721 } 722 723 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { 724 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 725 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 726 727 // mlir floordiv by zero or negative numbers is undefined and preserved as is. 728 if (!rhsConst || rhsConst.getValue() < 1) 729 return nullptr; 730 731 if (lhsConst) 732 return getAffineConstantExpr( 733 floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); 734 735 // Fold floordiv of a multiply with a constant that is a multiple of the 736 // divisor. Eg: (i * 128) floordiv 64 = i * 2. 737 if (rhsConst == 1) 738 return lhs; 739 740 // Simplify (expr * const) floordiv divConst when expr is known to be a 741 // multiple of divConst. 
// Fold (expr * c) floordiv d into expr * (c / d) when the constant c is an
// exact multiple of the divisor d, e.g. (i * 128) floordiv 64 = i * 2.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      // rhsConst is known to be a positive constant.
      if (lrhs.getValue() % rhsConst.getValue() == 0)
        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
    }
  }

  // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
  // known to be a multiple of divConst. The floordiv then distributes over the
  // sum: if divConst divides expr1, floor((expr1 + expr2) / divConst) equals
  // expr1 / divConst + floor(expr2 / divConst).
  if (lBin && lBin.getKind() == AffineExprKind::Add) {
    int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
    int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
    // rhsConst is known to be a positive constant.
    if (llhsDiv % rhsConst.getValue() == 0 ||
        lrhsDiv % rhsConst.getValue() == 0)
      return lBin.getLHS().floorDiv(rhsConst.getValue()) +
             lBin.getRHS().floorDiv(rhsConst.getValue());
  }

  // No simplification applies. AffineExpr::floorDiv then creates a uniqued
  // FloorDiv node instead.
  return nullptr;
}

/// Returns `*this floordiv v`. The divisor is wrapped in an AffineConstantExpr
/// created in this expression's context.
AffineExpr AffineExpr::floorDiv(uint64_t v) const {
  return floorDiv(getAffineConstantExpr(v, getContext()));
}
/// Returns `*this floordiv other`.
///
/// If simplifyFloorDiv finds a simpler equivalent expression, that expression
/// is returned. Otherwise a uniqued FloorDiv binary expression is created.
AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
  if (auto simplified = simplifyFloorDiv(*this, other))
    return simplified;

  StorageUniquer &uniquer = getContext()->getAffineUniquer();
  return uniquer.get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
      other);
}

/// Tries to simplify `lhs ceildiv rhs`.
///
/// Returns a null AffineExpr when no simplification applies. Only a constant
/// divisor >= 1 is ever simplified; any other divisor is left as is.
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  // Non-constant, zero or negative divisors are preserved unsimplified.
  if (!rhsConst || rhsConst.getValue() < 1)
    return nullptr;

  // Both operands are constants: fold to a constant.
  if (lhsConst)
    return getAffineConstantExpr(
        ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());

  // Division by one is the identity: expr ceildiv 1 = expr.
  if (rhsConst.getValue() == 1)
    return lhs;

  // Fold ceildiv of a multiply with a constant that is a multiple of the
  // divisor: (expr * const) ceildiv divConst = expr * (const / divConst).
  // Eg: (i * 128) ceildiv 64 = i * 2.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      // rhsConst is known to be a positive constant.
      if (lrhs.getValue() % rhsConst.getValue() == 0)
        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
    }
  }

  return nullptr;
}

/// Returns `*this ceildiv v`. The divisor is wrapped in an AffineConstantExpr
/// created in this expression's context.
AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
  return ceilDiv(getAffineConstantExpr(v, getContext()));
}
/// Returns `*this ceildiv other`.
///
/// If simplifyCeilDiv finds a simpler equivalent expression, that expression
/// is returned. Otherwise a uniqued CeilDiv binary expression is created.
AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
  if (auto simplified = simplifyCeilDiv(*this, other))
    return simplified;

  StorageUniquer &uniquer = getContext()->getAffineUniquer();
  return uniquer.get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
      other);
}

/// Tries to simplify `lhs mod rhs`.
///
/// Returns a null AffineExpr when no simplification applies. Only a constant
/// modulus >= 1 is ever simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  // mod w.r.t zero or negative numbers is undefined and preserved as is.
  if (!rhsConst || rhsConst.getValue() < 1)
    return nullptr;

  // Both operands are constants: fold to a constant.
  if (lhsConst)
    return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
                                 lhs.getContext());

  // Fold modulo of an expression that is known to be a multiple of a constant
  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
    return getAffineConstantExpr(0, lhs.getContext());

  // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
  // known to be a multiple of divConst. The multiple contributes nothing to
  // the remainder, so only the other operand is kept.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && lBin.getKind() == AffineExprKind::Add) {
    int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
    int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
    // rhsConst is known to be a positive constant.
    if (llhsDiv % rhsConst.getValue() == 0)
      return lBin.getRHS() % rhsConst.getValue();
    if (lrhsDiv % rhsConst.getValue() == 0)
      return lBin.getLHS() % rhsConst.getValue();
  }

  // Simplify (e % a) % b to e % b when b evenly divides a.
  if (lBin && lBin.getKind() == AffineExprKind::Mod) {
    auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
    if (intermediate && intermediate.getValue() >= 1 &&
        mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
      return lBin.getLHS() % rhsConst.getValue();
    }
  }

  return nullptr;
}

/// Returns `*this mod v`. The modulus is wrapped in an AffineConstantExpr
/// created in this expression's context.
AffineExpr AffineExpr::operator%(uint64_t v) const {
  return *this % getAffineConstantExpr(v, getContext());
}
/// Returns `*this mod other`.
///
/// If simplifyMod finds a simpler equivalent expression, that expression is
/// returned. Otherwise a uniqued Mod binary expression is created.
AffineExpr AffineExpr::operator%(AffineExpr other) const {
  if (auto simplified = simplifyMod(*this, other))
    return simplified;

  StorageUniquer &uniquer = getContext()->getAffineUniquer();
  return uniquer.get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
}

/// Composes this expression with `map`: every dimension dN is replaced by the
/// N-th result of `map`. No symbol replacements are passed, so symbols are
/// left untouched.
AffineExpr AffineExpr::compose(AffineMap map) const {
  SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
                                             map.getResults().end());
  return replaceDimsAndSymbols(dimReplacements, {});
}
/// Prints `expr` to `os` via AffineExpr::print and returns the stream.
raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
  expr.print(os);
  return os;
}

/// Constructs an affine expression from a flat ArrayRef. If there are local
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
/// products expression, `localExprs` is expected to have the AffineExpr
/// for it, and is substituted into.
The ArrayRef `flatExprs` is expected to be 891 /// in the format [dims, symbols, locals, constant term]. 892 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 893 unsigned numDims, 894 unsigned numSymbols, 895 ArrayRef<AffineExpr> localExprs, 896 MLIRContext *context) { 897 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 898 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 899 "unexpected number of local expressions"); 900 901 auto expr = getAffineConstantExpr(0, context); 902 // Dimensions and symbols. 903 for (unsigned j = 0; j < numDims + numSymbols; j++) { 904 if (flatExprs[j] == 0) 905 continue; 906 auto id = j < numDims ? getAffineDimExpr(j, context) 907 : getAffineSymbolExpr(j - numDims, context); 908 expr = expr + id * flatExprs[j]; 909 } 910 911 // Local identifiers. 912 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 913 j++) { 914 if (flatExprs[j] == 0) 915 continue; 916 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 917 expr = expr + term; 918 } 919 920 // Constant term. 921 int64_t constTerm = flatExprs[flatExprs.size() - 1]; 922 if (constTerm != 0) 923 expr = expr + constTerm; 924 return expr; 925 } 926 927 /// Constructs a semi-affine expression from a flat ArrayRef. If there are 928 /// local identifiers (neither dimensional nor symbolic) that appear in the sum 929 /// of products expression, `localExprs` is expected to have the AffineExprs for 930 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in 931 /// the format [dims, symbols, locals, constant term]. The semi-affine 932 /// expression is constructed in the sorted order of dimension and symbol 933 /// position numbers. Note: local expressions/ids are used for mod, div as well 934 /// as symbolic RHS terms for terms that are not pure affine. 
935 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 936 unsigned numDims, 937 unsigned numSymbols, 938 ArrayRef<AffineExpr> localExprs, 939 MLIRContext *context) { 940 assert(!flatExprs.empty() && "flatExprs cannot be empty"); 941 942 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 943 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 944 "unexpected number of local expressions"); 945 946 AffineExpr expr = getAffineConstantExpr(0, context); 947 948 // We design indices as a pair which help us present the semi-affine map as 949 // sum of product where terms are sorted based on dimension or symbol 950 // position: <keyA, keyB> for expressions of the form dimension * symbol, 951 // where keyA is the position number of the dimension and keyB is the 952 // position number of the symbol. For dimensional expressions we set the index 953 // as (position number of the dimension, -1), as we want dimensional 954 // expressions to appear before symbolic and product of dimensional and 955 // symbolic expressions having the dimension with the same position number. 956 // For symbolic expression set the index as (position number of the symbol, 957 // maximum of last dimension and symbol position) number. For example, we want 958 // the expression we are constructing to look something like: d0 + d0 * s0 + 959 // s0 + d1*s1 + s1. 960 961 // Stores the affine expression corresponding to a given index. 962 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap; 963 // Stores the constant coefficient value corresponding to a given 964 // dimension, symbol or a non-pure affine expression stored in `localExprs`. 965 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients; 966 // Stores the indices as defined above, and later sorted to produce 967 // the semi-affine expression in the desired form. 
968 SmallVector<std::pair<unsigned, signed>, 8> indices; 969 970 // Example: expression = d0 + d0 * s0 + 2 * s0. 971 // indices = [{0,-1}, {0, 0}, {0, 1}] 972 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}] 973 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}] 974 975 // Adds entries to `indexToExprMap`, `coefficients` and `indices`. 976 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient, 977 AffineExpr expr) { 978 assert(std::find(indices.begin(), indices.end(), index) == indices.end() && 979 "Key is already present in indices vector and overwriting will " 980 "happen in `indexToExprMap` and `coefficients`!"); 981 982 indices.push_back(index); 983 coefficients.insert({index, coefficient}); 984 indexToExprMap.insert({index, expr}); 985 }; 986 987 // Design indices for dimensional or symbolic terms, and store the indices, 988 // constant coefficient corresponding to the indices in `coefficients` map, 989 // and affine expression corresponding to indices in `indexToExprMap` map. 990 991 for (unsigned j = 0; j < numDims; ++j) { 992 if (flatExprs[j] == 0) 993 continue; 994 // For dimensional expressions we set the index as <position number of the 995 // dimension, 0>, as we want dimensional expressions to appear before 996 // symbolic ones and products of dimensional and symbolic expressions 997 // having the dimension with the same position number. 998 std::pair<unsigned, signed> indexEntry(j, -1); 999 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context)); 1000 } 1001 for (unsigned j = numDims; j < numDims + numSymbols; ++j) { 1002 if (flatExprs[j] == 0) 1003 continue; 1004 // For symbolic expression set the index as <position number 1005 // of the symbol, max(dimCount, symCount)> number, 1006 // as we want symbolic expressions with the same positional number to 1007 // appear after dimensional expressions having the same positional number. 
1008 std::pair<unsigned, signed> indexEntry(j - numDims, 1009 std::max(numDims, numSymbols)); 1010 addEntry(indexEntry, flatExprs[j], 1011 getAffineSymbolExpr(j - numDims, context)); 1012 } 1013 1014 // Denotes semi-affine product, modulo or division terms, which has been added 1015 // to the `indexToExpr` map. 1016 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1, 1017 false); 1018 unsigned lhsPos, rhsPos; 1019 // Construct indices for product terms involving dimension, symbol or constant 1020 // as lhs/rhs, and store the indices, constant coefficient corresponding to 1021 // the indices in `coefficients` map, and affine expression corresponding to 1022 // in indices in `indexToExprMap` map. 1023 for (const auto &it : llvm::enumerate(localExprs)) { 1024 AffineExpr expr = it.value(); 1025 if (flatExprs[numDims + numSymbols + it.index()] == 0) 1026 continue; 1027 AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS(); 1028 AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS(); 1029 if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) && 1030 (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() || 1031 rhs.isa<AffineConstantExpr>()))) { 1032 continue; 1033 } 1034 if (rhs.isa<AffineConstantExpr>()) { 1035 // For product/modulo/division expressions, when rhs of modulo/division 1036 // expression is constant, we put 0 in place of keyB, because we want 1037 // them to appear earlier in the semi-affine expression we are 1038 // constructing. When rhs is constant, we place 0 in place of keyB. 
1039 if (lhs.isa<AffineDimExpr>()) { 1040 lhsPos = lhs.cast<AffineDimExpr>().getPosition(); 1041 std::pair<unsigned, signed> indexEntry(lhsPos, -1); 1042 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], 1043 expr); 1044 } else { 1045 lhsPos = lhs.cast<AffineSymbolExpr>().getPosition(); 1046 std::pair<unsigned, signed> indexEntry(lhsPos, 1047 std::max(numDims, numSymbols)); 1048 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], 1049 expr); 1050 } 1051 } else if (lhs.isa<AffineDimExpr>()) { 1052 // For product/modulo/division expressions having lhs as dimension and rhs 1053 // as symbol, we order the terms in the semi-affine expression based on 1054 // the pair: <keyA, keyB> for expressions of the form dimension * symbol, 1055 // where keyA is the position number of the dimension and keyB is the 1056 // position number of the symbol. 1057 lhsPos = lhs.cast<AffineDimExpr>().getPosition(); 1058 rhsPos = rhs.cast<AffineSymbolExpr>().getPosition(); 1059 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos); 1060 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); 1061 } else { 1062 // For product/modulo/division expressions having both lhs and rhs as 1063 // symbol, we design indices as a pair: <keyA, keyB> for expressions 1064 // of the form dimension * symbol, where keyA is the position number of 1065 // the dimension and keyB is the position number of the symbol. 1066 lhsPos = lhs.cast<AffineSymbolExpr>().getPosition(); 1067 rhsPos = rhs.cast<AffineSymbolExpr>().getPosition(); 1068 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos); 1069 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); 1070 } 1071 addedToMap[it.index()] = true; 1072 } 1073 1074 // Constructing the simplified semi-affine sum of product/division/mod 1075 // expression from the flattened form in the desired sorted order of indices 1076 // of the various individual product/division/mod expressions. 
1077 std::sort(indices.begin(), indices.end()); 1078 for (const std::pair<unsigned, unsigned> index : indices) { 1079 assert(indexToExprMap.lookup(index) && 1080 "cannot find key in `indexToExprMap` map"); 1081 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index); 1082 } 1083 1084 // Local identifiers. 1085 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 1086 j++) { 1087 // If the coefficient of the local expression is 0, continue as we need not 1088 // add it in out final expression. 1089 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols]) 1090 continue; 1091 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 1092 expr = expr + term; 1093 } 1094 1095 // Constant term. 1096 int64_t constTerm = flatExprs.back(); 1097 if (constTerm != 0) 1098 expr = expr + constTerm; 1099 return expr; 1100 } 1101 1102 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, 1103 unsigned numSymbols) 1104 : numDims(numDims), numSymbols(numSymbols), numLocals(0) { 1105 operandExprStack.reserve(8); 1106 } 1107 1108 // In pure affine t = expr * c, we multiply each coefficient of lhs with c. 1109 // 1110 // In case of semi affine multiplication expressions, t = expr * symbolic_expr, 1111 // introduce a local variable p (= expr * symbolic_expr), and the affine 1112 // expression expr * symbolic_expr is added to `localExprs`. 1113 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { 1114 assert(operandExprStack.size() >= 2); 1115 SmallVector<int64_t, 8> rhs = operandExprStack.back(); 1116 operandExprStack.pop_back(); 1117 SmallVector<int64_t, 8> &lhs = operandExprStack.back(); 1118 1119 // Flatten semi-affine multiplication expressions by introducing a local 1120 // variable in place of the product; the affine expression 1121 // corresponding to the quantifier is added to `localExprs`. 
1122 if (!expr.getRHS().isa<AffineConstantExpr>()) { 1123 MLIRContext *context = expr.getContext(); 1124 AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, 1125 localExprs, context); 1126 AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, 1127 localExprs, context); 1128 addLocalVariableSemiAffine(a * b, lhs, lhs.size()); 1129 return; 1130 } 1131 1132 // Get the RHS constant. 1133 auto rhsConst = rhs[getConstantIndex()]; 1134 for (unsigned i = 0, e = lhs.size(); i < e; i++) { 1135 lhs[i] *= rhsConst; 1136 } 1137 } 1138 1139 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { 1140 assert(operandExprStack.size() >= 2); 1141 const auto &rhs = operandExprStack.back(); 1142 auto &lhs = operandExprStack[operandExprStack.size() - 2]; 1143 assert(lhs.size() == rhs.size()); 1144 // Update the LHS in place. 1145 for (unsigned i = 0, e = rhs.size(); i < e; i++) { 1146 lhs[i] += rhs[i]; 1147 } 1148 // Pop off the RHS. 1149 operandExprStack.pop_back(); 1150 } 1151 1152 // 1153 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 1154 // 1155 // A mod expression "expr mod c" is thus flattened by introducing a new local 1156 // variable q (= expr floordiv c), such that expr mod c is replaced with 1157 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. 1158 // 1159 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr, 1160 // introduce a local variable m (= expr mod symbolic_expr), and the affine 1161 // expression expr mod symbolic_expr is added to `localExprs`. 
void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
  assert(operandExprStack.size() >= 2);

  // Pop the RHS operand. The LHS stays on the stack and is rewritten in place
  // into the flattened result.
  SmallVector<int64_t, 8> rhs = operandExprStack.back();
  operandExprStack.pop_back();
  SmallVector<int64_t, 8> &lhs = operandExprStack.back();
  MLIRContext *context = expr.getContext();

  // Flatten semi affine modulo expressions by introducing a local
  // variable in place of the modulo value, and the affine expression
  // corresponding to the quantifier is added to `localExprs`.
  if (!expr.getRHS().isa<AffineConstantExpr>()) {
    AffineExpr dividendExpr = getAffineExprFromFlatForm(
        lhs, numDims, numSymbols, localExprs, context);
    AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                       localExprs, context);
    AffineExpr modExpr = dividendExpr % divisorExpr;
    addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
    return;
  }

  int64_t rhsConst = rhs[getConstantIndex()];
  // TODO: handle modulo by zero case when this issue is fixed
  // at the other places in the IR.
  assert(rhsConst > 0 && "RHS constant has to be positive");

  // Check if the LHS expression is a multiple of modulo factor.
  unsigned i, e;
  for (i = 0, e = lhs.size(); i < e; i++)
    if (lhs[i] % rhsConst != 0)
      break;
  // If yes, modulo expression here simplifies to zero.
  if (i == lhs.size()) {
    std::fill(lhs.begin(), lhs.end(), 0);
    return;
  }

  // Add a local variable for the quotient, i.e., expr % c is replaced by
  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
  // the GCD of expr and c.
  SmallVector<int64_t, 8> floorDividend(lhs);
  uint64_t gcd = rhsConst;
  for (unsigned i = 0, e = lhs.size(); i < e; i++)
    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
  // Simplify the numerator and the denominator. gcd divides every
  // coefficient, so these divisions are exact and
  // (expr / gcd) floordiv (c / gcd) == expr floordiv c.
  if (gcd != 1) {
    for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
      floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
  }
  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);

  // Construct the AffineExpr form of the floordiv to store in localExprs.

  AffineExpr dividendExpr = getAffineExprFromFlatForm(
      floorDividend, numDims, numSymbols, localExprs, context);
  AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
  AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
  int loc;
  if ((loc = findLocalId(floorDivExpr)) == -1) {
    // addLocalFloorDivId inserts a new zero column into every operand on the
    // stack, including `lhs`. The new local is therefore at column
    // getLocalVarStartIndex() + numLocals - 1.
    addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
    // Set result at top of stack to "lhs - rhsConst * q".
    lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
  } else {
    // Reuse the existing local id.
    lhs[getLocalVarStartIndex() + loc] = -rhsConst;
  }
}

/// Flattens `lhs ceildiv rhs`; see visitDivExpr.
void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
  visitDivExpr(expr, /*isCeil=*/true);
}
/// Flattens `lhs floordiv rhs`; see visitDivExpr.
void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
  visitDivExpr(expr, /*isCeil=*/false);
}

/// Pushes the flat form of a dimension: all zeros except a 1 in that
/// dimension's column.
void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  auto &eq = operandExprStack.back();
  assert(expr.getPosition() < numDims && "Inconsistent number of dims");
  eq[getDimStartIndex() + expr.getPosition()] = 1;
}

/// Pushes the flat form of a symbol: all zeros except a 1 in that symbol's
/// column.
void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  auto &eq = operandExprStack.back();
  assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
  eq[getSymbolStartIndex() + expr.getPosition()] = 1;
}

/// Pushes the flat form of a constant: all zeros except the constant column.
void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  auto &eq = operandExprStack.back();
  eq[getConstantIndex()] = expr.getValue();
}

/// Replaces `result` with a flat form that is exactly one local variable
/// standing for the semi-affine expression `expr`.
///
/// An existing local with the same expression is reused; otherwise a new
/// local is added via addLocalIdSemiAffine. `resultSize` must equal
/// result.size(); this is only checked by the assert.
void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
    AffineExpr expr, SmallVectorImpl<int64_t> &result,
    unsigned long resultSize) {
  assert(result.size() == resultSize &&
         "`result` vector passed is not of correct size");
  int loc;
  if ((loc = findLocalId(expr)) == -1)
    addLocalIdSemiAffine(expr);
  std::fill(result.begin(), result.end(), 0);
  // A newly added local always occupies the last local column.
  if (loc == -1)
    result[getLocalVarStartIndex() + numLocals - 1] = 1;
  else
    result[getLocalVarStartIndex() + loc] = 1;
}

// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
// A floordiv is thus flattened by introducing a new local variable q, and
// replacing that expression with 'q' while adding the constraints
// c * q <= expr <= c * q + c - 1 to localVarCst (done by
// FlatAffineConstraints::addLocalFloorDiv).
//
// A ceildiv is similarly flattened:
// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
//
// In case of semi affine division expressions, t = expr floordiv symbolic_expr
// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
// `localExprs`.
void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
                                             bool isCeil) {
  assert(operandExprStack.size() >= 2);

  MLIRContext *context = expr.getContext();
  // Pop the RHS operand. The LHS stays on the stack and is rewritten in place
  // into the flattened quotient.
  SmallVector<int64_t, 8> rhs = operandExprStack.back();
  operandExprStack.pop_back();
  SmallVector<int64_t, 8> &lhs = operandExprStack.back();

  // Flatten semi affine division expressions by introducing a local
  // variable in place of the quotient, and the affine expression corresponding
  // to the quantifier is added to `localExprs`.
  if (!expr.getRHS().isa<AffineConstantExpr>()) {
    AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
                                             localExprs, context);
    AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                             localExprs, context);
    AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
    addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
    return;
  }

  // This is a pure affine expr; the RHS is a positive constant.
  int64_t rhsConst = rhs[getConstantIndex()];
  // TODO: handle division by zero at the same time the issue is
  // fixed at other places.
  assert(rhsConst > 0 && "RHS constant has to be positive");

  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
  // common divisors of the numerator and denominator.
  uint64_t gcd = std::abs(rhsConst);
  for (unsigned i = 0, e = lhs.size(); i < e; i++)
    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
  // Simplify the numerator and the denominator.
  if (gcd != 1) {
    for (unsigned i = 0, e = lhs.size(); i < e; i++)
      lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
  }
  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
  // If the divisor becomes 1, the updated LHS is the result. (The
  // divisor can't be negative since rhsConst is positive).
  if (divisor == 1)
    return;

  // If the divisor cannot be simplified to one, we will have to retain
  // the ceil/floor expr (simplified up until here). Add an existential
  // quantifier to express its result, i.e., expr1 div expr2 is replaced
  // by a new identifier, q.
  AffineExpr a =
      getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
  AffineExpr b = getAffineConstantExpr(divisor, context);

  int loc;
  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
  if ((loc = findLocalId(divExpr)) == -1) {
    // The dividend is copied out of `lhs` because addLocalFloorDivId inserts
    // a column into every operand on the stack, including `lhs` itself.
    if (!isCeil) {
      SmallVector<int64_t, 8> dividend(lhs);
      addLocalFloorDivId(dividend, divisor, divExpr);
    } else {
      // lhs ceildiv c <=> (lhs + c - 1) floordiv c
      // (the constant term is the last column of the flat form).
      SmallVector<int64_t, 8> dividend(lhs);
      dividend.back() += divisor - 1;
      addLocalFloorDivId(dividend, divisor, divExpr);
    }
  }
  // Set the expression on stack to the local var introduced to capture the
  // result of the division (floor or ceil).
  std::fill(lhs.begin(), lhs.end(), 0);
  if (loc == -1)
    lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
  else
    lhs[getLocalVarStartIndex() + loc] = 1;
}

// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
// The local identifier added is always a floordiv of a pure add/mul affine
// function of other identifiers, coefficients of which are specified in
// dividend and with respect to a positive constant divisor. localExpr is the
// simplified tree expression (AffineExpr) corresponding to the quantifier.
void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
                                                   int64_t divisor,
                                                   AffineExpr localExpr) {
  assert(divisor > 0 && "positive constant divisor expected");
  // Insert a zero coefficient for the new local at the end of the local
  // columns of every flat expression currently on the stack.
  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
  localExprs.push_back(localExpr);
  numLocals++;
  // dividend and divisor are not used here; an override of this method uses it.
}

/// Adds a local identifier standing for the semi-affine expression
/// `localExpr`. A zero column for it is appended after the existing locals in
/// every flat expression on the stack.
void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
  localExprs.push_back(localExpr);
  ++numLocals;
}

/// Returns the index of `localExpr` in `localExprs`, or -1 if it is not there.
/// This is a linear search.
int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
  SmallVectorImpl<AffineExpr>::iterator it;
  if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
    return -1;
  return it - localExprs.begin();
}

/// Simplify the affine expression by flattening it and reconstructing it.
///
/// Semi-affine expressions are first passed through simplifySemiAffine. If
/// the expression is still not pure affine and the plain flat-form
/// reconstruction reproduces it exactly, it is returned unchanged. Otherwise
/// it is rebuilt with getSemiAffineExprFromFlatForm. Pure affine expressions
/// are rebuilt with getAffineExprFromFlatForm.
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
                                    unsigned numSymbols) {
  // Simplify semi-affine expressions separately.
  if (!expr.isPureAffine())
    expr = simplifySemiAffine(expr);

  SimpleAffineExprFlattener flattener(numDims, numSymbols);
  flattener.walkPostOrder(expr);
  // `flattenedExpr` views the last stack entry; it is only used before the
  // pop_back below.
  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
  if (!expr.isPureAffine() &&
      expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
                                        flattener.localExprs,
                                        expr.getContext()))
    return expr;
  AffineExpr simplifiedExpr =
      expr.isPureAffine()
          ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
                                      flattener.localExprs, expr.getContext())
          : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
                                          flattener.localExprs,
                                          expr.getContext());

  flattener.operandExprStack.pop_back();
  assert(flattener.operandExprStack.empty());
  return simplifiedExpr;
}