1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/SCF/Transforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Pass/PassManager.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "mlir/Transforms/Passes.h"
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 template <LinalgTilingLoopType LoopType>
27 static void fillFusionPatterns(MLIRContext *context,
28                                const LinalgDependenceGraph &dependenceGraph,
29                                RewritePatternSet &patterns) {
30   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
31                LinalgTileAndFusePattern<Conv2DOp>>(
32       context, dependenceGraph,
33       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
34       LinalgFusionOptions().setIndicesToFuse({2}),
35       LinalgTransformationFilter(
36           StringAttr::get(context, "basic_fusion"),
37           StringAttr::get(context, "after_basic_fusion")),
38       LinalgTransformationFilter(
39           ArrayRef<StringAttr>(),
40           StringAttr::get(context, "after_basic_fusion_producer")),
41       LinalgTransformationFilter(
42           ArrayRef<StringAttr>(),
43           StringAttr::get(context, "after_basic_fusion_original")));
44 
45   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
46       context, dependenceGraph,
47       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
48       LinalgFusionOptions().setIndicesToFuse({0}),
49       LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
50                                  StringAttr::get(context, "after_lhs_fusion")),
51       LinalgTransformationFilter(
52           ArrayRef<StringAttr>(),
53           StringAttr::get(context, "after_lhs_fusion_producer")),
54       LinalgTransformationFilter(
55           ArrayRef<StringAttr>(),
56           StringAttr::get(context, "after_lhs_fusion_original")));
57 
58   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
59       context, dependenceGraph,
60       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
61       LinalgFusionOptions().setIndicesToFuse({2}),
62       LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
63                                  StringAttr::get(context, "after_out_fusion")),
64       LinalgTransformationFilter(
65           ArrayRef<StringAttr>(),
66           StringAttr::get(context, "after_out_fusion_producer")),
67       LinalgTransformationFilter(
68           ArrayRef<StringAttr>(),
69           StringAttr::get(context, "after_out_fusion_original")));
70 
71   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
72       context, dependenceGraph,
73       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
74       LinalgFusionOptions().setIndicesToFuse({1}),
75       LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
76                                  StringAttr::get(context, "after_rhs_fusion")),
77       LinalgTransformationFilter(
78           ArrayRef<StringAttr>(),
79           StringAttr::get(context, "after_rhs_fusion_producer")),
80       LinalgTransformationFilter(
81           ArrayRef<StringAttr>(),
82           StringAttr::get(context, "after_rhs_fusion_original")));
83 
84   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
85       context, dependenceGraph,
86       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
87       LinalgFusionOptions().setIndicesToFuse({0, 2}),
88       LinalgTransformationFilter(
89           StringAttr::get(context, "two_operand_fusion"),
90           StringAttr::get(context, "after_two_operand_fusion")),
91       LinalgTransformationFilter(
92           ArrayRef<StringAttr>(),
93           StringAttr::get(context, "after_two_operand_fusion_producer")),
94       LinalgTransformationFilter(
95           ArrayRef<StringAttr>(),
96           StringAttr::get(context, "after_two_operand_fusion_original")));
97 
98   patterns.add<LinalgTileAndFusePattern<GenericOp>>(
99       context, dependenceGraph,
100       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
101       LinalgFusionOptions().setIndicesToFuse({0, 1}),
102       LinalgTransformationFilter(
103           StringAttr::get(context, "transpose_fusion"),
104           StringAttr::get(context, "after_transpose_fusion")),
105       LinalgTransformationFilter(
106           ArrayRef<StringAttr>(),
107           StringAttr::get(context, "after_transpose_fusion_producer")),
108       LinalgTransformationFilter(
109           ArrayRef<StringAttr>(),
110           StringAttr::get(context, "after_transpose_fusion_original")));
111 }
112 
113 namespace {
114 template <LinalgTilingLoopType LoopType>
115 struct TestLinalgFusionTransforms
116     : public PassWrapper<TestLinalgFusionTransforms<LoopType>,
117                          OperationPass<func::FuncOp>> {
118   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms)
119 
120   void getDependentDialects(DialectRegistry &registry) const override {
121     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
122                     scf::SCFDialect>();
123   }
124   TestLinalgFusionTransforms() = default;
125   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
126 
127   void runOnOperation() override {
128     MLIRContext *context = &this->getContext();
129     func::FuncOp funcOp = this->getOperation();
130     RewritePatternSet fusionPatterns(context);
131     Aliases alias;
132     LinalgDependenceGraph dependenceGraph =
133         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
134     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
135     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
136   }
137 };
138 
139 struct TestLinalgFusionTransformsParallelLoops
140     : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
141   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
142       TestLinalgFusionTransformsParallelLoops)
143 
144   StringRef getArgument() const final {
145     return "test-linalg-fusion-transform-patterns";
146   }
147   StringRef getDescription() const final {
148     return "Test Linalg fusion transformation patterns by applying them "
149            "greedily.";
150   }
151 };
152 
153 struct TestLinalgFusionTransformsLoops
154     : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
155   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops)
156 
157   StringRef getArgument() const final {
158     return "test-linalg-tensor-fusion-transform-patterns";
159   }
160   StringRef getDescription() const final {
161     return "Test Linalg on tensor fusion transformation "
162            "patterns by applying them greedily.";
163   }
164 };
165 
166 struct TestLinalgFusionTransformsTiledLoops
167     : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
168   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
169       TestLinalgFusionTransformsTiledLoops)
170 
171   StringRef getArgument() const final {
172     return "test-linalg-tiled-loop-fusion-transform-patterns";
173   }
174   StringRef getDescription() const final {
175     return "Test Linalg on tensor fusion transformation "
176            "patterns by applying them greedily.";
177   }
178 };
179 } // namespace
180 
181 static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
182   OpBuilder b(f);
183   DenseSet<Operation *> eraseSet;
184 
185   // Save original Linalg ops, we only want to make a pass over those.
186   SmallVector<LinalgOp, 8> linalgOps;
187   f.walk([&](LinalgOp op) {
188     // TODO: support multi-results.
189     if (op->getNumResults() <= 1)
190       linalgOps.push_back(op);
191   });
192 
193   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
194   bool changed = false;
195   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
196     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
197       if (opOperand->get().getType().isa<MemRefType>()) {
198         // TODO: LinalgDependenceGraph should be able to update itself.
199         // The current naive and expensive reconstruction of the graph should be
200         // removed.
201         linalg::Aliases aliases;
202         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
203         auto info = fuseProducerOfBuffer(b, *opOperand, graph);
204         if (failed(info))
205           continue;
206         auto *originalOp = info->originalProducer.getOperation();
207         eraseSet.insert(originalOp);
208         auto *originalOpInLinalgOpsVector =
209             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
210         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
211         changed = true;
212       } else if (opOperand->get().getType().isa<RankedTensorType>()) {
213         // Tile and Fuse tensor input.
214         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
215           continue;
216         auto info = fuseProducerOfTensor(b, *opOperand);
217         if (failed(info))
218           continue;
219         auto *originalOp = info->originalProducer.getOperation();
220         auto *originalOpInLinalgOpsVector =
221             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
222         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
223         // Don't mark for erasure in the tensor case, let DCE handle this.
224         changed = true;
225       }
226     }
227   }
228   // The `fuseProducerOfBuffer` function performs structural checks and in
229   // particular that no covering read or write exist between the consumer and
230   // the producer. As a consequence, the only fusions that may occur preserve
231   // subsequent dependences and are guaranteed by construction to produce the
232   // whole view. We may thus erase the producer once it is fused.
233   for (auto *e : eraseSet)
234     e->erase();
235 
236   return changed ? success() : failure();
237 }
238 
239 namespace {
240 struct TestLinalgGreedyFusion
241     : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> {
242   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion)
243 
244   void getDependentDialects(DialectRegistry &registry) const override {
245     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
246                     scf::SCFDialect>();
247   }
248   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
249   StringRef getDescription() const final {
250     return "Test Linalg fusion by applying a greedy test transformation.";
251   }
252   void runOnOperation() override {
253     MLIRContext *context = &getContext();
254     RewritePatternSet patterns =
255         linalg::getLinalgTilingCanonicalizationPatterns(context);
256     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
257     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
258     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
259     OpPassManager pm(func::FuncOp::getOperationName());
260     pm.addPass(createLoopInvariantCodeMotionPass());
261     pm.addPass(createCanonicalizerPass());
262     pm.addPass(createCSEPass());
263     do {
264       (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
265       if (failed(runPipeline(pm, getOperation())))
266         this->signalPassFailure();
267     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
268   }
269 };
270 
271 /// Pass to test tile and fuse of sequence of operations. Intended only for
272 /// testing.
273 struct TestLinalgTileAndFuseSequencePass
274     : public PassWrapper<TestLinalgTileAndFuseSequencePass,
275                          OperationPass<func::FuncOp>> {
276   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
277       TestLinalgTileAndFuseSequencePass)
278 
279   StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
280   StringRef getDescription() const final {
281     return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
282   }
283   TestLinalgTileAndFuseSequencePass() = default;
284   TestLinalgTileAndFuseSequencePass(
285       const TestLinalgTileAndFuseSequencePass &pass)
286       : PassWrapper(pass){};
287 
288   ListOption<int64_t> tileSizes{*this, "tile-sizes",
289                                 llvm::cl::desc("Tile sizes to use for ops"),
290                                 llvm::cl::ZeroOrMore};
291 
292   void getDependentDialects(DialectRegistry &registry) const override {
293     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
294                     scf::SCFDialect>();
295   }
296 
297   void runOnOperation() override {
298     func::FuncOp funcOp = getOperation();
299     auto &blocks = funcOp.getBody().getBlocks();
300     if (!llvm::hasSingleElement(blocks)) {
301       return;
302     }
303     SmallVector<LinalgOp, 2> linalgOps =
304         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
305     Aliases aliases;
306     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
307     OpBuilder builder(funcOp.getContext());
308     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
309     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
310           return linalgOp.hasTensorSemantics();
311         }))
312       loopType = LinalgTilingLoopType::Loops;
313     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
314         builder, linalgOps, dependenceGraph,
315         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
316     if (!tileAndFuseOps)
317       return signalPassFailure();
318     if (linalgOps.back().hasTensorSemantics()) {
319       linalgOps.back().getOperation()->replaceAllUsesWith(
320           tileAndFuseOps->fusedLoops.front());
321     }
322     for (auto op : linalgOps)
323       if (op.hasBufferSemantics())
324         op.erase();
325   }
326 };
327 
328 } // namespace
329 
330 namespace mlir {
331 namespace test {
332 void registerTestLinalgFusionTransforms() {
333   PassRegistration<TestLinalgFusionTransformsParallelLoops>();
334 }
335 void registerTestLinalgTensorFusionTransforms() {
336   PassRegistration<TestLinalgFusionTransformsLoops>();
337 }
338 void registerTestLinalgTiledLoopFusionTransforms() {
339   PassRegistration<TestLinalgFusionTransformsTiledLoops>();
340 }
341 void registerTestLinalgGreedyFusion() {
342   PassRegistration<TestLinalgGreedyFusion>();
343 }
344 void registerTestLinalgTileAndFuseSequencePass() {
345   PassRegistration<TestLinalgTileAndFuseSequencePass>();
346 }
347 
348 } // namespace test
349 } // namespace mlir
350