//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Specializes parallel loops and for loops for easier unrolling and
// vectorization, and peels the last (partial) iteration off for loops whose
// step does not evenly divide their iteration range.
//
//===----------------------------------------------------------------------===//
134bcd08ebSStephan Herhut
144bcd08ebSStephan Herhut #include "PassDetail.h"
15755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
164bcd08ebSStephan Herhut #include "mlir/Dialect/Affine/IR/AffineOps.h"
17a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
198b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Passes.h"
208b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
21f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
223a41ff48SMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
234bcd08ebSStephan Herhut #include "mlir/IR/AffineExpr.h"
244bcd08ebSStephan Herhut #include "mlir/IR/BlockAndValueMapping.h"
253a41ff48SMatthias Springer #include "mlir/IR/PatternMatch.h"
263a41ff48SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
273a41ff48SMatthias Springer #include "llvm/ADT/DenseMap.h"
284bcd08ebSStephan Herhut
294bcd08ebSStephan Herhut using namespace mlir;
304bcd08ebSStephan Herhut using scf::ForOp;
314bcd08ebSStephan Herhut using scf::ParallelOp;
324bcd08ebSStephan Herhut
334bcd08ebSStephan Herhut /// Rewrite a parallel loop with bounds defined by an affine.min with a constant
344bcd08ebSStephan Herhut /// into 2 loops after checking if the bounds are equal to that constant. This
354bcd08ebSStephan Herhut /// is beneficial if the loop will almost always have the constant bound and
364bcd08ebSStephan Herhut /// that version can be fully unrolled and vectorized.
specializeParallelLoopForUnrolling(ParallelOp op)374bcd08ebSStephan Herhut static void specializeParallelLoopForUnrolling(ParallelOp op) {
384bcd08ebSStephan Herhut SmallVector<int64_t, 2> constantIndices;
39c0342a2dSJacques Pienaar constantIndices.reserve(op.getUpperBound().size());
40c0342a2dSJacques Pienaar for (auto bound : op.getUpperBound()) {
414bcd08ebSStephan Herhut auto minOp = bound.getDefiningOp<AffineMinOp>();
424bcd08ebSStephan Herhut if (!minOp)
434bcd08ebSStephan Herhut return;
444bcd08ebSStephan Herhut int64_t minConstant = std::numeric_limits<int64_t>::max();
45*04235d07SJacques Pienaar for (AffineExpr expr : minOp.getMap().getResults()) {
464bcd08ebSStephan Herhut if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
474bcd08ebSStephan Herhut minConstant = std::min(minConstant, constantIndex.getValue());
484bcd08ebSStephan Herhut }
494bcd08ebSStephan Herhut if (minConstant == std::numeric_limits<int64_t>::max())
504bcd08ebSStephan Herhut return;
514bcd08ebSStephan Herhut constantIndices.push_back(minConstant);
524bcd08ebSStephan Herhut }
534bcd08ebSStephan Herhut
544bcd08ebSStephan Herhut OpBuilder b(op);
554bcd08ebSStephan Herhut BlockAndValueMapping map;
564bcd08ebSStephan Herhut Value cond;
57c0342a2dSJacques Pienaar for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
58a54f4eaeSMogball Value constant =
59a54f4eaeSMogball b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
60a54f4eaeSMogball Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
614bcd08ebSStephan Herhut std::get<0>(bound), constant);
62a54f4eaeSMogball cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
634bcd08ebSStephan Herhut map.map(std::get<0>(bound), constant);
644bcd08ebSStephan Herhut }
654bcd08ebSStephan Herhut auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
664bcd08ebSStephan Herhut ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
674bcd08ebSStephan Herhut ifOp.getElseBodyBuilder().clone(*op.getOperation());
684bcd08ebSStephan Herhut op.erase();
694bcd08ebSStephan Herhut }
704bcd08ebSStephan Herhut
714bcd08ebSStephan Herhut /// Rewrite a for loop with bounds defined by an affine.min with a constant into
724bcd08ebSStephan Herhut /// 2 loops after checking if the bounds are equal to that constant. This is
734bcd08ebSStephan Herhut /// beneficial if the loop will almost always have the constant bound and that
744bcd08ebSStephan Herhut /// version can be fully unrolled and vectorized.
specializeForLoopForUnrolling(ForOp op)754bcd08ebSStephan Herhut static void specializeForLoopForUnrolling(ForOp op) {
76c0342a2dSJacques Pienaar auto bound = op.getUpperBound();
774bcd08ebSStephan Herhut auto minOp = bound.getDefiningOp<AffineMinOp>();
784bcd08ebSStephan Herhut if (!minOp)
794bcd08ebSStephan Herhut return;
804bcd08ebSStephan Herhut int64_t minConstant = std::numeric_limits<int64_t>::max();
81*04235d07SJacques Pienaar for (AffineExpr expr : minOp.getMap().getResults()) {
824bcd08ebSStephan Herhut if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
834bcd08ebSStephan Herhut minConstant = std::min(minConstant, constantIndex.getValue());
844bcd08ebSStephan Herhut }
854bcd08ebSStephan Herhut if (minConstant == std::numeric_limits<int64_t>::max())
864bcd08ebSStephan Herhut return;
874bcd08ebSStephan Herhut
884bcd08ebSStephan Herhut OpBuilder b(op);
894bcd08ebSStephan Herhut BlockAndValueMapping map;
90a54f4eaeSMogball Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
91a54f4eaeSMogball Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
92a54f4eaeSMogball bound, constant);
934bcd08ebSStephan Herhut map.map(bound, constant);
944bcd08ebSStephan Herhut auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
954bcd08ebSStephan Herhut ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
964bcd08ebSStephan Herhut ifOp.getElseBodyBuilder().clone(*op.getOperation());
974bcd08ebSStephan Herhut op.erase();
984bcd08ebSStephan Herhut }
994bcd08ebSStephan Herhut
1003a41ff48SMatthias Springer /// Rewrite a for loop with bounds/step that potentially do not divide evenly
1013a41ff48SMatthias Springer /// into a for loop where the step divides the iteration space evenly, followed
1023a41ff48SMatthias Springer /// by an scf.if for the last (partial) iteration (if any).
1038e8b70aaSMatthias Springer ///
1048e8b70aaSMatthias Springer /// This function rewrites the given scf.for loop in-place and creates a new
1058e8b70aaSMatthias Springer /// scf.if operation for the last iteration. It replaces all uses of the
1068e8b70aaSMatthias Springer /// unpeeled loop with the results of the newly generated scf.if.
1078e8b70aaSMatthias Springer ///
1088e8b70aaSMatthias Springer /// The newly generated scf.if operation is returned via `ifOp`. The boundary
1098e8b70aaSMatthias Springer /// at which the loop is split (new upper bound) is returned via `splitBound`.
1108e8b70aaSMatthias Springer /// The return value indicates whether the loop was rewritten or not.
peelForLoop(RewriterBase & b,ForOp forOp,ForOp & partialIteration,Value & splitBound)1110f3544d1SMatthias Springer static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
1120f3544d1SMatthias Springer ForOp &partialIteration, Value &splitBound) {
1133a41ff48SMatthias Springer RewriterBase::InsertionGuard guard(b);
114c0342a2dSJacques Pienaar auto lbInt = getConstantIntValue(forOp.getLowerBound());
115c0342a2dSJacques Pienaar auto ubInt = getConstantIntValue(forOp.getUpperBound());
116c0342a2dSJacques Pienaar auto stepInt = getConstantIntValue(forOp.getStep());
1173a41ff48SMatthias Springer
1183a41ff48SMatthias Springer // No specialization necessary if step already divides upper bound evenly.
1193a41ff48SMatthias Springer if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
1203a41ff48SMatthias Springer return failure();
1213a41ff48SMatthias Springer // No specialization necessary if step size is 1.
1223a41ff48SMatthias Springer if (stepInt == static_cast<int64_t>(1))
1233a41ff48SMatthias Springer return failure();
1243a41ff48SMatthias Springer
1253a41ff48SMatthias Springer auto loc = forOp.getLoc();
1260c360829SMatthias Springer AffineExpr sym0, sym1, sym2;
1270c360829SMatthias Springer bindSymbols(b.getContext(), sym0, sym1, sym2);
1283a41ff48SMatthias Springer // New upper bound: %ub - (%ub - %lb) mod %step
1290c360829SMatthias Springer auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
130767974f3SMatthias Springer b.setInsertionPoint(forOp);
131c0342a2dSJacques Pienaar splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
132c0342a2dSJacques Pienaar ValueRange{forOp.getLowerBound(),
133c0342a2dSJacques Pienaar forOp.getUpperBound(),
134c0342a2dSJacques Pienaar forOp.getStep()});
1353a41ff48SMatthias Springer
1360f3544d1SMatthias Springer // Create ForOp for partial iteration.
1370f3544d1SMatthias Springer b.setInsertionPointAfter(forOp);
1380f3544d1SMatthias Springer partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
139c0342a2dSJacques Pienaar partialIteration.getLowerBoundMutable().assign(splitBound);
1400f3544d1SMatthias Springer forOp.replaceAllUsesWith(partialIteration->getResults());
141c0342a2dSJacques Pienaar partialIteration.getInitArgsMutable().assign(forOp->getResults());
1420f3544d1SMatthias Springer
1433a41ff48SMatthias Springer // Set new upper loop bound.
144c0342a2dSJacques Pienaar b.updateRootInPlace(
145c0342a2dSJacques Pienaar forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
1463a41ff48SMatthias Springer
1473a41ff48SMatthias Springer return success();
1483a41ff48SMatthias Springer }
1493a41ff48SMatthias Springer
150a9cff97fSMatthias Springer template <typename OpTy, bool IsMin>
rewriteAffineOpAfterPeeling(RewriterBase & rewriter,ForOp forOp,ForOp partialIteration,Value previousUb)1510f3544d1SMatthias Springer static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
1520f3544d1SMatthias Springer ForOp partialIteration,
1530f3544d1SMatthias Springer Value previousUb) {
1540f3544d1SMatthias Springer Value mainIv = forOp.getInductionVar();
1550f3544d1SMatthias Springer Value partialIv = partialIteration.getInductionVar();
156c0342a2dSJacques Pienaar assert(forOp.getStep() == partialIteration.getStep() &&
1570f3544d1SMatthias Springer "expected same step in main and partial loop");
158c0342a2dSJacques Pienaar Value step = forOp.getStep();
1590f3544d1SMatthias Springer
160a9cff97fSMatthias Springer forOp.walk([&](OpTy affineOp) {
161c57c4f88SMatthias Springer AffineMap map = affineOp.getAffineMap();
162c57c4f88SMatthias Springer (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
1630f3544d1SMatthias Springer affineOp.operands(), IsMin, mainIv,
1640f3544d1SMatthias Springer previousUb, step,
165a9cff97fSMatthias Springer /*insideLoop=*/true);
166a9cff97fSMatthias Springer });
1670f3544d1SMatthias Springer partialIteration.walk([&](OpTy affineOp) {
168c57c4f88SMatthias Springer AffineMap map = affineOp.getAffineMap();
169c57c4f88SMatthias Springer (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
1700f3544d1SMatthias Springer affineOp.operands(), IsMin, partialIv,
1710f3544d1SMatthias Springer previousUb, step, /*insideLoop=*/false);
172a9cff97fSMatthias Springer });
1738e8b70aaSMatthias Springer }
1748e8b70aaSMatthias Springer
peelAndCanonicalizeForLoop(RewriterBase & rewriter,ForOp forOp,ForOp & partialIteration)1758e8b70aaSMatthias Springer LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
176bc194a5bSMatthias Springer ForOp forOp,
1770f3544d1SMatthias Springer ForOp &partialIteration) {
178c0342a2dSJacques Pienaar Value previousUb = forOp.getUpperBound();
1798e8b70aaSMatthias Springer Value splitBound;
1800f3544d1SMatthias Springer if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound)))
1818e8b70aaSMatthias Springer return failure();
1828e8b70aaSMatthias Springer
183a9cff97fSMatthias Springer // Rewrite affine.min and affine.max ops.
184a9cff97fSMatthias Springer rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
1850f3544d1SMatthias Springer rewriter, forOp, partialIteration, previousUb);
186a9cff97fSMatthias Springer rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
1870f3544d1SMatthias Springer rewriter, forOp, partialIteration, previousUb);
1888e8b70aaSMatthias Springer
1898e8b70aaSMatthias Springer return success();
1908e8b70aaSMatthias Springer }
1918e8b70aaSMatthias Springer
/// Marker attribute put on both loops produced by peeling so that the peeling
/// pattern does not rewrite them a second time.
static constexpr const char *kPeeledLoopLabel = "__peeled_loop__";
/// Marker attribute put on the partial-iteration loop produced by peeling.
static constexpr const char *kPartialIterationLabel = "__partial_iteration__";
1943a41ff48SMatthias Springer
1953a41ff48SMatthias Springer namespace {
1963a41ff48SMatthias Springer struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
ForLoopPeelingPattern__anon9cb0d7000411::ForLoopPeelingPattern197bc194a5bSMatthias Springer ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
198bc194a5bSMatthias Springer : OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}
1993a41ff48SMatthias Springer
matchAndRewrite__anon9cb0d7000411::ForLoopPeelingPattern2003a41ff48SMatthias Springer LogicalResult matchAndRewrite(ForOp forOp,
2013a41ff48SMatthias Springer PatternRewriter &rewriter) const override {
202bc194a5bSMatthias Springer // Do not peel already peeled loops.
2033a41ff48SMatthias Springer if (forOp->hasAttr(kPeeledLoopLabel))
2043a41ff48SMatthias Springer return failure();
205bc194a5bSMatthias Springer if (skipPartial) {
2060f3544d1SMatthias Springer // No peeling of loops inside the partial iteration of another peeled
2070f3544d1SMatthias Springer // loop.
208bc194a5bSMatthias Springer Operation *op = forOp.getOperation();
2090f3544d1SMatthias Springer while ((op = op->getParentOfType<scf::ForOp>())) {
210bc194a5bSMatthias Springer if (op->hasAttr(kPartialIterationLabel))
211bc194a5bSMatthias Springer return failure();
212bc194a5bSMatthias Springer }
213bc194a5bSMatthias Springer }
214bc194a5bSMatthias Springer // Apply loop peeling.
2150f3544d1SMatthias Springer scf::ForOp partialIteration;
2160f3544d1SMatthias Springer if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration)))
2173a41ff48SMatthias Springer return failure();
2183a41ff48SMatthias Springer // Apply label, so that the same loop is not rewritten a second time.
2190f3544d1SMatthias Springer partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
2203a41ff48SMatthias Springer rewriter.updateRootInPlace(forOp, [&]() {
2213a41ff48SMatthias Springer forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
2223a41ff48SMatthias Springer });
2230f3544d1SMatthias Springer partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
2243a41ff48SMatthias Springer return success();
2253a41ff48SMatthias Springer }
226bc194a5bSMatthias Springer
227bc194a5bSMatthias Springer /// If set to true, loops inside partial iterations of another peeled loop
228bc194a5bSMatthias Springer /// are not peeled. This reduces the size of the generated code. Partial
229bc194a5bSMatthias Springer /// iterations are not usually performance critical.
230bc194a5bSMatthias Springer /// Note: Takes into account the entire chain of parent operations, not just
231bc194a5bSMatthias Springer /// the direct parent.
232bc194a5bSMatthias Springer bool skipPartial;
2333a41ff48SMatthias Springer };
2343a41ff48SMatthias Springer } // namespace
2353a41ff48SMatthias Springer
2364bcd08ebSStephan Herhut namespace {
2374bcd08ebSStephan Herhut struct ParallelLoopSpecialization
2384bcd08ebSStephan Herhut : public SCFParallelLoopSpecializationBase<ParallelLoopSpecialization> {
runOnOperation__anon9cb0d7000611::ParallelLoopSpecialization23941574554SRiver Riddle void runOnOperation() override {
24054998986SStella Laurenzo getOperation()->walk(
2414bcd08ebSStephan Herhut [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
2424bcd08ebSStephan Herhut }
2434bcd08ebSStephan Herhut };
2444bcd08ebSStephan Herhut
2454bcd08ebSStephan Herhut struct ForLoopSpecialization
2464bcd08ebSStephan Herhut : public SCFForLoopSpecializationBase<ForLoopSpecialization> {
runOnOperation__anon9cb0d7000611::ForLoopSpecialization24741574554SRiver Riddle void runOnOperation() override {
24854998986SStella Laurenzo getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); });
2494bcd08ebSStephan Herhut }
2504bcd08ebSStephan Herhut };
2513a41ff48SMatthias Springer
2523a41ff48SMatthias Springer struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
runOnOperation__anon9cb0d7000611::ForLoopPeeling25341574554SRiver Riddle void runOnOperation() override {
25454998986SStella Laurenzo auto *parentOp = getOperation();
25554998986SStella Laurenzo MLIRContext *ctx = parentOp->getContext();
2563a41ff48SMatthias Springer RewritePatternSet patterns(ctx);
257bc194a5bSMatthias Springer patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
25854998986SStella Laurenzo (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
2593a41ff48SMatthias Springer
260bc194a5bSMatthias Springer // Drop the markers.
26154998986SStella Laurenzo parentOp->walk([](Operation *op) {
262bc194a5bSMatthias Springer op->removeAttr(kPeeledLoopLabel);
263bc194a5bSMatthias Springer op->removeAttr(kPartialIterationLabel);
264bc194a5bSMatthias Springer });
2653a41ff48SMatthias Springer }
2663a41ff48SMatthias Springer };
2674bcd08ebSStephan Herhut } // namespace
2684bcd08ebSStephan Herhut
createParallelLoopSpecializationPass()2694bcd08ebSStephan Herhut std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
2704bcd08ebSStephan Herhut return std::make_unique<ParallelLoopSpecialization>();
2714bcd08ebSStephan Herhut }
2724bcd08ebSStephan Herhut
createForLoopSpecializationPass()2734bcd08ebSStephan Herhut std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
2744bcd08ebSStephan Herhut return std::make_unique<ForLoopSpecialization>();
2754bcd08ebSStephan Herhut }
2763a41ff48SMatthias Springer
createForLoopPeelingPass()2773a41ff48SMatthias Springer std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
2783a41ff48SMatthias Springer return std::make_unique<ForLoopPeeling>();
2793a41ff48SMatthias Springer }
280