1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/SCF/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "mlir/Transforms/Passes.h"
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
25 template <LinalgTilingLoopType LoopType>
26 static void fillFusionPatterns(MLIRContext *context,
27                                const LinalgDependenceGraph &dependenceGraph,
28                                RewritePatternSet &patterns) {
29   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
30                LinalgTileAndFusePattern<Conv2DOp>>(
31       context, dependenceGraph,
32       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
33       LinalgFusionOptions().setIndicesToFuse({2}),
34       LinalgTransformationFilter(
35           StringAttr::get(context, "basic_fusion"),
36           StringAttr::get(context, "after_basic_fusion")),
37       LinalgTransformationFilter(
38           ArrayRef<StringAttr>(),
39           StringAttr::get(context, "after_basic_fusion_producer")),
40       LinalgTransformationFilter(
41           ArrayRef<StringAttr>(),
42           StringAttr::get(context, "after_basic_fusion_original")));
43 
44   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
45       context, dependenceGraph,
46       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
47       LinalgFusionOptions().setIndicesToFuse({0}),
48       LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
49                                  StringAttr::get(context, "after_lhs_fusion")),
50       LinalgTransformationFilter(
51           ArrayRef<StringAttr>(),
52           StringAttr::get(context, "after_lhs_fusion_producer")),
53       LinalgTransformationFilter(
54           ArrayRef<StringAttr>(),
55           StringAttr::get(context, "after_lhs_fusion_original")));
56 
57   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
58       context, dependenceGraph,
59       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
60       LinalgFusionOptions().setIndicesToFuse({2}),
61       LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
62                                  StringAttr::get(context, "after_out_fusion")),
63       LinalgTransformationFilter(
64           ArrayRef<StringAttr>(),
65           StringAttr::get(context, "after_out_fusion_producer")),
66       LinalgTransformationFilter(
67           ArrayRef<StringAttr>(),
68           StringAttr::get(context, "after_out_fusion_original")));
69 
70   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
71       context, dependenceGraph,
72       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
73       LinalgFusionOptions().setIndicesToFuse({1}),
74       LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
75                                  StringAttr::get(context, "after_rhs_fusion")),
76       LinalgTransformationFilter(
77           ArrayRef<StringAttr>(),
78           StringAttr::get(context, "after_rhs_fusion_producer")),
79       LinalgTransformationFilter(
80           ArrayRef<StringAttr>(),
81           StringAttr::get(context, "after_rhs_fusion_original")));
82 
83   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
84       context, dependenceGraph,
85       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
86       LinalgFusionOptions().setIndicesToFuse({0, 2}),
87       LinalgTransformationFilter(
88           StringAttr::get(context, "two_operand_fusion"),
89           StringAttr::get(context, "after_two_operand_fusion")),
90       LinalgTransformationFilter(
91           ArrayRef<StringAttr>(),
92           StringAttr::get(context, "after_two_operand_fusion_producer")),
93       LinalgTransformationFilter(
94           ArrayRef<StringAttr>(),
95           StringAttr::get(context, "after_two_operand_fusion_original")));
96 
97   patterns.add<LinalgTileAndFusePattern<GenericOp>>(
98       context, dependenceGraph,
99       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
100       LinalgFusionOptions().setIndicesToFuse({0, 1}),
101       LinalgTransformationFilter(
102           StringAttr::get(context, "transpose_fusion"),
103           StringAttr::get(context, "after_transpose_fusion")),
104       LinalgTransformationFilter(
105           ArrayRef<StringAttr>(),
106           StringAttr::get(context, "after_transpose_fusion_producer")),
107       LinalgTransformationFilter(
108           ArrayRef<StringAttr>(),
109           StringAttr::get(context, "after_transpose_fusion_original")));
110 }
111 
112 namespace {
113 template <LinalgTilingLoopType LoopType>
114 struct TestLinalgFusionTransforms
115     : public PassWrapper<TestLinalgFusionTransforms<LoopType>,
116                          OperationPass<FuncOp>> {
117   void getDependentDialects(DialectRegistry &registry) const override {
118     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
119                     scf::SCFDialect>();
120   }
121   TestLinalgFusionTransforms() = default;
122   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
123 
124   void runOnOperation() override {
125     MLIRContext *context = &this->getContext();
126     FuncOp funcOp = this->getOperation();
127     RewritePatternSet fusionPatterns(context);
128     Aliases alias;
129     LinalgDependenceGraph dependenceGraph =
130         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
131     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
132     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
133   }
134 };
135 
136 struct TestLinalgFusionTransformsParallelLoops
137     : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
138   StringRef getArgument() const final {
139     return "test-linalg-fusion-transform-patterns";
140   }
141   StringRef getDescription() const final {
142     return "Test Linalg fusion transformation patterns by applying them "
143            "greedily.";
144   }
145 };
146 
147 struct TestLinalgFusionTransformsLoops
148     : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
149   StringRef getArgument() const final {
150     return "test-linalg-tensor-fusion-transform-patterns";
151   }
152   StringRef getDescription() const final {
153     return "Test Linalg on tensor fusion transformation "
154            "patterns by applying them greedily.";
155   }
156 };
157 
158 struct TestLinalgFusionTransformsTiledLoops
159     : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
160   StringRef getArgument() const final {
161     return "test-linalg-tiled-loop-fusion-transform-patterns";
162   }
163   StringRef getDescription() const final {
164     return "Test Linalg on tensor fusion transformation "
165            "patterns by applying them greedily.";
166   }
167 };
168 } // namespace
169 
170 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
171   OpBuilder b(f);
172   DenseSet<Operation *> eraseSet;
173 
174   // Save original Linalg ops, we only want to make a pass over those.
175   SmallVector<LinalgOp, 8> linalgOps;
176   f.walk([&](LinalgOp op) {
177     // TODO: support multi-results.
178     if (op->getNumResults() <= 1)
179       linalgOps.push_back(op);
180   });
181 
182   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
183   bool changed = false;
184   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
185     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
186       if (opOperand->get().getType().isa<MemRefType>()) {
187         // TODO: LinalgDependenceGraph should be able to update itself.
188         // The current naive and expensive reconstruction of the graph should be
189         // removed.
190         linalg::Aliases aliases;
191         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
192         auto info = fuseProducerOfBuffer(b, *opOperand, graph);
193         if (failed(info))
194           continue;
195         auto *originalOp = info->originalProducer.getOperation();
196         eraseSet.insert(originalOp);
197         auto *originalOpInLinalgOpsVector =
198             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
199         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
200         changed = true;
201       } else if (opOperand->get().getType().isa<RankedTensorType>()) {
202         // Tile and Fuse tensor input.
203         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
204           continue;
205         auto info = fuseProducerOfTensor(b, *opOperand);
206         if (failed(info))
207           continue;
208         auto *originalOp = info->originalProducer.getOperation();
209         auto *originalOpInLinalgOpsVector =
210             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
211         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
212         // Don't mark for erasure in the tensor case, let DCE handle this.
213         changed = true;
214       }
215     }
216   }
217   // The `fuseProducerOfBuffer` function performs structural checks and in
218   // particular that no covering read or write exist between the consumer and
219   // the producer. As a consequence, the only fusions that may occur preserve
220   // subsequent dependences and are guaranteed by construction to produce the
221   // whole view. We may thus erase the producer once it is fused.
222   for (auto *e : eraseSet)
223     e->erase();
224 
225   return changed ? success() : failure();
226 }
227 
228 namespace {
229 struct TestLinalgGreedyFusion
230     : public PassWrapper<TestLinalgGreedyFusion, OperationPass<FuncOp>> {
231   void getDependentDialects(DialectRegistry &registry) const override {
232     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
233                     scf::SCFDialect>();
234   }
235   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
236   StringRef getDescription() const final {
237     return "Test Linalg fusion by applying a greedy test transformation.";
238   }
239   void runOnOperation() override {
240     MLIRContext *context = &getContext();
241     RewritePatternSet patterns =
242         linalg::getLinalgTilingCanonicalizationPatterns(context);
243     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
244     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
245     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
246     do {
247       (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
248       PassManager pm(context);
249       pm.addPass(createLoopInvariantCodeMotionPass());
250       pm.addPass(createCanonicalizerPass());
251       pm.addPass(createCSEPass());
252       LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
253       if (failed(res))
254         this->signalPassFailure();
255     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
256   }
257 };
258 
259 /// Pass to test tile and fuse of sequence of operations. Intended only for
260 /// testing.
261 struct TestLinalgTileAndFuseSequencePass
262     : public PassWrapper<TestLinalgTileAndFuseSequencePass,
263                          OperationPass<FuncOp>> {
264   StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
265   StringRef getDescription() const final {
266     return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
267   }
268   TestLinalgTileAndFuseSequencePass() = default;
269   TestLinalgTileAndFuseSequencePass(
270       const TestLinalgTileAndFuseSequencePass &pass)
271       : PassWrapper(pass){};
272 
273   ListOption<int64_t> tileSizes{
274       *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
275       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
276 
277   void getDependentDialects(DialectRegistry &registry) const override {
278     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
279                     scf::SCFDialect>();
280   }
281 
282   void runOnOperation() override {
283     FuncOp funcOp = getOperation();
284     auto &blocks = funcOp.getBody().getBlocks();
285     if (!llvm::hasSingleElement(blocks)) {
286       return;
287     }
288     SmallVector<LinalgOp, 2> linalgOps =
289         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
290     Aliases aliases;
291     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
292     OpBuilder builder(funcOp.getContext());
293     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
294     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
295           return linalgOp.hasTensorSemantics();
296         }))
297       loopType = LinalgTilingLoopType::Loops;
298     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
299         builder, linalgOps, dependenceGraph,
300         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
301     if (!tileAndFuseOps)
302       return signalPassFailure();
303     if (linalgOps.back().hasTensorSemantics()) {
304       linalgOps.back().getOperation()->replaceAllUsesWith(
305           tileAndFuseOps->fusedLoops.front());
306     }
307     for (auto op : linalgOps)
308       if (op.hasBufferSemantics())
309         op.erase();
310   }
311 };
312 
313 } // namespace
314 
315 namespace mlir {
316 namespace test {
317 void registerTestLinalgFusionTransforms() {
318   PassRegistration<TestLinalgFusionTransformsParallelLoops>();
319 }
320 void registerTestLinalgTensorFusionTransforms() {
321   PassRegistration<TestLinalgFusionTransformsLoops>();
322 }
323 void registerTestLinalgTiledLoopFusionTransforms() {
324   PassRegistration<TestLinalgFusionTransformsTiledLoops>();
325 }
326 void registerTestLinalgGreedyFusion() {
327   PassRegistration<TestLinalgGreedyFusion>();
328 }
329 void registerTestLinalgTileAndFuseSequencePass() {
330   PassRegistration<TestLinalgTileAndFuseSequencePass>();
331 }
332 
333 } // namespace test
334 } // namespace mlir
335