1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "mlir/Transforms/Passes.h"
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 
23 template <LinalgTilingLoopType LoopType>
24 static void fillFusionPatterns(MLIRContext *context,
25                                const LinalgDependenceGraph &dependenceGraph,
26                                RewritePatternSet &patterns) {
27   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
28                LinalgTileAndFusePattern<ConvOp>>(
29       context, dependenceGraph,
30       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
31       LinalgFusionOptions().setIndicesToFuse({2}),
32       LinalgTransformationFilter(
33           Identifier::get("basic_fusion", context),
34           Identifier::get("after_basic_fusion", context)),
35       LinalgTransformationFilter(
36           ArrayRef<Identifier>(),
37           Identifier::get("after_basic_fusion_producer", context)),
38       LinalgTransformationFilter(
39           ArrayRef<Identifier>(),
40           Identifier::get("after_basic_fusion_original", context)));
41 
42   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
43       context, dependenceGraph,
44       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
45       LinalgFusionOptions().setIndicesToFuse({0}),
46       LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
47                                  Identifier::get("after_lhs_fusion", context)),
48       LinalgTransformationFilter(
49           ArrayRef<Identifier>(),
50           Identifier::get("after_lhs_fusion_producer", context)),
51       LinalgTransformationFilter(
52           ArrayRef<Identifier>(),
53           Identifier::get("after_lhs_fusion_original", context)));
54 
55   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
56       context, dependenceGraph,
57       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
58       LinalgFusionOptions().setIndicesToFuse({2}),
59       LinalgTransformationFilter(Identifier::get("out_fusion", context),
60                                  Identifier::get("after_out_fusion", context)),
61       LinalgTransformationFilter(
62           ArrayRef<Identifier>(),
63           Identifier::get("after_out_fusion_producer", context)),
64       LinalgTransformationFilter(
65           ArrayRef<Identifier>(),
66           Identifier::get("after_out_fusion_original", context)));
67 
68   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
69       context, dependenceGraph,
70       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
71       LinalgFusionOptions().setIndicesToFuse({1}),
72       LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
73                                  Identifier::get("after_rhs_fusion", context)),
74       LinalgTransformationFilter(
75           ArrayRef<Identifier>(),
76           Identifier::get("after_rhs_fusion_producer", context)),
77       LinalgTransformationFilter(
78           ArrayRef<Identifier>(),
79           Identifier::get("after_rhs_fusion_original", context)));
80 
81   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
82       context, dependenceGraph,
83       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
84       LinalgFusionOptions().setIndicesToFuse({0, 2}),
85       LinalgTransformationFilter(
86           Identifier::get("two_operand_fusion", context),
87           Identifier::get("after_two_operand_fusion", context)),
88       LinalgTransformationFilter(
89           ArrayRef<Identifier>(),
90           Identifier::get("after_two_operand_fusion_producer", context)),
91       LinalgTransformationFilter(
92           ArrayRef<Identifier>(),
93           Identifier::get("after_two_operand_fusion_original", context)));
94 
95   patterns.add<LinalgTileAndFusePattern<GenericOp>>(
96       context, dependenceGraph,
97       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
98       LinalgFusionOptions().setIndicesToFuse({0, 1}),
99       LinalgTransformationFilter(
100           Identifier::get("transpose_fusion", context),
101           Identifier::get("after_transpose_fusion", context)),
102       LinalgTransformationFilter(
103           ArrayRef<Identifier>(),
104           Identifier::get("after_transpose_fusion_producer", context)),
105       LinalgTransformationFilter(
106           ArrayRef<Identifier>(),
107           Identifier::get("after_transpose_fusion_original", context)));
108 }
109 
110 namespace {
111 template <LinalgTilingLoopType LoopType>
112 struct TestLinalgFusionTransforms
113     : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> {
114   void getDependentDialects(DialectRegistry &registry) const override {
115     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
116                     scf::SCFDialect, StandardOpsDialect>();
117   }
118   TestLinalgFusionTransforms() = default;
119   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
120 
121   void runOnFunction() override {
122     MLIRContext *context = &this->getContext();
123     FuncOp funcOp = this->getFunction();
124     RewritePatternSet fusionPatterns(context);
125     Aliases alias;
126     LinalgDependenceGraph dependenceGraph =
127         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
128     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
129     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
130   }
131 };
132 
133 struct TestLinalgFusionTransformsParallelLoops
134     : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
135   StringRef getArgument() const final {
136     return "test-linalg-fusion-transform-patterns";
137   }
138   StringRef getDescription() const final {
139     return "Test Linalg fusion transformation patterns by applying them "
140            "greedily.";
141   }
142 };
143 
144 struct TestLinalgFusionTransformsLoops
145     : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
146   StringRef getArgument() const final {
147     return "test-linalg-tensor-fusion-transform-patterns";
148   }
149   StringRef getDescription() const final {
150     return "Test Linalg on tensor fusion transformation "
151            "patterns by applying them greedily.";
152   }
153 };
154 
155 struct TestLinalgFusionTransformsTiledLoops
156     : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
157   StringRef getArgument() const final {
158     return "test-linalg-tiled-loop-fusion-transform-patterns";
159   }
160   StringRef getDescription() const final {
161     return "Test Linalg on tensor fusion transformation "
162            "patterns by applying them greedily.";
163   }
164 };
165 } // namespace
166 
167 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
168   OpBuilder b(f);
169   DenseSet<Operation *> eraseSet;
170 
171   // Save original Linalg ops, we only want to make a pass over those.
172   SmallVector<LinalgOp, 8> linalgOps;
173   f.walk([&](LinalgOp op) {
174     // TODO: support multi-results.
175     if (op->getNumResults() <= 1)
176       linalgOps.push_back(op);
177   });
178 
179   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
180   bool changed = false;
181   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
182     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
183       if (opOperand->get().getType().isa<MemRefType>()) {
184         // TODO: LinalgDependenceGraph should be able to update itself.
185         // The current naive and expensive reconstruction of the graph should be
186         // removed.
187         linalg::Aliases aliases;
188         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
189         if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) {
190           auto *originalOp = info->originalProducer.getOperation();
191           eraseSet.insert(originalOp);
192           auto *originalOpInLinalgOpsVector =
193               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
194           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
195           changed = true;
196         }
197       } else if (opOperand->get().getType().isa<RankedTensorType>()) {
198         // Tile and Fuse tensor input.
199         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
200           continue;
201         if (auto info = fuseProducerOfTensor(b, *opOperand)) {
202           auto *originalOp = info->originalProducer.getOperation();
203           auto *originalOpInLinalgOpsVector =
204               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
205           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
206           // Don't mark for erasure in the tensor case, let DCE handle this.
207           changed = true;
208         }
209       }
210     }
211   }
212   // The `fuseProducerOfBuffer` function performs structural checks and in
213   // particular that no covering read or write exist between the consumer and
214   // the producer. As a consequence, the only fusions that may occur preserve
215   // subsequent dependences and are guaranteed by construction to produce the
216   // whole view. We may thus erase the producer once it is fused.
217   for (auto *e : eraseSet)
218     e->erase();
219 
220   return changed ? success() : failure();
221 }
222 
223 namespace {
224 struct TestLinalgGreedyFusion
225     : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
226   void getDependentDialects(DialectRegistry &registry) const override {
227     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
228                     scf::SCFDialect>();
229   }
230   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
231   StringRef getDescription() const final {
232     return "Test Linalg fusion by applying a greedy test transformation.";
233   }
234   void runOnFunction() override {
235     MLIRContext *context = &getContext();
236     RewritePatternSet patterns =
237         linalg::getLinalgTilingCanonicalizationPatterns(context);
238     patterns.add<AffineMinSCFCanonicalizationPattern,
239                  ExtractSliceOfPadTensorSwapPattern>(context);
240     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
241     do {
242       (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
243       PassManager pm(context);
244       pm.addPass(createLoopInvariantCodeMotionPass());
245       pm.addPass(createCanonicalizerPass());
246       pm.addPass(createCSEPass());
247       LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>());
248       if (failed(res))
249         this->signalPassFailure();
250     } while (succeeded(fuseLinalgOpsGreedily(getFunction())));
251   }
252 };
253 
254 /// Pass to test tile and fuse of sequence of operations. Intended only for
255 /// testing.
256 struct TestLinalgTileAndFuseSequencePass
257     : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
258   StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
259   StringRef getDescription() const final {
260     return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
261   }
262   TestLinalgTileAndFuseSequencePass() = default;
263   TestLinalgTileAndFuseSequencePass(
264       const TestLinalgTileAndFuseSequencePass &pass){};
265 
266   ListOption<int64_t> tileSizes{
267       *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
268       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
269 
270   void getDependentDialects(DialectRegistry &registry) const override {
271     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
272                     scf::SCFDialect>();
273   }
274 
275   void runOnFunction() override {
276     FuncOp funcOp = getOperation();
277     auto &blocks = funcOp.getBody().getBlocks();
278     if (!llvm::hasSingleElement(blocks)) {
279       return;
280     }
281     SmallVector<LinalgOp, 2> linalgOps =
282         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
283     Aliases aliases;
284     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
285     OpBuilder builder(funcOp.getContext());
286     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
287     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
288           return linalgOp.hasTensorSemantics();
289         }))
290       loopType = LinalgTilingLoopType::Loops;
291     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
292         builder, linalgOps, dependenceGraph,
293         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
294     if (!tileAndFuseOps)
295       return signalPassFailure();
296     if (linalgOps.back().hasTensorSemantics()) {
297       linalgOps.back().getOperation()->replaceAllUsesWith(
298           tileAndFuseOps->fusedLoops.front());
299     }
300     for (auto op : linalgOps)
301       if (op.hasBufferSemantics())
302         op.erase();
303   }
304 };
305 
306 } // namespace
307 
308 namespace mlir {
309 namespace test {
310 void registerTestLinalgFusionTransforms() {
311   PassRegistration<TestLinalgFusionTransformsParallelLoops>();
312 }
313 void registerTestLinalgTensorFusionTransforms() {
314   PassRegistration<TestLinalgFusionTransformsLoops>();
315 }
316 void registerTestLinalgTiledLoopFusionTransforms() {
317   PassRegistration<TestLinalgFusionTransformsTiledLoops>();
318 }
319 void registerTestLinalgGreedyFusion() {
320   PassRegistration<TestLinalgGreedyFusion>();
321 }
322 void registerTestLinalgTileAndFuseSequencePass() {
323   PassRegistration<TestLinalgTileAndFuseSequencePass>();
324 }
325 
326 } // namespace test
327 } // namespace mlir
328