1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "AffineExprDetail.h" 12 #include "mlir/IR/AffineExpr.h" 13 #include "mlir/IR/AffineExprVisitor.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/Support/MathExtras.h" 17 #include "mlir/Support/TypeID.h" 18 #include "llvm/ADT/STLExtras.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 MLIRContext *AffineExpr::getContext() const { return expr->context; } 24 25 AffineExprKind AffineExpr::getKind() const { return expr->kind; } 26 27 /// Walk all of the AffineExprs in this subgraph in postorder. 28 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const { 29 struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> { 30 std::function<void(AffineExpr)> callback; 31 32 AffineExprWalker(std::function<void(AffineExpr)> callback) 33 : callback(std::move(callback)) {} 34 35 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); } 36 void visitConstantExpr(AffineConstantExpr expr) { callback(expr); } 37 void visitDimExpr(AffineDimExpr expr) { callback(expr); } 38 void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); } 39 }; 40 41 AffineExprWalker(std::move(callback)).walkPostOrder(*this); 42 } 43 44 // Dispatch affine expression construction based on kind. 45 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, 46 AffineExpr rhs) { 47 if (kind == AffineExprKind::Add) 48 return lhs + rhs; 49 if (kind == AffineExprKind::Mul) 50 return lhs * rhs; 51 if (kind == AffineExprKind::FloorDiv) 52 return lhs.floorDiv(rhs); 53 if (kind == AffineExprKind::CeilDiv) 54 return lhs.ceilDiv(rhs); 55 if (kind == AffineExprKind::Mod) 56 return lhs % rhs; 57 58 llvm_unreachable("unknown binary operation on affine expressions"); 59 } 60 61 /// This method substitutes any uses of dimensions and symbols (e.g. 62 /// dim#0 with dimReplacements[0]) and returns the modified expression tree. 63 AffineExpr 64 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 65 ArrayRef<AffineExpr> symReplacements) const { 66 switch (getKind()) { 67 case AffineExprKind::Constant: 68 return *this; 69 case AffineExprKind::DimId: { 70 unsigned dimId = cast<AffineDimExpr>().getPosition(); 71 if (dimId >= dimReplacements.size()) 72 return *this; 73 return dimReplacements[dimId]; 74 } 75 case AffineExprKind::SymbolId: { 76 unsigned symId = cast<AffineSymbolExpr>().getPosition(); 77 if (symId >= symReplacements.size()) 78 return *this; 79 return symReplacements[symId]; 80 } 81 case AffineExprKind::Add: 82 case AffineExprKind::Mul: 83 case AffineExprKind::FloorDiv: 84 case AffineExprKind::CeilDiv: 85 case AffineExprKind::Mod: 86 auto binOp = cast<AffineBinaryOpExpr>(); 87 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 88 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 89 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 90 if (newLHS == lhs && newRHS == rhs) 91 return *this; 92 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 93 } 94 llvm_unreachable("Unknown AffineExpr"); 95 } 96 97 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const { 98 return replaceDimsAndSymbols(dimReplacements, {}); 99 } 100 101 AffineExpr 102 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const { 103 return replaceDimsAndSymbols({}, symReplacements); 104 } 105 106 /// Replace dims[offset ... numDims) 107 /// by dims[offset + shift ... shift + numDims). 108 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift, 109 unsigned offset) const { 110 SmallVector<AffineExpr, 4> dims; 111 for (unsigned idx = 0; idx < offset; ++idx) 112 dims.push_back(getAffineDimExpr(idx, getContext())); 113 for (unsigned idx = offset; idx < numDims; ++idx) 114 dims.push_back(getAffineDimExpr(idx + shift, getContext())); 115 return replaceDimsAndSymbols(dims, {}); 116 } 117 118 /// Replace symbols[offset ... numSymbols) 119 /// by symbols[offset + shift ... shift + numSymbols). 120 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift, 121 unsigned offset) const { 122 SmallVector<AffineExpr, 4> symbols; 123 for (unsigned idx = 0; idx < offset; ++idx) 124 symbols.push_back(getAffineSymbolExpr(idx, getContext())); 125 for (unsigned idx = offset; idx < numSymbols; ++idx) 126 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext())); 127 return replaceDimsAndSymbols({}, symbols); 128 } 129 130 /// Sparse replace method. Return the modified expression tree. 131 AffineExpr 132 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 133 auto it = map.find(*this); 134 if (it != map.end()) 135 return it->second; 136 switch (getKind()) { 137 default: 138 return *this; 139 case AffineExprKind::Add: 140 case AffineExprKind::Mul: 141 case AffineExprKind::FloorDiv: 142 case AffineExprKind::CeilDiv: 143 case AffineExprKind::Mod: 144 auto binOp = cast<AffineBinaryOpExpr>(); 145 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 146 auto newLHS = lhs.replace(map); 147 auto newRHS = rhs.replace(map); 148 if (newLHS == lhs && newRHS == rhs) 149 return *this; 150 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 151 } 152 llvm_unreachable("Unknown AffineExpr"); 153 } 154 155 /// Sparse replace method. Return the modified expression tree. 156 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const { 157 DenseMap<AffineExpr, AffineExpr> map; 158 map.insert(std::make_pair(expr, replacement)); 159 return replace(map); 160 } 161 /// Returns true if this expression is made out of only symbols and 162 /// constants (no dimensional identifiers). 163 bool AffineExpr::isSymbolicOrConstant() const { 164 switch (getKind()) { 165 case AffineExprKind::Constant: 166 return true; 167 case AffineExprKind::DimId: 168 return false; 169 case AffineExprKind::SymbolId: 170 return true; 171 172 case AffineExprKind::Add: 173 case AffineExprKind::Mul: 174 case AffineExprKind::FloorDiv: 175 case AffineExprKind::CeilDiv: 176 case AffineExprKind::Mod: { 177 auto expr = this->cast<AffineBinaryOpExpr>(); 178 return expr.getLHS().isSymbolicOrConstant() && 179 expr.getRHS().isSymbolicOrConstant(); 180 } 181 } 182 llvm_unreachable("Unknown AffineExpr"); 183 } 184 185 /// Returns true if this is a pure affine expression, i.e., multiplication, 186 /// floordiv, ceildiv, and mod is only allowed w.r.t constants. 187 bool AffineExpr::isPureAffine() const { 188 switch (getKind()) { 189 case AffineExprKind::SymbolId: 190 case AffineExprKind::DimId: 191 case AffineExprKind::Constant: 192 return true; 193 case AffineExprKind::Add: { 194 auto op = cast<AffineBinaryOpExpr>(); 195 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine(); 196 } 197 198 case AffineExprKind::Mul: { 199 // TODO: Canonicalize the constants in binary operators to the RHS when 200 // possible, allowing this to merge into the next case. 201 auto op = cast<AffineBinaryOpExpr>(); 202 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() && 203 (op.getLHS().template isa<AffineConstantExpr>() || 204 op.getRHS().template isa<AffineConstantExpr>()); 205 } 206 case AffineExprKind::FloorDiv: 207 case AffineExprKind::CeilDiv: 208 case AffineExprKind::Mod: { 209 auto op = cast<AffineBinaryOpExpr>(); 210 return op.getLHS().isPureAffine() && 211 op.getRHS().template isa<AffineConstantExpr>(); 212 } 213 } 214 llvm_unreachable("Unknown AffineExpr"); 215 } 216 217 // Returns the greatest known integral divisor of this affine expression. 218 int64_t AffineExpr::getLargestKnownDivisor() const { 219 AffineBinaryOpExpr binExpr(nullptr); 220 switch (getKind()) { 221 case AffineExprKind::CeilDiv: 222 LLVM_FALLTHROUGH; 223 case AffineExprKind::DimId: 224 case AffineExprKind::FloorDiv: 225 case AffineExprKind::SymbolId: 226 return 1; 227 case AffineExprKind::Constant: 228 return std::abs(this->cast<AffineConstantExpr>().getValue()); 229 case AffineExprKind::Mul: { 230 binExpr = this->cast<AffineBinaryOpExpr>(); 231 return binExpr.getLHS().getLargestKnownDivisor() * 232 binExpr.getRHS().getLargestKnownDivisor(); 233 } 234 case AffineExprKind::Add: 235 LLVM_FALLTHROUGH; 236 case AffineExprKind::Mod: { 237 binExpr = cast<AffineBinaryOpExpr>(); 238 return llvm::GreatestCommonDivisor64( 239 binExpr.getLHS().getLargestKnownDivisor(), 240 binExpr.getRHS().getLargestKnownDivisor()); 241 } 242 } 243 llvm_unreachable("Unknown AffineExpr"); 244 } 245 246 bool AffineExpr::isMultipleOf(int64_t factor) const { 247 AffineBinaryOpExpr binExpr(nullptr); 248 uint64_t l, u; 249 switch (getKind()) { 250 case AffineExprKind::SymbolId: 251 LLVM_FALLTHROUGH; 252 case AffineExprKind::DimId: 253 return factor * factor == 1; 254 case AffineExprKind::Constant: 255 return cast<AffineConstantExpr>().getValue() % factor == 0; 256 case AffineExprKind::Mul: { 257 binExpr = cast<AffineBinaryOpExpr>(); 258 // It's probably not worth optimizing this further (to not traverse the 259 // whole sub-tree under - it that would require a version of isMultipleOf 260 // that on a 'false' return also returns the largest known divisor). 261 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 || 262 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 || 263 (l * u) % factor == 0; 264 } 265 case AffineExprKind::Add: 266 case AffineExprKind::FloorDiv: 267 case AffineExprKind::CeilDiv: 268 case AffineExprKind::Mod: { 269 binExpr = cast<AffineBinaryOpExpr>(); 270 return llvm::GreatestCommonDivisor64( 271 binExpr.getLHS().getLargestKnownDivisor(), 272 binExpr.getRHS().getLargestKnownDivisor()) % 273 factor == 274 0; 275 } 276 } 277 llvm_unreachable("Unknown AffineExpr"); 278 } 279 280 bool AffineExpr::isFunctionOfDim(unsigned position) const { 281 if (getKind() == AffineExprKind::DimId) { 282 return *this == mlir::getAffineDimExpr(position, getContext()); 283 } 284 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 285 return expr.getLHS().isFunctionOfDim(position) || 286 expr.getRHS().isFunctionOfDim(position); 287 } 288 return false; 289 } 290 291 bool AffineExpr::isFunctionOfSymbol(unsigned position) const { 292 if (getKind() == AffineExprKind::SymbolId) { 293 return *this == mlir::getAffineSymbolExpr(position, getContext()); 294 } 295 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 296 return expr.getLHS().isFunctionOfSymbol(position) || 297 expr.getRHS().isFunctionOfSymbol(position); 298 } 299 return false; 300 } 301 302 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr) 303 : AffineExpr(ptr) {} 304 AffineExpr AffineBinaryOpExpr::getLHS() const { 305 return static_cast<ImplType *>(expr)->lhs; 306 } 307 AffineExpr AffineBinaryOpExpr::getRHS() const { 308 return static_cast<ImplType *>(expr)->rhs; 309 } 310 311 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {} 312 unsigned AffineDimExpr::getPosition() const { 313 return static_cast<ImplType *>(expr)->position; 314 } 315 316 /// Returns true if the expression is divisible by the given symbol with 317 /// position `symbolPos`. The argument `opKind` specifies here what kind of 318 /// division or mod operation called this division. It helps in implementing the 319 /// commutative property of the floordiv and ceildiv operations. If the argument 320 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv 321 /// operation, then the commutative property can be used otherwise, the floordiv 322 /// operation is not divisible. The same argument holds for ceildiv operation. 323 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, 324 AffineExprKind opKind) { 325 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 326 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 327 opKind == AffineExprKind::CeilDiv) && 328 "unexpected opKind"); 329 switch (expr.getKind()) { 330 case AffineExprKind::Constant: 331 return expr.cast<AffineConstantExpr>().getValue() == 0; 332 case AffineExprKind::DimId: 333 return false; 334 case AffineExprKind::SymbolId: 335 return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos); 336 // Checks divisibility by the given symbol for both operands. 337 case AffineExprKind::Add: { 338 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 339 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) && 340 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 341 } 342 // Checks divisibility by the given symbol for both operands. Consider the 343 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`, 344 // this is a division by s1 and both the operands of modulo are divisible by 345 // s1 but it is not divisible by s1 always. The third argument is 346 // `AffineExprKind::Mod` for this reason. 347 case AffineExprKind::Mod: { 348 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 349 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, 350 AffineExprKind::Mod) && 351 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, 352 AffineExprKind::Mod); 353 } 354 // Checks if any of the operand divisible by the given symbol. 355 case AffineExprKind::Mul: { 356 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 357 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) || 358 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 359 } 360 // Floordiv and ceildiv are divisible by the given symbol when the first 361 // operand is divisible, and the affine expression kind of the argument expr 362 // is same as the argument `opKind`. This can be inferred from commutative 363 // property of floordiv and ceildiv operations and are as follow: 364 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 365 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 366 // It will fail if operations are not same. For example: 367 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 368 case AffineExprKind::FloorDiv: 369 case AffineExprKind::CeilDiv: { 370 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 371 if (opKind != expr.getKind()) 372 return false; 373 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()); 374 } 375 } 376 llvm_unreachable("Unknown AffineExpr"); 377 } 378 379 /// Divides the given expression by the given symbol at position `symbolPos`. It 380 /// considers the divisibility condition is checked before calling itself. A 381 /// null expression is returned whenever the divisibility condition fails. 382 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, 383 AffineExprKind opKind) { 384 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 385 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 386 opKind == AffineExprKind::CeilDiv) && 387 "unexpected opKind"); 388 switch (expr.getKind()) { 389 case AffineExprKind::Constant: 390 if (expr.cast<AffineConstantExpr>().getValue() != 0) 391 return nullptr; 392 return getAffineConstantExpr(0, expr.getContext()); 393 case AffineExprKind::DimId: 394 return nullptr; 395 case AffineExprKind::SymbolId: 396 return getAffineConstantExpr(1, expr.getContext()); 397 // Dividing both operands by the given symbol. 398 case AffineExprKind::Add: { 399 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 400 return getAffineBinaryOpExpr( 401 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind), 402 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind)); 403 } 404 // Dividing both operands by the given symbol. 405 case AffineExprKind::Mod: { 406 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 407 return getAffineBinaryOpExpr( 408 expr.getKind(), 409 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 410 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind())); 411 } 412 // Dividing any of the operand by the given symbol. 413 case AffineExprKind::Mul: { 414 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 415 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind)) 416 return binaryExpr.getLHS() * 417 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind); 418 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) * 419 binaryExpr.getRHS(); 420 } 421 // Dividing first operand only by the given symbol. 422 case AffineExprKind::FloorDiv: 423 case AffineExprKind::CeilDiv: { 424 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 425 return getAffineBinaryOpExpr( 426 expr.getKind(), 427 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 428 binaryExpr.getRHS()); 429 } 430 } 431 llvm_unreachable("Unknown AffineExpr"); 432 } 433 434 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv 435 /// operations when the second operand simplifies to a symbol and the first 436 /// operand is divisible by that symbol. It can be applied to any semi-affine 437 /// expression. Returned expression can either be a semi-affine or pure affine 438 /// expression. 439 static AffineExpr simplifySemiAffine(AffineExpr expr) { 440 switch (expr.getKind()) { 441 case AffineExprKind::Constant: 442 case AffineExprKind::DimId: 443 case AffineExprKind::SymbolId: 444 return expr; 445 case AffineExprKind::Add: 446 case AffineExprKind::Mul: { 447 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 448 return getAffineBinaryOpExpr(expr.getKind(), 449 simplifySemiAffine(binaryExpr.getLHS()), 450 simplifySemiAffine(binaryExpr.getRHS())); 451 } 452 // Check if the simplification of the second operand is a symbol, and the 453 // first operand is divisible by it. If the operation is a modulo, a constant 454 // zero expression is returned. In the case of floordiv and ceildiv, the 455 // symbol from the simplification of the second operand divides the first 456 // operand. Otherwise, simplification is not possible. 457 case AffineExprKind::FloorDiv: 458 case AffineExprKind::CeilDiv: 459 case AffineExprKind::Mod: { 460 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 461 AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS()); 462 AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS()); 463 AffineSymbolExpr symbolExpr = 464 simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>(); 465 if (!symbolExpr) 466 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 467 unsigned symbolPos = symbolExpr.getPosition(); 468 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind())) 469 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 470 if (expr.getKind() == AffineExprKind::Mod) 471 return getAffineConstantExpr(0, expr.getContext()); 472 return symbolicDivide(sLHS, symbolPos, expr.getKind()); 473 } 474 } 475 llvm_unreachable("Unknown AffineExpr"); 476 } 477 478 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, 479 MLIRContext *context) { 480 auto assignCtx = [context](AffineDimExprStorage *storage) { 481 storage->context = context; 482 }; 483 484 StorageUniquer &uniquer = context->getAffineUniquer(); 485 return uniquer.get<AffineDimExprStorage>( 486 assignCtx, static_cast<unsigned>(kind), position); 487 } 488 489 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { 490 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context); 491 } 492 493 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr) 494 : AffineExpr(ptr) {} 495 unsigned AffineSymbolExpr::getPosition() const { 496 return static_cast<ImplType *>(expr)->position; 497 } 498 499 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { 500 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context); 501 ; 502 } 503 504 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr) 505 : AffineExpr(ptr) {} 506 int64_t AffineConstantExpr::getValue() const { 507 return static_cast<ImplType *>(expr)->constant; 508 } 509 510 bool AffineExpr::operator==(int64_t v) const { 511 return *this == getAffineConstantExpr(v, getContext()); 512 } 513 514 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { 515 auto assignCtx = [context](AffineConstantExprStorage *storage) { 516 storage->context = context; 517 }; 518 519 StorageUniquer &uniquer = context->getAffineUniquer(); 520 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant); 521 } 522 523 /// Simplify add expression. Return nullptr if it can't be simplified. 524 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { 525 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 526 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 527 // Fold if both LHS, RHS are a constant. 528 if (lhsConst && rhsConst) 529 return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), 530 lhs.getContext()); 531 532 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). 533 // If only one of them is a symbolic expressions, make it the RHS. 534 if (lhs.isa<AffineConstantExpr>() || 535 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) { 536 return rhs + lhs; 537 } 538 539 // At this point, if there was a constant, it would be on the right. 540 541 // Addition with a zero is a noop, return the other input. 542 if (rhsConst) { 543 if (rhsConst.getValue() == 0) 544 return lhs; 545 } 546 // Fold successive additions like (d0 + 2) + 3 into d0 + 5. 547 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 548 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) { 549 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 550 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue()); 551 } 552 553 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr". 554 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their 555 // respective multiplicands. 556 Optional<int64_t> rLhsConst, rRhsConst; 557 AffineExpr firstExpr, secondExpr; 558 AffineConstantExpr rLhsConstExpr; 559 auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>(); 560 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul && 561 (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) { 562 rLhsConst = rLhsConstExpr.getValue(); 563 firstExpr = lBinOpExpr.getLHS(); 564 } else { 565 rLhsConst = 1; 566 firstExpr = lhs; 567 } 568 569 auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>(); 570 AffineConstantExpr rRhsConstExpr; 571 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul && 572 (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) { 573 rRhsConst = rRhsConstExpr.getValue(); 574 secondExpr = rBinOpExpr.getLHS(); 575 } else { 576 rRhsConst = 1; 577 secondExpr = rhs; 578 } 579 580 if (rLhsConst && rRhsConst && firstExpr == secondExpr) 581 return getAffineBinaryOpExpr( 582 AffineExprKind::Mul, firstExpr, 583 getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext())); 584 585 // When doing successive additions, bring constant to the right: turn (d0 + 2) 586 // + d1 into (d0 + d1) + 2. 587 if (lBin && lBin.getKind() == AffineExprKind::Add) { 588 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 589 return lBin.getLHS() + rhs + lrhs; 590 } 591 } 592 593 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where 594 // q may be a constant or symbolic expression. This leads to a much more 595 // efficient form when 'c' is a power of two, and in general a more compact 596 // and readable form. 597 598 // Process '(expr floordiv c) * (-c)'. 599 if (!rBinOpExpr) 600 return nullptr; 601 602 auto lrhs = rBinOpExpr.getLHS(); 603 auto rrhs = rBinOpExpr.getRHS(); 604 605 AffineExpr llrhs, rlrhs; 606 607 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a 608 // symbolic expression. 609 auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); 610 // Check rrhsConstOpExpr = -1. 611 auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>(); 612 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr && 613 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) { 614 // Check llrhs = expr floordiv q. 615 llrhs = lrhsBinOpExpr.getLHS(); 616 // Check rlrhs = q. 617 rlrhs = lrhsBinOpExpr.getRHS(); 618 auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>(); 619 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv) 620 return nullptr; 621 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS()) 622 return lhs % rlrhs; 623 } 624 625 // Process lrhs, which is 'expr floordiv c'. 626 AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); 627 if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) 628 return nullptr; 629 630 llrhs = lrBinOpExpr.getLHS(); 631 rlrhs = lrBinOpExpr.getRHS(); 632 633 if (lhs == llrhs && rlrhs == -rrhs) { 634 return lhs % rlrhs; 635 } 636 return nullptr; 637 } 638 639 AffineExpr AffineExpr::operator+(int64_t v) const { 640 return *this + getAffineConstantExpr(v, getContext()); 641 } 642 AffineExpr AffineExpr::operator+(AffineExpr other) const { 643 if (auto simplified = simplifyAdd(*this, other)) 644 return simplified; 645 646 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 647 return uniquer.get<AffineBinaryOpExprStorage>( 648 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other); 649 } 650 651 /// Simplify a multiply expression. Return nullptr if it can't be simplified. 652 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { 653 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 654 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 655 656 if (lhsConst && rhsConst) 657 return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), 658 lhs.getContext()); 659 660 assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant()); 661 662 // Canonicalize the mul expression so that the constant/symbolic term is the 663 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a 664 // constant. (Note that a constant is trivially symbolic). 665 if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) { 666 // At least one of them has to be symbolic. 667 return rhs * lhs; 668 } 669 670 // At this point, if there was a constant, it would be on the right. 671 672 // Multiplication with a one is a noop, return the other input. 673 if (rhsConst) { 674 if (rhsConst.getValue() == 1) 675 return lhs; 676 // Multiplication with zero. 677 if (rhsConst.getValue() == 0) 678 return rhsConst; 679 } 680 681 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6. 682 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 683 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) { 684 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 685 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue()); 686 } 687 688 // When doing successive multiplication, bring constant to the right: turn (d0 689 // * 2) * d1 into (d0 * d1) * 2. 690 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 691 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 692 return (lBin.getLHS() * rhs) * lrhs; 693 } 694 } 695 696 return nullptr; 697 } 698 699 AffineExpr AffineExpr::operator*(int64_t v) const { 700 return *this * getAffineConstantExpr(v, getContext()); 701 } 702 AffineExpr AffineExpr::operator*(AffineExpr other) const { 703 if (auto simplified = simplifyMul(*this, other)) 704 return simplified; 705 706 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 707 return uniquer.get<AffineBinaryOpExprStorage>( 708 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other); 709 } 710 711 // Unary minus, delegate to operator*. 712 AffineExpr AffineExpr::operator-() const { 713 return *this * getAffineConstantExpr(-1, getContext()); 714 } 715 716 // Delegate to operator+. 717 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); } 718 AffineExpr AffineExpr::operator-(AffineExpr other) const { 719 return *this + (-other); 720 } 721 722 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { 723 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 724 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 725 726 // mlir floordiv by zero or negative numbers is undefined and preserved as is. 727 if (!rhsConst || rhsConst.getValue() < 1) 728 return nullptr; 729 730 if (lhsConst) 731 return getAffineConstantExpr( 732 floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); 733 734 // Fold floordiv of a multiply with a constant that is a multiple of the 735 // divisor. Eg: (i * 128) floordiv 64 = i * 2. 736 if (rhsConst == 1) 737 return lhs; 738 739 // Simplify (expr * const) floordiv divConst when expr is known to be a 740 // multiple of divConst. 741 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 742 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 743 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 744 // rhsConst is known to be a positive constant. 745 if (lrhs.getValue() % rhsConst.getValue() == 0) 746 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); 747 } 748 } 749 750 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is 751 // known to be a multiple of divConst. 752 if (lBin && lBin.getKind() == AffineExprKind::Add) { 753 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); 754 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); 755 // rhsConst is known to be a positive constant. 756 if (llhsDiv % rhsConst.getValue() == 0 || 757 lrhsDiv % rhsConst.getValue() == 0) 758 return lBin.getLHS().floorDiv(rhsConst.getValue()) + 759 lBin.getRHS().floorDiv(rhsConst.getValue()); 760 } 761 762 return nullptr; 763 } 764 765 AffineExpr AffineExpr::floorDiv(uint64_t v) const { 766 return floorDiv(getAffineConstantExpr(v, getContext())); 767 } 768 AffineExpr AffineExpr::floorDiv(AffineExpr other) const { 769 if (auto simplified = simplifyFloorDiv(*this, other)) 770 return simplified; 771 772 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 773 return uniquer.get<AffineBinaryOpExprStorage>( 774 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this, 775 other); 776 } 777 778 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { 779 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 780 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 781 782 if (!rhsConst || rhsConst.getValue() < 1) 783 return nullptr; 784 785 if (lhsConst) 786 return getAffineConstantExpr( 787 ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); 788 789 // Fold ceildiv of a multiply with a constant that is a multiple of the 790 // divisor. Eg: (i * 128) ceildiv 64 = i * 2. 791 if (rhsConst.getValue() == 1) 792 return lhs; 793 794 // Simplify (expr * const) ceildiv divConst when const is known to be a 795 // multiple of divConst. 796 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 797 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 798 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 799 // rhsConst is known to be a positive constant. 800 if (lrhs.getValue() % rhsConst.getValue() == 0) 801 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); 802 } 803 } 804 805 return nullptr; 806 } 807 808 AffineExpr AffineExpr::ceilDiv(uint64_t v) const { 809 return ceilDiv(getAffineConstantExpr(v, getContext())); 810 } 811 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const { 812 if (auto simplified = simplifyCeilDiv(*this, other)) 813 return simplified; 814 815 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 816 return uniquer.get<AffineBinaryOpExprStorage>( 817 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this, 818 other); 819 } 820 821 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { 822 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 823 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 824 825 // mod w.r.t zero or negative numbers is undefined and preserved as is. 826 if (!rhsConst || rhsConst.getValue() < 1) 827 return nullptr; 828 829 if (lhsConst) 830 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()), 831 lhs.getContext()); 832 833 // Fold modulo of an expression that is known to be a multiple of a constant 834 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) 835 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0. 836 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0) 837 return getAffineConstantExpr(0, lhs.getContext()); 838 839 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is 840 // known to be a multiple of divConst. 841 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 842 if (lBin && lBin.getKind() == AffineExprKind::Add) { 843 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); 844 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); 845 // rhsConst is known to be a positive constant. 846 if (llhsDiv % rhsConst.getValue() == 0) 847 return lBin.getRHS() % rhsConst.getValue(); 848 if (lrhsDiv % rhsConst.getValue() == 0) 849 return lBin.getLHS() % rhsConst.getValue(); 850 } 851 852 // Simplify (e % a) % b to e % b when b evenly divides a 853 if (lBin && lBin.getKind() == AffineExprKind::Mod) { 854 auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>(); 855 if (intermediate && intermediate.getValue() >= 1 && 856 mod(intermediate.getValue(), rhsConst.getValue()) == 0) { 857 return lBin.getLHS() % rhsConst.getValue(); 858 } 859 } 860 861 return nullptr; 862 } 863 864 AffineExpr AffineExpr::operator%(uint64_t v) const { 865 return *this % getAffineConstantExpr(v, getContext()); 866 } 867 AffineExpr AffineExpr::operator%(AffineExpr other) const { 868 if (auto simplified = simplifyMod(*this, other)) 869 return simplified; 870 871 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 872 return uniquer.get<AffineBinaryOpExprStorage>( 873 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other); 874 } 875 876 AffineExpr AffineExpr::compose(AffineMap map) const { 877 SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(), 878 map.getResults().end()); 879 return replaceDimsAndSymbols(dimReplacements, {}); 880 } 881 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) { 882 expr.print(os); 883 return os; 884 } 885 886 /// Constructs an affine expression from a flat ArrayRef. If there are local 887 /// identifiers (neither dimensional nor symbolic) that appear in the sum of 888 /// products expression, `localExprs` is expected to have the AffineExpr 889 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be 890 /// in the format [dims, symbols, locals, constant term]. 891 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 892 unsigned numDims, 893 unsigned numSymbols, 894 ArrayRef<AffineExpr> localExprs, 895 MLIRContext *context) { 896 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 897 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 898 "unexpected number of local expressions"); 899 900 auto expr = getAffineConstantExpr(0, context); 901 // Dimensions and symbols. 902 for (unsigned j = 0; j < numDims + numSymbols; j++) { 903 if (flatExprs[j] == 0) 904 continue; 905 auto id = j < numDims ? getAffineDimExpr(j, context) 906 : getAffineSymbolExpr(j - numDims, context); 907 expr = expr + id * flatExprs[j]; 908 } 909 910 // Local identifiers. 911 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 912 j++) { 913 if (flatExprs[j] == 0) 914 continue; 915 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 916 expr = expr + term; 917 } 918 919 // Constant term. 920 int64_t constTerm = flatExprs[flatExprs.size() - 1]; 921 if (constTerm != 0) 922 expr = expr + constTerm; 923 return expr; 924 } 925 926 /// Constructs a semi-affine expression from a flat ArrayRef. If there are 927 /// local identifiers (neither dimensional nor symbolic) that appear in the sum 928 /// of products expression, `localExprs` is expected to have the AffineExprs for 929 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in 930 /// the format [dims, symbols, locals, constant term]. The semi-affine 931 /// expression is constructed in the sorted order of dimension and symbol 932 /// position numbers. Note: local expressions/ids are used for mod, div as well 933 /// as symbolic RHS terms for terms that are not pure affine. 934 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 935 unsigned numDims, 936 unsigned numSymbols, 937 ArrayRef<AffineExpr> localExprs, 938 MLIRContext *context) { 939 assert(!flatExprs.empty() && "flatExprs cannot be empty"); 940 941 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 942 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 943 "unexpected number of local expressions"); 944 945 AffineExpr expr = getAffineConstantExpr(0, context); 946 947 // We design indices as a pair which help us present the semi-affine map as 948 // sum of product where terms are sorted based on dimension or symbol 949 // position: <keyA, keyB> for expressions of the form dimension * symbol, 950 // where keyA is the position number of the dimension and keyB is the 951 // position number of the symbol. For dimensional expressions we set the index 952 // as (position number of the dimension, -1), as we want dimensional 953 // expressions to appear before symbolic and product of dimensional and 954 // symbolic expressions having the dimension with the same position number. 955 // For symbolic expression set the index as (position number of the symbol, 956 // maximum of last dimension and symbol position) number. For example, we want 957 // the expression we are constructing to look something like: d0 + d0 * s0 + 958 // s0 + d1*s1 + s1. 959 960 // Stores the affine expression corresponding to a given index. 961 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap; 962 // Stores the constant coefficient value corresponding to a given 963 // dimension, symbol or a non-pure affine expression stored in `localExprs`. 964 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients; 965 // Stores the indices as defined above, and later sorted to produce 966 // the semi-affine expression in the desired form. 967 SmallVector<std::pair<unsigned, signed>, 8> indices; 968 969 // Example: expression = d0 + d0 * s0 + 2 * s0. 970 // indices = [{0,-1}, {0, 0}, {0, 1}] 971 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}] 972 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}] 973 974 // Adds entries to `indexToExprMap`, `coefficients` and `indices`. 975 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient, 976 AffineExpr expr) { 977 assert(std::find(indices.begin(), indices.end(), index) == indices.end() && 978 "Key is already present in indices vector and overwriting will " 979 "happen in `indexToExprMap` and `coefficients`!"); 980 981 indices.push_back(index); 982 coefficients.insert({index, coefficient}); 983 indexToExprMap.insert({index, expr}); 984 }; 985 986 // Design indices for dimensional or symbolic terms, and store the indices, 987 // constant coefficient corresponding to the indices in `coefficients` map, 988 // and affine expression corresponding to indices in `indexToExprMap` map. 989 990 for (unsigned j = 0; j < numDims; ++j) { 991 if (flatExprs[j] == 0) 992 continue; 993 // For dimensional expressions we set the index as <position number of the 994 // dimension, 0>, as we want dimensional expressions to appear before 995 // symbolic ones and products of dimensional and symbolic expressions 996 // having the dimension with the same position number. 997 std::pair<unsigned, signed> indexEntry(j, -1); 998 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context)); 999 } 1000 for (unsigned j = numDims; j < numDims + numSymbols; ++j) { 1001 if (flatExprs[j] == 0) 1002 continue; 1003 // For symbolic expression set the index as <position number 1004 // of the symbol, max(dimCount, symCount)> number, 1005 // as we want symbolic expressions with the same positional number to 1006 // appear after dimensional expressions having the same positional number. 1007 std::pair<unsigned, signed> indexEntry(j - numDims, 1008 std::max(numDims, numSymbols)); 1009 addEntry(indexEntry, flatExprs[j], 1010 getAffineSymbolExpr(j - numDims, context)); 1011 } 1012 1013 // Denotes semi-affine product, modulo or division terms, which has been added 1014 // to the `indexToExpr` map. 1015 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1, 1016 false); 1017 unsigned lhsPos, rhsPos; 1018 // Construct indices for product terms involving dimension, symbol or constant 1019 // as lhs/rhs, and store the indices, constant coefficient corresponding to 1020 // the indices in `coefficients` map, and affine expression corresponding to 1021 // in indices in `indexToExprMap` map. 1022 for (const auto &it : llvm::enumerate(localExprs)) { 1023 AffineExpr expr = it.value(); 1024 if (flatExprs[numDims + numSymbols + it.index()] == 0) 1025 continue; 1026 AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS(); 1027 AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS(); 1028 if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) && 1029 (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() || 1030 rhs.isa<AffineConstantExpr>()))) { 1031 continue; 1032 } 1033 if (rhs.isa<AffineConstantExpr>()) { 1034 // For product/modulo/division expressions, when rhs of modulo/division 1035 // expression is constant, we put 0 in place of keyB, because we want 1036 // them to appear earlier in the semi-affine expression we are 1037 // constructing. When rhs is constant, we place 0 in place of keyB. 1038 if (lhs.isa<AffineDimExpr>()) { 1039 lhsPos = lhs.cast<AffineDimExpr>().getPosition(); 1040 std::pair<unsigned, signed> indexEntry(lhsPos, -1); 1041 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], 1042 expr); 1043 } else { 1044 lhsPos = lhs.cast<AffineSymbolExpr>().getPosition(); 1045 std::pair<unsigned, signed> indexEntry(lhsPos, 1046 std::max(numDims, numSymbols)); 1047 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], 1048 expr); 1049 } 1050 } else if (lhs.isa<AffineDimExpr>()) { 1051 // For product/modulo/division expressions having lhs as dimension and rhs 1052 // as symbol, we order the terms in the semi-affine expression based on 1053 // the pair: <keyA, keyB> for expressions of the form dimension * symbol, 1054 // where keyA is the position number of the dimension and keyB is the 1055 // position number of the symbol. 1056 lhsPos = lhs.cast<AffineDimExpr>().getPosition(); 1057 rhsPos = rhs.cast<AffineSymbolExpr>().getPosition(); 1058 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos); 1059 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); 1060 } else { 1061 // For product/modulo/division expressions having both lhs and rhs as 1062 // symbol, we design indices as a pair: <keyA, keyB> for expressions 1063 // of the form dimension * symbol, where keyA is the position number of 1064 // the dimension and keyB is the position number of the symbol. 1065 lhsPos = lhs.cast<AffineSymbolExpr>().getPosition(); 1066 rhsPos = rhs.cast<AffineSymbolExpr>().getPosition(); 1067 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos); 1068 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); 1069 } 1070 addedToMap[it.index()] = true; 1071 } 1072 1073 // Constructing the simplified semi-affine sum of product/division/mod 1074 // expression from the flattened form in the desired sorted order of indices 1075 // of the various individual product/division/mod expressions. 1076 std::sort(indices.begin(), indices.end()); 1077 for (const std::pair<unsigned, unsigned> index : indices) { 1078 assert(indexToExprMap.lookup(index) && 1079 "cannot find key in `indexToExprMap` map"); 1080 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index); 1081 } 1082 1083 // Local identifiers. 1084 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 1085 j++) { 1086 // If the coefficient of the local expression is 0, continue as we need not 1087 // add it in out final expression. 1088 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols]) 1089 continue; 1090 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 1091 expr = expr + term; 1092 } 1093 1094 // Constant term. 1095 int64_t constTerm = flatExprs.back(); 1096 if (constTerm != 0) 1097 expr = expr + constTerm; 1098 return expr; 1099 } 1100 1101 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, 1102 unsigned numSymbols) 1103 : numDims(numDims), numSymbols(numSymbols), numLocals(0) { 1104 operandExprStack.reserve(8); 1105 } 1106 1107 // In pure affine t = expr * c, we multiply each coefficient of lhs with c. 1108 // 1109 // In case of semi affine multiplication expressions, t = expr * symbolic_expr, 1110 // introduce a local variable p (= expr * symbolic_expr), and the affine 1111 // expression expr * symbolic_expr is added to `localExprs`. 1112 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { 1113 assert(operandExprStack.size() >= 2); 1114 SmallVector<int64_t, 8> rhs = operandExprStack.back(); 1115 operandExprStack.pop_back(); 1116 SmallVector<int64_t, 8> &lhs = operandExprStack.back(); 1117 1118 // Flatten semi-affine multiplication expressions by introducing a local 1119 // variable in place of the product; the affine expression 1120 // corresponding to the quantifier is added to `localExprs`. 1121 if (!expr.getRHS().isa<AffineConstantExpr>()) { 1122 MLIRContext *context = expr.getContext(); 1123 AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, 1124 localExprs, context); 1125 AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, 1126 localExprs, context); 1127 addLocalVariableSemiAffine(a * b, lhs, lhs.size()); 1128 return; 1129 } 1130 1131 // Get the RHS constant. 1132 auto rhsConst = rhs[getConstantIndex()]; 1133 for (unsigned i = 0, e = lhs.size(); i < e; i++) { 1134 lhs[i] *= rhsConst; 1135 } 1136 } 1137 1138 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { 1139 assert(operandExprStack.size() >= 2); 1140 const auto &rhs = operandExprStack.back(); 1141 auto &lhs = operandExprStack[operandExprStack.size() - 2]; 1142 assert(lhs.size() == rhs.size()); 1143 // Update the LHS in place. 1144 for (unsigned i = 0, e = rhs.size(); i < e; i++) { 1145 lhs[i] += rhs[i]; 1146 } 1147 // Pop off the RHS. 1148 operandExprStack.pop_back(); 1149 } 1150 1151 // 1152 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 1153 // 1154 // A mod expression "expr mod c" is thus flattened by introducing a new local 1155 // variable q (= expr floordiv c), such that expr mod c is replaced with 1156 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. 1157 // 1158 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr, 1159 // introduce a local variable m (= expr mod symbolic_expr), and the affine 1160 // expression expr mod symbolic_expr is added to `localExprs`. 1161 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { 1162 assert(operandExprStack.size() >= 2); 1163 1164 SmallVector<int64_t, 8> rhs = operandExprStack.back(); 1165 operandExprStack.pop_back(); 1166 SmallVector<int64_t, 8> &lhs = operandExprStack.back(); 1167 MLIRContext *context = expr.getContext(); 1168 1169 // Flatten semi affine modulo expressions by introducing a local 1170 // variable in place of the modulo value, and the affine expression 1171 // corresponding to the quantifier is added to `localExprs`. 1172 if (!expr.getRHS().isa<AffineConstantExpr>()) { 1173 AffineExpr dividendExpr = getAffineExprFromFlatForm( 1174 lhs, numDims, numSymbols, localExprs, context); 1175 AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols, 1176 localExprs, context); 1177 AffineExpr modExpr = dividendExpr % divisorExpr; 1178 addLocalVariableSemiAffine(modExpr, lhs, lhs.size()); 1179 return; 1180 } 1181 1182 int64_t rhsConst = rhs[getConstantIndex()]; 1183 // TODO: handle modulo by zero case when this issue is fixed 1184 // at the other places in the IR. 1185 assert(rhsConst > 0 && "RHS constant has to be positive"); 1186 1187 // Check if the LHS expression is a multiple of modulo factor. 1188 unsigned i, e; 1189 for (i = 0, e = lhs.size(); i < e; i++) 1190 if (lhs[i] % rhsConst != 0) 1191 break; 1192 // If yes, modulo expression here simplifies to zero. 1193 if (i == lhs.size()) { 1194 std::fill(lhs.begin(), lhs.end(), 0); 1195 return; 1196 } 1197 1198 // Add a local variable for the quotient, i.e., expr % c is replaced by 1199 // (expr - q * c) where q = expr floordiv c. Do this while canceling out 1200 // the GCD of expr and c. 1201 SmallVector<int64_t, 8> floorDividend(lhs); 1202 uint64_t gcd = rhsConst; 1203 for (unsigned i = 0, e = lhs.size(); i < e; i++) 1204 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); 1205 // Simplify the numerator and the denominator. 1206 if (gcd != 1) { 1207 for (unsigned i = 0, e = floorDividend.size(); i < e; i++) 1208 floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd); 1209 } 1210 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd); 1211 1212 // Construct the AffineExpr form of the floordiv to store in localExprs. 1213 1214 AffineExpr dividendExpr = getAffineExprFromFlatForm( 1215 floorDividend, numDims, numSymbols, localExprs, context); 1216 AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context); 1217 AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr); 1218 int loc; 1219 if ((loc = findLocalId(floorDivExpr)) == -1) { 1220 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); 1221 // Set result at top of stack to "lhs - rhsConst * q". 1222 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; 1223 } else { 1224 // Reuse the existing local id. 1225 lhs[getLocalVarStartIndex() + loc] = -rhsConst; 1226 } 1227 } 1228 1229 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) { 1230 visitDivExpr(expr, /*isCeil=*/true); 1231 } 1232 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) { 1233 visitDivExpr(expr, /*isCeil=*/false); 1234 } 1235 1236 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) { 1237 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 1238 auto &eq = operandExprStack.back(); 1239 assert(expr.getPosition() < numDims && "Inconsistent number of dims"); 1240 eq[getDimStartIndex() + expr.getPosition()] = 1; 1241 } 1242 1243 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) { 1244 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 1245 auto &eq = operandExprStack.back(); 1246 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); 1247 eq[getSymbolStartIndex() + expr.getPosition()] = 1; 1248 } 1249 1250 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) { 1251 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 1252 auto &eq = operandExprStack.back(); 1253 eq[getConstantIndex()] = expr.getValue(); 1254 } 1255 1256 void SimpleAffineExprFlattener::addLocalVariableSemiAffine( 1257 AffineExpr expr, SmallVectorImpl<int64_t> &result, 1258 unsigned long resultSize) { 1259 assert(result.size() == resultSize && 1260 "`result` vector passed is not of correct size"); 1261 int loc; 1262 if ((loc = findLocalId(expr)) == -1) 1263 addLocalIdSemiAffine(expr); 1264 std::fill(result.begin(), result.end(), 0); 1265 if (loc == -1) 1266 result[getLocalVarStartIndex() + numLocals - 1] = 1; 1267 else 1268 result[getLocalVarStartIndex() + loc] = 1; 1269 } 1270 1271 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 1272 // A floordiv is thus flattened by introducing a new local variable q, and 1273 // replacing that expression with 'q' while adding the constraints 1274 // c * q <= expr <= c * q + c - 1 to localVarCst (done by 1275 // FlatAffineConstraints::addLocalFloorDiv). 1276 // 1277 // A ceildiv is similarly flattened: 1278 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c 1279 // 1280 // In case of semi affine division expressions, t = expr floordiv symbolic_expr 1281 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr 1282 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to 1283 // `localExprs`. 1284 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, 1285 bool isCeil) { 1286 assert(operandExprStack.size() >= 2); 1287 1288 MLIRContext *context = expr.getContext(); 1289 SmallVector<int64_t, 8> rhs = operandExprStack.back(); 1290 operandExprStack.pop_back(); 1291 SmallVector<int64_t, 8> &lhs = operandExprStack.back(); 1292 1293 // Flatten semi affine division expressions by introducing a local 1294 // variable in place of the quotient, and the affine expression corresponding 1295 // to the quantifier is added to `localExprs`. 1296 if (!expr.getRHS().isa<AffineConstantExpr>()) { 1297 AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, 1298 localExprs, context); 1299 AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, 1300 localExprs, context); 1301 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); 1302 addLocalVariableSemiAffine(divExpr, lhs, lhs.size()); 1303 return; 1304 } 1305 1306 // This is a pure affine expr; the RHS is a positive constant. 1307 int64_t rhsConst = rhs[getConstantIndex()]; 1308 // TODO: handle division by zero at the same time the issue is 1309 // fixed at other places. 1310 assert(rhsConst > 0 && "RHS constant has to be positive"); 1311 1312 // Simplify the floordiv, ceildiv if possible by canceling out the greatest 1313 // common divisors of the numerator and denominator. 1314 uint64_t gcd = std::abs(rhsConst); 1315 for (unsigned i = 0, e = lhs.size(); i < e; i++) 1316 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); 1317 // Simplify the numerator and the denominator. 1318 if (gcd != 1) { 1319 for (unsigned i = 0, e = lhs.size(); i < e; i++) 1320 lhs[i] = lhs[i] / static_cast<int64_t>(gcd); 1321 } 1322 int64_t divisor = rhsConst / static_cast<int64_t>(gcd); 1323 // If the divisor becomes 1, the updated LHS is the result. (The 1324 // divisor can't be negative since rhsConst is positive). 1325 if (divisor == 1) 1326 return; 1327 1328 // If the divisor cannot be simplified to one, we will have to retain 1329 // the ceil/floor expr (simplified up until here). Add an existential 1330 // quantifier to express its result, i.e., expr1 div expr2 is replaced 1331 // by a new identifier, q. 1332 AffineExpr a = 1333 getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context); 1334 AffineExpr b = getAffineConstantExpr(divisor, context); 1335 1336 int loc; 1337 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); 1338 if ((loc = findLocalId(divExpr)) == -1) { 1339 if (!isCeil) { 1340 SmallVector<int64_t, 8> dividend(lhs); 1341 addLocalFloorDivId(dividend, divisor, divExpr); 1342 } else { 1343 // lhs ceildiv c <=> (lhs + c - 1) floordiv c 1344 SmallVector<int64_t, 8> dividend(lhs); 1345 dividend.back() += divisor - 1; 1346 addLocalFloorDivId(dividend, divisor, divExpr); 1347 } 1348 } 1349 // Set the expression on stack to the local var introduced to capture the 1350 // result of the division (floor or ceil). 1351 std::fill(lhs.begin(), lhs.end(), 0); 1352 if (loc == -1) 1353 lhs[getLocalVarStartIndex() + numLocals - 1] = 1; 1354 else 1355 lhs[getLocalVarStartIndex() + loc] = 1; 1356 } 1357 1358 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 1359 // The local identifier added is always a floordiv of a pure add/mul affine 1360 // function of other identifiers, coefficients of which are specified in 1361 // dividend and with respect to a positive constant divisor. localExpr is the 1362 // simplified tree expression (AffineExpr) corresponding to the quantifier. 1363 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend, 1364 int64_t divisor, 1365 AffineExpr localExpr) { 1366 assert(divisor > 0 && "positive constant divisor expected"); 1367 for (SmallVector<int64_t, 8> &subExpr : operandExprStack) 1368 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); 1369 localExprs.push_back(localExpr); 1370 numLocals++; 1371 // dividend and divisor are not used here; an override of this method uses it. 1372 } 1373 1374 void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) { 1375 for (SmallVector<int64_t, 8> &subExpr : operandExprStack) 1376 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); 1377 localExprs.push_back(localExpr); 1378 ++numLocals; 1379 } 1380 1381 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) { 1382 SmallVectorImpl<AffineExpr>::iterator it; 1383 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end()) 1384 return -1; 1385 return it - localExprs.begin(); 1386 } 1387 1388 /// Simplify the affine expression by flattening it and reconstructing it. 1389 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, 1390 unsigned numSymbols) { 1391 // Simplify semi-affine expressions separately. 1392 if (!expr.isPureAffine()) 1393 expr = simplifySemiAffine(expr); 1394 1395 SimpleAffineExprFlattener flattener(numDims, numSymbols); 1396 flattener.walkPostOrder(expr); 1397 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back(); 1398 if (!expr.isPureAffine() && 1399 expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, 1400 flattener.localExprs, 1401 expr.getContext())) 1402 return expr; 1403 AffineExpr simplifiedExpr = 1404 expr.isPureAffine() 1405 ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, 1406 flattener.localExprs, expr.getContext()) 1407 : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, 1408 flattener.localExprs, 1409 expr.getContext()); 1410 1411 flattener.operandExprStack.pop_back(); 1412 assert(flattener.operandExprStack.empty()); 1413 return simplifiedExpr; 1414 } 1415