1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "mlir/Transforms/Passes.h"
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 
23 template <LinalgTilingLoopType LoopType>
24 static void fillFusionPatterns(MLIRContext *context,
25                                const LinalgDependenceGraph &dependenceGraph,
26                                RewritePatternSet &patterns) {
27   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
28                LinalgTileAndFusePattern<ConvOp>>(
29       context, dependenceGraph,
30       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
31       LinalgFusionOptions().setIndicesToFuse({2}),
32       LinalgTransformationFilter(
33           Identifier::get("basic_fusion", context),
34           Identifier::get("after_basic_fusion", context)),
35       LinalgTransformationFilter(
36           ArrayRef<Identifier>(),
37           Identifier::get("after_basic_fusion_producer", context)),
38       LinalgTransformationFilter(
39           ArrayRef<Identifier>(),
40           Identifier::get("after_basic_fusion_original", context)));
41 
42   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
43       context, dependenceGraph,
44       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
45       LinalgFusionOptions().setIndicesToFuse({0}),
46       LinalgTransformationFilter(Identifier::get("lhs_fusion", context),
47                                  Identifier::get("after_lhs_fusion", context)),
48       LinalgTransformationFilter(
49           ArrayRef<Identifier>(),
50           Identifier::get("after_lhs_fusion_producer", context)),
51       LinalgTransformationFilter(
52           ArrayRef<Identifier>(),
53           Identifier::get("after_lhs_fusion_original", context)));
54 
55   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
56       context, dependenceGraph,
57       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
58       LinalgFusionOptions().setIndicesToFuse({2}),
59       LinalgTransformationFilter(Identifier::get("out_fusion", context),
60                                  Identifier::get("after_out_fusion", context)),
61       LinalgTransformationFilter(
62           ArrayRef<Identifier>(),
63           Identifier::get("after_out_fusion_producer", context)),
64       LinalgTransformationFilter(
65           ArrayRef<Identifier>(),
66           Identifier::get("after_out_fusion_original", context)));
67 
68   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
69       context, dependenceGraph,
70       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
71       LinalgFusionOptions().setIndicesToFuse({1}),
72       LinalgTransformationFilter(Identifier::get("rhs_fusion", context),
73                                  Identifier::get("after_rhs_fusion", context)),
74       LinalgTransformationFilter(
75           ArrayRef<Identifier>(),
76           Identifier::get("after_rhs_fusion_producer", context)),
77       LinalgTransformationFilter(
78           ArrayRef<Identifier>(),
79           Identifier::get("after_rhs_fusion_original", context)));
80 
81   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
82       context, dependenceGraph,
83       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
84       LinalgFusionOptions().setIndicesToFuse({0, 2}),
85       LinalgTransformationFilter(
86           Identifier::get("two_operand_fusion", context),
87           Identifier::get("after_two_operand_fusion", context)),
88       LinalgTransformationFilter(
89           ArrayRef<Identifier>(),
90           Identifier::get("after_two_operand_fusion_producer", context)),
91       LinalgTransformationFilter(
92           ArrayRef<Identifier>(),
93           Identifier::get("after_two_operand_fusion_original", context)));
94 
95   patterns.add<LinalgTileAndFusePattern<GenericOp>>(
96       context, dependenceGraph,
97       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
98       LinalgFusionOptions().setIndicesToFuse({0, 1}),
99       LinalgTransformationFilter(
100           Identifier::get("transpose_fusion", context),
101           Identifier::get("after_transpose_fusion", context)),
102       LinalgTransformationFilter(
103           ArrayRef<Identifier>(),
104           Identifier::get("after_transpose_fusion_producer", context)),
105       LinalgTransformationFilter(
106           ArrayRef<Identifier>(),
107           Identifier::get("after_transpose_fusion_original", context)));
108 }
109 
110 namespace {
111 template <LinalgTilingLoopType LoopType = LinalgTilingLoopType::ParallelLoops>
112 struct TestLinalgFusionTransforms
113     : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> {
114   TestLinalgFusionTransforms() = default;
115   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
116 
117   void getDependentDialects(DialectRegistry &registry) const override {
118     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
119                     scf::SCFDialect, StandardOpsDialect>();
120   }
121 
122   void runOnFunction() override {
123     MLIRContext *context = &this->getContext();
124     FuncOp funcOp = this->getFunction();
125     RewritePatternSet fusionPatterns(context);
126     Aliases alias;
127     LinalgDependenceGraph dependenceGraph =
128         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
129     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
130     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
131   }
132 };
133 } // namespace
134 
135 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
136   OpBuilder b(f);
137   DenseSet<Operation *> eraseSet;
138 
139   // Save original Linalg ops, we only want to make a pass over those.
140   SmallVector<LinalgOp, 8> linalgOps;
141   f.walk([&](LinalgOp op) {
142     // TODO: support multi-results.
143     if (op->getNumResults() <= 1)
144       linalgOps.push_back(op);
145   });
146 
147   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
148   bool changed = false;
149   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
150     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
151       if (opOperand->get().getType().isa<MemRefType>()) {
152         // TODO: LinalgDependenceGraph should be able to update itself.
153         // The current naive and expensive reconstruction of the graph should be
154         // removed.
155         linalg::Aliases aliases;
156         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
157         if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) {
158           auto *originalOp = info->originalProducer.getOperation();
159           eraseSet.insert(originalOp);
160           auto *originalOpInLinalgOpsVector =
161               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
162           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
163           changed = true;
164         }
165       } else {
166         assert(opOperand->get().getType().isa<RankedTensorType>());
167         // Tile and Fuse tensor input.
168         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
169           continue;
170         if (auto info = fuseProducerOfTensor(b, *opOperand)) {
171           auto *originalOp = info->originalProducer.getOperation();
172           auto *originalOpInLinalgOpsVector =
173               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
174           *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
175           // Don't mark for erasure in the tensor case, let DCE handle this.
176           changed = true;
177         }
178       }
179     }
180   }
181   // The `fuseProducerOfBuffer` function performs structural checks and in
182   // particular that no covering read or write exist between the consumer and
183   // the producer. As a consequence, the only fusions that may occur preserve
184   // subsequent dependences and are guaranteed by construction to produce the
185   // whole view. We may thus erase the producer once it is fused.
186   for (auto *e : eraseSet)
187     e->erase();
188 
189   return changed ? success() : failure();
190 }
191 
192 namespace {
193 struct TestLinalgGreedyFusion
194     : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
195   void getDependentDialects(DialectRegistry &registry) const override {
196     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
197                     scf::SCFDialect>();
198   }
199   void runOnFunction() override {
200     MLIRContext *context = &getContext();
201     RewritePatternSet patterns =
202         linalg::getLinalgTilingCanonicalizationPatterns(context);
203     patterns.add<AffineMinSCFCanonicalizationPattern>(context);
204     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
205     while (succeeded(fuseLinalgOpsGreedily(getFunction()))) {
206       (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
207       PassManager pm(context);
208       pm.addPass(createLoopInvariantCodeMotionPass());
209       pm.addPass(createCanonicalizerPass());
210       pm.addPass(createCSEPass());
211       LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>());
212       if (failed(res))
213         this->signalPassFailure();
214     }
215   }
216 };
217 
218 /// Pass to test tile and fuse of sequence of operations. Intended only for
219 /// testing.
220 struct TestLinalgTileAndFuseSequencePass
221     : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
222   TestLinalgTileAndFuseSequencePass() = default;
223   TestLinalgTileAndFuseSequencePass(
224       const TestLinalgTileAndFuseSequencePass &pass){};
225 
226   ListOption<int64_t> tileSizes{
227       *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
228       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
229 
230   void getDependentDialects(DialectRegistry &registry) const override {
231     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
232                     scf::SCFDialect>();
233   }
234 
235   void runOnFunction() override {
236     FuncOp funcOp = getOperation();
237     auto &blocks = funcOp.getBody().getBlocks();
238     if (!llvm::hasSingleElement(blocks)) {
239       return;
240     }
241     SmallVector<LinalgOp, 2> linalgOps =
242         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
243     Aliases aliases;
244     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
245     OpBuilder builder(funcOp.getContext());
246     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
247     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
248           return linalgOp.hasTensorSemantics();
249         }))
250       loopType = LinalgTilingLoopType::Loops;
251     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
252         builder, linalgOps, dependenceGraph,
253         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
254     if (!tileAndFuseOps)
255       return signalPassFailure();
256     if (linalgOps.back().hasTensorSemantics()) {
257       linalgOps.back().getOperation()->replaceAllUsesWith(
258           tileAndFuseOps->fusedLoops.front());
259     }
260     for (auto op : linalgOps)
261       if (op.hasBufferSemantics())
262         op.erase();
263   }
264 };
265 } // namespace
266 
267 namespace mlir {
268 namespace test {
269 void registerTestLinalgFusionTransforms() {
270   PassRegistration<TestLinalgFusionTransforms<>> testFusionTransformsPass(
271       "test-linalg-fusion-transform-patterns",
272       "Test Linalg fusion transformation patterns by applying them greedily.");
273 }
274 void registerTestLinalgTensorFusionTransforms() {
275   PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::Loops>>
276       testTensorFusionTransformsPass(
277           "test-linalg-tensor-fusion-transform-patterns",
278           "Test Linalg on tensor fusion transformation "
279           "patterns by applying them greedily.");
280 }
281 void registerTestLinalgTiledLoopFusionTransforms() {
282   PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>>
283       testTiledLoopFusionTransformsPass(
284           "test-linalg-tiled-loop-fusion-transform-patterns",
285           "Test Linalg on tensor fusion transformation "
286           "patterns by applying them greedily.");
287 }
288 void registerTestLinalgGreedyFusion() {
289   PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass(
290       "test-linalg-greedy-fusion",
291       "Test Linalg fusion by applying a greedy test transformation.");
292 }
293 void registerTestLinalgTileAndFuseSequencePass() {
294   PassRegistration<TestLinalgTileAndFuseSequencePass>
295       testTileAndFuseSequencePass(
296           "test-linalg-tile-and-fuse",
297           "Test Linalg tiling and fusion of a sequence of Linalg operations.");
298 }
299 
300 } // namespace test
301 } // namespace mlir
302