1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/SCF/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "mlir/Transforms/Passes.h"
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
25 template <LinalgTilingLoopType LoopType>
26 static void fillFusionPatterns(MLIRContext *context,
27                                const LinalgDependenceGraph &dependenceGraph,
28                                RewritePatternSet &patterns) {
29   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
30                LinalgTileAndFusePattern<Conv2DOp>>(
31       context, dependenceGraph,
32       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
33       LinalgFusionOptions().setIndicesToFuse({2}),
34       LinalgTransformationFilter(
35           StringAttr::get(context, "basic_fusion"),
36           StringAttr::get(context, "after_basic_fusion")),
37       LinalgTransformationFilter(
38           ArrayRef<StringAttr>(),
39           StringAttr::get(context, "after_basic_fusion_producer")),
40       LinalgTransformationFilter(
41           ArrayRef<StringAttr>(),
42           StringAttr::get(context, "after_basic_fusion_original")));
43 
44   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
45       context, dependenceGraph,
46       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
47       LinalgFusionOptions().setIndicesToFuse({0}),
48       LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
49                                  StringAttr::get(context, "after_lhs_fusion")),
50       LinalgTransformationFilter(
51           ArrayRef<StringAttr>(),
52           StringAttr::get(context, "after_lhs_fusion_producer")),
53       LinalgTransformationFilter(
54           ArrayRef<StringAttr>(),
55           StringAttr::get(context, "after_lhs_fusion_original")));
56 
57   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
58       context, dependenceGraph,
59       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
60       LinalgFusionOptions().setIndicesToFuse({2}),
61       LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
62                                  StringAttr::get(context, "after_out_fusion")),
63       LinalgTransformationFilter(
64           ArrayRef<StringAttr>(),
65           StringAttr::get(context, "after_out_fusion_producer")),
66       LinalgTransformationFilter(
67           ArrayRef<StringAttr>(),
68           StringAttr::get(context, "after_out_fusion_original")));
69 
70   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
71       context, dependenceGraph,
72       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
73       LinalgFusionOptions().setIndicesToFuse({1}),
74       LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
75                                  StringAttr::get(context, "after_rhs_fusion")),
76       LinalgTransformationFilter(
77           ArrayRef<StringAttr>(),
78           StringAttr::get(context, "after_rhs_fusion_producer")),
79       LinalgTransformationFilter(
80           ArrayRef<StringAttr>(),
81           StringAttr::get(context, "after_rhs_fusion_original")));
82 
83   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
84       context, dependenceGraph,
85       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
86       LinalgFusionOptions().setIndicesToFuse({0, 2}),
87       LinalgTransformationFilter(
88           StringAttr::get(context, "two_operand_fusion"),
89           StringAttr::get(context, "after_two_operand_fusion")),
90       LinalgTransformationFilter(
91           ArrayRef<StringAttr>(),
92           StringAttr::get(context, "after_two_operand_fusion_producer")),
93       LinalgTransformationFilter(
94           ArrayRef<StringAttr>(),
95           StringAttr::get(context, "after_two_operand_fusion_original")));
96 
97   patterns.add<LinalgTileAndFusePattern<GenericOp>>(
98       context, dependenceGraph,
99       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
100       LinalgFusionOptions().setIndicesToFuse({0, 1}),
101       LinalgTransformationFilter(
102           StringAttr::get(context, "transpose_fusion"),
103           StringAttr::get(context, "after_transpose_fusion")),
104       LinalgTransformationFilter(
105           ArrayRef<StringAttr>(),
106           StringAttr::get(context, "after_transpose_fusion_producer")),
107       LinalgTransformationFilter(
108           ArrayRef<StringAttr>(),
109           StringAttr::get(context, "after_transpose_fusion_original")));
110 }
111 
112 namespace {
113 template <LinalgTilingLoopType LoopType>
114 struct TestLinalgFusionTransforms
115     : public PassWrapper<TestLinalgFusionTransforms<LoopType>,
116                          OperationPass<FuncOp>> {
117   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms)
118 
119   void getDependentDialects(DialectRegistry &registry) const override {
120     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
121                     scf::SCFDialect>();
122   }
123   TestLinalgFusionTransforms() = default;
124   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
125 
126   void runOnOperation() override {
127     MLIRContext *context = &this->getContext();
128     FuncOp funcOp = this->getOperation();
129     RewritePatternSet fusionPatterns(context);
130     Aliases alias;
131     LinalgDependenceGraph dependenceGraph =
132         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
133     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
134     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
135   }
136 };
137 
138 struct TestLinalgFusionTransformsParallelLoops
139     : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
140   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
141       TestLinalgFusionTransformsParallelLoops)
142 
143   StringRef getArgument() const final {
144     return "test-linalg-fusion-transform-patterns";
145   }
146   StringRef getDescription() const final {
147     return "Test Linalg fusion transformation patterns by applying them "
148            "greedily.";
149   }
150 };
151 
152 struct TestLinalgFusionTransformsLoops
153     : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
154   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops)
155 
156   StringRef getArgument() const final {
157     return "test-linalg-tensor-fusion-transform-patterns";
158   }
159   StringRef getDescription() const final {
160     return "Test Linalg on tensor fusion transformation "
161            "patterns by applying them greedily.";
162   }
163 };
164 
165 struct TestLinalgFusionTransformsTiledLoops
166     : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
167   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
168       TestLinalgFusionTransformsTiledLoops)
169 
170   StringRef getArgument() const final {
171     return "test-linalg-tiled-loop-fusion-transform-patterns";
172   }
173   StringRef getDescription() const final {
174     return "Test Linalg on tensor fusion transformation "
175            "patterns by applying them greedily.";
176   }
177 };
178 } // namespace
179 
180 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
181   OpBuilder b(f);
182   DenseSet<Operation *> eraseSet;
183 
184   // Save original Linalg ops, we only want to make a pass over those.
185   SmallVector<LinalgOp, 8> linalgOps;
186   f.walk([&](LinalgOp op) {
187     // TODO: support multi-results.
188     if (op->getNumResults() <= 1)
189       linalgOps.push_back(op);
190   });
191 
192   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
193   bool changed = false;
194   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
195     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
196       if (opOperand->get().getType().isa<MemRefType>()) {
197         // TODO: LinalgDependenceGraph should be able to update itself.
198         // The current naive and expensive reconstruction of the graph should be
199         // removed.
200         linalg::Aliases aliases;
201         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
202         auto info = fuseProducerOfBuffer(b, *opOperand, graph);
203         if (failed(info))
204           continue;
205         auto *originalOp = info->originalProducer.getOperation();
206         eraseSet.insert(originalOp);
207         auto *originalOpInLinalgOpsVector =
208             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
209         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
210         changed = true;
211       } else if (opOperand->get().getType().isa<RankedTensorType>()) {
212         // Tile and Fuse tensor input.
213         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
214           continue;
215         auto info = fuseProducerOfTensor(b, *opOperand);
216         if (failed(info))
217           continue;
218         auto *originalOp = info->originalProducer.getOperation();
219         auto *originalOpInLinalgOpsVector =
220             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
221         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
222         // Don't mark for erasure in the tensor case, let DCE handle this.
223         changed = true;
224       }
225     }
226   }
227   // The `fuseProducerOfBuffer` function performs structural checks and in
228   // particular that no covering read or write exist between the consumer and
229   // the producer. As a consequence, the only fusions that may occur preserve
230   // subsequent dependences and are guaranteed by construction to produce the
231   // whole view. We may thus erase the producer once it is fused.
232   for (auto *e : eraseSet)
233     e->erase();
234 
235   return changed ? success() : failure();
236 }
237 
238 namespace {
239 struct TestLinalgGreedyFusion
240     : public PassWrapper<TestLinalgGreedyFusion, OperationPass<FuncOp>> {
241   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion)
242 
243   void getDependentDialects(DialectRegistry &registry) const override {
244     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
245                     scf::SCFDialect>();
246   }
247   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
248   StringRef getDescription() const final {
249     return "Test Linalg fusion by applying a greedy test transformation.";
250   }
251   void runOnOperation() override {
252     MLIRContext *context = &getContext();
253     RewritePatternSet patterns =
254         linalg::getLinalgTilingCanonicalizationPatterns(context);
255     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
256     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
257     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
258     do {
259       (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
260       PassManager pm(context);
261       pm.addPass(createLoopInvariantCodeMotionPass());
262       pm.addPass(createCanonicalizerPass());
263       pm.addPass(createCSEPass());
264       LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
265       if (failed(res))
266         this->signalPassFailure();
267     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
268   }
269 };
270 
271 /// Pass to test tile and fuse of sequence of operations. Intended only for
272 /// testing.
273 struct TestLinalgTileAndFuseSequencePass
274     : public PassWrapper<TestLinalgTileAndFuseSequencePass,
275                          OperationPass<FuncOp>> {
276   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
277       TestLinalgTileAndFuseSequencePass)
278 
279   StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
280   StringRef getDescription() const final {
281     return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
282   }
283   TestLinalgTileAndFuseSequencePass() = default;
284   TestLinalgTileAndFuseSequencePass(
285       const TestLinalgTileAndFuseSequencePass &pass)
286       : PassWrapper(pass){};
287 
288   ListOption<int64_t> tileSizes{*this, "tile-sizes",
289                                 llvm::cl::desc("Tile sizes to use for ops"),
290                                 llvm::cl::ZeroOrMore};
291 
292   void getDependentDialects(DialectRegistry &registry) const override {
293     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
294                     scf::SCFDialect>();
295   }
296 
297   void runOnOperation() override {
298     FuncOp funcOp = getOperation();
299     auto &blocks = funcOp.getBody().getBlocks();
300     if (!llvm::hasSingleElement(blocks)) {
301       return;
302     }
303     SmallVector<LinalgOp, 2> linalgOps =
304         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
305     Aliases aliases;
306     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
307     OpBuilder builder(funcOp.getContext());
308     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
309     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
310           return linalgOp.hasTensorSemantics();
311         }))
312       loopType = LinalgTilingLoopType::Loops;
313     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
314         builder, linalgOps, dependenceGraph,
315         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
316     if (!tileAndFuseOps)
317       return signalPassFailure();
318     if (linalgOps.back().hasTensorSemantics()) {
319       linalgOps.back().getOperation()->replaceAllUsesWith(
320           tileAndFuseOps->fusedLoops.front());
321     }
322     for (auto op : linalgOps)
323       if (op.hasBufferSemantics())
324         op.erase();
325   }
326 };
327 
328 } // namespace
329 
330 namespace mlir {
331 namespace test {
332 void registerTestLinalgFusionTransforms() {
333   PassRegistration<TestLinalgFusionTransformsParallelLoops>();
334 }
335 void registerTestLinalgTensorFusionTransforms() {
336   PassRegistration<TestLinalgFusionTransformsLoops>();
337 }
338 void registerTestLinalgTiledLoopFusionTransforms() {
339   PassRegistration<TestLinalgFusionTransformsTiledLoops>();
340 }
341 void registerTestLinalgGreedyFusion() {
342   PassRegistration<TestLinalgGreedyFusion>();
343 }
344 void registerTestLinalgTileAndFuseSequencePass() {
345   PassRegistration<TestLinalgTileAndFuseSequencePass>();
346 }
347 
348 } // namespace test
349 } // namespace mlir
350