1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineOps.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/SCF/Patterns.h"
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/Dialect/SCF/Transforms.h"
15 #include "mlir/Dialect/SCF/Utils/Utils.h"
16 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 /// A simple pattern rewriter that implements no special logic.
23 class SimpleRewriter : public PatternRewriter {
24 public:
25   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
26 };
27 } // namespace
28 
29 //===----------------------------------------------------------------------===//
30 // GetParentForOp
31 //===----------------------------------------------------------------------===//
32 
33 LogicalResult
34 transform::GetParentForOp::apply(transform::TransformResults &results,
35                                  transform::TransformState &state) {
36   SetVector<Operation *> parents;
37   for (Operation *target : state.getPayloadOps(getTarget())) {
38     scf::ForOp loop;
39     Operation *current = target;
40     for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
41       loop = current->getParentOfType<scf::ForOp>();
42       if (!loop) {
43         InFlightDiagnostic diag = emitError() << "could not find an '"
44                                               << scf::ForOp::getOperationName()
45                                               << "' parent";
46         diag.attachNote(target->getLoc()) << "target op";
47         return diag;
48       }
49       current = loop;
50     }
51     parents.insert(loop);
52   }
53   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
54   return success();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // LoopOutlineOp
59 //===----------------------------------------------------------------------===//
60 
61 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
62 /// the provided rewriter for all operations to remain compatible with the
63 /// rewriting infra, as opposed to just splicing the op in place.
64 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
65                                                 Operation *op) {
66   if (op->getNumRegions() != 1)
67     return nullptr;
68   OpBuilder::InsertionGuard g(b);
69   b.setInsertionPoint(op);
70   scf::ExecuteRegionOp executeRegionOp =
71       b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
72   {
73     OpBuilder::InsertionGuard g(b);
74     b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
75     Operation *clonedOp = b.cloneWithoutRegions(*op);
76     Region &clonedRegion = clonedOp->getRegions().front();
77     assert(clonedRegion.empty() && "expected empty region");
78     b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
79                          clonedRegion.end());
80     b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
81   }
82   b.replaceOp(op, executeRegionOp.getResults());
83   return executeRegionOp;
84 }
85 
86 LogicalResult
87 transform::LoopOutlineOp::apply(transform::TransformResults &results,
88                                 transform::TransformState &state) {
89   SmallVector<Operation *> transformed;
90   DenseMap<Operation *, SymbolTable> symbolTables;
91   for (Operation *target : state.getPayloadOps(getTarget())) {
92     Location location = target->getLoc();
93     Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
94     SimpleRewriter rewriter(getContext());
95     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
96     if (!exec) {
97       InFlightDiagnostic diag = emitError() << "failed to outline";
98       diag.attachNote(target->getLoc()) << "target op";
99       return diag;
100     }
101     func::CallOp call;
102     FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
103         rewriter, location, exec.getRegion(), getFuncName(), &call);
104 
105     if (failed(outlined))
106       return reportUnknownTransformError(target);
107 
108     if (symbolTableOp) {
109       SymbolTable &symbolTable =
110           symbolTables.try_emplace(symbolTableOp, symbolTableOp)
111               .first->getSecond();
112       symbolTable.insert(*outlined);
113       call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
114     }
115     transformed.push_back(*outlined);
116   }
117   results.set(getTransformed().cast<OpResult>(), transformed);
118   return success();
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // LoopPeelOp
123 //===----------------------------------------------------------------------===//
124 
125 FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop) {
126   scf::ForOp result;
127   IRRewriter rewriter(loop->getContext());
128   LogicalResult status =
129       scf::peelAndCanonicalizeForLoop(rewriter, loop, result);
130   if (failed(status)) {
131     if (getFailIfAlreadyDivisible())
132       return reportUnknownTransformError(loop);
133     return loop;
134   }
135   return result;
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // LoopPipelineOp
140 //===----------------------------------------------------------------------===//
141 
142 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
143 /// operation to its logical time position given the iteration interval and the
144 /// read latency. The latter is only relevant for vector transfers.
145 static void
146 loopScheduling(scf::ForOp forOp,
147                std::vector<std::pair<Operation *, unsigned>> &schedule,
148                unsigned iterationInterval, unsigned readLatency) {
149   auto getLatency = [&](Operation *op) -> unsigned {
150     if (isa<vector::TransferReadOp>(op))
151       return readLatency;
152     return 1;
153   };
154 
155   DenseMap<Operation *, unsigned> opCycles;
156   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
157   for (Operation &op : forOp.getBody()->getOperations()) {
158     if (isa<scf::YieldOp>(op))
159       continue;
160     unsigned earlyCycle = 0;
161     for (Value operand : op.getOperands()) {
162       Operation *def = operand.getDefiningOp();
163       if (!def)
164         continue;
165       earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
166     }
167     opCycles[&op] = earlyCycle;
168     wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
169   }
170   for (const auto &it : wrappedSchedule) {
171     for (Operation *op : it.second) {
172       unsigned cycle = opCycles[op];
173       schedule.push_back(std::make_pair(op, cycle / iterationInterval));
174     }
175   }
176 }
177 
178 FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
179   scf::PipeliningOption options;
180   options.getScheduleFn =
181       [this](scf::ForOp forOp,
182              std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
183         loopScheduling(forOp, schedule, getIterationInterval(),
184                        getReadLatency());
185       };
186 
187   scf::ForLoopPipeliningPattern pattern(options, loop->getContext());
188   SimpleRewriter rewriter(getContext());
189   rewriter.setInsertionPoint(loop);
190   FailureOr<scf::ForOp> patternResult =
191       pattern.returningMatchAndRewrite(loop, rewriter);
192   if (failed(patternResult))
193     return reportUnknownTransformError(loop);
194   return patternResult;
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // LoopUnrollOp
199 //===----------------------------------------------------------------------===//
200 
201 LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) {
202   if (failed(loopUnrollByFactor(loop, getFactor())))
203     return reportUnknownTransformError(loop);
204   return success();
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // Transform op registration
209 //===----------------------------------------------------------------------===//
210 
211 namespace {
212 class SCFTransformDialectExtension
213     : public transform::TransformDialectExtension<
214           SCFTransformDialectExtension> {
215 public:
216   SCFTransformDialectExtension() {
217     declareDependentDialect<AffineDialect>();
218     declareDependentDialect<func::FuncDialect>();
219     registerTransformOps<
220 #define GET_OP_LIST
221 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
222         >();
223   }
224 };
225 } // namespace
226 
227 #define GET_OP_CLASSES
228 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
229 
230 void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
231   registry.addExtensions<SCFTransformDialectExtension>();
232 }
233