1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/IR/AffineExpr.h" 12 #include "AffineExprDetail.h" 13 #include "mlir/IR/AffineExprVisitor.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/IntegerSet.h" 16 #include "mlir/Support/MathExtras.h" 17 #include "mlir/Support/TypeID.h" 18 #include "llvm/ADT/STLExtras.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 MLIRContext *AffineExpr::getContext() const { return expr->context; } 24 25 AffineExprKind AffineExpr::getKind() const { return expr->kind; } 26 27 /// Walk all of the AffineExprs in this subgraph in postorder. 28 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const { 29 struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> { 30 std::function<void(AffineExpr)> callback; 31 32 AffineExprWalker(std::function<void(AffineExpr)> callback) 33 : callback(std::move(callback)) {} 34 35 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); } 36 void visitConstantExpr(AffineConstantExpr expr) { callback(expr); } 37 void visitDimExpr(AffineDimExpr expr) { callback(expr); } 38 void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); } 39 }; 40 41 AffineExprWalker(std::move(callback)).walkPostOrder(*this); 42 } 43 44 // Dispatch affine expression construction based on kind. 
45 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, 46 AffineExpr rhs) { 47 if (kind == AffineExprKind::Add) 48 return lhs + rhs; 49 if (kind == AffineExprKind::Mul) 50 return lhs * rhs; 51 if (kind == AffineExprKind::FloorDiv) 52 return lhs.floorDiv(rhs); 53 if (kind == AffineExprKind::CeilDiv) 54 return lhs.ceilDiv(rhs); 55 if (kind == AffineExprKind::Mod) 56 return lhs % rhs; 57 58 llvm_unreachable("unknown binary operation on affine expressions"); 59 } 60 61 /// This method substitutes any uses of dimensions and symbols (e.g. 62 /// dim#0 with dimReplacements[0]) and returns the modified expression tree. 63 AffineExpr 64 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 65 ArrayRef<AffineExpr> symReplacements) const { 66 switch (getKind()) { 67 case AffineExprKind::Constant: 68 return *this; 69 case AffineExprKind::DimId: { 70 unsigned dimId = cast<AffineDimExpr>().getPosition(); 71 if (dimId >= dimReplacements.size()) 72 return *this; 73 return dimReplacements[dimId]; 74 } 75 case AffineExprKind::SymbolId: { 76 unsigned symId = cast<AffineSymbolExpr>().getPosition(); 77 if (symId >= symReplacements.size()) 78 return *this; 79 return symReplacements[symId]; 80 } 81 case AffineExprKind::Add: 82 case AffineExprKind::Mul: 83 case AffineExprKind::FloorDiv: 84 case AffineExprKind::CeilDiv: 85 case AffineExprKind::Mod: 86 auto binOp = cast<AffineBinaryOpExpr>(); 87 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 88 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 89 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements); 90 if (newLHS == lhs && newRHS == rhs) 91 return *this; 92 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 93 } 94 llvm_unreachable("Unknown AffineExpr"); 95 } 96 97 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const { 98 return replaceDimsAndSymbols(dimReplacements, {}); 99 } 100 101 AffineExpr 102 
AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const { 103 return replaceDimsAndSymbols({}, symReplacements); 104 } 105 106 /// Replace dims[offset ... numDims) 107 /// by dims[offset + shift ... shift + numDims). 108 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift, 109 unsigned offset) const { 110 SmallVector<AffineExpr, 4> dims; 111 for (unsigned idx = 0; idx < offset; ++idx) 112 dims.push_back(getAffineDimExpr(idx, getContext())); 113 for (unsigned idx = offset; idx < numDims; ++idx) 114 dims.push_back(getAffineDimExpr(idx + shift, getContext())); 115 return replaceDimsAndSymbols(dims, {}); 116 } 117 118 /// Replace symbols[offset ... numSymbols) 119 /// by symbols[offset + shift ... shift + numSymbols). 120 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift, 121 unsigned offset) const { 122 SmallVector<AffineExpr, 4> symbols; 123 for (unsigned idx = 0; idx < offset; ++idx) 124 symbols.push_back(getAffineSymbolExpr(idx, getContext())); 125 for (unsigned idx = offset; idx < numSymbols; ++idx) 126 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext())); 127 return replaceDimsAndSymbols({}, symbols); 128 } 129 130 /// Sparse replace method. Return the modified expression tree. 
131 AffineExpr 132 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 133 auto it = map.find(*this); 134 if (it != map.end()) 135 return it->second; 136 switch (getKind()) { 137 default: 138 return *this; 139 case AffineExprKind::Add: 140 case AffineExprKind::Mul: 141 case AffineExprKind::FloorDiv: 142 case AffineExprKind::CeilDiv: 143 case AffineExprKind::Mod: 144 auto binOp = cast<AffineBinaryOpExpr>(); 145 auto lhs = binOp.getLHS(), rhs = binOp.getRHS(); 146 auto newLHS = lhs.replace(map); 147 auto newRHS = rhs.replace(map); 148 if (newLHS == lhs && newRHS == rhs) 149 return *this; 150 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS); 151 } 152 llvm_unreachable("Unknown AffineExpr"); 153 } 154 155 /// Sparse replace method. Return the modified expression tree. 156 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const { 157 DenseMap<AffineExpr, AffineExpr> map; 158 map.insert(std::make_pair(expr, replacement)); 159 return replace(map); 160 } 161 /// Returns true if this expression is made out of only symbols and 162 /// constants (no dimensional identifiers). 163 bool AffineExpr::isSymbolicOrConstant() const { 164 switch (getKind()) { 165 case AffineExprKind::Constant: 166 return true; 167 case AffineExprKind::DimId: 168 return false; 169 case AffineExprKind::SymbolId: 170 return true; 171 172 case AffineExprKind::Add: 173 case AffineExprKind::Mul: 174 case AffineExprKind::FloorDiv: 175 case AffineExprKind::CeilDiv: 176 case AffineExprKind::Mod: { 177 auto expr = this->cast<AffineBinaryOpExpr>(); 178 return expr.getLHS().isSymbolicOrConstant() && 179 expr.getRHS().isSymbolicOrConstant(); 180 } 181 } 182 llvm_unreachable("Unknown AffineExpr"); 183 } 184 185 /// Returns true if this is a pure affine expression, i.e., multiplication, 186 /// floordiv, ceildiv, and mod is only allowed w.r.t constants. 
187 bool AffineExpr::isPureAffine() const { 188 switch (getKind()) { 189 case AffineExprKind::SymbolId: 190 case AffineExprKind::DimId: 191 case AffineExprKind::Constant: 192 return true; 193 case AffineExprKind::Add: { 194 auto op = cast<AffineBinaryOpExpr>(); 195 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine(); 196 } 197 198 case AffineExprKind::Mul: { 199 // TODO: Canonicalize the constants in binary operators to the RHS when 200 // possible, allowing this to merge into the next case. 201 auto op = cast<AffineBinaryOpExpr>(); 202 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() && 203 (op.getLHS().template isa<AffineConstantExpr>() || 204 op.getRHS().template isa<AffineConstantExpr>()); 205 } 206 case AffineExprKind::FloorDiv: 207 case AffineExprKind::CeilDiv: 208 case AffineExprKind::Mod: { 209 auto op = cast<AffineBinaryOpExpr>(); 210 return op.getLHS().isPureAffine() && 211 op.getRHS().template isa<AffineConstantExpr>(); 212 } 213 } 214 llvm_unreachable("Unknown AffineExpr"); 215 } 216 217 // Returns the greatest known integral divisor of this affine expression. 
218 int64_t AffineExpr::getLargestKnownDivisor() const { 219 AffineBinaryOpExpr binExpr(nullptr); 220 switch (getKind()) { 221 case AffineExprKind::CeilDiv: 222 LLVM_FALLTHROUGH; 223 case AffineExprKind::DimId: 224 case AffineExprKind::FloorDiv: 225 case AffineExprKind::SymbolId: 226 return 1; 227 case AffineExprKind::Constant: 228 return std::abs(this->cast<AffineConstantExpr>().getValue()); 229 case AffineExprKind::Mul: { 230 binExpr = this->cast<AffineBinaryOpExpr>(); 231 return binExpr.getLHS().getLargestKnownDivisor() * 232 binExpr.getRHS().getLargestKnownDivisor(); 233 } 234 case AffineExprKind::Add: 235 LLVM_FALLTHROUGH; 236 case AffineExprKind::Mod: { 237 binExpr = cast<AffineBinaryOpExpr>(); 238 return llvm::GreatestCommonDivisor64( 239 binExpr.getLHS().getLargestKnownDivisor(), 240 binExpr.getRHS().getLargestKnownDivisor()); 241 } 242 } 243 llvm_unreachable("Unknown AffineExpr"); 244 } 245 246 bool AffineExpr::isMultipleOf(int64_t factor) const { 247 AffineBinaryOpExpr binExpr(nullptr); 248 uint64_t l, u; 249 switch (getKind()) { 250 case AffineExprKind::SymbolId: 251 LLVM_FALLTHROUGH; 252 case AffineExprKind::DimId: 253 return factor * factor == 1; 254 case AffineExprKind::Constant: 255 return cast<AffineConstantExpr>().getValue() % factor == 0; 256 case AffineExprKind::Mul: { 257 binExpr = cast<AffineBinaryOpExpr>(); 258 // It's probably not worth optimizing this further (to not traverse the 259 // whole sub-tree under - it that would require a version of isMultipleOf 260 // that on a 'false' return also returns the largest known divisor). 
261 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 || 262 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 || 263 (l * u) % factor == 0; 264 } 265 case AffineExprKind::Add: 266 case AffineExprKind::FloorDiv: 267 case AffineExprKind::CeilDiv: 268 case AffineExprKind::Mod: { 269 binExpr = cast<AffineBinaryOpExpr>(); 270 return llvm::GreatestCommonDivisor64( 271 binExpr.getLHS().getLargestKnownDivisor(), 272 binExpr.getRHS().getLargestKnownDivisor()) % 273 factor == 274 0; 275 } 276 } 277 llvm_unreachable("Unknown AffineExpr"); 278 } 279 280 bool AffineExpr::isFunctionOfDim(unsigned position) const { 281 if (getKind() == AffineExprKind::DimId) { 282 return *this == mlir::getAffineDimExpr(position, getContext()); 283 } 284 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 285 return expr.getLHS().isFunctionOfDim(position) || 286 expr.getRHS().isFunctionOfDim(position); 287 } 288 return false; 289 } 290 291 bool AffineExpr::isFunctionOfSymbol(unsigned position) const { 292 if (getKind() == AffineExprKind::SymbolId) { 293 return *this == mlir::getAffineSymbolExpr(position, getContext()); 294 } 295 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 296 return expr.getLHS().isFunctionOfSymbol(position) || 297 expr.getRHS().isFunctionOfSymbol(position); 298 } 299 return false; 300 } 301 302 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr) 303 : AffineExpr(ptr) {} 304 AffineExpr AffineBinaryOpExpr::getLHS() const { 305 return static_cast<ImplType *>(expr)->lhs; 306 } 307 AffineExpr AffineBinaryOpExpr::getRHS() const { 308 return static_cast<ImplType *>(expr)->rhs; 309 } 310 311 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {} 312 unsigned AffineDimExpr::getPosition() const { 313 return static_cast<ImplType *>(expr)->position; 314 } 315 316 /// Returns true if the expression is divisible by the given symbol with 317 /// position `symbolPos`. 
The argument `opKind` specifies here what kind of 318 /// division or mod operation called this division. It helps in implementing the 319 /// commutative property of the floordiv and ceildiv operations. If the argument 320 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv 321 /// operation, then the commutative property can be used otherwise, the floordiv 322 /// operation is not divisible. The same argument holds for ceildiv operation. 323 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos, 324 AffineExprKind opKind) { 325 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 326 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 327 opKind == AffineExprKind::CeilDiv) && 328 "unexpected opKind"); 329 switch (expr.getKind()) { 330 case AffineExprKind::Constant: 331 if (expr.cast<AffineConstantExpr>().getValue()) 332 return false; 333 return true; 334 case AffineExprKind::DimId: 335 return false; 336 case AffineExprKind::SymbolId: 337 return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos); 338 // Checks divisibility by the given symbol for both operands. 339 case AffineExprKind::Add: { 340 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 341 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) && 342 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 343 } 344 // Checks divisibility by the given symbol for both operands. Consider the 345 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`, 346 // this is a division by s1 and both the operands of modulo are divisible by 347 // s1 but it is not divisible by s1 always. The third argument is 348 // `AffineExprKind::Mod` for this reason. 
349 case AffineExprKind::Mod: { 350 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 351 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, 352 AffineExprKind::Mod) && 353 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, 354 AffineExprKind::Mod); 355 } 356 // Checks if any of the operand divisible by the given symbol. 357 case AffineExprKind::Mul: { 358 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 359 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) || 360 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind); 361 } 362 // Floordiv and ceildiv are divisible by the given symbol when the first 363 // operand is divisible, and the affine expression kind of the argument expr 364 // is same as the argument `opKind`. This can be inferred from commutative 365 // property of floordiv and ceildiv operations and are as follow: 366 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2 367 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2 368 // It will fail if operations are not same. For example: 369 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 370 case AffineExprKind::FloorDiv: 371 case AffineExprKind::CeilDiv: { 372 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 373 if (opKind != expr.getKind()) 374 return false; 375 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()); 376 } 377 } 378 llvm_unreachable("Unknown AffineExpr"); 379 } 380 381 /// Divides the given expression by the given symbol at position `symbolPos`. It 382 /// considers the divisibility condition is checked before calling itself. A 383 /// null expression is returned whenever the divisibility condition fails. 384 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos, 385 AffineExprKind opKind) { 386 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only. 
387 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv || 388 opKind == AffineExprKind::CeilDiv) && 389 "unexpected opKind"); 390 switch (expr.getKind()) { 391 case AffineExprKind::Constant: 392 if (expr.cast<AffineConstantExpr>().getValue() != 0) 393 return nullptr; 394 return getAffineConstantExpr(0, expr.getContext()); 395 case AffineExprKind::DimId: 396 return nullptr; 397 case AffineExprKind::SymbolId: 398 return getAffineConstantExpr(1, expr.getContext()); 399 // Dividing both operands by the given symbol. 400 case AffineExprKind::Add: { 401 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 402 return getAffineBinaryOpExpr( 403 expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind), 404 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind)); 405 } 406 // Dividing both operands by the given symbol. 407 case AffineExprKind::Mod: { 408 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 409 return getAffineBinaryOpExpr( 410 expr.getKind(), 411 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 412 symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind())); 413 } 414 // Dividing any of the operand by the given symbol. 415 case AffineExprKind::Mul: { 416 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 417 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind)) 418 return binaryExpr.getLHS() * 419 symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind); 420 return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) * 421 binaryExpr.getRHS(); 422 } 423 // Dividing first operand only by the given symbol. 
424 case AffineExprKind::FloorDiv: 425 case AffineExprKind::CeilDiv: { 426 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 427 return getAffineBinaryOpExpr( 428 expr.getKind(), 429 symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()), 430 binaryExpr.getRHS()); 431 } 432 } 433 llvm_unreachable("Unknown AffineExpr"); 434 } 435 436 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv 437 /// operations when the second operand simplifies to a symbol and the first 438 /// operand is divisible by that symbol. It can be applied to any semi-affine 439 /// expression. Returned expression can either be a semi-affine or pure affine 440 /// expression. 441 static AffineExpr simplifySemiAffine(AffineExpr expr) { 442 switch (expr.getKind()) { 443 case AffineExprKind::Constant: 444 case AffineExprKind::DimId: 445 case AffineExprKind::SymbolId: 446 return expr; 447 case AffineExprKind::Add: 448 case AffineExprKind::Mul: { 449 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 450 return getAffineBinaryOpExpr(expr.getKind(), 451 simplifySemiAffine(binaryExpr.getLHS()), 452 simplifySemiAffine(binaryExpr.getRHS())); 453 } 454 // Check if the simplification of the second operand is a symbol, and the 455 // first operand is divisible by it. If the operation is a modulo, a constant 456 // zero expression is returned. In the case of floordiv and ceildiv, the 457 // symbol from the simplification of the second operand divides the first 458 // operand. Otherwise, simplification is not possible. 
459 case AffineExprKind::FloorDiv: 460 case AffineExprKind::CeilDiv: 461 case AffineExprKind::Mod: { 462 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 463 AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS()); 464 AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS()); 465 AffineSymbolExpr symbolExpr = 466 simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>(); 467 if (!symbolExpr) 468 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 469 unsigned symbolPos = symbolExpr.getPosition(); 470 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind())) 471 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS); 472 if (expr.getKind() == AffineExprKind::Mod) 473 return getAffineConstantExpr(0, expr.getContext()); 474 return symbolicDivide(sLHS, symbolPos, expr.getKind()); 475 } 476 } 477 llvm_unreachable("Unknown AffineExpr"); 478 } 479 480 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, 481 MLIRContext *context) { 482 auto assignCtx = [context](AffineDimExprStorage *storage) { 483 storage->context = context; 484 }; 485 486 StorageUniquer &uniquer = context->getAffineUniquer(); 487 return uniquer.get<AffineDimExprStorage>( 488 assignCtx, static_cast<unsigned>(kind), position); 489 } 490 491 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { 492 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context); 493 } 494 495 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr) 496 : AffineExpr(ptr) {} 497 unsigned AffineSymbolExpr::getPosition() const { 498 return static_cast<ImplType *>(expr)->position; 499 } 500 501 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { 502 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context); 503 ; 504 } 505 506 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr) 507 : AffineExpr(ptr) {} 508 int64_t AffineConstantExpr::getValue() const { 509 
return static_cast<ImplType *>(expr)->constant; 510 } 511 512 bool AffineExpr::operator==(int64_t v) const { 513 return *this == getAffineConstantExpr(v, getContext()); 514 } 515 516 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { 517 auto assignCtx = [context](AffineConstantExprStorage *storage) { 518 storage->context = context; 519 }; 520 521 StorageUniquer &uniquer = context->getAffineUniquer(); 522 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant); 523 } 524 525 /// Simplify add expression. Return nullptr if it can't be simplified. 526 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { 527 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 528 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 529 // Fold if both LHS, RHS are a constant. 530 if (lhsConst && rhsConst) 531 return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), 532 lhs.getContext()); 533 534 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). 535 // If only one of them is a symbolic expressions, make it the RHS. 536 if (lhs.isa<AffineConstantExpr>() || 537 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) { 538 return rhs + lhs; 539 } 540 541 // At this point, if there was a constant, it would be on the right. 542 543 // Addition with a zero is a noop, return the other input. 544 if (rhsConst) { 545 if (rhsConst.getValue() == 0) 546 return lhs; 547 } 548 // Fold successive additions like (d0 + 2) + 3 into d0 + 5. 549 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 550 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) { 551 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 552 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue()); 553 } 554 555 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr". 556 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their 557 // respective multiplicands. 
558 Optional<int64_t> rLhsConst, rRhsConst; 559 AffineExpr firstExpr, secondExpr; 560 AffineConstantExpr rLhsConstExpr; 561 auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>(); 562 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul && 563 (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) { 564 rLhsConst = rLhsConstExpr.getValue(); 565 firstExpr = lBinOpExpr.getLHS(); 566 } else { 567 rLhsConst = 1; 568 firstExpr = lhs; 569 } 570 571 auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>(); 572 AffineConstantExpr rRhsConstExpr; 573 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul && 574 (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) { 575 rRhsConst = rRhsConstExpr.getValue(); 576 secondExpr = rBinOpExpr.getLHS(); 577 } else { 578 rRhsConst = 1; 579 secondExpr = rhs; 580 } 581 582 if (rLhsConst && rRhsConst && firstExpr == secondExpr) 583 return getAffineBinaryOpExpr( 584 AffineExprKind::Mul, firstExpr, 585 getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(), 586 lhs.getContext())); 587 588 // When doing successive additions, bring constant to the right: turn (d0 + 2) 589 // + d1 into (d0 + d1) + 2. 590 if (lBin && lBin.getKind() == AffineExprKind::Add) { 591 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 592 return lBin.getLHS() + rhs + lrhs; 593 } 594 } 595 596 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where 597 // q may be a constant or symbolic expression. This leads to a much more 598 // efficient form when 'c' is a power of two, and in general a more compact 599 // and readable form. 600 601 // Process '(expr floordiv c) * (-c)'. 602 if (!rBinOpExpr) 603 return nullptr; 604 605 auto lrhs = rBinOpExpr.getLHS(); 606 auto rrhs = rBinOpExpr.getRHS(); 607 608 AffineExpr llrhs, rlrhs; 609 610 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a 611 // symbolic expression. 
612 auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); 613 // Check rrhsConstOpExpr = -1. 614 auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>(); 615 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr && 616 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) { 617 // Check llrhs = expr floordiv q. 618 llrhs = lrhsBinOpExpr.getLHS(); 619 // Check rlrhs = q. 620 rlrhs = lrhsBinOpExpr.getRHS(); 621 auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>(); 622 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv) 623 return nullptr; 624 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS()) 625 return lhs % rlrhs; 626 } 627 628 // Process lrhs, which is 'expr floordiv c'. 629 AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); 630 if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) 631 return nullptr; 632 633 llrhs = lrBinOpExpr.getLHS(); 634 rlrhs = lrBinOpExpr.getRHS(); 635 636 if (lhs == llrhs && rlrhs == -rrhs) { 637 return lhs % rlrhs; 638 } 639 return nullptr; 640 } 641 642 AffineExpr AffineExpr::operator+(int64_t v) const { 643 return *this + getAffineConstantExpr(v, getContext()); 644 } 645 AffineExpr AffineExpr::operator+(AffineExpr other) const { 646 if (auto simplified = simplifyAdd(*this, other)) 647 return simplified; 648 649 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 650 return uniquer.get<AffineBinaryOpExprStorage>( 651 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other); 652 } 653 654 /// Simplify a multiply expression. Return nullptr if it can't be simplified. 
655 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { 656 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 657 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 658 659 if (lhsConst && rhsConst) 660 return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), 661 lhs.getContext()); 662 663 assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant()); 664 665 // Canonicalize the mul expression so that the constant/symbolic term is the 666 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a 667 // constant. (Note that a constant is trivially symbolic). 668 if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) { 669 // At least one of them has to be symbolic. 670 return rhs * lhs; 671 } 672 673 // At this point, if there was a constant, it would be on the right. 674 675 // Multiplication with a one is a noop, return the other input. 676 if (rhsConst) { 677 if (rhsConst.getValue() == 1) 678 return lhs; 679 // Multiplication with zero. 680 if (rhsConst.getValue() == 0) 681 return rhsConst; 682 } 683 684 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6. 685 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 686 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) { 687 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 688 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue()); 689 } 690 691 // When doing successive multiplication, bring constant to the right: turn (d0 692 // * 2) * d1 into (d0 * d1) * 2. 
693 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 694 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 695 return (lBin.getLHS() * rhs) * lrhs; 696 } 697 } 698 699 return nullptr; 700 } 701 702 AffineExpr AffineExpr::operator*(int64_t v) const { 703 return *this * getAffineConstantExpr(v, getContext()); 704 } 705 AffineExpr AffineExpr::operator*(AffineExpr other) const { 706 if (auto simplified = simplifyMul(*this, other)) 707 return simplified; 708 709 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 710 return uniquer.get<AffineBinaryOpExprStorage>( 711 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other); 712 } 713 714 // Unary minus, delegate to operator*. 715 AffineExpr AffineExpr::operator-() const { 716 return *this * getAffineConstantExpr(-1, getContext()); 717 } 718 719 // Delegate to operator+. 720 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); } 721 AffineExpr AffineExpr::operator-(AffineExpr other) const { 722 return *this + (-other); 723 } 724 725 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { 726 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 727 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 728 729 // mlir floordiv by zero or negative numbers is undefined and preserved as is. 730 if (!rhsConst || rhsConst.getValue() < 1) 731 return nullptr; 732 733 if (lhsConst) 734 return getAffineConstantExpr( 735 floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); 736 737 // Fold floordiv of a multiply with a constant that is a multiple of the 738 // divisor. Eg: (i * 128) floordiv 64 = i * 2. 739 if (rhsConst == 1) 740 return lhs; 741 742 // Simplify (expr * const) floordiv divConst when expr is known to be a 743 // multiple of divConst. 
744 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 745 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 746 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 747 // rhsConst is known to be a positive constant. 748 if (lrhs.getValue() % rhsConst.getValue() == 0) 749 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); 750 } 751 } 752 753 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is 754 // known to be a multiple of divConst. 755 if (lBin && lBin.getKind() == AffineExprKind::Add) { 756 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); 757 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); 758 // rhsConst is known to be a positive constant. 759 if (llhsDiv % rhsConst.getValue() == 0 || 760 lrhsDiv % rhsConst.getValue() == 0) 761 return lBin.getLHS().floorDiv(rhsConst.getValue()) + 762 lBin.getRHS().floorDiv(rhsConst.getValue()); 763 } 764 765 return nullptr; 766 } 767 768 AffineExpr AffineExpr::floorDiv(uint64_t v) const { 769 return floorDiv(getAffineConstantExpr(v, getContext())); 770 } 771 AffineExpr AffineExpr::floorDiv(AffineExpr other) const { 772 if (auto simplified = simplifyFloorDiv(*this, other)) 773 return simplified; 774 775 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 776 return uniquer.get<AffineBinaryOpExprStorage>( 777 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this, 778 other); 779 } 780 781 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { 782 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 783 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 784 785 if (!rhsConst || rhsConst.getValue() < 1) 786 return nullptr; 787 788 if (lhsConst) 789 return getAffineConstantExpr( 790 ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); 791 792 // Fold ceildiv of a multiply with a constant that is a multiple of the 793 // divisor. Eg: (i * 128) ceildiv 64 = i * 2. 
794 if (rhsConst.getValue() == 1) 795 return lhs; 796 797 // Simplify (expr * const) ceildiv divConst when const is known to be a 798 // multiple of divConst. 799 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 800 if (lBin && lBin.getKind() == AffineExprKind::Mul) { 801 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 802 // rhsConst is known to be a positive constant. 803 if (lrhs.getValue() % rhsConst.getValue() == 0) 804 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); 805 } 806 } 807 808 return nullptr; 809 } 810 811 AffineExpr AffineExpr::ceilDiv(uint64_t v) const { 812 return ceilDiv(getAffineConstantExpr(v, getContext())); 813 } 814 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const { 815 if (auto simplified = simplifyCeilDiv(*this, other)) 816 return simplified; 817 818 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 819 return uniquer.get<AffineBinaryOpExprStorage>( 820 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this, 821 other); 822 } 823 824 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { 825 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 826 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 827 828 // mod w.r.t zero or negative numbers is undefined and preserved as is. 829 if (!rhsConst || rhsConst.getValue() < 1) 830 return nullptr; 831 832 if (lhsConst) 833 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()), 834 lhs.getContext()); 835 836 // Fold modulo of an expression that is known to be a multiple of a constant 837 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) 838 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0. 839 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0) 840 return getAffineConstantExpr(0, lhs.getContext()); 841 842 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is 843 // known to be a multiple of divConst. 
844 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 845 if (lBin && lBin.getKind() == AffineExprKind::Add) { 846 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); 847 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); 848 // rhsConst is known to be a positive constant. 849 if (llhsDiv % rhsConst.getValue() == 0) 850 return lBin.getRHS() % rhsConst.getValue(); 851 if (lrhsDiv % rhsConst.getValue() == 0) 852 return lBin.getLHS() % rhsConst.getValue(); 853 } 854 855 // Simplify (e % a) % b to e % b when b evenly divides a 856 if (lBin && lBin.getKind() == AffineExprKind::Mod) { 857 auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>(); 858 if (intermediate && intermediate.getValue() >= 1 && 859 mod(intermediate.getValue(), rhsConst.getValue()) == 0) { 860 return lBin.getLHS() % rhsConst.getValue(); 861 } 862 } 863 864 return nullptr; 865 } 866 867 AffineExpr AffineExpr::operator%(uint64_t v) const { 868 return *this % getAffineConstantExpr(v, getContext()); 869 } 870 AffineExpr AffineExpr::operator%(AffineExpr other) const { 871 if (auto simplified = simplifyMod(*this, other)) 872 return simplified; 873 874 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 875 return uniquer.get<AffineBinaryOpExprStorage>( 876 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other); 877 } 878 879 AffineExpr AffineExpr::compose(AffineMap map) const { 880 SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(), 881 map.getResults().end()); 882 return replaceDimsAndSymbols(dimReplacements, {}); 883 } 884 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) { 885 expr.print(os); 886 return os; 887 } 888 889 /// Constructs an affine expression from a flat ArrayRef. If there are local 890 /// identifiers (neither dimensional nor symbolic) that appear in the sum of 891 /// products expression, `localExprs` is expected to have the AffineExpr 892 /// for it, and is substituted into. 
The ArrayRef `flatExprs` is expected to be 893 /// in the format [dims, symbols, locals, constant term]. 894 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 895 unsigned numDims, 896 unsigned numSymbols, 897 ArrayRef<AffineExpr> localExprs, 898 MLIRContext *context) { 899 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 900 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 901 "unexpected number of local expressions"); 902 903 auto expr = getAffineConstantExpr(0, context); 904 // Dimensions and symbols. 905 for (unsigned j = 0; j < numDims + numSymbols; j++) { 906 if (flatExprs[j] == 0) 907 continue; 908 auto id = j < numDims ? getAffineDimExpr(j, context) 909 : getAffineSymbolExpr(j - numDims, context); 910 expr = expr + id * flatExprs[j]; 911 } 912 913 // Local identifiers. 914 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 915 j++) { 916 if (flatExprs[j] == 0) 917 continue; 918 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 919 expr = expr + term; 920 } 921 922 // Constant term. 923 int64_t constTerm = flatExprs[flatExprs.size() - 1]; 924 if (constTerm != 0) 925 expr = expr + constTerm; 926 return expr; 927 } 928 929 /// Constructs a semi-affine expression from a flat ArrayRef. If there are 930 /// local identifiers (neither dimensional nor symbolic) that appear in the sum 931 /// of products expression, `localExprs` is expected to have the AffineExprs for 932 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in 933 /// the format [dims, symbols, locals, constant term]. The semi-affine 934 /// expression is constructed in the sorted order of dimension and symbol 935 /// position numbers. Note: local expressions/ids are used for mod, div as well 936 /// as symbolic RHS terms for terms that are not pure affine. 
937 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs, 938 unsigned numDims, 939 unsigned numSymbols, 940 ArrayRef<AffineExpr> localExprs, 941 MLIRContext *context) { 942 assert(!flatExprs.empty() && "flatExprs cannot be empty"); 943 944 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1. 945 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() && 946 "unexpected number of local expressions"); 947 948 AffineExpr expr = getAffineConstantExpr(0, context); 949 950 // We design indices as a pair which help us present the semi-affine map as 951 // sum of product where terms are sorted based on dimension or symbol 952 // position: <keyA, keyB> for expressions of the form dimension * symbol, 953 // where keyA is the position number of the dimension and keyB is the 954 // position number of the symbol. For dimensional expressions we set the index 955 // as (position number of the dimension, -1), as we want dimensional 956 // expressions to appear before symbolic and product of dimensional and 957 // symbolic expressions having the dimension with the same position number. 958 // For symbolic expression set the index as (position number of the symbol, 959 // maximum of last dimension and symbol position) number. For example, we want 960 // the expression we are constructing to look something like: d0 + d0 * s0 + 961 // s0 + d1*s1 + s1. 962 963 // Stores the affine expression corresponding to a given index. 964 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap; 965 // Stores the constant coefficient value corresponding to a given 966 // dimension, symbol or a non-pure affine expression stored in `localExprs`. 967 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients; 968 // Stores the indices as defined above, and later sorted to produce 969 // the semi-affine expression in the desired form. 
970 SmallVector<std::pair<unsigned, signed>, 8> indices; 971 972 // Example: expression = d0 + d0 * s0 + 2 * s0. 973 // indices = [{0,-1}, {0, 0}, {0, 1}] 974 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}] 975 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}] 976 977 // Adds entries to `indexToExprMap`, `coefficients` and `indices`. 978 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient, 979 AffineExpr expr) { 980 assert(std::find(indices.begin(), indices.end(), index) == indices.end() && 981 "Key is already present in indices vector and overwriting will " 982 "happen in `indexToExprMap` and `coefficients`!"); 983 984 indices.push_back(index); 985 coefficients.insert({index, coefficient}); 986 indexToExprMap.insert({index, expr}); 987 }; 988 989 // Design indices for dimensional or symbolic terms, and store the indices, 990 // constant coefficient corresponding to the indices in `coefficients` map, 991 // and affine expression corresponding to indices in `indexToExprMap` map. 992 993 for (unsigned j = 0; j < numDims; ++j) { 994 if (flatExprs[j] == 0) 995 continue; 996 // For dimensional expressions we set the index as <position number of the 997 // dimension, 0>, as we want dimensional expressions to appear before 998 // symbolic ones and products of dimensional and symbolic expressions 999 // having the dimension with the same position number. 1000 std::pair<unsigned, signed> indexEntry(j, -1); 1001 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context)); 1002 } 1003 for (unsigned j = numDims; j < numDims + numSymbols; ++j) { 1004 if (flatExprs[j] == 0) 1005 continue; 1006 // For symbolic expression set the index as <position number 1007 // of the symbol, max(dimCount, symCount)> number, 1008 // as we want symbolic expressions with the same positional number to 1009 // appear after dimensional expressions having the same positional number. 
1010 std::pair<unsigned, signed> indexEntry(j - numDims, 1011 std::max(numDims, numSymbols)); 1012 addEntry(indexEntry, flatExprs[j], 1013 getAffineSymbolExpr(j - numDims, context)); 1014 } 1015 1016 // Denotes semi-affine product, modulo or division terms, which has been added 1017 // to the `indexToExpr` map. 1018 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1, 1019 false); 1020 unsigned lhsPos, rhsPos; 1021 // Construct indices for product terms involving dimension, symbol or constant 1022 // as lhs/rhs, and store the indices, constant coefficient corresponding to 1023 // the indices in `coefficients` map, and affine expression corresponding to 1024 // in indices in `indexToExprMap` map. 1025 for (const auto &it : llvm::enumerate(localExprs)) { 1026 AffineExpr expr = it.value(); 1027 if (flatExprs[numDims + numSymbols + it.index()] == 0) 1028 continue; 1029 AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS(); 1030 AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS(); 1031 if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) && 1032 (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() || 1033 rhs.isa<AffineConstantExpr>()))) { 1034 continue; 1035 } 1036 if (rhs.isa<AffineConstantExpr>()) { 1037 // For product/modulo/division expressions, when rhs of modulo/division 1038 // expression is constant, we put 0 in place of keyB, because we want 1039 // them to appear earlier in the semi-affine expression we are 1040 // constructing. When rhs is constant, we place 0 in place of keyB. 
1041 if (lhs.isa<AffineDimExpr>()) { 1042 lhsPos = lhs.cast<AffineDimExpr>().getPosition(); 1043 std::pair<unsigned, signed> indexEntry(lhsPos, -1); 1044 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], 1045 expr); 1046 } else { 1047 lhsPos = lhs.cast<AffineSymbolExpr>().getPosition(); 1048 std::pair<unsigned, signed> indexEntry(lhsPos, 1049 std::max(numDims, numSymbols)); 1050 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], 1051 expr); 1052 } 1053 } else if (lhs.isa<AffineDimExpr>()) { 1054 // For product/modulo/division expressions having lhs as dimension and rhs 1055 // as symbol, we order the terms in the semi-affine expression based on 1056 // the pair: <keyA, keyB> for expressions of the form dimension * symbol, 1057 // where keyA is the position number of the dimension and keyB is the 1058 // position number of the symbol. 1059 lhsPos = lhs.cast<AffineDimExpr>().getPosition(); 1060 rhsPos = rhs.cast<AffineSymbolExpr>().getPosition(); 1061 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos); 1062 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); 1063 } else { 1064 // For product/modulo/division expressions having both lhs and rhs as 1065 // symbol, we design indices as a pair: <keyA, keyB> for expressions 1066 // of the form dimension * symbol, where keyA is the position number of 1067 // the dimension and keyB is the position number of the symbol. 1068 lhsPos = lhs.cast<AffineSymbolExpr>().getPosition(); 1069 rhsPos = rhs.cast<AffineSymbolExpr>().getPosition(); 1070 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos); 1071 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr); 1072 } 1073 addedToMap[it.index()] = true; 1074 } 1075 1076 // Constructing the simplified semi-affine sum of product/division/mod 1077 // expression from the flattened form in the desired sorted order of indices 1078 // of the various individual product/division/mod expressions. 
1079 std::sort(indices.begin(), indices.end()); 1080 for (const std::pair<unsigned, unsigned> index : indices) { 1081 assert(indexToExprMap.lookup(index) && 1082 "cannot find key in `indexToExprMap` map"); 1083 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index); 1084 } 1085 1086 // Local identifiers. 1087 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e; 1088 j++) { 1089 // If the coefficient of the local expression is 0, continue as we need not 1090 // add it in out final expression. 1091 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols]) 1092 continue; 1093 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j]; 1094 expr = expr + term; 1095 } 1096 1097 // Constant term. 1098 int64_t constTerm = flatExprs.back(); 1099 if (constTerm != 0) 1100 expr = expr + constTerm; 1101 return expr; 1102 } 1103 1104 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, 1105 unsigned numSymbols) 1106 : numDims(numDims), numSymbols(numSymbols), numLocals(0) { 1107 operandExprStack.reserve(8); 1108 } 1109 1110 // In pure affine t = expr * c, we multiply each coefficient of lhs with c. 1111 // 1112 // In case of semi affine multiplication expressions, t = expr * symbolic_expr, 1113 // introduce a local variable p (= expr * symbolic_expr), and the affine 1114 // expression expr * symbolic_expr is added to `localExprs`. 1115 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { 1116 assert(operandExprStack.size() >= 2); 1117 SmallVector<int64_t, 8> rhs = operandExprStack.back(); 1118 operandExprStack.pop_back(); 1119 SmallVector<int64_t, 8> &lhs = operandExprStack.back(); 1120 1121 // Flatten semi-affine multiplication expressions by introducing a local 1122 // variable in place of the product; the affine expression 1123 // corresponding to the quantifier is added to `localExprs`. 
1124 if (!expr.getRHS().isa<AffineConstantExpr>()) { 1125 MLIRContext *context = expr.getContext(); 1126 AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols, 1127 localExprs, context); 1128 AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, 1129 localExprs, context); 1130 addLocalVariableSemiAffine(a * b, lhs, lhs.size()); 1131 return; 1132 } 1133 1134 // Get the RHS constant. 1135 auto rhsConst = rhs[getConstantIndex()]; 1136 for (unsigned i = 0, e = lhs.size(); i < e; i++) { 1137 lhs[i] *= rhsConst; 1138 } 1139 } 1140 1141 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { 1142 assert(operandExprStack.size() >= 2); 1143 const auto &rhs = operandExprStack.back(); 1144 auto &lhs = operandExprStack[operandExprStack.size() - 2]; 1145 assert(lhs.size() == rhs.size()); 1146 // Update the LHS in place. 1147 for (unsigned i = 0, e = rhs.size(); i < e; i++) { 1148 lhs[i] += rhs[i]; 1149 } 1150 // Pop off the RHS. 1151 operandExprStack.pop_back(); 1152 } 1153 1154 // 1155 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 1156 // 1157 // A mod expression "expr mod c" is thus flattened by introducing a new local 1158 // variable q (= expr floordiv c), such that expr mod c is replaced with 1159 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. 1160 // 1161 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr, 1162 // introduce a local variable m (= expr mod symbolic_expr), and the affine 1163 // expression expr mod symbolic_expr is added to `localExprs`. 
void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
  assert(operandExprStack.size() >= 2);

  // The rhs row is copied out and popped; the lhs row stays on the stack and
  // is overwritten in place with the flattened result.
  SmallVector<int64_t, 8> rhs = operandExprStack.back();
  operandExprStack.pop_back();
  SmallVector<int64_t, 8> &lhs = operandExprStack.back();
  MLIRContext *context = expr.getContext();

  // Flatten semi affine modulo expressions by introducing a local
  // variable in place of the modulo value, and the affine expression
  // corresponding to the quantifier is added to `localExprs`.
  if (!expr.getRHS().isa<AffineConstantExpr>()) {
    AffineExpr dividendExpr = getAffineExprFromFlatForm(
        lhs, numDims, numSymbols, localExprs, context);
    AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                       localExprs, context);
    AffineExpr modExpr = dividendExpr % divisorExpr;
    addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
    return;
  }

  int64_t rhsConst = rhs[getConstantIndex()];
  // TODO: handle modulo by zero case when this issue is fixed
  // at the other places in the IR.
  assert(rhsConst > 0 && "RHS constant has to be positive");

  // Check if the LHS expression is a multiple of modulo factor, i.e. every
  // coefficient (including the constant term) is divisible by rhsConst.
  unsigned i, e;
  for (i = 0, e = lhs.size(); i < e; i++)
    if (lhs[i] % rhsConst != 0)
      break;
  // If yes, modulo expression here simplifies to zero.
  if (i == lhs.size()) {
    std::fill(lhs.begin(), lhs.end(), 0);
    return;
  }

  // Add a local variable for the quotient, i.e., expr % c is replaced by
  // (expr - q * c) where q = expr floordiv c. Do this while canceling out
  // the GCD of expr and c.
  // The dividend is copied before any new local column is inserted into the
  // rows on the stack.
  SmallVector<int64_t, 8> floorDividend(lhs);
  uint64_t gcd = rhsConst;
  for (unsigned i = 0, e = lhs.size(); i < e; i++)
    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
  // Simplify the numerator and the denominator.
  if (gcd != 1) {
    for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
      floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
  }
  int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);

  // Construct the AffineExpr form of the floordiv to store in localExprs.

  AffineExpr dividendExpr = getAffineExprFromFlatForm(
      floorDividend, numDims, numSymbols, localExprs, context);
  AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
  AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
  int loc;
  if ((loc = findLocalId(floorDivExpr)) == -1) {
    addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
    // Set result at top of stack to "lhs - rhsConst * q". The new local is
    // the last one, hence index `numLocals - 1` after the call above.
    lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
  } else {
    // Reuse the existing local id.
    lhs[getLocalVarStartIndex() + loc] = -rhsConst;
  }
}

// Flattens a ceildiv expression; forwards to visitDivExpr with isCeil = true.
void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
  visitDivExpr(expr, /*isCeil=*/true);
}
// Flattens a floordiv expression; forwards to visitDivExpr with
// isCeil = false.
void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
  visitDivExpr(expr, /*isCeil=*/false);
}

// Pushes a new all-zero row of getNumCols() entries onto the operand stack and
// sets the coefficient of the visited dimension to 1.
void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  auto &eq = operandExprStack.back();
  assert(expr.getPosition() < numDims && "Inconsistent number of dims");
  eq[getDimStartIndex() + expr.getPosition()] = 1;
}

// Pushes a new all-zero row of getNumCols() entries onto the operand stack and
// sets the coefficient of the visited symbol to 1.
void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  auto &eq = operandExprStack.back();
  assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
  eq[getSymbolStartIndex() + expr.getPosition()] = 1;
}

// Pushes a new all-zero row of getNumCols() entries onto the operand stack and
// stores the constant's value in the constant column.
void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
  operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
  auto &eq = operandExprStack.back();
  eq[getConstantIndex()] = expr.getValue();
}

// Replaces `result` with a row that refers only to the local identifier
// standing for the semi-affine expression `expr`: every entry is zeroed and
// the entry of that local is set to 1. An existing local equal to `expr` is
// reused; otherwise a new one is added first. `resultSize` is only checked
// against `result.size()` in the assertion.
void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
    AffineExpr expr, SmallVectorImpl<int64_t> &result,
    unsigned long resultSize) {
  assert(result.size() == resultSize &&
         "`result` vector passed is not of correct size");
  int loc;
  if ((loc = findLocalId(expr)) == -1)
    addLocalIdSemiAffine(expr);
  std::fill(result.begin(), result.end(), 0);
  // A freshly added local is the last one, i.e. at index `numLocals - 1`.
  if (loc == -1)
    result[getLocalVarStartIndex() + numLocals - 1] = 1;
  else
    result[getLocalVarStartIndex() + loc] = 1;
}

// t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
// A floordiv is thus flattened by introducing a new local variable q, and
// replacing that expression with 'q' while adding the constraints
// c * q <= expr <= c * q + c - 1 to localVarCst (done by
// FlatAffineConstraints::addLocalFloorDiv).
//
// A ceildiv is similarly flattened:
// t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
//
// In case of semi affine division expressions, t = expr floordiv symbolic_expr
// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
// `localExprs`.
void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
                                             bool isCeil) {
  assert(operandExprStack.size() >= 2);

  MLIRContext *context = expr.getContext();
  // The rhs row is copied out and popped; the lhs row stays on the stack and
  // is overwritten in place with the flattened result.
  SmallVector<int64_t, 8> rhs = operandExprStack.back();
  operandExprStack.pop_back();
  SmallVector<int64_t, 8> &lhs = operandExprStack.back();

  // Flatten semi affine division expressions by introducing a local
  // variable in place of the quotient, and the affine expression corresponding
  // to the quantifier is added to `localExprs`.
  if (!expr.getRHS().isa<AffineConstantExpr>()) {
    AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
                                             localExprs, context);
    AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                             localExprs, context);
    AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
    addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
    return;
  }

  // This is a pure affine expr; the RHS is a positive constant.
  int64_t rhsConst = rhs[getConstantIndex()];
  // TODO: handle division by zero at the same time the issue is
  // fixed at other places.
  assert(rhsConst > 0 && "RHS constant has to be positive");

  // Simplify the floordiv, ceildiv if possible by canceling out the greatest
  // common divisors of the numerator and denominator.
  uint64_t gcd = std::abs(rhsConst);
  for (unsigned i = 0, e = lhs.size(); i < e; i++)
    gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
  // Simplify the numerator and the denominator.
  if (gcd != 1) {
    for (unsigned i = 0, e = lhs.size(); i < e; i++)
      lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
  }
  int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
  // If the divisor becomes 1, the updated LHS is the result. (The
  // divisor can't be negative since rhsConst is positive).
  if (divisor == 1)
    return;

  // If the divisor cannot be simplified to one, we will have to retain
  // the ceil/floor expr (simplified up until here). Add an existential
  // quantifier to express its result, i.e., expr1 div expr2 is replaced
  // by a new identifier, q.
  AffineExpr a =
      getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
  AffineExpr b = getAffineConstantExpr(divisor, context);

  int loc;
  AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
  if ((loc = findLocalId(divExpr)) == -1) {
    // The dividend is copied from lhs before the new local column is
    // inserted into the rows on the stack.
    if (!isCeil) {
      SmallVector<int64_t, 8> dividend(lhs);
      addLocalFloorDivId(dividend, divisor, divExpr);
    } else {
      // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
      SmallVector<int64_t, 8> dividend(lhs);
      dividend.back() += divisor - 1;
      addLocalFloorDivId(dividend, divisor, divExpr);
    }
  }
  // Set the expression on stack to the local var introduced to capture the
  // result of the division (floor or ceil).
  std::fill(lhs.begin(), lhs.end(), 0);
  if (loc == -1)
    lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
  else
    lhs[getLocalVarStartIndex() + loc] = 1;
}

// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
// The local identifier added is always a floordiv of a pure add/mul affine
// function of other identifiers, coefficients of which are specified in
// dividend and with respect to a positive constant divisor. localExpr is the
// simplified tree expression (AffineExpr) corresponding to the quantifier.
void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
                                                   int64_t divisor,
                                                   AffineExpr localExpr) {
  assert(divisor > 0 && "positive constant divisor expected");
  // Insert a zero column for the new local, after the existing locals, into
  // every row currently on the operand stack.
  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
  localExprs.push_back(localExpr);
  numLocals++;
  // dividend and divisor are not used here; an override of this method uses it.
}

// Adds a local identifier standing for the semi-affine expression `localExpr`:
// a zero column is inserted after the existing locals into every row on the
// operand stack, and `localExpr` is appended to `localExprs`.
void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
  for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
    subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
  localExprs.push_back(localExpr);
  ++numLocals;
}

// Returns the position of `localExpr` in `localExprs`, or -1 if it is not
// present.
int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
  SmallVectorImpl<AffineExpr>::iterator it;
  if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
    return -1;
  return it - localExprs.begin();
}

/// Simplify the affine expression by flattening it and reconstructing it.
/// Expressions that are not pure affine are first passed through
/// `simplifySemiAffine`. If the result is still not pure affine and rebuilding
/// it from the flat form reproduces it unchanged, it is returned as is;
/// otherwise it is rebuilt with `getAffineExprFromFlatForm` (pure affine) or
/// `getSemiAffineExprFromFlatForm` (semi-affine).
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
                                    unsigned numSymbols) {
  // Simplify semi-affine expressions separately.
  if (!expr.isPureAffine())
    expr = simplifySemiAffine(expr);

  SimpleAffineExprFlattener flattener(numDims, numSymbols);
  flattener.walkPostOrder(expr);
  // The flattened form of the whole expression is the row left on top of the
  // operand stack.
  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
  if (!expr.isPureAffine() &&
      expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
                                        flattener.localExprs,
                                        expr.getContext()))
    return expr;
  AffineExpr simplifiedExpr =
      expr.isPureAffine()
          ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
                                      flattener.localExprs, expr.getContext())
          : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
                                          flattener.localExprs,
                                          expr.getContext());

  flattener.operandExprStack.pop_back();
  assert(flattener.operandExprStack.empty());
  return simplifiedExpr;
}