1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/SCF/Transforms.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassManager.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "mlir/Transforms/Passes.h"
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 template <LinalgTilingLoopType LoopType>
25 static void fillFusionPatterns(MLIRContext *context,
26                                const LinalgDependenceGraph &dependenceGraph,
27                                RewritePatternSet &patterns) {
28   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
29                LinalgTileAndFusePattern<Conv2DOp>>(
30       context, dependenceGraph,
31       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
32       LinalgFusionOptions().setIndicesToFuse({2}),
33       LinalgTransformationFilter(
34           StringAttr::get(context, "basic_fusion"),
35           StringAttr::get(context, "after_basic_fusion")),
36       LinalgTransformationFilter(
37           ArrayRef<StringAttr>(),
38           StringAttr::get(context, "after_basic_fusion_producer")),
39       LinalgTransformationFilter(
40           ArrayRef<StringAttr>(),
41           StringAttr::get(context, "after_basic_fusion_original")));
42 
43   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
44       context, dependenceGraph,
45       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
46       LinalgFusionOptions().setIndicesToFuse({0}),
47       LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
48                                  StringAttr::get(context, "after_lhs_fusion")),
49       LinalgTransformationFilter(
50           ArrayRef<StringAttr>(),
51           StringAttr::get(context, "after_lhs_fusion_producer")),
52       LinalgTransformationFilter(
53           ArrayRef<StringAttr>(),
54           StringAttr::get(context, "after_lhs_fusion_original")));
55 
56   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
57       context, dependenceGraph,
58       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
59       LinalgFusionOptions().setIndicesToFuse({2}),
60       LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
61                                  StringAttr::get(context, "after_out_fusion")),
62       LinalgTransformationFilter(
63           ArrayRef<StringAttr>(),
64           StringAttr::get(context, "after_out_fusion_producer")),
65       LinalgTransformationFilter(
66           ArrayRef<StringAttr>(),
67           StringAttr::get(context, "after_out_fusion_original")));
68 
69   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
70       context, dependenceGraph,
71       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
72       LinalgFusionOptions().setIndicesToFuse({1}),
73       LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
74                                  StringAttr::get(context, "after_rhs_fusion")),
75       LinalgTransformationFilter(
76           ArrayRef<StringAttr>(),
77           StringAttr::get(context, "after_rhs_fusion_producer")),
78       LinalgTransformationFilter(
79           ArrayRef<StringAttr>(),
80           StringAttr::get(context, "after_rhs_fusion_original")));
81 
82   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
83       context, dependenceGraph,
84       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
85       LinalgFusionOptions().setIndicesToFuse({0, 2}),
86       LinalgTransformationFilter(
87           StringAttr::get(context, "two_operand_fusion"),
88           StringAttr::get(context, "after_two_operand_fusion")),
89       LinalgTransformationFilter(
90           ArrayRef<StringAttr>(),
91           StringAttr::get(context, "after_two_operand_fusion_producer")),
92       LinalgTransformationFilter(
93           ArrayRef<StringAttr>(),
94           StringAttr::get(context, "after_two_operand_fusion_original")));
95 
96   patterns.add<LinalgTileAndFusePattern<GenericOp>>(
97       context, dependenceGraph,
98       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
99       LinalgFusionOptions().setIndicesToFuse({0, 1}),
100       LinalgTransformationFilter(
101           StringAttr::get(context, "transpose_fusion"),
102           StringAttr::get(context, "after_transpose_fusion")),
103       LinalgTransformationFilter(
104           ArrayRef<StringAttr>(),
105           StringAttr::get(context, "after_transpose_fusion_producer")),
106       LinalgTransformationFilter(
107           ArrayRef<StringAttr>(),
108           StringAttr::get(context, "after_transpose_fusion_original")));
109 }
110 
111 namespace {
112 template <LinalgTilingLoopType LoopType>
113 struct TestLinalgFusionTransforms
114     : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> {
115   void getDependentDialects(DialectRegistry &registry) const override {
116     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
117                     scf::SCFDialect, StandardOpsDialect>();
118   }
119   TestLinalgFusionTransforms() = default;
120   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
121 
122   void runOnFunction() override {
123     MLIRContext *context = &this->getContext();
124     FuncOp funcOp = this->getFunction();
125     RewritePatternSet fusionPatterns(context);
126     Aliases alias;
127     LinalgDependenceGraph dependenceGraph =
128         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
129     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
130     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
131   }
132 };
133 
134 struct TestLinalgFusionTransformsParallelLoops
135     : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
136   StringRef getArgument() const final {
137     return "test-linalg-fusion-transform-patterns";
138   }
139   StringRef getDescription() const final {
140     return "Test Linalg fusion transformation patterns by applying them "
141            "greedily.";
142   }
143 };
144 
145 struct TestLinalgFusionTransformsLoops
146     : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
147   StringRef getArgument() const final {
148     return "test-linalg-tensor-fusion-transform-patterns";
149   }
150   StringRef getDescription() const final {
151     return "Test Linalg on tensor fusion transformation "
152            "patterns by applying them greedily.";
153   }
154 };
155 
156 struct TestLinalgFusionTransformsTiledLoops
157     : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
158   StringRef getArgument() const final {
159     return "test-linalg-tiled-loop-fusion-transform-patterns";
160   }
161   StringRef getDescription() const final {
162     return "Test Linalg on tensor fusion transformation "
163            "patterns by applying them greedily.";
164   }
165 };
166 } // namespace
167 
168 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
169   OpBuilder b(f);
170   DenseSet<Operation *> eraseSet;
171 
172   // Save original Linalg ops, we only want to make a pass over those.
173   SmallVector<LinalgOp, 8> linalgOps;
174   f.walk([&](LinalgOp op) {
175     // TODO: support multi-results.
176     if (op->getNumResults() <= 1)
177       linalgOps.push_back(op);
178   });
179 
180   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
181   bool changed = false;
182   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
183     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
184       if (opOperand->get().getType().isa<MemRefType>()) {
185         // TODO: LinalgDependenceGraph should be able to update itself.
186         // The current naive and expensive reconstruction of the graph should be
187         // removed.
188         linalg::Aliases aliases;
189         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
190         auto info = fuseProducerOfBuffer(b, *opOperand, graph);
191         if (failed(info))
192           continue;
193         auto *originalOp = info->originalProducer.getOperation();
194         eraseSet.insert(originalOp);
195         auto *originalOpInLinalgOpsVector =
196             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
197         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
198         changed = true;
199       } else if (opOperand->get().getType().isa<RankedTensorType>()) {
200         // Tile and Fuse tensor input.
201         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
202           continue;
203         auto info = fuseProducerOfTensor(b, *opOperand);
204         if (failed(info))
205           continue;
206         auto *originalOp = info->originalProducer.getOperation();
207         auto *originalOpInLinalgOpsVector =
208             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
209         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
210         // Don't mark for erasure in the tensor case, let DCE handle this.
211         changed = true;
212       }
213     }
214   }
215   // The `fuseProducerOfBuffer` function performs structural checks and in
216   // particular that no covering read or write exist between the consumer and
217   // the producer. As a consequence, the only fusions that may occur preserve
218   // subsequent dependences and are guaranteed by construction to produce the
219   // whole view. We may thus erase the producer once it is fused.
220   for (auto *e : eraseSet)
221     e->erase();
222 
223   return changed ? success() : failure();
224 }
225 
226 namespace {
227 struct TestLinalgGreedyFusion
228     : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
229   void getDependentDialects(DialectRegistry &registry) const override {
230     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
231                     scf::SCFDialect>();
232   }
233   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
234   StringRef getDescription() const final {
235     return "Test Linalg fusion by applying a greedy test transformation.";
236   }
237   void runOnFunction() override {
238     MLIRContext *context = &getContext();
239     RewritePatternSet patterns =
240         linalg::getLinalgTilingCanonicalizationPatterns(context);
241     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
242     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
243     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
244     do {
245       (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
246       PassManager pm(context);
247       pm.addPass(createLoopInvariantCodeMotionPass());
248       pm.addPass(createCanonicalizerPass());
249       pm.addPass(createCSEPass());
250       LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>());
251       if (failed(res))
252         this->signalPassFailure();
253     } while (succeeded(fuseLinalgOpsGreedily(getFunction())));
254   }
255 };
256 
257 /// Pass to test tile and fuse of sequence of operations. Intended only for
258 /// testing.
259 struct TestLinalgTileAndFuseSequencePass
260     : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
261   StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
262   StringRef getDescription() const final {
263     return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
264   }
265   TestLinalgTileAndFuseSequencePass() = default;
266   TestLinalgTileAndFuseSequencePass(
267       const TestLinalgTileAndFuseSequencePass &pass){};
268 
269   ListOption<int64_t> tileSizes{
270       *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
271       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
272 
273   void getDependentDialects(DialectRegistry &registry) const override {
274     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
275                     scf::SCFDialect>();
276   }
277 
278   void runOnFunction() override {
279     FuncOp funcOp = getOperation();
280     auto &blocks = funcOp.getBody().getBlocks();
281     if (!llvm::hasSingleElement(blocks)) {
282       return;
283     }
284     SmallVector<LinalgOp, 2> linalgOps =
285         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
286     Aliases aliases;
287     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
288     OpBuilder builder(funcOp.getContext());
289     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
290     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
291           return linalgOp.hasTensorSemantics();
292         }))
293       loopType = LinalgTilingLoopType::Loops;
294     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
295         builder, linalgOps, dependenceGraph,
296         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
297     if (!tileAndFuseOps)
298       return signalPassFailure();
299     if (linalgOps.back().hasTensorSemantics()) {
300       linalgOps.back().getOperation()->replaceAllUsesWith(
301           tileAndFuseOps->fusedLoops.front());
302     }
303     for (auto op : linalgOps)
304       if (op.hasBufferSemantics())
305         op.erase();
306   }
307 };
308 
309 } // namespace
310 
311 namespace mlir {
312 namespace test {
313 void registerTestLinalgFusionTransforms() {
314   PassRegistration<TestLinalgFusionTransformsParallelLoops>();
315 }
316 void registerTestLinalgTensorFusionTransforms() {
317   PassRegistration<TestLinalgFusionTransformsLoops>();
318 }
319 void registerTestLinalgTiledLoopFusionTransforms() {
320   PassRegistration<TestLinalgFusionTransformsTiledLoops>();
321 }
322 void registerTestLinalgGreedyFusion() {
323   PassRegistration<TestLinalgGreedyFusion>();
324 }
325 void registerTestLinalgTileAndFuseSequencePass() {
326   PassRegistration<TestLinalgTileAndFuseSequencePass>();
327 }
328 
329 } // namespace test
330 } // namespace mlir
331