1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineExpr.h" 10 #include "AffineExprDetail.h" 11 #include "mlir/IR/AffineExprVisitor.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/IntegerSet.h" 14 #include "mlir/Support/MathExtras.h" 15 #include "mlir/Support/TypeID.h" 16 #include "llvm/ADT/STLExtras.h" 17 18 using namespace mlir; 19 using namespace mlir::detail; 20 21 MLIRContext *AffineExpr::getContext() const { return expr->context; } 22 23 AffineExprKind AffineExpr::getKind() const { return expr->kind; } 24 25 /// Walk all of the AffineExprs in this subgraph in postorder. 26 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const { 27 struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> { 28 std::function<void(AffineExpr)> callback; 29 30 AffineExprWalker(std::function<void(AffineExpr)> callback) 31 : callback(callback) {} 32 33 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); } 34 void visitConstantExpr(AffineConstantExpr expr) { callback(expr); } 35 void visitDimExpr(AffineDimExpr expr) { callback(expr); } 36 void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); } 37 }; 38 39 AffineExprWalker(callback).walkPostOrder(*this); 40 } 41 42 // Dispatch affine expression construction based on kind. 43 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, 44 AffineExpr rhs) { 45 if (kind == AffineExprKind::Add) 46 return lhs + rhs; 47 if (kind == AffineExprKind::Mul) 48 return lhs * rhs; 49 if (kind == AffineExprKind::FloorDiv) 50 return lhs.floorDiv(rhs); 51 if (kind == AffineExprKind::CeilDiv) 52 return lhs.ceilDiv(rhs); 53 if (kind == AffineExprKind::Mod) 54 return lhs % rhs; 55 56 llvm_unreachable("unknown binary operation on affine expressions"); 57 } 58 59 /// This method substitutes any uses of dimensions and symbols (e.g. 60 /// dim#0 with dimReplacements[0]) and returns the modified expression tree. 61 AffineExpr 62 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 63 ArrayRef<AffineExpr> symReplacements) const { 64 switch (getKind()) { 65 case AffineExprKind::Constant: 66 return *this; 67 case AffineExprKind::DimId: { 68 unsigned dimId = cast<AffineDimExpr>().getPosition(); 69 if (dimId >= dimReplacements.size()) 70 return *this; 71 return dimReplacements[dimId]; 72 } 73 case AffineExprKind::SymbolId: { 74 unsigned symId = cast<AffineSymbolExpr>().getPosition(); 75 if (symId >= symReplacements.size()) 76 return *this; 77 return symReplacements[symId]; 78 } 79 case AffineExprKind::Add: 80 case AffineExprKind::Mul: 81 case AffineExprKind::FloorDiv: 82 case AffineExprKind::CeilDiv: 83 case AffineExprKind::Mod: 84 auto binOp = cast<AffineBinaryOpExpr>(); 85 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 86 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 87 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 88 if (newLHS == lhs && newRHS == rhs) 89 return *this; 90 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 91 } 92 llvm_unreachable("Unknown AffineExpr"); 93 } 94 95 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const { 96 return replaceDimsAndSymbols(dimReplacements, {}); 97 } 98 99 AffineExpr 100 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const { 101 return replaceDimsAndSymbols({}, symReplacements); 102 } 103 104 /// Replace symbols[0 .. numDims - 1] by symbols[shift .. shift + numDims - 1]. 105 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift) const { 106 SmallVector<AffineExpr, 4> dims; 107 for (unsigned idx = 0; idx < numDims; ++idx) 108 dims.push_back(getAffineDimExpr(idx + shift, getContext())); 109 return replaceDimsAndSymbols(dims, {}); 110 } 111 112 /// Replace symbols[0 .. numSymbols - 1] by 113 /// symbols[shift .. shift + numSymbols - 1]. 114 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift) const { 115 SmallVector<AffineExpr, 4> symbols; 116 for (unsigned idx = 0; idx < numSymbols; ++idx) 117 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext())); 118 return replaceDimsAndSymbols({}, symbols); 119 } 120 121 /// Sparse replace method. Return the modified expression tree. 122 AffineExpr 123 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 124 auto it = map.find(*this); 125 if (it != map.end()) 126 return it->second; 127 switch (getKind()) { 128 default: 129 return *this; 130 case AffineExprKind::Add: 131 case AffineExprKind::Mul: 132 case AffineExprKind::FloorDiv: 133 case AffineExprKind::CeilDiv: 134 case AffineExprKind::Mod: 135 auto binOp = cast<AffineBinaryOpExpr>(); 136 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 137 auto newLHS = lhs.replace(map); 138 auto newRHS = rhs.replace(map); 139 if (newLHS == lhs && newRHS == rhs) 140 return *this; 141 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 142 } 143 llvm_unreachable("Unknown AffineExpr"); 144 } 145 146 /// Sparse replace method. Return the modified expression tree. 147 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const { 148 DenseMap<AffineExpr, AffineExpr> map; 149 map.insert(std::make_pair(expr, replacement)); 150 return replace(map); 151 } 152 /// Returns true if this expression is made out of only symbols and 153 /// constants (no dimensional identifiers). 154 bool AffineExpr::isSymbolicOrConstant() const { 155 switch (getKind()) { 156 case AffineExprKind::Constant: 157 return true; 158 case AffineExprKind::DimId: 159 return false; 160 case AffineExprKind::SymbolId: 161 return true; 162 163 case AffineExprKind::Add: 164 case AffineExprKind::Mul: 165 case AffineExprKind::FloorDiv: 166 case AffineExprKind::CeilDiv: 167 case AffineExprKind::Mod: { 168 auto expr = this->cast<AffineBinaryOpExpr>(); 169 return expr.getLHS().isSymbolicOrConstant() && 170 expr.getRHS().isSymbolicOrConstant(); 171 } 172 } 173 llvm_unreachable("Unknown AffineExpr"); 174 } 175 176 /// Returns true if this is a pure affine expression, i.e., multiplication, 177 /// floordiv, ceildiv, and mod is only allowed w.r.t constants. 178 bool AffineExpr::isPureAffine() const { 179 switch (getKind()) { 180 case AffineExprKind::SymbolId: 181 case AffineExprKind::DimId: 182 case AffineExprKind::Constant: 183 return true; 184 case AffineExprKind::Add: { 185 auto op = cast<AffineBinaryOpExpr>(); 186 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine(); 187 } 188 189 case AffineExprKind::Mul: { 190 // TODO: Canonicalize the constants in binary operators to the RHS when 191 // possible, allowing this to merge into the next case. 192 auto op = cast<AffineBinaryOpExpr>(); 193 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() && 194 (op.getLHS().template isa<AffineConstantExpr>() || 195 op.getRHS().template isa<AffineConstantExpr>()); 196 } 197 case AffineExprKind::FloorDiv: 198 case AffineExprKind::CeilDiv: 199 case AffineExprKind::Mod: { 200 auto op = cast<AffineBinaryOpExpr>(); 201 return op.getLHS().isPureAffine() && 202 op.getRHS().template isa<AffineConstantExpr>(); 203 } 204 } 205 llvm_unreachable("Unknown AffineExpr"); 206 } 207 208 // Returns the greatest known integral divisor of this affine expression. 209 int64_t AffineExpr::getLargestKnownDivisor() const { 210 AffineBinaryOpExpr binExpr(nullptr); 211 switch (getKind()) { 212 case AffineExprKind::SymbolId: 213 LLVM_FALLTHROUGH; 214 case AffineExprKind::DimId: 215 return 1; 216 case AffineExprKind::Constant: 217 return std::abs(this->cast<AffineConstantExpr>().getValue()); 218 case AffineExprKind::Mul: { 219 binExpr = this->cast<AffineBinaryOpExpr>(); 220 return binExpr.getLHS().getLargestKnownDivisor() * 221 binExpr.getRHS().getLargestKnownDivisor(); 222 } 223 case AffineExprKind::Add: 224 LLVM_FALLTHROUGH; 225 case AffineExprKind::FloorDiv: 226 case AffineExprKind::CeilDiv: 227 case AffineExprKind::Mod: { 228 binExpr = cast<AffineBinaryOpExpr>(); 229 return llvm::GreatestCommonDivisor64( 230 binExpr.getLHS().getLargestKnownDivisor(), 231 binExpr.getRHS().getLargestKnownDivisor()); 232 } 233 } 234 llvm_unreachable("Unknown AffineExpr"); 235 } 236 237 bool AffineExpr::isMultipleOf(int64_t factor) const { 238 AffineBinaryOpExpr binExpr(nullptr); 239 uint64_t l, u; 240 switch (getKind()) { 241 case AffineExprKind::SymbolId: 242 LLVM_FALLTHROUGH; 243 case AffineExprKind::DimId: 244 return factor * factor == 1; 245 case AffineExprKind::Constant: 246 return cast<AffineConstantExpr>().getValue() % factor == 0; 247 case AffineExprKind::Mul: { 248 binExpr = cast<AffineBinaryOpExpr>(); 249 // It's probably not worth optimizing this further (to not traverse the 250 // whole sub-tree under - it that would require a version of isMultipleOf 251 // that on a 'false' return also returns the largest known divisor). 252 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 || 253 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 || 254 (l * u) % factor == 0; 255 } 256 case AffineExprKind::Add: 257 case AffineExprKind::FloorDiv: 258 case AffineExprKind::CeilDiv: 259 case AffineExprKind::Mod: { 260 binExpr = cast<AffineBinaryOpExpr>(); 261 return llvm::GreatestCommonDivisor64( 262 binExpr.getLHS().getLargestKnownDivisor(), 263 binExpr.getRHS().getLargestKnownDivisor()) % 264 factor == 265 0; 266 } 267 } 268 llvm_unreachable("Unknown AffineExpr"); 269 } 270 271 bool AffineExpr::isFunctionOfDim(unsigned position) const { 272 if (getKind() == AffineExprKind::DimId) { 273 return *this == mlir::getAffineDimExpr(position, getContext()); 274 } 275 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 276 return expr.getLHS().isFunctionOfDim(position) || 277 expr.getRHS().isFunctionOfDim(position); 278 } 279 return false; 280 } 281 282 bool AffineExpr::isFunctionOfSymbol(unsigned position) const { 283 if (getKind() == AffineExprKind::SymbolId) { 284 return *this == mlir::getAffineSymbolExpr(position, getContext()); 285 } 286 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 287 return expr.getLHS().isFunctionOfSymbol(position) || 288 expr.getRHS().isFunctionOfSymbol(position); 289 } 290 return false; 291 } 292 293 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr) 294 : AffineExpr(ptr) {} 295 AffineExpr AffineBinaryOpExpr::getLHS() const { 296 return static_cast<ImplType *>(expr)->lhs; 297 } 298 AffineExpr AffineBinaryOpExpr::getRHS() const { 299 return static_cast<ImplType *>(expr)->rhs; 300 } 301 302 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {} 303 unsigned AffineDimExpr::getPosition() const { 304 return static_cast<ImplType *>(expr)->position; 305 } 306 307 /// Returns true if the expression is divisible by the given symbol with 308 /// position `symbolPos`. The argument `opKind` specifies here what kind of 309 /// division or mod operation called this division. It helps in implementing the 310 /// commutative property of the floordiv and ceildiv operations. If the argument 311 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv 312 /// operation, then the commutative property can be used otherwise, the floordiv 313 /// operation is not divisible. The same argument holds for ceildiv operation. 314 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, 315 AffineExprKind opKind) { 316 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 317 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 318 opKind == AffineExprKind::CeilDiv) && 319 "unexpected opKind"); 320 switch (expr.getKind()) { 321 case AffineExprKind::Constant: 322 if (expr.cast<AffineConstantExpr>().getValue()) 323 return false; 324 return true; 325 case AffineExprKind::DimId: 326 return false; 327 case AffineExprKind::SymbolId: 328 return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos); 329 // Checks divisibility by the given symbol for both operands. 330 case AffineExprKind::Add: { 331 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 332 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) && 333 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 334 } 335 // Checks divisibility by the given symbol for both operands. Consider the 336 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`, 337 // this is a division by s1 and both the operands of modulo are divisible by 338 // s1 but it is not divisible by s1 always. The third argument is 339 // `AffineExprKind::Mod` for this reason. 340 case AffineExprKind::Mod: { 341 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 342 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, 343 AffineExprKind::Mod) && 344 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, 345 AffineExprKind::Mod); 346 } 347 // Checks if any of the operand divisible by the given symbol. 348 case AffineExprKind::Mul: { 349 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 350 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) || 351 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 352 } 353 // Floordiv and ceildiv are divisible by the given symbol when the first 354 // operand is divisible, and the affine expression kind of the argument expr 355 // is same as the argument `opKind`. This can be inferred from commutative 356 // property of floordiv and ceildiv operations and are as follow: 357 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 358 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 359 // It will fail if operations are not same. For example: 360 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 361 case AffineExprKind::FloorDiv: 362 case AffineExprKind::CeilDiv: { 363 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 364 if (opKind != expr.getKind()) 365 return false; 366 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()); 367 } 368 } 369 llvm_unreachable("Unknown AffineExpr"); 370 } 371 372 /// Divides the given expression by the given symbol at position `symbolPos`. It 373 /// considers the divisibility condition is checked before calling itself. A 374 /// null expression is returned whenever the divisibility condition fails. 375 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, 376 AffineExprKind opKind) { 377 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 378 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 379 opKind == AffineExprKind::CeilDiv) && 380 "unexpected opKind"); 381 switch (expr.getKind()) { 382 case AffineExprKind::Constant: 383 if (expr.cast<AffineConstantExpr>().getValue() != 0) 384 return nullptr; 385 return getAffineConstantExpr(0, expr.getContext()); 386 case AffineExprKind::DimId: 387 return nullptr; 388 case AffineExprKind::SymbolId: 389 return getAffineConstantExpr(1, expr.getContext()); 390 // Dividing both operands by the given symbol. 391 case AffineExprKind::Add: { 392 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 393 return getAffineBinaryOpExpr( 394 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind), 395 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind)); 396 } 397 // Dividing both operands by the given symbol. 398 case AffineExprKind::Mod: { 399 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 400 return getAffineBinaryOpExpr( 401 expr.getKind(), 402 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 403 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind())); 404 } 405 // Dividing any of the operand by the given symbol. 406 case AffineExprKind::Mul: { 407 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 408 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind)) 409 return binaryExpr.getLHS() * 410 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind); 411 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) * 412 binaryExpr.getRHS(); 413 } 414 // Dividing first operand only by the given symbol. 415 case AffineExprKind::FloorDiv: 416 case AffineExprKind::CeilDiv: { 417 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 418 return getAffineBinaryOpExpr( 419 expr.getKind(), 420 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 421 binaryExpr.getRHS()); 422 } 423 } 424 llvm_unreachable("Unknown AffineExpr"); 425 } 426 427 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv 428 /// operations when the second operand simplifies to a symbol and the first 429 /// operand is divisible by that symbol. It can be applied to any semi-affine 430 /// expression. Returned expression can either be a semi-affine or pure affine 431 /// expression. 432 static AffineExpr simplifySemiAffine(AffineExpr expr) { 433 switch (expr.getKind()) { 434 case AffineExprKind::Constant: 435 case AffineExprKind::DimId: 436 case AffineExprKind::SymbolId: 437 return expr; 438 case AffineExprKind::Add: 439 case AffineExprKind::Mul: { 440 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 441 return getAffineBinaryOpExpr(expr.getKind(), 442 simplifySemiAffine(binaryExpr.getLHS()), 443 simplifySemiAffine(binaryExpr.getRHS())); 444 } 445 // Check if the simplification of the second operand is a symbol, and the 446 // first operand is divisible by it. If the operation is a modulo, a constant 447 // zero expression is returned. In the case of floordiv and ceildiv, the 448 // symbol from the simplification of the second operand divides the first 449 // operand. Otherwise, simplification is not possible. 450 case AffineExprKind::FloorDiv: 451 case AffineExprKind::CeilDiv: 452 case AffineExprKind::Mod: { 453 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 454 AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS()); 455 AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS()); 456 AffineSymbolExpr symbolExpr = 457 simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>(); 458 if (!symbolExpr) 459 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 460 unsigned symbolPos = symbolExpr.getPosition(); 461 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind())) 462 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 463 if (expr.getKind() == AffineExprKind::Mod) 464 return getAffineConstantExpr(0, expr.getContext()); 465 return symbolicDivide(sLHS, symbolPos, expr.getKind()); 466 } 467 } 468 llvm_unreachable("Unknown AffineExpr"); 469 } 470 471 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, 472 MLIRContext *context) { 473 auto assignCtx = [context](AffineDimExprStorage *storage) { 474 storage->context = context; 475 }; 476 477 StorageUniquer &uniquer = context->getAffineUniquer(); 478 return uniquer.get<AffineDimExprStorage>( 479 assignCtx, static_cast<unsigned>(kind), position); 480 } 481 482 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { 483 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context); 484 } 485 486 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr) 487 : AffineExpr(ptr) {} 488 unsigned AffineSymbolExpr::getPosition() const { 489 return static_cast<ImplType *>(expr)->position; 490 } 491 492 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { 493 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context); 494 ; 495 } 496 497 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr) 498 : AffineExpr(ptr) {} 499 int64_t AffineConstantExpr::getValue() const { 500 return static_cast<ImplType *>(expr)->constant; 501 } 502 503 bool AffineExpr::operator==(int64_t v) const { 504 return *this == getAffineConstantExpr(v, getContext()); 505 } 506 507 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { 508 auto assignCtx = [context](AffineConstantExprStorage *storage) { 509 storage->context = context; 510 }; 511 512 StorageUniquer &uniquer = context->getAffineUniquer(); 513 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant); 514 } 515 516 /// Simplify add expression. Return nullptr if it can't be simplified. 517 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { 518 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 519 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 520 // Fold if both LHS, RHS are a constant. 521 if (lhsConst && rhsConst) 522 return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), 523 lhs.getContext()); 524 525 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). 526 // If only one of them is a symbolic expressions, make it the RHS. 527 if (lhs.isa<AffineConstantExpr>() || 528 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) { 529 return rhs + lhs; 530 } 531 532 // At this point, if there was a constant, it would be on the right. 533 534 // Addition with a zero is a noop, return the other input. 535 if (rhsConst) { 536 if (rhsConst.getValue() == 0) 537 return lhs; 538 } 539 // Fold successive additions like (d0 + 2) + 3 into d0 + 5. 540 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 541 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) { 542 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 543 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue()); 544 } 545 546 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr". 547 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their 548 // respective multiplicands. 549 Optional<int64_t> rLhsConst, rRhsConst; 550 AffineExpr firstExpr, secondExpr; 551 AffineConstantExpr rLhsConstExpr; 552 auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>(); 553 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul && 554 (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) { 555 rLhsConst = rLhsConstExpr.getValue(); 556 firstExpr = lBinOpExpr.getLHS(); 557 } else { 558 rLhsConst = 1; 559 firstExpr = lhs; 560 } 561 562 auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>(); 563 AffineConstantExpr rRhsConstExpr; 564 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul && 565 (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) { 566 rRhsConst = rRhsConstExpr.getValue(); 567 secondExpr = rBinOpExpr.getLHS(); 568 } else { 569 rRhsConst = 1; 570 secondExpr = rhs; 571 } 572 573 if (rLhsConst && rRhsConst && firstExpr == secondExpr) 574 return getAffineBinaryOpExpr( 575 AffineExprKind::Mul, firstExpr, 576 getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(), 577 lhs.getContext())); 578 579 // When doing successive additions, bring constant to the right: turn (d0 + 2) 580 // + d1 into (d0 + d1) + 2. 581 if (lBin && lBin.getKind() == AffineExprKind::Add) { 582 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 583 return lBin.getLHS() + rhs + lrhs; 584 } 585 } 586 587 // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This 588 // leads to a much more efficient form when 'c' is a power of two, and in 589 // general a more compact and readable form. 590 591 // Process '(expr floordiv c) * (-c)'. 592 if (!rBinOpExpr) 593 return nullptr; 594 595 auto lrhs = rBinOpExpr.getLHS(); 596 auto rrhs = rBinOpExpr.getRHS(); 597 598 // Process lrhs, which is 'expr floordiv c'. 599 AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); 600 if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) 601 return nullptr; 602 603 auto llrhs = lrBinOpExpr.getLHS(); 604 auto rlrhs = lrBinOpExpr.getRHS(); 605 606 if (lhs == llrhs && rlrhs == -rrhs) { 607 return lhs % rlrhs; 608 } 609 return nullptr; 610 } 611 612 AffineExpr AffineExpr::operator+(int64_t v) const { 613 return *this + getAffineConstantExpr(v, getContext()); 614 } 615 AffineExpr AffineExpr::operator+(AffineExpr other) const { 616 if (auto simplified = simplifyAdd(*this, other)) 617 return simplified; 618 619 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 620 return uniquer.get<AffineBinaryOpExprStorage>( 621 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other); 622 } 623 624 /// Simplify a multiply expression. Return nullptr if it can't be simplified. 625 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { 626 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 627 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 628 629 if (lhsConst && rhsConst) 630 return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), 631 lhs.getContext()); 632 633 assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant()); 634 635 // Canonicalize the mul expression so that the constant/symbolic term is the 636 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a 637 // constant. (Note that a constant is trivially symbolic). 638 if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) { 639 // At least one of them has to be symbolic. 640 return rhs * lhs; 641 } 642 643 // At this point, if there was a constant, it would be on the right. 644 645 // Multiplication with a one is a noop, return the other input. 646 if (rhsConst) { 647 if (rhsConst.getValue() == 1) 648 return lhs; 649 // Multiplication with zero. 650 if (rhsConst.getValue() == 0) 651 return rhsConst; 652 } 653 654 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6. 655 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 656 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) { 657 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 658 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue()); 659 } 660 661 // When doing successive multiplication, bring constant to the right: turn (d0 662 // * 2) * d1 into (d0 * d1) * 2. 663 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 664 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 665 return (lBin.getLHS() * rhs) * lrhs; 666 } 667 } 668 669 return nullptr; 670 } 671 672 AffineExpr AffineExpr::operator*(int64_t v) const { 673 return *this * getAffineConstantExpr(v, getContext()); 674 } 675 AffineExpr AffineExpr::operator*(AffineExpr other) const { 676 if (auto simplified = simplifyMul(*this, other)) 677 return simplified; 678 679 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 680 return uniquer.get<AffineBinaryOpExprStorage>( 681 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other); 682 } 683 684 // Unary minus, delegate to operator*. 685 AffineExpr AffineExpr::operator-() const { 686 return *this * getAffineConstantExpr(-1, getContext()); 687 } 688 689 // Delegate to operator+. 690 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); } 691 AffineExpr AffineExpr::operator-(AffineExpr other) const { 692 return *this + (-other); 693 } 694 695 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { 696 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 697 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 698 699 // mlir floordiv by zero or negative numbers is undefined and preserved as is. 700 if (!rhsConst || rhsConst.getValue() < 1) 701 return nullptr; 702 703 if (lhsConst) 704 return getAffineConstantExpr( 705 floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); 706 707 // Fold floordiv of a multiply with a constant that is a multiple of the 708 // divisor. Eg: (i * 128) floordiv 64 = i * 2. 709 if (rhsConst == 1) 710 return lhs; 711 712 // Simplify (expr * const) floordiv divConst when expr is known to be a 713 // multiple of divConst. 714 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 715 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 716 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 717 // rhsConst is known to be a positive constant. 718 if (lrhs.getValue() % rhsConst.getValue() == 0) 719 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); 720 } 721 } 722 723 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is 724 // known to be a multiple of divConst. 725 if (lBin && lBin.getKind() == AffineExprKind::Add) { 726 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); 727 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); 728 // rhsConst is known to be a positive constant. 729 if (llhsDiv % rhsConst.getValue() == 0 || 730 lrhsDiv % rhsConst.getValue() == 0) 731 return lBin.getLHS().floorDiv(rhsConst.getValue()) + 732 lBin.getRHS().floorDiv(rhsConst.getValue()); 733 } 734 735 return nullptr; 736 } 737 738 AffineExpr AffineExpr::floorDiv(uint64_t v) const { 739 return floorDiv(getAffineConstantExpr(v, getContext())); 740 } 741 AffineExpr AffineExpr::floorDiv(AffineExpr other) const { 742 if (auto simplified = simplifyFloorDiv(*this, other)) 743 return simplified; 744 745 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 746 return uniquer.get<AffineBinaryOpExprStorage>( 747 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this, 748 other); 749 } 750 751 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { 752 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 753 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 754 755 if (!rhsConst || rhsConst.getValue() < 1) 756 return nullptr; 757 758 if (lhsConst) 759 return getAffineConstantExpr( 760 ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); 761 762 // Fold ceildiv of a multiply with a constant that is a multiple of the 763 // divisor. Eg: (i * 128) ceildiv 64 = i * 2. 764 if (rhsConst.getValue() == 1) 765 return lhs; 766 767 // Simplify (expr * const) ceildiv divConst when const is known to be a 768 // multiple of divConst. 769 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 770 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 771 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 772 // rhsConst is known to be a positive constant. 773 if (lrhs.getValue() % rhsConst.getValue() == 0) 774 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); 775 } 776 } 777 778 return nullptr; 779 } 780 781 AffineExpr AffineExpr::ceilDiv(uint64_t v) const { 782 return ceilDiv(getAffineConstantExpr(v, getContext())); 783 } 784 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const { 785 if (auto simplified = simplifyCeilDiv(*this, other)) 786 return simplified; 787 788 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 789 return uniquer.get<AffineBinaryOpExprStorage>( 790 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this, 791 other); 792 } 793 794 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { 795 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 796 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 797 798 // mod w.r.t zero or negative numbers is undefined and preserved as is. 799 if (!rhsConst || rhsConst.getValue() < 1) 800 return nullptr; 801 802 if (lhsConst) 803 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()), 804 lhs.getContext()); 805 806 // Fold modulo of an expression that is known to be a multiple of a constant 807 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) 808 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0. 809 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0) 810 return getAffineConstantExpr(0, lhs.getContext()); 811 812 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is 813 // known to be a multiple of divConst. 814 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 815 if (lBin && lBin.getKind() == AffineExprKind::Add) { 816 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); 817 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); 818 // rhsConst is known to be a positive constant. 819 if (llhsDiv % rhsConst.getValue() == 0) 820 return lBin.getRHS() % rhsConst.getValue(); 821 if (lrhsDiv % rhsConst.getValue() == 0) 822 return lBin.getLHS() % rhsConst.getValue(); 823 } 824 825 return nullptr; 826 } 827 828 AffineExpr AffineExpr::operator%(uint64_t v) const { 829 return *this % getAffineConstantExpr(v, getContext()); 830 } 831 AffineExpr AffineExpr::operator%(AffineExpr other) const { 832 if (auto simplified = simplifyMod(*this, other)) 833 return simplified; 834 835 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 836 return uniquer.get<AffineBinaryOpExprStorage>( 837 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other); 838 } 839 840 AffineExpr AffineExpr::compose(AffineMap map) const { 841 SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(), 842 map.getResults().end()); 843 return replaceDimsAndSymbols(dimReplacements, {}); 844 } 845 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) { 846 expr.print(os); 847 return os; 848 } 849 850 /// Constructs an affine expression from a flat ArrayRef. If there are local 851 /// identifiers (neither dimensional nor symbolic) that appear in the sum of 852 /// products expression, `localExprs` is expected to have the AffineExpr 853 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be 854 /// in the format [dims, symbols, locals, constant term]. 855 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 856 unsigned numDims, 857 unsigned numSymbols, 858 ArrayRef<AffineExpr> localExprs, 859 MLIRContext *context) { 860 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 861 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 862 "unexpected number of local expressions"); 863 864 auto expr = getAffineConstantExpr(0, context); 865 // Dimensions and symbols. 866 for (unsigned j = 0; j < numDims + numSymbols; j++) { 867 if (flatExprs[j] == 0) 868 continue; 869 auto id = j < numDims ? getAffineDimExpr(j, context) 870 : getAffineSymbolExpr(j - numDims, context); 871 expr = expr + id * flatExprs[j]; 872 } 873 874 // Local identifiers. 875 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 876 j++) { 877 if (flatExprs[j] == 0) 878 continue; 879 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 880 expr = expr + term; 881 } 882 883 // Constant term. 884 int64_t constTerm = flatExprs[flatExprs.size() - 1]; 885 if (constTerm != 0) 886 expr = expr + constTerm; 887 return expr; 888 } 889 890 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, 891 unsigned numSymbols) 892 : numDims(numDims), numSymbols(numSymbols), numLocals(0) { 893 operandExprStack.reserve(8); 894 } 895 896 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { 897 assert(operandExprStack.size() >= 2); 898 // This is a pure affine expr; the RHS will be a constant. 899 assert(expr.getRHS().isa<AffineConstantExpr>()); 900 // Get the RHS constant. 901 auto rhsConst = operandExprStack.back()[getConstantIndex()]; 902 operandExprStack.pop_back(); 903 // Update the LHS in place instead of pop and push. 904 auto &lhs = operandExprStack.back(); 905 for (unsigned i = 0, e = lhs.size(); i < e; i++) { 906 lhs[i] *= rhsConst; 907 } 908 } 909 910 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { 911 assert(operandExprStack.size() >= 2); 912 const auto &rhs = operandExprStack.back(); 913 auto &lhs = operandExprStack[operandExprStack.size() - 2]; 914 assert(lhs.size() == rhs.size()); 915 // Update the LHS in place. 916 for (unsigned i = 0, e = rhs.size(); i < e; i++) { 917 lhs[i] += rhs[i]; 918 } 919 // Pop off the RHS. 920 operandExprStack.pop_back(); 921 } 922 923 // 924 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 925 // 926 // A mod expression "expr mod c" is thus flattened by introducing a new local 927 // variable q (= expr floordiv c), such that expr mod c is replaced with 928 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. 929 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { 930 assert(operandExprStack.size() >= 2); 931 // This is a pure affine expr; the RHS will be a constant. 932 assert(expr.getRHS().isa<AffineConstantExpr>()); 933 auto rhsConst = operandExprStack.back()[getConstantIndex()]; 934 operandExprStack.pop_back(); 935 auto &lhs = operandExprStack.back(); 936 // TODO: handle modulo by zero case when this issue is fixed 937 // at the other places in the IR. 938 assert(rhsConst > 0 && "RHS constant has to be positive"); 939 940 // Check if the LHS expression is a multiple of modulo factor. 941 unsigned i, e; 942 for (i = 0, e = lhs.size(); i < e; i++) 943 if (lhs[i] % rhsConst != 0) 944 break; 945 // If yes, modulo expression here simplifies to zero. 946 if (i == lhs.size()) { 947 std::fill(lhs.begin(), lhs.end(), 0); 948 return; 949 } 950 951 // Add a local variable for the quotient, i.e., expr % c is replaced by 952 // (expr - q * c) where q = expr floordiv c. Do this while canceling out 953 // the GCD of expr and c. 954 SmallVector<int64_t, 8> floorDividend(lhs); 955 uint64_t gcd = rhsConst; 956 for (unsigned i = 0, e = lhs.size(); i < e; i++) 957 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); 958 // Simplify the numerator and the denominator. 959 if (gcd != 1) { 960 for (unsigned i = 0, e = floorDividend.size(); i < e; i++) 961 floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd); 962 } 963 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd); 964 965 // Construct the AffineExpr form of the floordiv to store in localExprs. 966 MLIRContext *context = expr.getContext(); 967 auto dividendExpr = getAffineExprFromFlatForm( 968 floorDividend, numDims, numSymbols, localExprs, context); 969 auto divisorExpr = getAffineConstantExpr(floorDivisor, context); 970 auto floorDivExpr = dividendExpr.floorDiv(divisorExpr); 971 int loc; 972 if ((loc = findLocalId(floorDivExpr)) == -1) { 973 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); 974 // Set result at top of stack to "lhs - rhsConst * q". 975 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; 976 } else { 977 // Reuse the existing local id. 978 lhs[getLocalVarStartIndex() + loc] = -rhsConst; 979 } 980 } 981 982 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) { 983 visitDivExpr(expr, /*isCeil=*/true); 984 } 985 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) { 986 visitDivExpr(expr, /*isCeil=*/false); 987 } 988 989 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) { 990 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 991 auto &eq = operandExprStack.back(); 992 assert(expr.getPosition() < numDims && "Inconsistent number of dims"); 993 eq[getDimStartIndex() + expr.getPosition()] = 1; 994 } 995 996 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) { 997 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 998 auto &eq = operandExprStack.back(); 999 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); 1000 eq[getSymbolStartIndex() + expr.getPosition()] = 1; 1001 } 1002 1003 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) { 1004 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 1005 auto &eq = operandExprStack.back(); 1006 eq[getConstantIndex()] = expr.getValue(); 1007 } 1008 1009 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 1010 // A floordiv is thus flattened by introducing a new local variable q, and 1011 // replacing that expression with 'q' while adding the constraints 1012 // c * q <= expr <= c * q + c - 1 to localVarCst (done by 1013 // FlatAffineConstraints::addLocalFloorDiv). 1014 // 1015 // A ceildiv is similarly flattened: 1016 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c 1017 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, 1018 bool isCeil) { 1019 assert(operandExprStack.size() >= 2); 1020 assert(expr.getRHS().isa<AffineConstantExpr>()); 1021 1022 // This is a pure affine expr; the RHS is a positive constant. 1023 int64_t rhsConst = operandExprStack.back()[getConstantIndex()]; 1024 // TODO: handle division by zero at the same time the issue is 1025 // fixed at other places. 1026 assert(rhsConst > 0 && "RHS constant has to be positive"); 1027 operandExprStack.pop_back(); 1028 auto &lhs = operandExprStack.back(); 1029 1030 // Simplify the floordiv, ceildiv if possible by canceling out the greatest 1031 // common divisors of the numerator and denominator. 1032 uint64_t gcd = std::abs(rhsConst); 1033 for (unsigned i = 0, e = lhs.size(); i < e; i++) 1034 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); 1035 // Simplify the numerator and the denominator. 1036 if (gcd != 1) { 1037 for (unsigned i = 0, e = lhs.size(); i < e; i++) 1038 lhs[i] = lhs[i] / static_cast<int64_t>(gcd); 1039 } 1040 int64_t divisor = rhsConst / static_cast<int64_t>(gcd); 1041 // If the divisor becomes 1, the updated LHS is the result. (The 1042 // divisor can't be negative since rhsConst is positive). 1043 if (divisor == 1) 1044 return; 1045 1046 // If the divisor cannot be simplified to one, we will have to retain 1047 // the ceil/floor expr (simplified up until here). Add an existential 1048 // quantifier to express its result, i.e., expr1 div expr2 is replaced 1049 // by a new identifier, q. 1050 MLIRContext *context = expr.getContext(); 1051 auto a = 1052 getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context); 1053 auto b = getAffineConstantExpr(divisor, context); 1054 1055 int loc; 1056 auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); 1057 if ((loc = findLocalId(divExpr)) == -1) { 1058 if (!isCeil) { 1059 SmallVector<int64_t, 8> dividend(lhs); 1060 addLocalFloorDivId(dividend, divisor, divExpr); 1061 } else { 1062 // lhs ceildiv c <=> (lhs + c - 1) floordiv c 1063 SmallVector<int64_t, 8> dividend(lhs); 1064 dividend.back() += divisor - 1; 1065 addLocalFloorDivId(dividend, divisor, divExpr); 1066 } 1067 } 1068 // Set the expression on stack to the local var introduced to capture the 1069 // result of the division (floor or ceil). 1070 std::fill(lhs.begin(), lhs.end(), 0); 1071 if (loc == -1) 1072 lhs[getLocalVarStartIndex() + numLocals - 1] = 1; 1073 else 1074 lhs[getLocalVarStartIndex() + loc] = 1; 1075 } 1076 1077 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 1078 // The local identifier added is always a floordiv of a pure add/mul affine 1079 // function of other identifiers, coefficients of which are specified in 1080 // dividend and with respect to a positive constant divisor. localExpr is the 1081 // simplified tree expression (AffineExpr) corresponding to the quantifier. 1082 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend, 1083 int64_t divisor, 1084 AffineExpr localExpr) { 1085 assert(divisor > 0 && "positive constant divisor expected"); 1086 for (auto &subExpr : operandExprStack) 1087 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); 1088 localExprs.push_back(localExpr); 1089 numLocals++; 1090 // dividend and divisor are not used here; an override of this method uses it. 1091 } 1092 1093 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) { 1094 SmallVectorImpl<AffineExpr>::iterator it; 1095 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end()) 1096 return -1; 1097 return it - localExprs.begin(); 1098 } 1099 1100 /// Simplify the affine expression by flattening it and reconstructing it. 1101 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, 1102 unsigned numSymbols) { 1103 // Simplify semi-affine expressions separately. 1104 if (!expr.isPureAffine()) 1105 expr = simplifySemiAffine(expr); 1106 if (!expr.isPureAffine()) 1107 return expr; 1108 1109 SimpleAffineExprFlattener flattener(numDims, numSymbols); 1110 flattener.walkPostOrder(expr); 1111 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back(); 1112 auto simplifiedExpr = 1113 getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, 1114 flattener.localExprs, expr.getContext()); 1115 flattener.operandExprStack.pop_back(); 1116 assert(flattener.operandExprStack.empty()); 1117 1118 return simplifiedExpr; 1119 } 1120