1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/SCF/IR/SCF.h"
14 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
15 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
16 #include "mlir/Dialect/SCF/Utils/Utils.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 /// A simple pattern rewriter that implements no special logic.
25 class SimpleRewriter : public PatternRewriter {
26 public:
SimpleRewriter(MLIRContext * context)27   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
28 };
29 } // namespace
30 
31 //===----------------------------------------------------------------------===//
32 // GetParentForOp
33 //===----------------------------------------------------------------------===//
34 
35 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)36 transform::GetParentForOp::apply(transform::TransformResults &results,
37                                  transform::TransformState &state) {
38   SetVector<Operation *> parents;
39   for (Operation *target : state.getPayloadOps(getTarget())) {
40     scf::ForOp loop;
41     Operation *current = target;
42     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
43       loop = current->getParentOfType<scf::ForOp>();
44       if (!loop) {
45         DiagnosedSilenceableFailure diag = emitSilenceableError()
46                                            << "could not find an '"
47                                            << scf::ForOp::getOperationName()
48                                            << "' parent";
49         diag.attachNote(target->getLoc()) << "target op";
50         return diag;
51       }
52       current = loop;
53     }
54     parents.insert(loop);
55   }
56   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
57   return DiagnosedSilenceableFailure::success();
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // LoopOutlineOp
62 //===----------------------------------------------------------------------===//
63 
64 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
65 /// the provided rewriter for all operations to remain compatible with the
66 /// rewriting infra, as opposed to just splicing the op in place.
wrapInExecuteRegion(RewriterBase & b,Operation * op)67 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
68                                                 Operation *op) {
69   if (op->getNumRegions() != 1)
70     return nullptr;
71   OpBuilder::InsertionGuard g(b);
72   b.setInsertionPoint(op);
73   scf::ExecuteRegionOp executeRegionOp =
74       b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
75   {
76     OpBuilder::InsertionGuard g(b);
77     b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
78     Operation *clonedOp = b.cloneWithoutRegions(*op);
79     Region &clonedRegion = clonedOp->getRegions().front();
80     assert(clonedRegion.empty() && "expected empty region");
81     b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
82                          clonedRegion.end());
83     b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
84   }
85   b.replaceOp(op, executeRegionOp.getResults());
86   return executeRegionOp;
87 }
88 
89 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)90 transform::LoopOutlineOp::apply(transform::TransformResults &results,
91                                 transform::TransformState &state) {
92   SmallVector<Operation *> transformed;
93   DenseMap<Operation *, SymbolTable> symbolTables;
94   for (Operation *target : state.getPayloadOps(getTarget())) {
95     Location location = target->getLoc();
96     Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
97     SimpleRewriter rewriter(getContext());
98     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
99     if (!exec) {
100       DiagnosedSilenceableFailure diag = emitSilenceableError()
101                                          << "failed to outline";
102       diag.attachNote(target->getLoc()) << "target op";
103       return diag;
104     }
105     func::CallOp call;
106     FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
107         rewriter, location, exec.getRegion(), getFuncName(), &call);
108 
109     if (failed(outlined)) {
110       (void)reportUnknownTransformError(target);
111       return DiagnosedSilenceableFailure::definiteFailure();
112     }
113 
114     if (symbolTableOp) {
115       SymbolTable &symbolTable =
116           symbolTables.try_emplace(symbolTableOp, symbolTableOp)
117               .first->getSecond();
118       symbolTable.insert(*outlined);
119       call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
120     }
121     transformed.push_back(*outlined);
122   }
123   results.set(getTransformed().cast<OpResult>(), transformed);
124   return DiagnosedSilenceableFailure::success();
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // LoopPeelOp
129 //===----------------------------------------------------------------------===//
130 
131 DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)132 transform::LoopPeelOp::applyToOne(scf::ForOp target,
133                                   SmallVector<Operation *> &results,
134                                   transform::TransformState &state) {
135   scf::ForOp result;
136   IRRewriter rewriter(target->getContext());
137   // This helper returns failure when peeling does not occur (i.e. when the IR
138   // is not modified). This is not a failure for the op as the postcondition:
139   //    "the loop trip count is divisible by the step"
140   // is valid.
141   LogicalResult status =
142       scf::peelAndCanonicalizeForLoop(rewriter, target, result);
143   // TODO: Return both the peeled loop and the remainder loop.
144   results.push_back(failed(status) ? target : result);
145   return DiagnosedSilenceableFailure(success());
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // LoopPipelineOp
150 //===----------------------------------------------------------------------===//
151 
152 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
153 /// operation to its logical time position given the iteration interval and the
154 /// read latency. The latter is only relevant for vector transfers.
155 static void
loopScheduling(scf::ForOp forOp,std::vector<std::pair<Operation *,unsigned>> & schedule,unsigned iterationInterval,unsigned readLatency)156 loopScheduling(scf::ForOp forOp,
157                std::vector<std::pair<Operation *, unsigned>> &schedule,
158                unsigned iterationInterval, unsigned readLatency) {
159   auto getLatency = [&](Operation *op) -> unsigned {
160     if (isa<vector::TransferReadOp>(op))
161       return readLatency;
162     return 1;
163   };
164 
165   DenseMap<Operation *, unsigned> opCycles;
166   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
167   for (Operation &op : forOp.getBody()->getOperations()) {
168     if (isa<scf::YieldOp>(op))
169       continue;
170     unsigned earlyCycle = 0;
171     for (Value operand : op.getOperands()) {
172       Operation *def = operand.getDefiningOp();
173       if (!def)
174         continue;
175       earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
176     }
177     opCycles[&op] = earlyCycle;
178     wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
179   }
180   for (const auto &it : wrappedSchedule) {
181     for (Operation *op : it.second) {
182       unsigned cycle = opCycles[op];
183       schedule.emplace_back(op, cycle / iterationInterval);
184     }
185   }
186 }
187 
188 DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)189 transform::LoopPipelineOp::applyToOne(scf::ForOp target,
190                                       SmallVector<Operation *> &results,
191                                       transform::TransformState &state) {
192   scf::PipeliningOption options;
193   options.getScheduleFn =
194       [this](scf::ForOp forOp,
195              std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
196         loopScheduling(forOp, schedule, getIterationInterval(),
197                        getReadLatency());
198       };
199   scf::ForLoopPipeliningPattern pattern(options, target->getContext());
200   SimpleRewriter rewriter(getContext());
201   rewriter.setInsertionPoint(target);
202   FailureOr<scf::ForOp> patternResult =
203       pattern.returningMatchAndRewrite(target, rewriter);
204   if (succeeded(patternResult)) {
205     results.push_back(*patternResult);
206     return DiagnosedSilenceableFailure(success());
207   }
208   results.assign(1, nullptr);
209   return emitDefaultSilenceableFailure(target);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // LoopUnrollOp
214 //===----------------------------------------------------------------------===//
215 
216 DiagnosedSilenceableFailure
applyToOne(scf::ForOp target,SmallVector<Operation * > & results,transform::TransformState & state)217 transform::LoopUnrollOp::applyToOne(scf::ForOp target,
218                                     SmallVector<Operation *> &results,
219                                     transform::TransformState &state) {
220   if (failed(loopUnrollByFactor(target, getFactor()))) {
221     Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
222     diag << "op failed to unroll";
223     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
224   }
225   return DiagnosedSilenceableFailure(success());
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // Transform op registration
230 //===----------------------------------------------------------------------===//
231 
namespace {
/// Transform dialect extension that registers the SCF transform ops from the
/// ODS-generated op list, together with the dialects those ops depend on or
/// may create.
class SCFTransformDialectExtension
    : public transform::TransformDialectExtension<
          SCFTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares this extension's dialect dependencies and registers its ops.
  void init() {
    // NOTE(review): PDL is presumably needed for the handle types these ops
    // use; confirm.
    declareDependentDialect<pdl::PDLDialect>();

    // Dialects whose ops the transforms may create. `func` is needed for the
    // function and call produced by outlining. `affine` is presumably produced
    // by the loop utilities (e.g. unroll/peel); TODO confirm.
    declareGeneratedDialect<AffineDialect>();
    declareGeneratedDialect<func::FuncDialect>();

    // Register every op listed in the generated op list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
        >();
  }
};
} // namespace
252 
253 #define GET_OP_CLASSES
254 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
255 
/// Adds the SCF transform dialect extension defined in this file to
/// `registry`.
void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<SCFTransformDialectExtension>();
}
259