1f6f88e66Sthomasraoux //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
2f6f88e66Sthomasraoux //
3f6f88e66Sthomasraoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f6f88e66Sthomasraoux // See https://llvm.org/LICENSE.txt for license information.
5f6f88e66Sthomasraoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f6f88e66Sthomasraoux //
7f6f88e66Sthomasraoux //===----------------------------------------------------------------------===//
8f6f88e66Sthomasraoux //
9f6f88e66Sthomasraoux // This file implements loop software pipelining
10f6f88e66Sthomasraoux //
11f6f88e66Sthomasraoux //===----------------------------------------------------------------------===//
12f6f88e66Sthomasraoux 
13f6f88e66Sthomasraoux #include "PassDetail.h"
14a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
16*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h"
17*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
19f6f88e66Sthomasraoux #include "mlir/IR/BlockAndValueMapping.h"
20f6f88e66Sthomasraoux #include "mlir/IR/PatternMatch.h"
21f6f88e66Sthomasraoux #include "mlir/Support/MathExtras.h"
22f6f88e66Sthomasraoux 
23f6f88e66Sthomasraoux using namespace mlir;
24f6f88e66Sthomasraoux using namespace mlir::scf;
25f6f88e66Sthomasraoux 
26f6f88e66Sthomasraoux namespace {
27f6f88e66Sthomasraoux 
28f6f88e66Sthomasraoux /// Helper to keep internal information during pipelining transformation.
29f6f88e66Sthomasraoux struct LoopPipelinerInternal {
30f6f88e66Sthomasraoux   /// Coarse liverange information for ops used across stages.
31f6f88e66Sthomasraoux   struct LiverangeInfo {
32f6f88e66Sthomasraoux     unsigned lastUseStage = 0;
33f6f88e66Sthomasraoux     unsigned defStage = 0;
34f6f88e66Sthomasraoux   };
35f6f88e66Sthomasraoux 
36f6f88e66Sthomasraoux protected:
37f6f88e66Sthomasraoux   ForOp forOp;
38f6f88e66Sthomasraoux   unsigned maxStage = 0;
39f6f88e66Sthomasraoux   DenseMap<Operation *, unsigned> stages;
40f6f88e66Sthomasraoux   std::vector<Operation *> opOrder;
41f6f88e66Sthomasraoux   int64_t ub;
42f6f88e66Sthomasraoux   int64_t lb;
43f6f88e66Sthomasraoux   int64_t step;
440736bbd7SThomas Raoux   PipeliningOption::AnnotationlFnType annotateFn = nullptr;
45205c08b5SThomas Raoux   bool peelEpilogue;
46205c08b5SThomas Raoux   PipeliningOption::PredicateOpFn predicateFn = nullptr;
47f6f88e66Sthomasraoux 
48f6f88e66Sthomasraoux   // When peeling the kernel we generate several version of each value for
49f6f88e66Sthomasraoux   // different stage of the prologue. This map tracks the mapping between
50f6f88e66Sthomasraoux   // original Values in the loop and the different versions
51f6f88e66Sthomasraoux   // peeled from the loop.
52f6f88e66Sthomasraoux   DenseMap<Value, llvm::SmallVector<Value>> valueMapping;
53f6f88e66Sthomasraoux 
54f6f88e66Sthomasraoux   /// Assign a value to `valueMapping`, this means `val` represents the version
55f6f88e66Sthomasraoux   /// `idx` of `key` in the epilogue.
56f6f88e66Sthomasraoux   void setValueMapping(Value key, Value el, int64_t idx);
57f6f88e66Sthomasraoux 
58f6f88e66Sthomasraoux public:
59f6f88e66Sthomasraoux   /// Initalize the information for the given `op`, return true if it
60f6f88e66Sthomasraoux   /// satisfies the pre-condition to apply pipelining.
61f6f88e66Sthomasraoux   bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
62f6f88e66Sthomasraoux   /// Emits the prologue, this creates `maxStage - 1` part which will contain
63f6f88e66Sthomasraoux   /// operations from stages [0; i], where i is the part index.
64f6f88e66Sthomasraoux   void emitPrologue(PatternRewriter &rewriter);
65f6f88e66Sthomasraoux   /// Gather liverange information for Values that are used in a different stage
66f6f88e66Sthomasraoux   /// than its definition.
67f6f88e66Sthomasraoux   llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
68f6f88e66Sthomasraoux   scf::ForOp createKernelLoop(
69f6f88e66Sthomasraoux       const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
70f6f88e66Sthomasraoux       PatternRewriter &rewriter,
71f6f88e66Sthomasraoux       llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
72f6f88e66Sthomasraoux   /// Emits the pipelined kernel. This clones loop operations following user
73f6f88e66Sthomasraoux   /// order and remaps operands defined in a different stage as their use.
74f6f88e66Sthomasraoux   void createKernel(
75f6f88e66Sthomasraoux       scf::ForOp newForOp,
76f6f88e66Sthomasraoux       const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
77f6f88e66Sthomasraoux       const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
78f6f88e66Sthomasraoux       PatternRewriter &rewriter);
79f6f88e66Sthomasraoux   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
80f6f88e66Sthomasraoux   /// operations from stages [i; maxStage], where i is the part index.
8145cb4140Sthomasraoux   llvm::SmallVector<Value> emitEpilogue(PatternRewriter &rewriter);
82f6f88e66Sthomasraoux };
83f6f88e66Sthomasraoux 
initializeLoopInfo(ForOp op,const PipeliningOption & options)84f6f88e66Sthomasraoux bool LoopPipelinerInternal::initializeLoopInfo(
85f6f88e66Sthomasraoux     ForOp op, const PipeliningOption &options) {
86f6f88e66Sthomasraoux   forOp = op;
87a54f4eaeSMogball   auto upperBoundCst =
88c0342a2dSJacques Pienaar       forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
89a54f4eaeSMogball   auto lowerBoundCst =
90c0342a2dSJacques Pienaar       forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
91c0342a2dSJacques Pienaar   auto stepCst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
92f6f88e66Sthomasraoux   if (!upperBoundCst || !lowerBoundCst || !stepCst)
93f6f88e66Sthomasraoux     return false;
94a54f4eaeSMogball   ub = upperBoundCst.value();
95a54f4eaeSMogball   lb = lowerBoundCst.value();
96a54f4eaeSMogball   step = stepCst.value();
97205c08b5SThomas Raoux   peelEpilogue = options.peelEpilogue;
98205c08b5SThomas Raoux   predicateFn = options.predicateFn;
99205c08b5SThomas Raoux   if (!peelEpilogue && predicateFn == nullptr)
100205c08b5SThomas Raoux     return false;
101f6f88e66Sthomasraoux   int64_t numIteration = ceilDiv(ub - lb, step);
102f6f88e66Sthomasraoux   std::vector<std::pair<Operation *, unsigned>> schedule;
103f6f88e66Sthomasraoux   options.getScheduleFn(forOp, schedule);
104f6f88e66Sthomasraoux   if (schedule.empty())
105f6f88e66Sthomasraoux     return false;
106f6f88e66Sthomasraoux 
107f6f88e66Sthomasraoux   opOrder.reserve(schedule.size());
108f6f88e66Sthomasraoux   for (auto &opSchedule : schedule) {
109f6f88e66Sthomasraoux     maxStage = std::max(maxStage, opSchedule.second);
110f6f88e66Sthomasraoux     stages[opSchedule.first] = opSchedule.second;
111f6f88e66Sthomasraoux     opOrder.push_back(opSchedule.first);
112f6f88e66Sthomasraoux   }
113f6f88e66Sthomasraoux   if (numIteration <= maxStage)
114f6f88e66Sthomasraoux     return false;
115f6f88e66Sthomasraoux 
116f6f88e66Sthomasraoux   // All operations need to have a stage.
117f6f88e66Sthomasraoux   if (forOp
118f6f88e66Sthomasraoux           .walk([this](Operation *op) {
119f6f88e66Sthomasraoux             if (op != forOp.getOperation() && !isa<scf::YieldOp>(op) &&
120f6f88e66Sthomasraoux                 stages.find(op) == stages.end())
121f6f88e66Sthomasraoux               return WalkResult::interrupt();
122f6f88e66Sthomasraoux             return WalkResult::advance();
123f6f88e66Sthomasraoux           })
124f6f88e66Sthomasraoux           .wasInterrupted())
125f6f88e66Sthomasraoux     return false;
126f6f88e66Sthomasraoux 
12745cb4140Sthomasraoux   // Only support loop carried dependency with a distance of 1. This means the
12845cb4140Sthomasraoux   // source of all the scf.yield operands needs to be defined by operations in
12945cb4140Sthomasraoux   // the loop.
13045cb4140Sthomasraoux   if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
13145cb4140Sthomasraoux                    [this](Value operand) {
13245cb4140Sthomasraoux                      Operation *def = operand.getDefiningOp();
13345cb4140Sthomasraoux                      return !def || stages.find(def) == stages.end();
13445cb4140Sthomasraoux                    }))
135f6f88e66Sthomasraoux     return false;
1360736bbd7SThomas Raoux   annotateFn = options.annotateFn;
137f6f88e66Sthomasraoux   return true;
138f6f88e66Sthomasraoux }
139f6f88e66Sthomasraoux 
emitPrologue(PatternRewriter & rewriter)140f6f88e66Sthomasraoux void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
14145cb4140Sthomasraoux   // Initialize the iteration argument to the loop initiale values.
14245cb4140Sthomasraoux   for (BlockArgument &arg : forOp.getRegionIterArgs()) {
14345cb4140Sthomasraoux     OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
14445cb4140Sthomasraoux     setValueMapping(arg, operand.get(), 0);
14545cb4140Sthomasraoux   }
14645cb4140Sthomasraoux   auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
147f6f88e66Sthomasraoux   for (int64_t i = 0; i < maxStage; i++) {
148f6f88e66Sthomasraoux     // special handling for induction variable as the increment is implicit.
149c3c1c5c6SThomas Raoux     Value iv =
150c3c1c5c6SThomas Raoux         rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(), lb + i * step);
151f6f88e66Sthomasraoux     setValueMapping(forOp.getInductionVar(), iv, i);
152f6f88e66Sthomasraoux     for (Operation *op : opOrder) {
153f6f88e66Sthomasraoux       if (stages[op] > i)
154f6f88e66Sthomasraoux         continue;
155f6f88e66Sthomasraoux       Operation *newOp = rewriter.clone(*op);
156f6f88e66Sthomasraoux       for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
157f6f88e66Sthomasraoux         auto it = valueMapping.find(op->getOperand(opIdx));
158f6f88e66Sthomasraoux         if (it != valueMapping.end())
159f6f88e66Sthomasraoux           newOp->setOperand(opIdx, it->second[i - stages[op]]);
160f6f88e66Sthomasraoux       }
1610736bbd7SThomas Raoux       if (annotateFn)
1620736bbd7SThomas Raoux         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
163f6f88e66Sthomasraoux       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
164f6f88e66Sthomasraoux         setValueMapping(op->getResult(destId), newOp->getResult(destId),
165f6f88e66Sthomasraoux                         i - stages[op]);
16645cb4140Sthomasraoux         // If the value is a loop carried dependency update the loop argument
16745cb4140Sthomasraoux         // mapping.
16845cb4140Sthomasraoux         for (OpOperand &operand : yield->getOpOperands()) {
16945cb4140Sthomasraoux           if (operand.get() != op->getResult(destId))
17045cb4140Sthomasraoux             continue;
17145cb4140Sthomasraoux           setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
17245cb4140Sthomasraoux                           newOp->getResult(destId), i - stages[op] + 1);
17345cb4140Sthomasraoux         }
174f6f88e66Sthomasraoux       }
175f6f88e66Sthomasraoux     }
176f6f88e66Sthomasraoux   }
177f6f88e66Sthomasraoux }
178f6f88e66Sthomasraoux 
179f6f88e66Sthomasraoux llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
analyzeCrossStageValues()180f6f88e66Sthomasraoux LoopPipelinerInternal::analyzeCrossStageValues() {
181f6f88e66Sthomasraoux   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
182f6f88e66Sthomasraoux   for (Operation *op : opOrder) {
183f6f88e66Sthomasraoux     unsigned stage = stages[op];
184f6f88e66Sthomasraoux     for (OpOperand &operand : op->getOpOperands()) {
185f6f88e66Sthomasraoux       Operation *def = operand.get().getDefiningOp();
186f6f88e66Sthomasraoux       if (!def)
187f6f88e66Sthomasraoux         continue;
188f6f88e66Sthomasraoux       auto defStage = stages.find(def);
189f6f88e66Sthomasraoux       if (defStage == stages.end() || defStage->second == stage)
190f6f88e66Sthomasraoux         continue;
191f6f88e66Sthomasraoux       assert(stage > defStage->second);
192f6f88e66Sthomasraoux       LiverangeInfo &info = crossStageValues[operand.get()];
193f6f88e66Sthomasraoux       info.defStage = defStage->second;
194f6f88e66Sthomasraoux       info.lastUseStage = std::max(info.lastUseStage, stage);
195f6f88e66Sthomasraoux     }
196f6f88e66Sthomasraoux   }
197f6f88e66Sthomasraoux   return crossStageValues;
198f6f88e66Sthomasraoux }
199f6f88e66Sthomasraoux 
createKernelLoop(const llvm::MapVector<Value,LoopPipelinerInternal::LiverangeInfo> & crossStageValues,PatternRewriter & rewriter,llvm::DenseMap<std::pair<Value,unsigned>,unsigned> & loopArgMap)200f6f88e66Sthomasraoux scf::ForOp LoopPipelinerInternal::createKernelLoop(
201f6f88e66Sthomasraoux     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
202f6f88e66Sthomasraoux         &crossStageValues,
203f6f88e66Sthomasraoux     PatternRewriter &rewriter,
204f6f88e66Sthomasraoux     llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
205f6f88e66Sthomasraoux   // Creates the list of initial values associated to values used across
206f6f88e66Sthomasraoux   // stages. The initial values come from the prologue created above.
207f6f88e66Sthomasraoux   // Keep track of the kernel argument associated to each version of the
208f6f88e66Sthomasraoux   // values passed to the kernel.
20945cb4140Sthomasraoux   llvm::SmallVector<Value> newLoopArg;
21045cb4140Sthomasraoux   // For existing loop argument initialize them with the right version from the
21145cb4140Sthomasraoux   // prologue.
212e4853be2SMehdi Amini   for (const auto &retVal :
21345cb4140Sthomasraoux        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
21445cb4140Sthomasraoux     Operation *def = retVal.value().getDefiningOp();
21545cb4140Sthomasraoux     assert(def && "Only support loop carried dependencies of distance 1");
21645cb4140Sthomasraoux     unsigned defStage = stages[def];
21745cb4140Sthomasraoux     Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
21845cb4140Sthomasraoux                                      [maxStage - defStage];
21945cb4140Sthomasraoux     assert(valueVersion);
22045cb4140Sthomasraoux     newLoopArg.push_back(valueVersion);
22145cb4140Sthomasraoux   }
222f6f88e66Sthomasraoux   for (auto escape : crossStageValues) {
223f6f88e66Sthomasraoux     LiverangeInfo &info = escape.second;
224f6f88e66Sthomasraoux     Value value = escape.first;
225f6f88e66Sthomasraoux     for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
226f6f88e66Sthomasraoux          stageIdx++) {
227f6f88e66Sthomasraoux       Value valueVersion =
228f6f88e66Sthomasraoux           valueMapping[value][maxStage - info.lastUseStage + stageIdx];
229f6f88e66Sthomasraoux       assert(valueVersion);
230f6f88e66Sthomasraoux       newLoopArg.push_back(valueVersion);
231f6f88e66Sthomasraoux       loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
232f6f88e66Sthomasraoux                                            stageIdx)] = newLoopArg.size() - 1;
233f6f88e66Sthomasraoux     }
234f6f88e66Sthomasraoux   }
235f6f88e66Sthomasraoux 
236205c08b5SThomas Raoux   // Create the new kernel loop. When we peel the epilgue we need to peel
237205c08b5SThomas Raoux   // `numStages - 1` iterations. Then we adjust the upper bound to remove those
238205c08b5SThomas Raoux   // iterations.
239205c08b5SThomas Raoux   Value newUb = forOp.getUpperBound();
240205c08b5SThomas Raoux   if (peelEpilogue)
241205c08b5SThomas Raoux     newUb = rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(),
242a54f4eaeSMogball                                                     ub - maxStage * step);
243c0342a2dSJacques Pienaar   auto newForOp =
244c0342a2dSJacques Pienaar       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
245c0342a2dSJacques Pienaar                                   forOp.getStep(), newLoopArg);
246f6f88e66Sthomasraoux   return newForOp;
247f6f88e66Sthomasraoux }
248f6f88e66Sthomasraoux 
createKernel(scf::ForOp newForOp,const llvm::MapVector<Value,LoopPipelinerInternal::LiverangeInfo> & crossStageValues,const llvm::DenseMap<std::pair<Value,unsigned>,unsigned> & loopArgMap,PatternRewriter & rewriter)249f6f88e66Sthomasraoux void LoopPipelinerInternal::createKernel(
250f6f88e66Sthomasraoux     scf::ForOp newForOp,
251f6f88e66Sthomasraoux     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
252f6f88e66Sthomasraoux         &crossStageValues,
253f6f88e66Sthomasraoux     const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
254f6f88e66Sthomasraoux     PatternRewriter &rewriter) {
255f6f88e66Sthomasraoux   valueMapping.clear();
256f6f88e66Sthomasraoux 
257f6f88e66Sthomasraoux   // Create the kernel, we clone instruction based on the order given by
258f6f88e66Sthomasraoux   // user and remap operands coming from a previous stages.
259f6f88e66Sthomasraoux   rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
260f6f88e66Sthomasraoux   BlockAndValueMapping mapping;
261f6f88e66Sthomasraoux   mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
262e4853be2SMehdi Amini   for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
26345cb4140Sthomasraoux     mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
26445cb4140Sthomasraoux   }
265205c08b5SThomas Raoux   SmallVector<Value> predicates(maxStage + 1, nullptr);
266205c08b5SThomas Raoux   if (!peelEpilogue) {
267205c08b5SThomas Raoux     // Create a predicate for each stage except the last stage.
268205c08b5SThomas Raoux     for (unsigned i = 0; i < maxStage; i++) {
269205c08b5SThomas Raoux       Value c = rewriter.create<arith::ConstantIndexOp>(
270205c08b5SThomas Raoux           newForOp.getLoc(), ub - (maxStage - i) * step);
271205c08b5SThomas Raoux       Value pred = rewriter.create<arith::CmpIOp>(
272205c08b5SThomas Raoux           newForOp.getLoc(), arith::CmpIPredicate::slt,
273205c08b5SThomas Raoux           newForOp.getInductionVar(), c);
274205c08b5SThomas Raoux       predicates[i] = pred;
275205c08b5SThomas Raoux     }
276205c08b5SThomas Raoux   }
277f6f88e66Sthomasraoux   for (Operation *op : opOrder) {
278f6f88e66Sthomasraoux     int64_t useStage = stages[op];
279f6f88e66Sthomasraoux     auto *newOp = rewriter.clone(*op, mapping);
280f6f88e66Sthomasraoux     for (OpOperand &operand : op->getOpOperands()) {
281f6f88e66Sthomasraoux       // Special case for the induction variable uses. We replace it with a
282f6f88e66Sthomasraoux       // version incremented based on the stage where it is used.
283f6f88e66Sthomasraoux       if (operand.get() == forOp.getInductionVar()) {
284f6f88e66Sthomasraoux         rewriter.setInsertionPoint(newOp);
285a54f4eaeSMogball         Value offset = rewriter.create<arith::ConstantIndexOp>(
286f6f88e66Sthomasraoux             forOp.getLoc(), (maxStage - stages[op]) * step);
287a54f4eaeSMogball         Value iv = rewriter.create<arith::AddIOp>(
288a54f4eaeSMogball             forOp.getLoc(), newForOp.getInductionVar(), offset);
289f6f88e66Sthomasraoux         newOp->setOperand(operand.getOperandNumber(), iv);
290f6f88e66Sthomasraoux         rewriter.setInsertionPointAfter(newOp);
291f6f88e66Sthomasraoux         continue;
292f6f88e66Sthomasraoux       }
29345cb4140Sthomasraoux       auto arg = operand.get().dyn_cast<BlockArgument>();
29445cb4140Sthomasraoux       if (arg && arg.getOwner() == forOp.getBody()) {
29545cb4140Sthomasraoux         // If the value is a loop carried value coming from stage N + 1 remap,
29645cb4140Sthomasraoux         // it will become a direct use.
29745cb4140Sthomasraoux         Value ret = forOp.getBody()->getTerminator()->getOperand(
29845cb4140Sthomasraoux             arg.getArgNumber() - 1);
29945cb4140Sthomasraoux         Operation *dep = ret.getDefiningOp();
30045cb4140Sthomasraoux         if (!dep)
30145cb4140Sthomasraoux           continue;
30245cb4140Sthomasraoux         auto stageDep = stages.find(dep);
30345cb4140Sthomasraoux         if (stageDep == stages.end() || stageDep->second == useStage)
30445cb4140Sthomasraoux           continue;
30545cb4140Sthomasraoux         assert(stageDep->second == useStage + 1);
30645cb4140Sthomasraoux         newOp->setOperand(operand.getOperandNumber(),
30745cb4140Sthomasraoux                           mapping.lookupOrDefault(ret));
30845cb4140Sthomasraoux         continue;
30945cb4140Sthomasraoux       }
310f6f88e66Sthomasraoux       // For operands defined in a previous stage we need to remap it to use
311f6f88e66Sthomasraoux       // the correct region argument. We look for the right version of the
312f6f88e66Sthomasraoux       // Value based on the stage where it is used.
313f6f88e66Sthomasraoux       Operation *def = operand.get().getDefiningOp();
314f6f88e66Sthomasraoux       if (!def)
315f6f88e66Sthomasraoux         continue;
316f6f88e66Sthomasraoux       auto stageDef = stages.find(def);
317f6f88e66Sthomasraoux       if (stageDef == stages.end() || stageDef->second == useStage)
318f6f88e66Sthomasraoux         continue;
319f6f88e66Sthomasraoux       auto remap = loopArgMap.find(
320f6f88e66Sthomasraoux           std::make_pair(operand.get(), useStage - stageDef->second));
321f6f88e66Sthomasraoux       assert(remap != loopArgMap.end());
322f6f88e66Sthomasraoux       newOp->setOperand(operand.getOperandNumber(),
323f6f88e66Sthomasraoux                         newForOp.getRegionIterArgs()[remap->second]);
324f6f88e66Sthomasraoux     }
325205c08b5SThomas Raoux     if (predicates[useStage]) {
326205c08b5SThomas Raoux       newOp = predicateFn(newOp, predicates[useStage], rewriter);
327205c08b5SThomas Raoux       // Remap the results to the new predicated one.
328205c08b5SThomas Raoux       for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
329205c08b5SThomas Raoux         mapping.map(std::get<0>(values), std::get<1>(values));
330205c08b5SThomas Raoux     }
331205c08b5SThomas Raoux     rewriter.setInsertionPointAfter(newOp);
3320736bbd7SThomas Raoux     if (annotateFn)
3330736bbd7SThomas Raoux       annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
334f6f88e66Sthomasraoux   }
335f6f88e66Sthomasraoux 
336f6f88e66Sthomasraoux   // Collect the Values that need to be returned by the forOp. For each
337f6f88e66Sthomasraoux   // value we need to have `LastUseStage - DefStage` number of versions
338f6f88e66Sthomasraoux   // returned.
339f6f88e66Sthomasraoux   // We create a mapping between original values and the associated loop
340f6f88e66Sthomasraoux   // returned values that will be needed by the epilogue.
341f6f88e66Sthomasraoux   llvm::SmallVector<Value> yieldOperands;
34245cb4140Sthomasraoux   for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
34345cb4140Sthomasraoux     yieldOperands.push_back(mapping.lookupOrDefault(retVal));
34445cb4140Sthomasraoux   }
345f6f88e66Sthomasraoux   for (auto &it : crossStageValues) {
346f6f88e66Sthomasraoux     int64_t version = maxStage - it.second.lastUseStage + 1;
347f6f88e66Sthomasraoux     unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
348f6f88e66Sthomasraoux     // add the original verstion to yield ops.
349f6f88e66Sthomasraoux     // If there is a liverange spanning across more than 2 stages we need to add
350f6f88e66Sthomasraoux     // extra arg.
351f6f88e66Sthomasraoux     for (unsigned i = 1; i < numVersionReturned; i++) {
352f6f88e66Sthomasraoux       setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
353f6f88e66Sthomasraoux                       version++);
354f6f88e66Sthomasraoux       yieldOperands.push_back(
355f6f88e66Sthomasraoux           newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
356f6f88e66Sthomasraoux                                              newForOp.getNumInductionVars()]);
357f6f88e66Sthomasraoux     }
358f6f88e66Sthomasraoux     setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
359f6f88e66Sthomasraoux                     version++);
360f6f88e66Sthomasraoux     yieldOperands.push_back(mapping.lookupOrDefault(it.first));
361f6f88e66Sthomasraoux   }
36245cb4140Sthomasraoux   // Map the yield operand to the forOp returned value.
363e4853be2SMehdi Amini   for (const auto &retVal :
36445cb4140Sthomasraoux        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
36545cb4140Sthomasraoux     Operation *def = retVal.value().getDefiningOp();
36645cb4140Sthomasraoux     assert(def && "Only support loop carried dependencies of distance 1");
36745cb4140Sthomasraoux     unsigned defStage = stages[def];
36845cb4140Sthomasraoux     setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
36945cb4140Sthomasraoux                     newForOp->getResult(retVal.index()),
37045cb4140Sthomasraoux                     maxStage - defStage + 1);
37145cb4140Sthomasraoux   }
372f6f88e66Sthomasraoux   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
373f6f88e66Sthomasraoux }
374f6f88e66Sthomasraoux 
37545cb4140Sthomasraoux llvm::SmallVector<Value>
emitEpilogue(PatternRewriter & rewriter)37645cb4140Sthomasraoux LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
37745cb4140Sthomasraoux   llvm::SmallVector<Value> returnValues(forOp->getNumResults());
378f6f88e66Sthomasraoux   // Emit different versions of the induction variable. They will be
379f6f88e66Sthomasraoux   // removed by dead code if not used.
380f6f88e66Sthomasraoux   for (int64_t i = 0; i < maxStage; i++) {
381a54f4eaeSMogball     Value newlastIter = rewriter.create<arith::ConstantIndexOp>(
382f6f88e66Sthomasraoux         forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
383f6f88e66Sthomasraoux     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
384f6f88e66Sthomasraoux   }
385f6f88e66Sthomasraoux   // Emit `maxStage - 1` epilogue part that includes operations fro stages
386f6f88e66Sthomasraoux   // [i; maxStage].
387f6f88e66Sthomasraoux   for (int64_t i = 1; i <= maxStage; i++) {
388f6f88e66Sthomasraoux     for (Operation *op : opOrder) {
389f6f88e66Sthomasraoux       if (stages[op] < i)
390f6f88e66Sthomasraoux         continue;
391f6f88e66Sthomasraoux       Operation *newOp = rewriter.clone(*op);
392f6f88e66Sthomasraoux       for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
393f6f88e66Sthomasraoux         auto it = valueMapping.find(op->getOperand(opIdx));
394f6f88e66Sthomasraoux         if (it != valueMapping.end()) {
395f6f88e66Sthomasraoux           Value v = it->second[maxStage - stages[op] + i];
396f6f88e66Sthomasraoux           assert(v);
397f6f88e66Sthomasraoux           newOp->setOperand(opIdx, v);
398f6f88e66Sthomasraoux         }
399f6f88e66Sthomasraoux       }
4000736bbd7SThomas Raoux       if (annotateFn)
4010736bbd7SThomas Raoux         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
402f6f88e66Sthomasraoux       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
403f6f88e66Sthomasraoux         setValueMapping(op->getResult(destId), newOp->getResult(destId),
404f6f88e66Sthomasraoux                         maxStage - stages[op] + i);
40545cb4140Sthomasraoux         // If the value is a loop carried dependency update the loop argument
40645cb4140Sthomasraoux         // mapping and keep track of the last version to replace the original
40745cb4140Sthomasraoux         // forOp uses.
40845cb4140Sthomasraoux         for (OpOperand &operand :
40945cb4140Sthomasraoux              forOp.getBody()->getTerminator()->getOpOperands()) {
41045cb4140Sthomasraoux           if (operand.get() != op->getResult(destId))
41145cb4140Sthomasraoux             continue;
41245cb4140Sthomasraoux           unsigned version = maxStage - stages[op] + i + 1;
41345cb4140Sthomasraoux           // If the version is greater than maxStage it means it maps to the
41445cb4140Sthomasraoux           // original forOp returned value.
41545cb4140Sthomasraoux           if (version > maxStage) {
41645cb4140Sthomasraoux             returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
41745cb4140Sthomasraoux             continue;
41845cb4140Sthomasraoux           }
41945cb4140Sthomasraoux           setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
42045cb4140Sthomasraoux                           newOp->getResult(destId), version);
421f6f88e66Sthomasraoux         }
422f6f88e66Sthomasraoux       }
423f6f88e66Sthomasraoux     }
424f6f88e66Sthomasraoux   }
42545cb4140Sthomasraoux   return returnValues;
42645cb4140Sthomasraoux }
427f6f88e66Sthomasraoux 
setValueMapping(Value key,Value el,int64_t idx)428f6f88e66Sthomasraoux void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
429f6f88e66Sthomasraoux   auto it = valueMapping.find(key);
430f6f88e66Sthomasraoux   // If the value is not in the map yet add a vector big enough to store all
431f6f88e66Sthomasraoux   // versions.
432f6f88e66Sthomasraoux   if (it == valueMapping.end())
433f6f88e66Sthomasraoux     it =
434f6f88e66Sthomasraoux         valueMapping
435f6f88e66Sthomasraoux             .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1)))
436f6f88e66Sthomasraoux             .first;
437f6f88e66Sthomasraoux   it->second[idx] = el;
438f6f88e66Sthomasraoux }
439f6f88e66Sthomasraoux 
4405f0d4f20SAlex Zinenko } // namespace
4415f0d4f20SAlex Zinenko 
returningMatchAndRewrite(ForOp forOp,PatternRewriter & rewriter) const4425f0d4f20SAlex Zinenko FailureOr<ForOp> ForLoopPipeliningPattern::returningMatchAndRewrite(
4435f0d4f20SAlex Zinenko     ForOp forOp, PatternRewriter &rewriter) const {
444f6f88e66Sthomasraoux 
445f6f88e66Sthomasraoux   LoopPipelinerInternal pipeliner;
446f6f88e66Sthomasraoux   if (!pipeliner.initializeLoopInfo(forOp, options))
447f6f88e66Sthomasraoux     return failure();
448f6f88e66Sthomasraoux 
449f6f88e66Sthomasraoux   // 1. Emit prologue.
450f6f88e66Sthomasraoux   pipeliner.emitPrologue(rewriter);
451f6f88e66Sthomasraoux 
452f6f88e66Sthomasraoux   // 2. Track values used across stages. When a value cross stages it will
453f6f88e66Sthomasraoux   // need to be passed as loop iteration arguments.
454f6f88e66Sthomasraoux   // We first collect the values that are used in a different stage than where
455f6f88e66Sthomasraoux   // they are defined.
456f6f88e66Sthomasraoux   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
457f6f88e66Sthomasraoux       crossStageValues = pipeliner.analyzeCrossStageValues();
458f6f88e66Sthomasraoux 
459f6f88e66Sthomasraoux   // Mapping between original loop values used cross stage and the block
460f6f88e66Sthomasraoux   // arguments associated after pipelining. A Value may map to several
461f6f88e66Sthomasraoux   // arguments if its liverange spans across more than 2 stages.
462f6f88e66Sthomasraoux   llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
463f6f88e66Sthomasraoux   // 3. Create the new kernel loop and return the block arguments mapping.
464f6f88e66Sthomasraoux   ForOp newForOp =
465f6f88e66Sthomasraoux       pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
466f6f88e66Sthomasraoux   // Create the kernel block, order ops based on user choice and remap
467f6f88e66Sthomasraoux   // operands.
468f6f88e66Sthomasraoux   pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter);
469f6f88e66Sthomasraoux 
470205c08b5SThomas Raoux   llvm::SmallVector<Value> returnValues =
471205c08b5SThomas Raoux       newForOp.getResults().take_front(forOp->getNumResults());
472205c08b5SThomas Raoux   if (options.peelEpilogue) {
473f6f88e66Sthomasraoux     // 4. Emit the epilogue after the new forOp.
474f6f88e66Sthomasraoux     rewriter.setInsertionPointAfter(newForOp);
475205c08b5SThomas Raoux     returnValues = pipeliner.emitEpilogue(rewriter);
476205c08b5SThomas Raoux   }
477f6f88e66Sthomasraoux   // 5. Erase the original loop and replace the uses with the epilogue output.
478f6f88e66Sthomasraoux   if (forOp->getNumResults() > 0)
47945cb4140Sthomasraoux     rewriter.replaceOp(forOp, returnValues);
480f6f88e66Sthomasraoux   else
481f6f88e66Sthomasraoux     rewriter.eraseOp(forOp);
482f6f88e66Sthomasraoux 
4835f0d4f20SAlex Zinenko   return newForOp;
484f6f88e66Sthomasraoux }
485f6f88e66Sthomasraoux 
populateSCFLoopPipeliningPatterns(RewritePatternSet & patterns,const PipeliningOption & options)486f6f88e66Sthomasraoux void mlir::scf::populateSCFLoopPipeliningPatterns(
487f6f88e66Sthomasraoux     RewritePatternSet &patterns, const PipeliningOption &options) {
4885f0d4f20SAlex Zinenko   patterns.add<ForLoopPipeliningPattern>(options, patterns.getContext());
489f6f88e66Sthomasraoux }
490