1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Pass/PassManager.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "mlir/Transforms/Passes.h"
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 /// Use this to safely fill patterns for this test, since RewritePatternSet::add
27 /// forwards Rvalues only to the first pattern.
28 template <typename OpTy, LinalgTilingLoopType LoopType>
29 static void fillFusionPattern(MLIRContext *context,
30                               const LinalgDependenceGraph &dependenceGraph,
31                               RewritePatternSet &patterns,
32                               const Twine &testCase,
33                               ArrayRef<int64_t> tileSizes,
34                               ArrayRef<int64_t> indicesToFuse) {
35   patterns.add<LinalgTileAndFusePattern<OpTy>>(
36       context, dependenceGraph,
37       LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType),
38       LinalgFusionOptions().setIndicesToFuse(indicesToFuse),
39       LinalgTransformationFilter(
40           StringAttr::get(context, testCase + "_fusion"),
41           StringAttr::get(context, "after_" + testCase + "_fusion")),
42       LinalgTransformationFilter(
43           ArrayRef<StringAttr>(),
44           StringAttr::get(context, "after_" + testCase + "_fusion_producer")),
45       LinalgTransformationFilter(
46           ArrayRef<StringAttr>(),
47           StringAttr::get(context, "after_" + testCase + "_fusion_original")));
48 }
49 
50 template <LinalgTilingLoopType LoopType>
51 static void fillFusionPatterns(MLIRContext *context,
52                                const LinalgDependenceGraph &dependenceGraph,
53                                RewritePatternSet &patterns) {
54   fillFusionPattern<Conv2DOp, LoopType>(context, dependenceGraph, patterns,
55                                         /*testCase=*/"basic",
56                                         /*tileSizes=*/{32, 64, 16},
57                                         /*indicesToFuse=*/{2});
58 
59   auto fillMatmulPattern = [&](const Twine &testCase,
60                                ArrayRef<int64_t> indicesToFuse) {
61     fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns,
62                                           testCase, /*tileSizes=*/{32, 64, 16},
63                                           indicesToFuse);
64   };
65   fillMatmulPattern(/*testCase=*/"basic",
66                     /*indicesToFuse=*/{2});
67   fillMatmulPattern(/*testCase=*/"lhs",
68                     /*indicesToFuse=*/{0});
69   fillMatmulPattern(/*testCase=*/"out",
70                     /*indicesToFuse=*/{2});
71   fillMatmulPattern(/*testCase=*/"rhs",
72                     /*indicesToFuse=*/{1});
73   fillMatmulPattern(/*testCase=*/"two_operand",
74                     /*indicesToFuse=*/{0, 2});
75 
76   fillFusionPattern<GenericOp, LoopType>(context, dependenceGraph, patterns,
77                                          /*testCase=*/"transpose",
78                                          /*tileSizes=*/{32, 64},
79                                          /*indicesToFuse=*/{0, 1});
80 }
81 
82 namespace {
83 template <LinalgTilingLoopType LoopType>
84 struct TestLinalgFusionTransforms
85     : public PassWrapper<TestLinalgFusionTransforms<LoopType>,
86                          OperationPass<func::FuncOp>> {
87   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms)
88 
89   void getDependentDialects(DialectRegistry &registry) const override {
90     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
91                     scf::SCFDialect>();
92   }
93   TestLinalgFusionTransforms() = default;
94   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
95 
96   void runOnOperation() override {
97     MLIRContext *context = &this->getContext();
98     func::FuncOp funcOp = this->getOperation();
99     RewritePatternSet fusionPatterns(context);
100     Aliases alias;
101     LinalgDependenceGraph dependenceGraph =
102         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
103     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
104     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
105   }
106 };
107 
108 struct TestLinalgFusionTransformsParallelLoops
109     : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
110   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
111       TestLinalgFusionTransformsParallelLoops)
112 
113   StringRef getArgument() const final {
114     return "test-linalg-fusion-transform-patterns";
115   }
116   StringRef getDescription() const final {
117     return "Test Linalg fusion transformation patterns by applying them "
118            "greedily.";
119   }
120 };
121 
122 struct TestLinalgFusionTransformsLoops
123     : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
124   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops)
125 
126   StringRef getArgument() const final {
127     return "test-linalg-tensor-fusion-transform-patterns";
128   }
129   StringRef getDescription() const final {
130     return "Test Linalg on tensor fusion transformation "
131            "patterns by applying them greedily.";
132   }
133 };
134 
135 struct TestLinalgFusionTransformsTiledLoops
136     : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
137   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
138       TestLinalgFusionTransformsTiledLoops)
139 
140   StringRef getArgument() const final {
141     return "test-linalg-tiled-loop-fusion-transform-patterns";
142   }
143   StringRef getDescription() const final {
144     return "Test Linalg on tensor fusion transformation "
145            "patterns by applying them greedily.";
146   }
147 };
148 } // namespace
149 
150 static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
151   OpBuilder b(f);
152   DenseSet<Operation *> eraseSet;
153 
154   // Save original Linalg ops, we only want to make a pass over those.
155   SmallVector<LinalgOp, 8> linalgOps;
156   f.walk([&](LinalgOp op) {
157     // TODO: support multi-results.
158     if (op->getNumResults() <= 1)
159       linalgOps.push_back(op);
160   });
161 
162   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
163   bool changed = false;
164   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
165     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
166       if (opOperand->get().getType().isa<MemRefType>()) {
167         // TODO: LinalgDependenceGraph should be able to update itself.
168         // The current naive and expensive reconstruction of the graph should be
169         // removed.
170         linalg::Aliases aliases;
171         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
172         auto info = fuseProducerOfBuffer(b, *opOperand, graph);
173         if (failed(info))
174           continue;
175         auto *originalOp = info->originalProducer.getOperation();
176         eraseSet.insert(originalOp);
177         auto *originalOpInLinalgOpsVector =
178             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
179         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
180         changed = true;
181       } else if (opOperand->get().getType().isa<RankedTensorType>()) {
182         // Tile and Fuse tensor input.
183         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
184           continue;
185         auto info = fuseProducerOfTensor(b, *opOperand);
186         if (failed(info))
187           continue;
188         auto *originalOp = info->originalProducer.getOperation();
189         auto *originalOpInLinalgOpsVector =
190             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
191         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
192         // Don't mark for erasure in the tensor case, let DCE handle this.
193         changed = true;
194       }
195     }
196   }
197   // The `fuseProducerOfBuffer` function performs structural checks and in
198   // particular that no covering read or write exist between the consumer and
199   // the producer. As a consequence, the only fusions that may occur preserve
200   // subsequent dependences and are guaranteed by construction to produce the
201   // whole view. We may thus erase the producer once it is fused.
202   for (auto *e : eraseSet)
203     e->erase();
204 
205   return changed ? success() : failure();
206 }
207 
208 namespace {
209 struct TestLinalgGreedyFusion
210     : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> {
211   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion)
212 
213   void getDependentDialects(DialectRegistry &registry) const override {
214     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
215                     scf::SCFDialect>();
216   }
217   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
218   StringRef getDescription() const final {
219     return "Test Linalg fusion by applying a greedy test transformation.";
220   }
221   void runOnOperation() override {
222     MLIRContext *context = &getContext();
223     RewritePatternSet patterns =
224         linalg::getLinalgTilingCanonicalizationPatterns(context);
225     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
226     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
227     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
228     OpPassManager pm(func::FuncOp::getOperationName());
229     pm.addPass(createLoopInvariantCodeMotionPass());
230     pm.addPass(createCanonicalizerPass());
231     pm.addPass(createCSEPass());
232     do {
233       (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
234       if (failed(runPipeline(pm, getOperation())))
235         this->signalPassFailure();
236     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
237   }
238 };
239 
240 /// Pass to test tile and fuse of sequence of operations. Intended only for
241 /// testing.
242 struct TestLinalgTileAndFuseSequencePass
243     : public PassWrapper<TestLinalgTileAndFuseSequencePass,
244                          OperationPass<func::FuncOp>> {
245   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
246       TestLinalgTileAndFuseSequencePass)
247 
248   StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
249   StringRef getDescription() const final {
250     return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
251   }
252   TestLinalgTileAndFuseSequencePass() = default;
253   TestLinalgTileAndFuseSequencePass(
254       const TestLinalgTileAndFuseSequencePass &pass)
255       : PassWrapper(pass){};
256 
257   ListOption<int64_t> tileSizes{*this, "tile-sizes",
258                                 llvm::cl::desc("Tile sizes to use for ops")};
259 
260   void getDependentDialects(DialectRegistry &registry) const override {
261     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
262                     scf::SCFDialect>();
263   }
264 
265   void runOnOperation() override {
266     func::FuncOp funcOp = getOperation();
267     auto &blocks = funcOp.getBody().getBlocks();
268     if (!llvm::hasSingleElement(blocks)) {
269       return;
270     }
271     SmallVector<LinalgOp, 2> linalgOps =
272         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
273     Aliases aliases;
274     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
275     OpBuilder builder(funcOp.getContext());
276     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
277     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
278           return linalgOp.hasTensorSemantics();
279         }))
280       loopType = LinalgTilingLoopType::Loops;
281     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
282         builder, linalgOps, dependenceGraph,
283         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
284     if (!tileAndFuseOps)
285       return signalPassFailure();
286     if (linalgOps.back().hasTensorSemantics()) {
287       linalgOps.back().getOperation()->replaceAllUsesWith(
288           tileAndFuseOps->fusedLoops.front());
289     }
290     for (auto op : linalgOps)
291       if (op.hasBufferSemantics())
292         op.erase();
293   }
294 };
295 
296 } // namespace
297 
298 namespace mlir {
299 namespace test {
300 void registerTestLinalgFusionTransforms() {
301   PassRegistration<TestLinalgFusionTransformsParallelLoops>();
302 }
303 void registerTestLinalgTensorFusionTransforms() {
304   PassRegistration<TestLinalgFusionTransformsLoops>();
305 }
306 void registerTestLinalgTiledLoopFusionTransforms() {
307   PassRegistration<TestLinalgFusionTransformsTiledLoops>();
308 }
309 void registerTestLinalgGreedyFusion() {
310   PassRegistration<TestLinalgGreedyFusion>();
311 }
312 void registerTestLinalgTileAndFuseSequencePass() {
313   PassRegistration<TestLinalgTileAndFuseSequencePass>();
314 }
315 
316 } // namespace test
317 } // namespace mlir
318