1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
14 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
15 #include "mlir/Dialect/SCF/Utils/Utils.h"
16 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
17 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 /// A simple pattern rewriter that implements no special logic.
24 class SimpleRewriter : public PatternRewriter {
25 public:
26   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
27 };
28 } // namespace
29 
30 //===----------------------------------------------------------------------===//
31 // GetParentForOp
32 //===----------------------------------------------------------------------===//
33 
34 DiagnosedSilenceableFailure
35 transform::GetParentForOp::apply(transform::TransformResults &results,
36                                  transform::TransformState &state) {
37   SetVector<Operation *> parents;
38   for (Operation *target : state.getPayloadOps(getTarget())) {
39     scf::ForOp loop;
40     Operation *current = target;
41     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
42       loop = current->getParentOfType<scf::ForOp>();
43       if (!loop) {
44         DiagnosedSilenceableFailure diag = emitSilenceableError()
45                                            << "could not find an '"
46                                            << scf::ForOp::getOperationName()
47                                            << "' parent";
48         diag.attachNote(target->getLoc()) << "target op";
49         return diag;
50       }
51       current = loop;
52     }
53     parents.insert(loop);
54   }
55   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
56   return DiagnosedSilenceableFailure::success();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // LoopOutlineOp
61 //===----------------------------------------------------------------------===//
62 
63 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
64 /// the provided rewriter for all operations to remain compatible with the
65 /// rewriting infra, as opposed to just splicing the op in place.
66 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
67                                                 Operation *op) {
68   if (op->getNumRegions() != 1)
69     return nullptr;
70   OpBuilder::InsertionGuard g(b);
71   b.setInsertionPoint(op);
72   scf::ExecuteRegionOp executeRegionOp =
73       b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
74   {
75     OpBuilder::InsertionGuard g(b);
76     b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
77     Operation *clonedOp = b.cloneWithoutRegions(*op);
78     Region &clonedRegion = clonedOp->getRegions().front();
79     assert(clonedRegion.empty() && "expected empty region");
80     b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
81                          clonedRegion.end());
82     b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
83   }
84   b.replaceOp(op, executeRegionOp.getResults());
85   return executeRegionOp;
86 }
87 
88 DiagnosedSilenceableFailure
89 transform::LoopOutlineOp::apply(transform::TransformResults &results,
90                                 transform::TransformState &state) {
91   SmallVector<Operation *> transformed;
92   DenseMap<Operation *, SymbolTable> symbolTables;
93   for (Operation *target : state.getPayloadOps(getTarget())) {
94     Location location = target->getLoc();
95     Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
96     SimpleRewriter rewriter(getContext());
97     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
98     if (!exec) {
99       DiagnosedSilenceableFailure diag = emitSilenceableError()
100                                          << "failed to outline";
101       diag.attachNote(target->getLoc()) << "target op";
102       return diag;
103     }
104     func::CallOp call;
105     FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
106         rewriter, location, exec.getRegion(), getFuncName(), &call);
107 
108     if (failed(outlined)) {
109       (void)reportUnknownTransformError(target);
110       return DiagnosedSilenceableFailure::definiteFailure();
111     }
112 
113     if (symbolTableOp) {
114       SymbolTable &symbolTable =
115           symbolTables.try_emplace(symbolTableOp, symbolTableOp)
116               .first->getSecond();
117       symbolTable.insert(*outlined);
118       call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
119     }
120     transformed.push_back(*outlined);
121   }
122   results.set(getTransformed().cast<OpResult>(), transformed);
123   return DiagnosedSilenceableFailure::success();
124 }
125 
126 //===----------------------------------------------------------------------===//
127 // LoopPeelOp
128 //===----------------------------------------------------------------------===//
129 
130 DiagnosedSilenceableFailure
131 transform::LoopPeelOp::applyToOne(scf::ForOp target,
132                                   SmallVector<Operation *> &results,
133                                   transform::TransformState &state) {
134   scf::ForOp result;
135   IRRewriter rewriter(target->getContext());
136   // This helper returns failure when peeling does not occur (i.e. when the IR
137   // is not modified). This is not a failure for the op as the postcondition:
138   //    "the loop trip count is divisible by the step"
139   // is valid.
140   LogicalResult status =
141       scf::peelAndCanonicalizeForLoop(rewriter, target, result);
142   // TODO: Return both the peeled loop and the remainder loop.
143   results.push_back(failed(status) ? target : result);
144   return DiagnosedSilenceableFailure(success());
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // LoopPipelineOp
149 //===----------------------------------------------------------------------===//
150 
151 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
152 /// operation to its logical time position given the iteration interval and the
153 /// read latency. The latter is only relevant for vector transfers.
154 static void
155 loopScheduling(scf::ForOp forOp,
156                std::vector<std::pair<Operation *, unsigned>> &schedule,
157                unsigned iterationInterval, unsigned readLatency) {
158   auto getLatency = [&](Operation *op) -> unsigned {
159     if (isa<vector::TransferReadOp>(op))
160       return readLatency;
161     return 1;
162   };
163 
164   DenseMap<Operation *, unsigned> opCycles;
165   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
166   for (Operation &op : forOp.getBody()->getOperations()) {
167     if (isa<scf::YieldOp>(op))
168       continue;
169     unsigned earlyCycle = 0;
170     for (Value operand : op.getOperands()) {
171       Operation *def = operand.getDefiningOp();
172       if (!def)
173         continue;
174       earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
175     }
176     opCycles[&op] = earlyCycle;
177     wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
178   }
179   for (const auto &it : wrappedSchedule) {
180     for (Operation *op : it.second) {
181       unsigned cycle = opCycles[op];
182       schedule.emplace_back(op, cycle / iterationInterval);
183     }
184   }
185 }
186 
187 DiagnosedSilenceableFailure
188 transform::LoopPipelineOp::applyToOne(scf::ForOp target,
189                                       SmallVector<Operation *> &results,
190                                       transform::TransformState &state) {
191   scf::PipeliningOption options;
192   options.getScheduleFn =
193       [this](scf::ForOp forOp,
194              std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
195         loopScheduling(forOp, schedule, getIterationInterval(),
196                        getReadLatency());
197       };
198   scf::ForLoopPipeliningPattern pattern(options, target->getContext());
199   SimpleRewriter rewriter(getContext());
200   rewriter.setInsertionPoint(target);
201   FailureOr<scf::ForOp> patternResult =
202       pattern.returningMatchAndRewrite(target, rewriter);
203   if (succeeded(patternResult)) {
204     results.push_back(*patternResult);
205     return DiagnosedSilenceableFailure(success());
206   }
207   results.assign(1, nullptr);
208   return emitDefaultSilenceableFailure(target);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // LoopUnrollOp
213 //===----------------------------------------------------------------------===//
214 
215 DiagnosedSilenceableFailure
216 transform::LoopUnrollOp::applyToOne(scf::ForOp target,
217                                     SmallVector<Operation *> &results,
218                                     transform::TransformState &state) {
219   if (failed(loopUnrollByFactor(target, getFactor()))) {
220     Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
221     diag << "op failed to unroll";
222     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
223   }
224   return DiagnosedSilenceableFailure(success());
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // Transform op registration
229 //===----------------------------------------------------------------------===//
230 
231 namespace {
232 class SCFTransformDialectExtension
233     : public transform::TransformDialectExtension<
234           SCFTransformDialectExtension> {
235 public:
236   SCFTransformDialectExtension() {
237     declareDependentDialect<AffineDialect>();
238     declareDependentDialect<func::FuncDialect>();
239     registerTransformOps<
240 #define GET_OP_LIST
241 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
242         >();
243   }
244 };
245 } // namespace
246 
247 #define GET_OP_CLASSES
248 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
249 
/// Adds the SCF transform dialect extension (SCFTransformDialectExtension)
/// to the given dialect registry.
void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<SCFTransformDialectExtension>();
}
253