1 //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file lowers affine constructs (If and For statements, AffineApply 10 // operations) within a function into their standard If and For equivalent ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 15 16 #include "../PassDetail.h" 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/Dialect/Vector/VectorOps.h" 23 #include "mlir/IR/AffineExprVisitor.h" 24 #include "mlir/IR/BlockAndValueMapping.h" 25 #include "mlir/IR/Builders.h" 26 #include "mlir/IR/IntegerSet.h" 27 #include "mlir/IR/MLIRContext.h" 28 #include "mlir/Pass/Pass.h" 29 #include "mlir/Transforms/DialectConversion.h" 30 #include "mlir/Transforms/Passes.h" 31 32 using namespace mlir; 33 using namespace mlir::vector; 34 35 namespace { 36 /// Visit affine expressions recursively and build the sequence of operations 37 /// that correspond to it. Visitation functions return an Value of the 38 /// expression subtree they visited or `nullptr` on error. 39 class AffineApplyExpander 40 : public AffineExprVisitor<AffineApplyExpander, Value> { 41 public: 42 /// This internal class expects arguments to be non-null, checks must be 43 /// performed at the call site. 
44 AffineApplyExpander(OpBuilder &builder, ValueRange dimValues, 45 ValueRange symbolValues, Location loc) 46 : builder(builder), dimValues(dimValues), symbolValues(symbolValues), 47 loc(loc) {} 48 49 template <typename OpTy> 50 Value buildBinaryExpr(AffineBinaryOpExpr expr) { 51 auto lhs = visit(expr.getLHS()); 52 auto rhs = visit(expr.getRHS()); 53 if (!lhs || !rhs) 54 return nullptr; 55 auto op = builder.create<OpTy>(loc, lhs, rhs); 56 return op.getResult(); 57 } 58 59 Value visitAddExpr(AffineBinaryOpExpr expr) { 60 return buildBinaryExpr<arith::AddIOp>(expr); 61 } 62 63 Value visitMulExpr(AffineBinaryOpExpr expr) { 64 return buildBinaryExpr<arith::MulIOp>(expr); 65 } 66 67 /// Euclidean modulo operation: negative RHS is not allowed. 68 /// Remainder of the euclidean integer division is always non-negative. 69 /// 70 /// Implemented as 71 /// 72 /// a mod b = 73 /// let remainder = srem a, b; 74 /// negative = a < 0 in 75 /// select negative, remainder + b, remainder. 76 Value visitModExpr(AffineBinaryOpExpr expr) { 77 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 78 if (!rhsConst) { 79 emitError( 80 loc, 81 "semi-affine expressions (modulo by non-const) are not supported"); 82 return nullptr; 83 } 84 if (rhsConst.getValue() <= 0) { 85 emitError(loc, "modulo by non-positive value is not supported"); 86 return nullptr; 87 } 88 89 auto lhs = visit(expr.getLHS()); 90 auto rhs = visit(expr.getRHS()); 91 assert(lhs && rhs && "unexpected affine expr lowering failure"); 92 93 Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs); 94 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 95 Value isRemainderNegative = builder.create<arith::CmpIOp>( 96 loc, arith::CmpIPredicate::slt, remainder, zeroCst); 97 Value correctedRemainder = 98 builder.create<arith::AddIOp>(loc, remainder, rhs); 99 Value result = builder.create<SelectOp>(loc, isRemainderNegative, 100 correctedRemainder, remainder); 101 return result; 102 } 103 104 /// Floor 
division operation (rounds towards negative infinity). 105 /// 106 /// For positive divisors, it can be implemented without branching and with a 107 /// single division operation as 108 /// 109 /// a floordiv b = 110 /// let negative = a < 0 in 111 /// let absolute = negative ? -a - 1 : a in 112 /// let quotient = absolute / b in 113 /// negative ? -quotient - 1 : quotient 114 Value visitFloorDivExpr(AffineBinaryOpExpr expr) { 115 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 116 if (!rhsConst) { 117 emitError( 118 loc, 119 "semi-affine expressions (division by non-const) are not supported"); 120 return nullptr; 121 } 122 if (rhsConst.getValue() <= 0) { 123 emitError(loc, "division by non-positive value is not supported"); 124 return nullptr; 125 } 126 127 auto lhs = visit(expr.getLHS()); 128 auto rhs = visit(expr.getRHS()); 129 assert(lhs && rhs && "unexpected affine expr lowering failure"); 130 131 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 132 Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1); 133 Value negative = builder.create<arith::CmpIOp>( 134 loc, arith::CmpIPredicate::slt, lhs, zeroCst); 135 Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs); 136 Value dividend = 137 builder.create<SelectOp>(loc, negative, negatedDecremented, lhs); 138 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); 139 Value correctedQuotient = 140 builder.create<arith::SubIOp>(loc, noneCst, quotient); 141 Value result = 142 builder.create<SelectOp>(loc, negative, correctedQuotient, quotient); 143 return result; 144 } 145 146 /// Ceiling division operation (rounds towards positive infinity). 147 /// 148 /// For positive divisors, it can be implemented without branching and with a 149 /// single division operation as 150 /// 151 /// a ceildiv b = 152 /// let negative = a <= 0 in 153 /// let absolute = negative ? -a : a - 1 in 154 /// let quotient = absolute / b in 155 /// negative ? 
-quotient : quotient + 1 156 Value visitCeilDivExpr(AffineBinaryOpExpr expr) { 157 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 158 if (!rhsConst) { 159 emitError(loc) << "semi-affine expressions (division by non-const) are " 160 "not supported"; 161 return nullptr; 162 } 163 if (rhsConst.getValue() <= 0) { 164 emitError(loc, "division by non-positive value is not supported"); 165 return nullptr; 166 } 167 auto lhs = visit(expr.getLHS()); 168 auto rhs = visit(expr.getRHS()); 169 assert(lhs && rhs && "unexpected affine expr lowering failure"); 170 171 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 172 Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1); 173 Value nonPositive = builder.create<arith::CmpIOp>( 174 loc, arith::CmpIPredicate::sle, lhs, zeroCst); 175 Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs); 176 Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst); 177 Value dividend = 178 builder.create<SelectOp>(loc, nonPositive, negated, decremented); 179 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); 180 Value negatedQuotient = 181 builder.create<arith::SubIOp>(loc, zeroCst, quotient); 182 Value incrementedQuotient = 183 builder.create<arith::AddIOp>(loc, quotient, oneCst); 184 Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient, 185 incrementedQuotient); 186 return result; 187 } 188 189 Value visitConstantExpr(AffineConstantExpr expr) { 190 auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue()); 191 return op.getResult(); 192 } 193 194 Value visitDimExpr(AffineDimExpr expr) { 195 assert(expr.getPosition() < dimValues.size() && 196 "affine dim position out of range"); 197 return dimValues[expr.getPosition()]; 198 } 199 200 Value visitSymbolExpr(AffineSymbolExpr expr) { 201 assert(expr.getPosition() < symbolValues.size() && 202 "symbol dim position out of range"); 203 return symbolValues[expr.getPosition()]; 204 } 205 206 
private: 207 OpBuilder &builder; 208 ValueRange dimValues; 209 ValueRange symbolValues; 210 211 Location loc; 212 }; 213 } // namespace 214 215 /// Create a sequence of operations that implement the `expr` applied to the 216 /// given dimension and symbol values. 217 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc, 218 AffineExpr expr, ValueRange dimValues, 219 ValueRange symbolValues) { 220 return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); 221 } 222 223 /// Create a sequence of operations that implement the `affineMap` applied to 224 /// the given `operands` (as it it were an AffineApplyOp). 225 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder, 226 Location loc, 227 AffineMap affineMap, 228 ValueRange operands) { 229 auto numDims = affineMap.getNumDims(); 230 auto expanded = llvm::to_vector<8>( 231 llvm::map_range(affineMap.getResults(), 232 [numDims, &builder, loc, operands](AffineExpr expr) { 233 return expandAffineExpr(builder, loc, expr, 234 operands.take_front(numDims), 235 operands.drop_front(numDims)); 236 })); 237 if (llvm::all_of(expanded, [](Value v) { return v; })) 238 return expanded; 239 return None; 240 } 241 242 /// Given a range of values, emit the code that reduces them with "min" or "max" 243 /// depending on the provided comparison predicate. The predicate defines which 244 /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the 245 /// `cmpi` operation followed by the `select` operation: 246 /// 247 /// %cond = arith.cmpi "predicate" %v0, %v1 248 /// %result = select %cond, %v0, %v1 249 /// 250 /// Multiple values are scanned in a linear sequence. This creates a data 251 /// dependences that wouldn't exist in a tree reduction, but is easier to 252 /// recognize as a reduction by the subsequent passes. 
253 static Value buildMinMaxReductionSeq(Location loc, 254 arith::CmpIPredicate predicate, 255 ValueRange values, OpBuilder &builder) { 256 assert(!llvm::empty(values) && "empty min/max chain"); 257 258 auto valueIt = values.begin(); 259 Value value = *valueIt++; 260 for (; valueIt != values.end(); ++valueIt) { 261 auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt); 262 value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt); 263 } 264 265 return value; 266 } 267 268 /// Emit instructions that correspond to computing the maximum value among the 269 /// values of a (potentially) multi-output affine map applied to `operands`. 270 static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, 271 ValueRange operands) { 272 if (auto values = expandAffineMap(builder, loc, map, operands)) 273 return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values, 274 builder); 275 return nullptr; 276 } 277 278 /// Emit instructions that correspond to computing the minimum value among the 279 /// values of a (potentially) multi-output affine map applied to `operands`. 280 static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, 281 ValueRange operands) { 282 if (auto values = expandAffineMap(builder, loc, map, operands)) 283 return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values, 284 builder); 285 return nullptr; 286 } 287 288 /// Emit instructions that correspond to the affine map in the upper bound 289 /// applied to the respective operands, and compute the minimum value across 290 /// the results. 291 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { 292 return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(), 293 op.getUpperBoundOperands()); 294 } 295 296 /// Emit instructions that correspond to the affine map in the lower bound 297 /// applied to the respective operands, and compute the maximum value across 298 /// the results. 
Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
  return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
                           op.getLowerBoundOperands());
}

namespace {
/// Lower `affine.min` to the expanded map results reduced with cmpi/select.
class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
public:
  using OpRewritePattern<AffineMinOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineMinOp op,
                                PatternRewriter &rewriter) const override {
    Value reduced =
        lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
    // Expansion fails on semi-affine maps; leave the op for other patterns.
    if (!reduced)
      return failure();

    rewriter.replaceOp(op, reduced);
    return success();
  }
};

/// Lower `affine.max` to the expanded map results reduced with cmpi/select.
class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
public:
  using OpRewritePattern<AffineMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineMaxOp op,
                                PatternRewriter &rewriter) const override {
    Value reduced =
        lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
    if (!reduced)
      return failure();

    rewriter.replaceOp(op, reduced);
    return success();
  }
};

/// Affine yields ops are removed.
class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
public:
  using OpRewritePattern<AffineYieldOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineYieldOp op,
                                PatternRewriter &rewriter) const override {
    if (isa<scf::ParallelOp>(op->getParentOp())) {
      // scf.parallel does not yield any values via its terminator scf.yield but
      // models reductions differently using additional ops in its region.
      rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
      return success();
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.operands());
    return success();
  }
};

/// Lower `affine.for` to `scf.for`: materialize the max/min of the bound maps
/// as SSA values, then move the loop body over wholesale.
class AffineForLowering : public OpRewritePattern<AffineForOp> {
public:
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineForOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lowerBound = lowerAffineLowerBound(op, rewriter);
    Value upperBound = lowerAffineUpperBound(op, rewriter);
    Value step = rewriter.create<arith::ConstantIndexOp>(loc, op.getStep());
    auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
                                                step, op.getIterOperands());
    // Discard the auto-created body and splice in the affine loop's region;
    // the affine.yield terminator is handled by AffineYieldOpLowering.
    rewriter.eraseBlock(scfForOp.getBody());
    rewriter.inlineRegionBefore(op.region(), scfForOp.getRegion(),
                                scfForOp.getRegion().end());
    rewriter.replaceOp(op, scfForOp.getResults());
    return success();
  }
};

/// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
/// operation.
class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
public:
  using OpRewritePattern<AffineParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineParallelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector<Value, 8> steps;
    SmallVector<Value, 8> upperBoundTuple;
    SmallVector<Value, 8> lowerBoundTuple;
    SmallVector<Value, 8> identityVals;
    // Emit IR computing the lower and upper bound by expanding the map
    // expression.
    lowerBoundTuple.reserve(op.getNumDims());
    upperBoundTuple.reserve(op.getNumDims());
    for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
      Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i),
                                      op.getLowerBoundsOperands());
      if (!lower)
        return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
      lowerBoundTuple.push_back(lower);

      Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i),
                                      op.getUpperBoundsOperands());
      if (!upper)
        return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
      upperBoundTuple.push_back(upper);
    }
    // Materialize each (constant) step as an index constant.
    steps.reserve(op.steps().size());
    for (Attribute step : op.steps())
      steps.push_back(rewriter.create<arith::ConstantIndexOp>(
          loc, step.cast<IntegerAttr>().getInt()));

    // Get the terminator op. Captured before the body region is moved so the
    // reduction loop below can still read its operands.
    Operation *affineParOpTerminator = op.getBody()->getTerminator();
    scf::ParallelOp parOp;
    if (op.results().empty()) {
      // Case with no reduction operations/return values.
      parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
                                               upperBoundTuple, steps,
                                               /*bodyBuilderFn=*/nullptr);
      rewriter.eraseBlock(parOp.getBody());
      rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
                                  parOp.getRegion().end());
      rewriter.replaceOp(op, parOp.getResults());
      return success();
    }
    // Case with affine.parallel with reduction operations/return values.
    // scf.parallel handles the reduction operation differently unlike
    // affine.parallel.
    ArrayRef<Attribute> reductions = op.reductions().getValue();
    for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
      // For each of the reduction operations get the identity values for
      // initialization of the result values.
      Attribute reduction = std::get<0>(pair);
      Type resultType = std::get<1>(pair);
      Optional<arith::AtomicRMWKind> reductionOp =
          arith::symbolizeAtomicRMWKind(
              static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
      assert(reductionOp.hasValue() &&
             "Reduction operation cannot be of None Type");
      arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
      identityVals.push_back(
          arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
    }
    parOp = rewriter.create<scf::ParallelOp>(
        loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
        /*bodyBuilderFn=*/nullptr);

    // Copy the body of the affine.parallel op.
    rewriter.eraseBlock(parOp.getBody());
    rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
                                parOp.getRegion().end());
    assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
           "Unequal number of reductions and operands.");
    // Emit one scf.reduce per reduction, just before the body terminator, and
    // fill its region with the arithmetic op matching the reduction kind.
    for (unsigned i = 0, end = reductions.size(); i < end; i++) {
      // For each of the reduction operations get the respective mlir::Value.
      Optional<arith::AtomicRMWKind> reductionOp =
          arith::symbolizeAtomicRMWKind(
              reductions[i].cast<IntegerAttr>().getInt());
      assert(reductionOp.hasValue() &&
             "Reduction Operation cannot be of None Type");
      arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
      rewriter.setInsertionPoint(&parOp.getBody()->back());
      auto reduceOp = rewriter.create<scf::ReduceOp>(
          loc, affineParOpTerminator->getOperand(i));
      rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
      Value reductionResult = arith::getReductionOp(
          reductionOpValue, rewriter, loc,
          reduceOp.getReductionOperator().front().getArgument(0),
          reduceOp.getReductionOperator().front().getArgument(1));
      rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
    }
    rewriter.replaceOp(op, parOp.getResults());
    return success();
  }
};

/// Lower `affine.if` to `scf.if`: the integer-set constraints become a chain
/// of cmpi ops combined with `andi`, and the then/else regions are moved over.
class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
public:
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Now we just have to handle the condition logic.
    auto integerSet = op.getIntegerSet();
    Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value, 8> operands(op.getOperands());
    auto operandsRef = llvm::makeArrayRef(operands);

    // Calculate cond as a conjunction without short-circuiting.
    Value cond = nullptr;
    for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
      AffineExpr constraintExpr = integerSet.getConstraint(i);
      bool isEquality = integerSet.isEq(i);

      // Build and apply an affine expression
      auto numDims = integerSet.getNumDims();
      Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
                                         operandsRef.take_front(numDims),
                                         operandsRef.drop_front(numDims));
      if (!affResult)
        return failure();
      // Integer-set constraints are `expr == 0` or `expr >= 0`.
      auto pred =
          isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
      Value cmpVal =
          rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
      cond = cond
                 ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
                 : cmpVal;
    }
    // An empty integer set is trivially true (i1 constant 1).
    cond = cond ? cond
                : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
                                                        /*width=*/1);

    bool hasElseRegion = !op.elseRegion().empty();
    auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
                                           hasElseRegion);
    // Inline the affine regions before the auto-created blocks, then erase
    // those placeholder blocks.
    rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.getThenRegion().back());
    rewriter.eraseBlock(&ifOp.getThenRegion().back());
    if (hasElseRegion) {
      rewriter.inlineRegionBefore(op.elseRegion(),
                                  &ifOp.getElseRegion().back());
      rewriter.eraseBlock(&ifOp.getElseRegion().back());
    }

    // Replace the Affine IfOp finally.
    rewriter.replaceOp(op, ifOp.getResults());
    return success();
  }
};

/// Convert an "affine.apply" operation into a sequence of arithmetic
/// operations using the StandardOps dialect.
class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
public:
  using OpRewritePattern<AffineApplyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineApplyOp op,
                                PatternRewriter &rewriter) const override {
    // A multi-result map yields one replacement value per result; failure
    // means the map contained unsupported semi-affine expressions.
    auto maybeExpandedMap =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
                        llvm::to_vector<8>(op.getOperands()));
    if (!maybeExpandedMap)
      return failure();
    rewriter.replaceOp(op, *maybeExpandedMap);
    return success();
  }
};

/// Apply the affine map from an 'affine.load' operation to its operands, and
/// feed the results to a newly created 'memref.load' operation (which replaces
/// the original 'affine.load').
class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
public:
  using OpRewritePattern<AffineLoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineLoadOp op,
                                PatternRewriter &rewriter) const override {
    // Expand affine map from 'affineLoadOp'.
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto resultOperands =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!resultOperands)
      return failure();

    // Build memref.load memref[expandedMap.results].
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
                                                *resultOperands);
    return success();
  }
};

/// Apply the affine map from an 'affine.prefetch' operation to its operands,
/// and feed the results to a newly created 'memref.prefetch' operation (which
/// replaces the original 'affine.prefetch').
class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
public:
  using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffinePrefetchOp op,
                                PatternRewriter &rewriter) const override {
    // Expand affine map from 'affinePrefetchOp'.
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto resultOperands =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!resultOperands)
      return failure();

    // Build memref.prefetch memref[expandedMap.results], forwarding the
    // read/write, locality, and cache-kind attributes unchanged.
    rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
        op, op.memref(), *resultOperands, op.isWrite(), op.localityHint(),
        op.isDataCache());
    return success();
  }
};

/// Apply the affine map from an 'affine.store' operation to its operands, and
/// feed the results to a newly created 'memref.store' operation (which replaces
/// the original 'affine.store').
class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
public:
  using OpRewritePattern<AffineStoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineStoreOp op,
                                PatternRewriter &rewriter) const override {
    // Expand affine map from 'affineStoreOp'.
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto maybeExpandedMap =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!maybeExpandedMap)
      return failure();

    // Build memref.store valueToStore, memref[expandedMap.results].
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
    return success();
  }
};

/// Apply the affine maps from an 'affine.dma_start' operation to each of their
/// respective map operands, and feed the results to a newly created
/// 'memref.dma_start' operation (which replaces the original
/// 'affine.dma_start').
623 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> { 624 public: 625 using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern; 626 627 LogicalResult matchAndRewrite(AffineDmaStartOp op, 628 PatternRewriter &rewriter) const override { 629 SmallVector<Value, 8> operands(op.getOperands()); 630 auto operandsRef = llvm::makeArrayRef(operands); 631 632 // Expand affine map for DMA source memref. 633 auto maybeExpandedSrcMap = expandAffineMap( 634 rewriter, op.getLoc(), op.getSrcMap(), 635 operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1)); 636 if (!maybeExpandedSrcMap) 637 return failure(); 638 // Expand affine map for DMA destination memref. 639 auto maybeExpandedDstMap = expandAffineMap( 640 rewriter, op.getLoc(), op.getDstMap(), 641 operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1)); 642 if (!maybeExpandedDstMap) 643 return failure(); 644 // Expand affine map for DMA tag memref. 645 auto maybeExpandedTagMap = expandAffineMap( 646 rewriter, op.getLoc(), op.getTagMap(), 647 operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1)); 648 if (!maybeExpandedTagMap) 649 return failure(); 650 651 // Build memref.dma_start operation with affine map results. 652 rewriter.replaceOpWithNewOp<memref::DmaStartOp>( 653 op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(), 654 *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(), 655 *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride()); 656 return success(); 657 } 658 }; 659 660 /// Apply the affine map from an 'affine.dma_wait' operation tag memref, 661 /// and feed the results to a newly created 'memref.dma_wait' operation (which 662 /// replaces the original 'affine.dma_wait'). 
663 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> { 664 public: 665 using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern; 666 667 LogicalResult matchAndRewrite(AffineDmaWaitOp op, 668 PatternRewriter &rewriter) const override { 669 // Expand affine map for DMA tag memref. 670 SmallVector<Value, 8> indices(op.getTagIndices()); 671 auto maybeExpandedTagMap = 672 expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices); 673 if (!maybeExpandedTagMap) 674 return failure(); 675 676 // Build memref.dma_wait operation with affine map results. 677 rewriter.replaceOpWithNewOp<memref::DmaWaitOp>( 678 op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements()); 679 return success(); 680 } 681 }; 682 683 /// Apply the affine map from an 'affine.vector_load' operation to its operands, 684 /// and feed the results to a newly created 'vector.load' operation (which 685 /// replaces the original 'affine.vector_load'). 686 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> { 687 public: 688 using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern; 689 690 LogicalResult matchAndRewrite(AffineVectorLoadOp op, 691 PatternRewriter &rewriter) const override { 692 // Expand affine map from 'affineVectorLoadOp'. 693 SmallVector<Value, 8> indices(op.getMapOperands()); 694 auto resultOperands = 695 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 696 if (!resultOperands) 697 return failure(); 698 699 // Build vector.load memref[expandedMap.results]. 700 rewriter.replaceOpWithNewOp<vector::LoadOp>( 701 op, op.getVectorType(), op.getMemRef(), *resultOperands); 702 return success(); 703 } 704 }; 705 706 /// Apply the affine map from an 'affine.vector_store' operation to its 707 /// operands, and feed the results to a newly created 'vector.store' operation 708 /// (which replaces the original 'affine.vector_store'). 
class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
public:
  using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineVectorStoreOp op,
                                PatternRewriter &rewriter) const override {
    // Expand affine map from 'affineVectorStoreOp'.
    SmallVector<Value, 8> indices(op.getMapOperands());
    auto maybeExpandedMap =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
    if (!maybeExpandedMap)
      return failure();

    // Build vector.store valueToStore, memref[expandedMap.results].
    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
    return success();
  }
};

} // namespace

/// Populate `patterns` with the lowerings for all non-vector affine ops
/// (control flow, apply/min/max, memref accesses, and DMA ops).
void mlir::populateAffineToStdConversionPatterns(RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AffineApplyLowering,
      AffineDmaStartLowering,
      AffineDmaWaitLowering,
      AffineLoadLowering,
      AffineMinLowering,
      AffineMaxLowering,
      AffineParallelLowering,
      AffinePrefetchLowering,
      AffineStoreLowering,
      AffineForLowering,
      AffineIfLowering,
      AffineYieldOpLowering>(patterns.getContext());
  // clang-format on
}

/// Populate `patterns` with the lowerings for the affine vector-transfer ops.
void mlir::populateAffineToVectorConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AffineVectorLoadLowering,
      AffineVectorStoreLowering>(patterns.getContext());
  // clang-format on
}

namespace {
/// Pass driver: runs all affine-to-std and affine-to-vector patterns as a
/// partial conversion, so any op already in a legal dialect is left alone.
class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAffineToStdConversionPatterns(patterns);
    populateAffineToVectorConversionPatterns(patterns);
    // The target dialects below are what the patterns produce; affine ops are
    // implicitly illegal because they are not listed.
    ConversionTarget target(getContext());
    target
        .addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
                         scf::SCFDialect, StandardOpsDialect, VectorDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

/// Lowers If and For operations within a function into their lower level CFG
/// equivalent blocks.
std::unique_ptr<Pass> mlir::createLowerAffinePass() {
  return std::make_unique<LowerAffinePass>();
}