1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/SCF/Transforms.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassManager.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "mlir/Transforms/Passes.h"
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 template <LinalgTilingLoopType LoopType>
25 static void fillFusionPatterns(MLIRContext *context,
26                                const LinalgDependenceGraph &dependenceGraph,
27                                RewritePatternSet &patterns) {
28   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
29                LinalgTileAndFusePattern<Conv2DOp>>(
30       context, dependenceGraph,
31       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
32       LinalgFusionOptions().setIndicesToFuse({2}),
33       LinalgTransformationFilter(
34           Identifier::get("basic_fusion", context),
35           Identifier::get("after_basic_fusion", context)),
36       LinalgTransformationFilter(
37           ArrayRef<Identifier>(),
38           Identifier::get("after_basic_fusion_producer", context)),
39       LinalgTransformationFilter(
40           ArrayRef<Identifier>(),
41           Identifier::get("after_basic_fusion_original", context)));
42 
43   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
44       context, dependenceGraph,
45       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
46       LinalgFusionOptions().setIndicesToFuse({0}),
47       LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
48                                  Identifier::get("after_lhs_fusion", context)),
49       LinalgTransformationFilter(
50           ArrayRef<Identifier>(),
51           Identifier::get("after_lhs_fusion_producer", context)),
52       LinalgTransformationFilter(
53           ArrayRef<Identifier>(),
54           Identifier::get("after_lhs_fusion_original", context)));
55 
56   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
57       context, dependenceGraph,
58       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
59       LinalgFusionOptions().setIndicesToFuse({2}),
60       LinalgTransformationFilter(Identifier::get("out_fusion", context),
61                                  Identifier::get("after_out_fusion", context)),
62       LinalgTransformationFilter(
63           ArrayRef<Identifier>(),
64           Identifier::get("after_out_fusion_producer", context)),
65       LinalgTransformationFilter(
66           ArrayRef<Identifier>(),
67           Identifier::get("after_out_fusion_original", context)));
68 
69   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
70       context, dependenceGraph,
71       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
72       LinalgFusionOptions().setIndicesToFuse({1}),
73       LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
74                                  Identifier::get("after_rhs_fusion", context)),
75       LinalgTransformationFilter(
76           ArrayRef<Identifier>(),
77           Identifier::get("after_rhs_fusion_producer", context)),
78       LinalgTransformationFilter(
79           ArrayRef<Identifier>(),
80           Identifier::get("after_rhs_fusion_original", context)));
81 
82   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
83       context, dependenceGraph,
84       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
85       LinalgFusionOptions().setIndicesToFuse({0, 2}),
86       LinalgTransformationFilter(
87           Identifier::get("two_operand_fusion", context),
88           Identifier::get("after_two_operand_fusion", context)),
89       LinalgTransformationFilter(
90           ArrayRef<Identifier>(),
91           Identifier::get("after_two_operand_fusion_producer", context)),
92       LinalgTransformationFilter(
93           ArrayRef<Identifier>(),
94           Identifier::get("after_two_operand_fusion_original", context)));
95 
96   patterns.add<LinalgTileAndFusePattern<GenericOp>>(
97       context, dependenceGraph,
98       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
99       LinalgFusionOptions().setIndicesToFuse({0, 1}),
100       LinalgTransformationFilter(
101           Identifier::get("transpose_fusion", context),
102           Identifier::get("after_transpose_fusion", context)),
103       LinalgTransformationFilter(
104           ArrayRef<Identifier>(),
105           Identifier::get("after_transpose_fusion_producer", context)),
106       LinalgTransformationFilter(
107           ArrayRef<Identifier>(),
108           Identifier::get("after_transpose_fusion_original", context)));
109 }
110 
111 namespace {
112 template <LinalgTilingLoopType LoopType>
113 struct TestLinalgFusionTransforms
114     : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> {
115   void getDependentDialects(DialectRegistry &registry) const override {
116     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
117                     scf::SCFDialect, StandardOpsDialect>();
118   }
119   TestLinalgFusionTransforms() = default;
120   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
121 
122   void runOnFunction() override {
123     MLIRContext *context = &this->getContext();
124     FuncOp funcOp = this->getFunction();
125     RewritePatternSet fusionPatterns(context);
126     Aliases alias;
127     LinalgDependenceGraph dependenceGraph =
128         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
129     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
130     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
131   }
132 };
133 
134 struct TestLinalgFusionTransformsParallelLoops
135     : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
136   StringRef getArgument() const final {
137     return "test-linalg-fusion-transform-patterns";
138   }
139   StringRef getDescription() const final {
140     return "Test Linalg fusion transformation patterns by applying them "
141            "greedily.";
142   }
143 };
144 
145 struct TestLinalgFusionTransformsLoops
146     : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
147   StringRef getArgument() const final {
148     return "test-linalg-tensor-fusion-transform-patterns";
149   }
150   StringRef getDescription() const final {
151     return "Test Linalg on tensor fusion transformation "
152            "patterns by applying them greedily.";
153   }
154 };
155 
156 struct TestLinalgFusionTransformsTiledLoops
157     : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
158   StringRef getArgument() const final {
159     return "test-linalg-tiled-loop-fusion-transform-patterns";
160   }
161   StringRef getDescription() const final {
162     return "Test Linalg on tensor fusion transformation "
163            "patterns by applying them greedily.";
164   }
165 };
166 } // namespace
167 
168 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
169   OpBuilder b(f);
170   DenseSet<Operation *> eraseSet;
171 
172   // Save original Linalg ops, we only want to make a pass over those.
173   SmallVector<LinalgOp, 8> linalgOps;
174   f.walk([&](LinalgOp op) {
175     // TODO: support multi-results.
176     if (op->getNumResults() <= 1)
177       linalgOps.push_back(op);
178   });
179 
180   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
181   bool changed = false;
182   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
183     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
184       if (opOperand->get().getType().isa<MemRefType>()) {
185         // TODO: LinalgDependenceGraph should be able to update itself.
186         // The current naive and expensive reconstruction of the graph should be
187         // removed.
188         linalg::Aliases aliases;
189         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
190         if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) {
191           auto *originalOp = info->originalProducer.getOperation();
192           eraseSet.insert(originalOp);
193           auto *originalOpInLinalgOpsVector =
194               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
195           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
196           changed = true;
197         }
198       } else if (opOperand->get().getType().isa<RankedTensorType>()) {
199         // Tile and Fuse tensor input.
200         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
201           continue;
202         if (auto info = fuseProducerOfTensor(b, *opOperand)) {
203           auto *originalOp = info->originalProducer.getOperation();
204           auto *originalOpInLinalgOpsVector =
205               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
206           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
207           // Don't mark for erasure in the tensor case, let DCE handle this.
208           changed = true;
209         }
210       }
211     }
212   }
213   // The `fuseProducerOfBuffer` function performs structural checks and in
214   // particular that no covering read or write exist between the consumer and
215   // the producer. As a consequence, the only fusions that may occur preserve
216   // subsequent dependences and are guaranteed by construction to produce the
217   // whole view. We may thus erase the producer once it is fused.
218   for (auto *e : eraseSet)
219     e->erase();
220 
221   return changed ? success() : failure();
222 }
223 
224 namespace {
225 struct TestLinalgGreedyFusion
226     : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
227   void getDependentDialects(DialectRegistry &registry) const override {
228     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
229                     scf::SCFDialect>();
230   }
231   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
232   StringRef getDescription() const final {
233     return "Test Linalg fusion by applying a greedy test transformation.";
234   }
235   void runOnFunction() override {
236     MLIRContext *context = &getContext();
237     RewritePatternSet patterns =
238         linalg::getLinalgTilingCanonicalizationPatterns(context);
239     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
240     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
241     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
242     do {
243       (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
244       PassManager pm(context);
245       pm.addPass(createLoopInvariantCodeMotionPass());
246       pm.addPass(createCanonicalizerPass());
247       pm.addPass(createCSEPass());
248       LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>());
249       if (failed(res))
250         this->signalPassFailure();
251     } while (succeeded(fuseLinalgOpsGreedily(getFunction())));
252   }
253 };
254 
255 /// Pass to test tile and fuse of sequence of operations. Intended only for
256 /// testing.
257 struct TestLinalgTileAndFuseSequencePass
258     : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
259   StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
260   StringRef getDescription() const final {
261     return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
262   }
263   TestLinalgTileAndFuseSequencePass() = default;
264   TestLinalgTileAndFuseSequencePass(
265       const TestLinalgTileAndFuseSequencePass &pass){};
266 
267   ListOption<int64_t> tileSizes{
268       *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
269       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
270 
271   void getDependentDialects(DialectRegistry &registry) const override {
272     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
273                     scf::SCFDialect>();
274   }
275 
276   void runOnFunction() override {
277     FuncOp funcOp = getOperation();
278     auto &blocks = funcOp.getBody().getBlocks();
279     if (!llvm::hasSingleElement(blocks)) {
280       return;
281     }
282     SmallVector<LinalgOp, 2> linalgOps =
283         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
284     Aliases aliases;
285     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
286     OpBuilder builder(funcOp.getContext());
287     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
288     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
289           return linalgOp.hasTensorSemantics();
290         }))
291       loopType = LinalgTilingLoopType::Loops;
292     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
293         builder, linalgOps, dependenceGraph,
294         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
295     if (!tileAndFuseOps)
296       return signalPassFailure();
297     if (linalgOps.back().hasTensorSemantics()) {
298       linalgOps.back().getOperation()->replaceAllUsesWith(
299           tileAndFuseOps->fusedLoops.front());
300     }
301     for (auto op : linalgOps)
302       if (op.hasBufferSemantics())
303         op.erase();
304   }
305 };
306 
307 } // namespace
308 
309 namespace mlir {
310 namespace test {
311 void registerTestLinalgFusionTransforms() {
312   PassRegistration<TestLinalgFusionTransformsParallelLoops>();
313 }
314 void registerTestLinalgTensorFusionTransforms() {
315   PassRegistration<TestLinalgFusionTransformsLoops>();
316 }
317 void registerTestLinalgTiledLoopFusionTransforms() {
318   PassRegistration<TestLinalgFusionTransformsTiledLoops>();
319 }
320 void registerTestLinalgGreedyFusion() {
321   PassRegistration<TestLinalgGreedyFusion>();
322 }
323 void registerTestLinalgTileAndFuseSequencePass() {
324   PassRegistration<TestLinalgTileAndFuseSequencePass>();
325 }
326 
327 } // namespace test
328 } // namespace mlir
329