1f6f88e66Sthomasraoux //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// 2f6f88e66Sthomasraoux // 3f6f88e66Sthomasraoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f6f88e66Sthomasraoux // See https://llvm.org/LICENSE.txt for license information. 5f6f88e66Sthomasraoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f6f88e66Sthomasraoux // 7f6f88e66Sthomasraoux //===----------------------------------------------------------------------===// 8f6f88e66Sthomasraoux // 9f6f88e66Sthomasraoux // This file implements loop software pipelining 10f6f88e66Sthomasraoux // 11f6f88e66Sthomasraoux //===----------------------------------------------------------------------===// 12f6f88e66Sthomasraoux 13f6f88e66Sthomasraoux #include "PassDetail.h" 14*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15f6f88e66Sthomasraoux #include "mlir/Dialect/SCF/SCF.h" 16f6f88e66Sthomasraoux #include "mlir/Dialect/SCF/Transforms.h" 17f6f88e66Sthomasraoux #include "mlir/Dialect/SCF/Utils.h" 18f6f88e66Sthomasraoux #include "mlir/Dialect/StandardOps/IR/Ops.h" 19f6f88e66Sthomasraoux #include "mlir/IR/BlockAndValueMapping.h" 20f6f88e66Sthomasraoux #include "mlir/IR/PatternMatch.h" 21f6f88e66Sthomasraoux #include "mlir/Support/MathExtras.h" 22f6f88e66Sthomasraoux 23f6f88e66Sthomasraoux using namespace mlir; 24f6f88e66Sthomasraoux using namespace mlir::scf; 25f6f88e66Sthomasraoux 26f6f88e66Sthomasraoux namespace { 27f6f88e66Sthomasraoux 28f6f88e66Sthomasraoux /// Helper to keep internal information during pipelining transformation. 29f6f88e66Sthomasraoux struct LoopPipelinerInternal { 30f6f88e66Sthomasraoux /// Coarse liverange information for ops used across stages. 31f6f88e66Sthomasraoux struct LiverangeInfo { 32f6f88e66Sthomasraoux unsigned lastUseStage = 0; 33f6f88e66Sthomasraoux unsigned defStage = 0; 34f6f88e66Sthomasraoux }; 35f6f88e66Sthomasraoux 36f6f88e66Sthomasraoux protected: 37f6f88e66Sthomasraoux ForOp forOp; 38f6f88e66Sthomasraoux unsigned maxStage = 0; 39f6f88e66Sthomasraoux DenseMap<Operation *, unsigned> stages; 40f6f88e66Sthomasraoux std::vector<Operation *> opOrder; 41f6f88e66Sthomasraoux int64_t ub; 42f6f88e66Sthomasraoux int64_t lb; 43f6f88e66Sthomasraoux int64_t step; 44f6f88e66Sthomasraoux 45f6f88e66Sthomasraoux // When peeling the kernel we generate several version of each value for 46f6f88e66Sthomasraoux // different stage of the prologue. This map tracks the mapping between 47f6f88e66Sthomasraoux // original Values in the loop and the different versions 48f6f88e66Sthomasraoux // peeled from the loop. 49f6f88e66Sthomasraoux DenseMap<Value, llvm::SmallVector<Value>> valueMapping; 50f6f88e66Sthomasraoux 51f6f88e66Sthomasraoux /// Assign a value to `valueMapping`, this means `val` represents the version 52f6f88e66Sthomasraoux /// `idx` of `key` in the epilogue. 53f6f88e66Sthomasraoux void setValueMapping(Value key, Value el, int64_t idx); 54f6f88e66Sthomasraoux 55f6f88e66Sthomasraoux public: 56f6f88e66Sthomasraoux /// Initalize the information for the given `op`, return true if it 57f6f88e66Sthomasraoux /// satisfies the pre-condition to apply pipelining. 58f6f88e66Sthomasraoux bool initializeLoopInfo(ForOp op, const PipeliningOption &options); 59f6f88e66Sthomasraoux /// Emits the prologue, this creates `maxStage - 1` part which will contain 60f6f88e66Sthomasraoux /// operations from stages [0; i], where i is the part index. 61f6f88e66Sthomasraoux void emitPrologue(PatternRewriter &rewriter); 62f6f88e66Sthomasraoux /// Gather liverange information for Values that are used in a different stage 63f6f88e66Sthomasraoux /// than its definition. 64f6f88e66Sthomasraoux llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues(); 65f6f88e66Sthomasraoux scf::ForOp createKernelLoop( 66f6f88e66Sthomasraoux const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, 67f6f88e66Sthomasraoux PatternRewriter &rewriter, 68f6f88e66Sthomasraoux llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap); 69f6f88e66Sthomasraoux /// Emits the pipelined kernel. This clones loop operations following user 70f6f88e66Sthomasraoux /// order and remaps operands defined in a different stage as their use. 71f6f88e66Sthomasraoux void createKernel( 72f6f88e66Sthomasraoux scf::ForOp newForOp, 73f6f88e66Sthomasraoux const llvm::MapVector<Value, LiverangeInfo> &crossStageValues, 74f6f88e66Sthomasraoux const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, 75f6f88e66Sthomasraoux PatternRewriter &rewriter); 76f6f88e66Sthomasraoux /// Emits the epilogue, this creates `maxStage - 1` part which will contain 77f6f88e66Sthomasraoux /// operations from stages [i; maxStage], where i is the part index. 7845cb4140Sthomasraoux llvm::SmallVector<Value> emitEpilogue(PatternRewriter &rewriter); 79f6f88e66Sthomasraoux }; 80f6f88e66Sthomasraoux 81f6f88e66Sthomasraoux bool LoopPipelinerInternal::initializeLoopInfo( 82f6f88e66Sthomasraoux ForOp op, const PipeliningOption &options) { 83f6f88e66Sthomasraoux forOp = op; 84*a54f4eaeSMogball auto upperBoundCst = 85*a54f4eaeSMogball forOp.upperBound().getDefiningOp<arith::ConstantIndexOp>(); 86*a54f4eaeSMogball auto lowerBoundCst = 87*a54f4eaeSMogball forOp.lowerBound().getDefiningOp<arith::ConstantIndexOp>(); 88*a54f4eaeSMogball auto stepCst = forOp.step().getDefiningOp<arith::ConstantIndexOp>(); 89f6f88e66Sthomasraoux if (!upperBoundCst || !lowerBoundCst || !stepCst) 90f6f88e66Sthomasraoux return false; 91*a54f4eaeSMogball ub = upperBoundCst.value(); 92*a54f4eaeSMogball lb = lowerBoundCst.value(); 93*a54f4eaeSMogball step = stepCst.value(); 94f6f88e66Sthomasraoux int64_t numIteration = ceilDiv(ub - lb, step); 95f6f88e66Sthomasraoux std::vector<std::pair<Operation *, unsigned>> schedule; 96f6f88e66Sthomasraoux options.getScheduleFn(forOp, schedule); 97f6f88e66Sthomasraoux if (schedule.empty()) 98f6f88e66Sthomasraoux return false; 99f6f88e66Sthomasraoux 100f6f88e66Sthomasraoux opOrder.reserve(schedule.size()); 101f6f88e66Sthomasraoux for (auto &opSchedule : schedule) { 102f6f88e66Sthomasraoux maxStage = std::max(maxStage, opSchedule.second); 103f6f88e66Sthomasraoux stages[opSchedule.first] = opSchedule.second; 104f6f88e66Sthomasraoux opOrder.push_back(opSchedule.first); 105f6f88e66Sthomasraoux } 106f6f88e66Sthomasraoux if (numIteration <= maxStage) 107f6f88e66Sthomasraoux return false; 108f6f88e66Sthomasraoux 109f6f88e66Sthomasraoux // All operations need to have a stage. 110f6f88e66Sthomasraoux if (forOp 111f6f88e66Sthomasraoux .walk([this](Operation *op) { 112f6f88e66Sthomasraoux if (op != forOp.getOperation() && !isa<scf::YieldOp>(op) && 113f6f88e66Sthomasraoux stages.find(op) == stages.end()) 114f6f88e66Sthomasraoux return WalkResult::interrupt(); 115f6f88e66Sthomasraoux return WalkResult::advance(); 116f6f88e66Sthomasraoux }) 117f6f88e66Sthomasraoux .wasInterrupted()) 118f6f88e66Sthomasraoux return false; 119f6f88e66Sthomasraoux 12045cb4140Sthomasraoux // Only support loop carried dependency with a distance of 1. This means the 12145cb4140Sthomasraoux // source of all the scf.yield operands needs to be defined by operations in 12245cb4140Sthomasraoux // the loop. 12345cb4140Sthomasraoux if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), 12445cb4140Sthomasraoux [this](Value operand) { 12545cb4140Sthomasraoux Operation *def = operand.getDefiningOp(); 12645cb4140Sthomasraoux return !def || stages.find(def) == stages.end(); 12745cb4140Sthomasraoux })) 128f6f88e66Sthomasraoux return false; 129f6f88e66Sthomasraoux return true; 130f6f88e66Sthomasraoux } 131f6f88e66Sthomasraoux 132f6f88e66Sthomasraoux void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) { 13345cb4140Sthomasraoux // Initialize the iteration argument to the loop initiale values. 13445cb4140Sthomasraoux for (BlockArgument &arg : forOp.getRegionIterArgs()) { 13545cb4140Sthomasraoux OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); 13645cb4140Sthomasraoux setValueMapping(arg, operand.get(), 0); 13745cb4140Sthomasraoux } 13845cb4140Sthomasraoux auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); 139f6f88e66Sthomasraoux for (int64_t i = 0; i < maxStage; i++) { 140f6f88e66Sthomasraoux // special handling for induction variable as the increment is implicit. 141*a54f4eaeSMogball Value iv = rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(), lb + i); 142f6f88e66Sthomasraoux setValueMapping(forOp.getInductionVar(), iv, i); 143f6f88e66Sthomasraoux for (Operation *op : opOrder) { 144f6f88e66Sthomasraoux if (stages[op] > i) 145f6f88e66Sthomasraoux continue; 146f6f88e66Sthomasraoux Operation *newOp = rewriter.clone(*op); 147f6f88e66Sthomasraoux for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) { 148f6f88e66Sthomasraoux auto it = valueMapping.find(op->getOperand(opIdx)); 149f6f88e66Sthomasraoux if (it != valueMapping.end()) 150f6f88e66Sthomasraoux newOp->setOperand(opIdx, it->second[i - stages[op]]); 151f6f88e66Sthomasraoux } 152f6f88e66Sthomasraoux for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { 153f6f88e66Sthomasraoux setValueMapping(op->getResult(destId), newOp->getResult(destId), 154f6f88e66Sthomasraoux i - stages[op]); 15545cb4140Sthomasraoux // If the value is a loop carried dependency update the loop argument 15645cb4140Sthomasraoux // mapping. 15745cb4140Sthomasraoux for (OpOperand &operand : yield->getOpOperands()) { 15845cb4140Sthomasraoux if (operand.get() != op->getResult(destId)) 15945cb4140Sthomasraoux continue; 16045cb4140Sthomasraoux setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], 16145cb4140Sthomasraoux newOp->getResult(destId), i - stages[op] + 1); 16245cb4140Sthomasraoux } 163f6f88e66Sthomasraoux } 164f6f88e66Sthomasraoux } 165f6f88e66Sthomasraoux } 166f6f88e66Sthomasraoux } 167f6f88e66Sthomasraoux 168f6f88e66Sthomasraoux llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 169f6f88e66Sthomasraoux LoopPipelinerInternal::analyzeCrossStageValues() { 170f6f88e66Sthomasraoux llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues; 171f6f88e66Sthomasraoux for (Operation *op : opOrder) { 172f6f88e66Sthomasraoux unsigned stage = stages[op]; 173f6f88e66Sthomasraoux for (OpOperand &operand : op->getOpOperands()) { 174f6f88e66Sthomasraoux Operation *def = operand.get().getDefiningOp(); 175f6f88e66Sthomasraoux if (!def) 176f6f88e66Sthomasraoux continue; 177f6f88e66Sthomasraoux auto defStage = stages.find(def); 178f6f88e66Sthomasraoux if (defStage == stages.end() || defStage->second == stage) 179f6f88e66Sthomasraoux continue; 180f6f88e66Sthomasraoux assert(stage > defStage->second); 181f6f88e66Sthomasraoux LiverangeInfo &info = crossStageValues[operand.get()]; 182f6f88e66Sthomasraoux info.defStage = defStage->second; 183f6f88e66Sthomasraoux info.lastUseStage = std::max(info.lastUseStage, stage); 184f6f88e66Sthomasraoux } 185f6f88e66Sthomasraoux } 186f6f88e66Sthomasraoux return crossStageValues; 187f6f88e66Sthomasraoux } 188f6f88e66Sthomasraoux 189f6f88e66Sthomasraoux scf::ForOp LoopPipelinerInternal::createKernelLoop( 190f6f88e66Sthomasraoux const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 191f6f88e66Sthomasraoux &crossStageValues, 192f6f88e66Sthomasraoux PatternRewriter &rewriter, 193f6f88e66Sthomasraoux llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) { 194f6f88e66Sthomasraoux // Creates the list of initial values associated to values used across 195f6f88e66Sthomasraoux // stages. The initial values come from the prologue created above. 196f6f88e66Sthomasraoux // Keep track of the kernel argument associated to each version of the 197f6f88e66Sthomasraoux // values passed to the kernel. 19845cb4140Sthomasraoux llvm::SmallVector<Value> newLoopArg; 19945cb4140Sthomasraoux // For existing loop argument initialize them with the right version from the 20045cb4140Sthomasraoux // prologue. 20145cb4140Sthomasraoux for (auto retVal : 20245cb4140Sthomasraoux llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { 20345cb4140Sthomasraoux Operation *def = retVal.value().getDefiningOp(); 20445cb4140Sthomasraoux assert(def && "Only support loop carried dependencies of distance 1"); 20545cb4140Sthomasraoux unsigned defStage = stages[def]; 20645cb4140Sthomasraoux Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]] 20745cb4140Sthomasraoux [maxStage - defStage]; 20845cb4140Sthomasraoux assert(valueVersion); 20945cb4140Sthomasraoux newLoopArg.push_back(valueVersion); 21045cb4140Sthomasraoux } 211f6f88e66Sthomasraoux for (auto escape : crossStageValues) { 212f6f88e66Sthomasraoux LiverangeInfo &info = escape.second; 213f6f88e66Sthomasraoux Value value = escape.first; 214f6f88e66Sthomasraoux for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; 215f6f88e66Sthomasraoux stageIdx++) { 216f6f88e66Sthomasraoux Value valueVersion = 217f6f88e66Sthomasraoux valueMapping[value][maxStage - info.lastUseStage + stageIdx]; 218f6f88e66Sthomasraoux assert(valueVersion); 219f6f88e66Sthomasraoux newLoopArg.push_back(valueVersion); 220f6f88e66Sthomasraoux loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - 221f6f88e66Sthomasraoux stageIdx)] = newLoopArg.size() - 1; 222f6f88e66Sthomasraoux } 223f6f88e66Sthomasraoux } 224f6f88e66Sthomasraoux 225f6f88e66Sthomasraoux // Create the new kernel loop. Since we need to peel `numStages - 1` 226f6f88e66Sthomasraoux // iteration we change the upper bound to remove those iterations. 227*a54f4eaeSMogball Value newUb = rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(), 228*a54f4eaeSMogball ub - maxStage * step); 229f6f88e66Sthomasraoux auto newForOp = rewriter.create<scf::ForOp>( 230f6f88e66Sthomasraoux forOp.getLoc(), forOp.lowerBound(), newUb, forOp.step(), newLoopArg); 231f6f88e66Sthomasraoux return newForOp; 232f6f88e66Sthomasraoux } 233f6f88e66Sthomasraoux 234f6f88e66Sthomasraoux void LoopPipelinerInternal::createKernel( 235f6f88e66Sthomasraoux scf::ForOp newForOp, 236f6f88e66Sthomasraoux const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 237f6f88e66Sthomasraoux &crossStageValues, 238f6f88e66Sthomasraoux const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap, 239f6f88e66Sthomasraoux PatternRewriter &rewriter) { 240f6f88e66Sthomasraoux valueMapping.clear(); 241f6f88e66Sthomasraoux 242f6f88e66Sthomasraoux // Create the kernel, we clone instruction based on the order given by 243f6f88e66Sthomasraoux // user and remap operands coming from a previous stages. 244f6f88e66Sthomasraoux rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); 245f6f88e66Sthomasraoux BlockAndValueMapping mapping; 246f6f88e66Sthomasraoux mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); 24745cb4140Sthomasraoux for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) { 24845cb4140Sthomasraoux mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); 24945cb4140Sthomasraoux } 250f6f88e66Sthomasraoux for (Operation *op : opOrder) { 251f6f88e66Sthomasraoux int64_t useStage = stages[op]; 252f6f88e66Sthomasraoux auto *newOp = rewriter.clone(*op, mapping); 253f6f88e66Sthomasraoux for (OpOperand &operand : op->getOpOperands()) { 254f6f88e66Sthomasraoux // Special case for the induction variable uses. We replace it with a 255f6f88e66Sthomasraoux // version incremented based on the stage where it is used. 256f6f88e66Sthomasraoux if (operand.get() == forOp.getInductionVar()) { 257f6f88e66Sthomasraoux rewriter.setInsertionPoint(newOp); 258*a54f4eaeSMogball Value offset = rewriter.create<arith::ConstantIndexOp>( 259f6f88e66Sthomasraoux forOp.getLoc(), (maxStage - stages[op]) * step); 260*a54f4eaeSMogball Value iv = rewriter.create<arith::AddIOp>( 261*a54f4eaeSMogball forOp.getLoc(), newForOp.getInductionVar(), offset); 262f6f88e66Sthomasraoux newOp->setOperand(operand.getOperandNumber(), iv); 263f6f88e66Sthomasraoux rewriter.setInsertionPointAfter(newOp); 264f6f88e66Sthomasraoux continue; 265f6f88e66Sthomasraoux } 26645cb4140Sthomasraoux auto arg = operand.get().dyn_cast<BlockArgument>(); 26745cb4140Sthomasraoux if (arg && arg.getOwner() == forOp.getBody()) { 26845cb4140Sthomasraoux // If the value is a loop carried value coming from stage N + 1 remap, 26945cb4140Sthomasraoux // it will become a direct use. 27045cb4140Sthomasraoux Value ret = forOp.getBody()->getTerminator()->getOperand( 27145cb4140Sthomasraoux arg.getArgNumber() - 1); 27245cb4140Sthomasraoux Operation *dep = ret.getDefiningOp(); 27345cb4140Sthomasraoux if (!dep) 27445cb4140Sthomasraoux continue; 27545cb4140Sthomasraoux auto stageDep = stages.find(dep); 27645cb4140Sthomasraoux if (stageDep == stages.end() || stageDep->second == useStage) 27745cb4140Sthomasraoux continue; 27845cb4140Sthomasraoux assert(stageDep->second == useStage + 1); 27945cb4140Sthomasraoux newOp->setOperand(operand.getOperandNumber(), 28045cb4140Sthomasraoux mapping.lookupOrDefault(ret)); 28145cb4140Sthomasraoux continue; 28245cb4140Sthomasraoux } 283f6f88e66Sthomasraoux // For operands defined in a previous stage we need to remap it to use 284f6f88e66Sthomasraoux // the correct region argument. We look for the right version of the 285f6f88e66Sthomasraoux // Value based on the stage where it is used. 286f6f88e66Sthomasraoux Operation *def = operand.get().getDefiningOp(); 287f6f88e66Sthomasraoux if (!def) 288f6f88e66Sthomasraoux continue; 289f6f88e66Sthomasraoux auto stageDef = stages.find(def); 290f6f88e66Sthomasraoux if (stageDef == stages.end() || stageDef->second == useStage) 291f6f88e66Sthomasraoux continue; 292f6f88e66Sthomasraoux auto remap = loopArgMap.find( 293f6f88e66Sthomasraoux std::make_pair(operand.get(), useStage - stageDef->second)); 294f6f88e66Sthomasraoux assert(remap != loopArgMap.end()); 295f6f88e66Sthomasraoux newOp->setOperand(operand.getOperandNumber(), 296f6f88e66Sthomasraoux newForOp.getRegionIterArgs()[remap->second]); 297f6f88e66Sthomasraoux } 298f6f88e66Sthomasraoux } 299f6f88e66Sthomasraoux 300f6f88e66Sthomasraoux // Collect the Values that need to be returned by the forOp. For each 301f6f88e66Sthomasraoux // value we need to have `LastUseStage - DefStage` number of versions 302f6f88e66Sthomasraoux // returned. 303f6f88e66Sthomasraoux // We create a mapping between original values and the associated loop 304f6f88e66Sthomasraoux // returned values that will be needed by the epilogue. 305f6f88e66Sthomasraoux llvm::SmallVector<Value> yieldOperands; 30645cb4140Sthomasraoux for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) { 30745cb4140Sthomasraoux yieldOperands.push_back(mapping.lookupOrDefault(retVal)); 30845cb4140Sthomasraoux } 309f6f88e66Sthomasraoux for (auto &it : crossStageValues) { 310f6f88e66Sthomasraoux int64_t version = maxStage - it.second.lastUseStage + 1; 311f6f88e66Sthomasraoux unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; 312f6f88e66Sthomasraoux // add the original verstion to yield ops. 313f6f88e66Sthomasraoux // If there is a liverange spanning across more than 2 stages we need to add 314f6f88e66Sthomasraoux // extra arg. 315f6f88e66Sthomasraoux for (unsigned i = 1; i < numVersionReturned; i++) { 316f6f88e66Sthomasraoux setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), 317f6f88e66Sthomasraoux version++); 318f6f88e66Sthomasraoux yieldOperands.push_back( 319f6f88e66Sthomasraoux newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + 320f6f88e66Sthomasraoux newForOp.getNumInductionVars()]); 321f6f88e66Sthomasraoux } 322f6f88e66Sthomasraoux setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), 323f6f88e66Sthomasraoux version++); 324f6f88e66Sthomasraoux yieldOperands.push_back(mapping.lookupOrDefault(it.first)); 325f6f88e66Sthomasraoux } 32645cb4140Sthomasraoux // Map the yield operand to the forOp returned value. 32745cb4140Sthomasraoux for (auto retVal : 32845cb4140Sthomasraoux llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { 32945cb4140Sthomasraoux Operation *def = retVal.value().getDefiningOp(); 33045cb4140Sthomasraoux assert(def && "Only support loop carried dependencies of distance 1"); 33145cb4140Sthomasraoux unsigned defStage = stages[def]; 33245cb4140Sthomasraoux setValueMapping(forOp.getRegionIterArgs()[retVal.index()], 33345cb4140Sthomasraoux newForOp->getResult(retVal.index()), 33445cb4140Sthomasraoux maxStage - defStage + 1); 33545cb4140Sthomasraoux } 336f6f88e66Sthomasraoux rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands); 337f6f88e66Sthomasraoux } 338f6f88e66Sthomasraoux 33945cb4140Sthomasraoux llvm::SmallVector<Value> 34045cb4140Sthomasraoux LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) { 34145cb4140Sthomasraoux llvm::SmallVector<Value> returnValues(forOp->getNumResults()); 342f6f88e66Sthomasraoux // Emit different versions of the induction variable. They will be 343f6f88e66Sthomasraoux // removed by dead code if not used. 344f6f88e66Sthomasraoux for (int64_t i = 0; i < maxStage; i++) { 345*a54f4eaeSMogball Value newlastIter = rewriter.create<arith::ConstantIndexOp>( 346f6f88e66Sthomasraoux forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i)); 347f6f88e66Sthomasraoux setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); 348f6f88e66Sthomasraoux } 349f6f88e66Sthomasraoux // Emit `maxStage - 1` epilogue part that includes operations fro stages 350f6f88e66Sthomasraoux // [i; maxStage]. 351f6f88e66Sthomasraoux for (int64_t i = 1; i <= maxStage; i++) { 352f6f88e66Sthomasraoux for (Operation *op : opOrder) { 353f6f88e66Sthomasraoux if (stages[op] < i) 354f6f88e66Sthomasraoux continue; 355f6f88e66Sthomasraoux Operation *newOp = rewriter.clone(*op); 356f6f88e66Sthomasraoux for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) { 357f6f88e66Sthomasraoux auto it = valueMapping.find(op->getOperand(opIdx)); 358f6f88e66Sthomasraoux if (it != valueMapping.end()) { 359f6f88e66Sthomasraoux Value v = it->second[maxStage - stages[op] + i]; 360f6f88e66Sthomasraoux assert(v); 361f6f88e66Sthomasraoux newOp->setOperand(opIdx, v); 362f6f88e66Sthomasraoux } 363f6f88e66Sthomasraoux } 364f6f88e66Sthomasraoux for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { 365f6f88e66Sthomasraoux setValueMapping(op->getResult(destId), newOp->getResult(destId), 366f6f88e66Sthomasraoux maxStage - stages[op] + i); 36745cb4140Sthomasraoux // If the value is a loop carried dependency update the loop argument 36845cb4140Sthomasraoux // mapping and keep track of the last version to replace the original 36945cb4140Sthomasraoux // forOp uses. 37045cb4140Sthomasraoux for (OpOperand &operand : 37145cb4140Sthomasraoux forOp.getBody()->getTerminator()->getOpOperands()) { 37245cb4140Sthomasraoux if (operand.get() != op->getResult(destId)) 37345cb4140Sthomasraoux continue; 37445cb4140Sthomasraoux unsigned version = maxStage - stages[op] + i + 1; 37545cb4140Sthomasraoux // If the version is greater than maxStage it means it maps to the 37645cb4140Sthomasraoux // original forOp returned value. 37745cb4140Sthomasraoux if (version > maxStage) { 37845cb4140Sthomasraoux returnValues[operand.getOperandNumber()] = newOp->getResult(destId); 37945cb4140Sthomasraoux continue; 38045cb4140Sthomasraoux } 38145cb4140Sthomasraoux setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], 38245cb4140Sthomasraoux newOp->getResult(destId), version); 383f6f88e66Sthomasraoux } 384f6f88e66Sthomasraoux } 385f6f88e66Sthomasraoux } 386f6f88e66Sthomasraoux } 38745cb4140Sthomasraoux return returnValues; 38845cb4140Sthomasraoux } 389f6f88e66Sthomasraoux 390f6f88e66Sthomasraoux void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { 391f6f88e66Sthomasraoux auto it = valueMapping.find(key); 392f6f88e66Sthomasraoux // If the value is not in the map yet add a vector big enough to store all 393f6f88e66Sthomasraoux // versions. 394f6f88e66Sthomasraoux if (it == valueMapping.end()) 395f6f88e66Sthomasraoux it = 396f6f88e66Sthomasraoux valueMapping 397f6f88e66Sthomasraoux .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1))) 398f6f88e66Sthomasraoux .first; 399f6f88e66Sthomasraoux it->second[idx] = el; 400f6f88e66Sthomasraoux } 401f6f88e66Sthomasraoux 402f6f88e66Sthomasraoux /// Generate a pipelined version of the scf.for loop based on the schedule given 403f6f88e66Sthomasraoux /// as option. This applies the mechanical transformation of changing the loop 404f6f88e66Sthomasraoux /// and generating the prologue/epilogue for the pipelining and doesn't make any 405f6f88e66Sthomasraoux /// decision regarding the schedule. 406f6f88e66Sthomasraoux /// Based on the option the loop is split into several stages. 407f6f88e66Sthomasraoux /// The transformation assumes that the scheduling given by user is valid. 408f6f88e66Sthomasraoux /// For example if we break a loop into 3 stages named S0, S1, S2 we would 409f6f88e66Sthomasraoux /// generate the following code with the number in parenthesis the iteration 410f6f88e66Sthomasraoux /// index: 411f6f88e66Sthomasraoux /// S0(0) // Prologue 412f6f88e66Sthomasraoux /// S0(1) S1(0) // Prologue 413f6f88e66Sthomasraoux /// scf.for %I = %C0 to %N - 2 { 414f6f88e66Sthomasraoux /// S0(I+2) S1(I+1) S2(I) // Pipelined kernel 415f6f88e66Sthomasraoux /// } 416f6f88e66Sthomasraoux /// S1(N) S2(N-1) // Epilogue 417f6f88e66Sthomasraoux /// S2(N) // Epilogue 418f6f88e66Sthomasraoux struct ForLoopPipelining : public OpRewritePattern<ForOp> { 419f6f88e66Sthomasraoux ForLoopPipelining(const PipeliningOption &options, MLIRContext *context) 420f6f88e66Sthomasraoux : OpRewritePattern<ForOp>(context), options(options) {} 421f6f88e66Sthomasraoux LogicalResult matchAndRewrite(ForOp forOp, 422f6f88e66Sthomasraoux PatternRewriter &rewriter) const override { 423f6f88e66Sthomasraoux 424f6f88e66Sthomasraoux LoopPipelinerInternal pipeliner; 425f6f88e66Sthomasraoux if (!pipeliner.initializeLoopInfo(forOp, options)) 426f6f88e66Sthomasraoux return failure(); 427f6f88e66Sthomasraoux 428f6f88e66Sthomasraoux // 1. Emit prologue. 429f6f88e66Sthomasraoux pipeliner.emitPrologue(rewriter); 430f6f88e66Sthomasraoux 431f6f88e66Sthomasraoux // 2. Track values used across stages. When a value cross stages it will 432f6f88e66Sthomasraoux // need to be passed as loop iteration arguments. 433f6f88e66Sthomasraoux // We first collect the values that are used in a different stage than where 434f6f88e66Sthomasraoux // they are defined. 435f6f88e66Sthomasraoux llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> 436f6f88e66Sthomasraoux crossStageValues = pipeliner.analyzeCrossStageValues(); 437f6f88e66Sthomasraoux 438f6f88e66Sthomasraoux // Mapping between original loop values used cross stage and the block 439f6f88e66Sthomasraoux // arguments associated after pipelining. A Value may map to several 440f6f88e66Sthomasraoux // arguments if its liverange spans across more than 2 stages. 441f6f88e66Sthomasraoux llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap; 442f6f88e66Sthomasraoux // 3. Create the new kernel loop and return the block arguments mapping. 443f6f88e66Sthomasraoux ForOp newForOp = 444f6f88e66Sthomasraoux pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); 445f6f88e66Sthomasraoux // Create the kernel block, order ops based on user choice and remap 446f6f88e66Sthomasraoux // operands. 447f6f88e66Sthomasraoux pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter); 448f6f88e66Sthomasraoux 449f6f88e66Sthomasraoux // 4. Emit the epilogue after the new forOp. 450f6f88e66Sthomasraoux rewriter.setInsertionPointAfter(newForOp); 45145cb4140Sthomasraoux llvm::SmallVector<Value> returnValues = pipeliner.emitEpilogue(rewriter); 452f6f88e66Sthomasraoux 453f6f88e66Sthomasraoux // 5. Erase the original loop and replace the uses with the epilogue output. 454f6f88e66Sthomasraoux if (forOp->getNumResults() > 0) 45545cb4140Sthomasraoux rewriter.replaceOp(forOp, returnValues); 456f6f88e66Sthomasraoux else 457f6f88e66Sthomasraoux rewriter.eraseOp(forOp); 458f6f88e66Sthomasraoux 459f6f88e66Sthomasraoux return success(); 460f6f88e66Sthomasraoux } 461f6f88e66Sthomasraoux 462f6f88e66Sthomasraoux protected: 463f6f88e66Sthomasraoux PipeliningOption options; 464f6f88e66Sthomasraoux }; 465f6f88e66Sthomasraoux 466f6f88e66Sthomasraoux } // namespace 467f6f88e66Sthomasraoux 468f6f88e66Sthomasraoux void mlir::scf::populateSCFLoopPipeliningPatterns( 469f6f88e66Sthomasraoux RewritePatternSet &patterns, const PipeliningOption &options) { 470f6f88e66Sthomasraoux patterns.add<ForLoopPipelining>(options, patterns.getContext()); 471f6f88e66Sthomasraoux } 472