1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg fusion patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/SCF/Transforms.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassManager.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "mlir/Transforms/Passes.h"
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 template <LinalgTilingLoopType LoopType>
25 static void fillFusionPatterns(MLIRContext *context,
26                                const LinalgDependenceGraph &dependenceGraph,
27                                RewritePatternSet &patterns) {
28   patterns.add<LinalgTileAndFusePattern<MatmulOp>,
29                LinalgTileAndFusePattern<Conv2DOp>>(
30       context, dependenceGraph,
31       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
32       LinalgFusionOptions().setIndicesToFuse({2}),
33       LinalgTransformationFilter(
34           StringAttr::get(context, "basic_fusion"),
35           StringAttr::get(context, "after_basic_fusion")),
36       LinalgTransformationFilter(
37           ArrayRef<StringAttr>(),
38           StringAttr::get(context, "after_basic_fusion_producer")),
39       LinalgTransformationFilter(
40           ArrayRef<StringAttr>(),
41           StringAttr::get(context, "after_basic_fusion_original")));
42 
43   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
44       context, dependenceGraph,
45       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
46       LinalgFusionOptions().setIndicesToFuse({0}),
47       LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
48                                  StringAttr::get(context, "after_lhs_fusion")),
49       LinalgTransformationFilter(
50           ArrayRef<StringAttr>(),
51           StringAttr::get(context, "after_lhs_fusion_producer")),
52       LinalgTransformationFilter(
53           ArrayRef<StringAttr>(),
54           StringAttr::get(context, "after_lhs_fusion_original")));
55 
56   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
57       context, dependenceGraph,
58       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
59       LinalgFusionOptions().setIndicesToFuse({2}),
60       LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
61                                  StringAttr::get(context, "after_out_fusion")),
62       LinalgTransformationFilter(
63           ArrayRef<StringAttr>(),
64           StringAttr::get(context, "after_out_fusion_producer")),
65       LinalgTransformationFilter(
66           ArrayRef<StringAttr>(),
67           StringAttr::get(context, "after_out_fusion_original")));
68 
69   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
70       context, dependenceGraph,
71       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
72       LinalgFusionOptions().setIndicesToFuse({1}),
73       LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
74                                  StringAttr::get(context, "after_rhs_fusion")),
75       LinalgTransformationFilter(
76           ArrayRef<StringAttr>(),
77           StringAttr::get(context, "after_rhs_fusion_producer")),
78       LinalgTransformationFilter(
79           ArrayRef<StringAttr>(),
80           StringAttr::get(context, "after_rhs_fusion_original")));
81 
82   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
83       context, dependenceGraph,
84       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
85       LinalgFusionOptions().setIndicesToFuse({0, 2}),
86       LinalgTransformationFilter(
87           StringAttr::get(context, "two_operand_fusion"),
88           StringAttr::get(context, "after_two_operand_fusion")),
89       LinalgTransformationFilter(
90           ArrayRef<StringAttr>(),
91           StringAttr::get(context, "after_two_operand_fusion_producer")),
92       LinalgTransformationFilter(
93           ArrayRef<StringAttr>(),
94           StringAttr::get(context, "after_two_operand_fusion_original")));
95 
96   patterns.add<LinalgTileAndFusePattern<GenericOp>>(
97       context, dependenceGraph,
98       LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
99       LinalgFusionOptions().setIndicesToFuse({0, 1}),
100       LinalgTransformationFilter(
101           StringAttr::get(context, "transpose_fusion"),
102           StringAttr::get(context, "after_transpose_fusion")),
103       LinalgTransformationFilter(
104           ArrayRef<StringAttr>(),
105           StringAttr::get(context, "after_transpose_fusion_producer")),
106       LinalgTransformationFilter(
107           ArrayRef<StringAttr>(),
108           StringAttr::get(context, "after_transpose_fusion_original")));
109 }
110 
111 namespace {
112 template <LinalgTilingLoopType LoopType>
113 struct TestLinalgFusionTransforms
114     : public PassWrapper<TestLinalgFusionTransforms<LoopType>,
115                          OperationPass<FuncOp>> {
116   void getDependentDialects(DialectRegistry &registry) const override {
117     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
118                     scf::SCFDialect>();
119   }
120   TestLinalgFusionTransforms() = default;
121   TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
122 
123   void runOnOperation() override {
124     MLIRContext *context = &this->getContext();
125     FuncOp funcOp = this->getOperation();
126     RewritePatternSet fusionPatterns(context);
127     Aliases alias;
128     LinalgDependenceGraph dependenceGraph =
129         LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
130     fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
131     (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
132   }
133 };
134 
135 struct TestLinalgFusionTransformsParallelLoops
136     : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
137   StringRef getArgument() const final {
138     return "test-linalg-fusion-transform-patterns";
139   }
140   StringRef getDescription() const final {
141     return "Test Linalg fusion transformation patterns by applying them "
142            "greedily.";
143   }
144 };
145 
146 struct TestLinalgFusionTransformsLoops
147     : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
148   StringRef getArgument() const final {
149     return "test-linalg-tensor-fusion-transform-patterns";
150   }
151   StringRef getDescription() const final {
152     return "Test Linalg on tensor fusion transformation "
153            "patterns by applying them greedily.";
154   }
155 };
156 
157 struct TestLinalgFusionTransformsTiledLoops
158     : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
159   StringRef getArgument() const final {
160     return "test-linalg-tiled-loop-fusion-transform-patterns";
161   }
162   StringRef getDescription() const final {
163     return "Test Linalg on tensor fusion transformation "
164            "patterns by applying them greedily.";
165   }
166 };
167 } // namespace
168 
169 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
170   OpBuilder b(f);
171   DenseSet<Operation *> eraseSet;
172 
173   // Save original Linalg ops, we only want to make a pass over those.
174   SmallVector<LinalgOp, 8> linalgOps;
175   f.walk([&](LinalgOp op) {
176     // TODO: support multi-results.
177     if (op->getNumResults() <= 1)
178       linalgOps.push_back(op);
179   });
180 
181   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
182   bool changed = false;
183   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
184     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
185       if (opOperand->get().getType().isa<MemRefType>()) {
186         // TODO: LinalgDependenceGraph should be able to update itself.
187         // The current naive and expensive reconstruction of the graph should be
188         // removed.
189         linalg::Aliases aliases;
190         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
191         auto info = fuseProducerOfBuffer(b, *opOperand, graph);
192         if (failed(info))
193           continue;
194         auto *originalOp = info->originalProducer.getOperation();
195         eraseSet.insert(originalOp);
196         auto *originalOpInLinalgOpsVector =
197             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
198         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
199         changed = true;
200       } else if (opOperand->get().getType().isa<RankedTensorType>()) {
201         // Tile and Fuse tensor input.
202         if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
203           continue;
204         auto info = fuseProducerOfTensor(b, *opOperand);
205         if (failed(info))
206           continue;
207         auto *originalOp = info->originalProducer.getOperation();
208         auto *originalOpInLinalgOpsVector =
209             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
210         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
211         // Don't mark for erasure in the tensor case, let DCE handle this.
212         changed = true;
213       }
214     }
215   }
216   // The `fuseProducerOfBuffer` function performs structural checks and in
217   // particular that no covering read or write exist between the consumer and
218   // the producer. As a consequence, the only fusions that may occur preserve
219   // subsequent dependences and are guaranteed by construction to produce the
220   // whole view. We may thus erase the producer once it is fused.
221   for (auto *e : eraseSet)
222     e->erase();
223 
224   return changed ? success() : failure();
225 }
226 
227 namespace {
228 struct TestLinalgGreedyFusion
229     : public PassWrapper<TestLinalgGreedyFusion, OperationPass<FuncOp>> {
230   void getDependentDialects(DialectRegistry &registry) const override {
231     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
232                     scf::SCFDialect>();
233   }
234   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
235   StringRef getDescription() const final {
236     return "Test Linalg fusion by applying a greedy test transformation.";
237   }
238   void runOnOperation() override {
239     MLIRContext *context = &getContext();
240     RewritePatternSet patterns =
241         linalg::getLinalgTilingCanonicalizationPatterns(context);
242     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
243     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
244     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
245     do {
246       (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
247       PassManager pm(context);
248       pm.addPass(createLoopInvariantCodeMotionPass());
249       pm.addPass(createCanonicalizerPass());
250       pm.addPass(createCSEPass());
251       LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
252       if (failed(res))
253         this->signalPassFailure();
254     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
255   }
256 };
257 
258 /// Pass to test tile and fuse of sequence of operations. Intended only for
259 /// testing.
260 struct TestLinalgTileAndFuseSequencePass
261     : public PassWrapper<TestLinalgTileAndFuseSequencePass,
262                          OperationPass<FuncOp>> {
263   StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
264   StringRef getDescription() const final {
265     return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
266   }
267   TestLinalgTileAndFuseSequencePass() = default;
268   TestLinalgTileAndFuseSequencePass(
269       const TestLinalgTileAndFuseSequencePass &pass)
270       : PassWrapper(pass){};
271 
272   ListOption<int64_t> tileSizes{
273       *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
274       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
275 
276   void getDependentDialects(DialectRegistry &registry) const override {
277     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
278                     scf::SCFDialect>();
279   }
280 
281   void runOnOperation() override {
282     FuncOp funcOp = getOperation();
283     auto &blocks = funcOp.getBody().getBlocks();
284     if (!llvm::hasSingleElement(blocks)) {
285       return;
286     }
287     SmallVector<LinalgOp, 2> linalgOps =
288         llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
289     Aliases aliases;
290     LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
291     OpBuilder builder(funcOp.getContext());
292     linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
293     if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
294           return linalgOp.hasTensorSemantics();
295         }))
296       loopType = LinalgTilingLoopType::Loops;
297     Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
298         builder, linalgOps, dependenceGraph,
299         LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
300     if (!tileAndFuseOps)
301       return signalPassFailure();
302     if (linalgOps.back().hasTensorSemantics()) {
303       linalgOps.back().getOperation()->replaceAllUsesWith(
304           tileAndFuseOps->fusedLoops.front());
305     }
306     for (auto op : linalgOps)
307       if (op.hasBufferSemantics())
308         op.erase();
309   }
310 };
311 
312 } // namespace
313 
314 namespace mlir {
315 namespace test {
316 void registerTestLinalgFusionTransforms() {
317   PassRegistration<TestLinalgFusionTransformsParallelLoops>();
318 }
319 void registerTestLinalgTensorFusionTransforms() {
320   PassRegistration<TestLinalgFusionTransformsLoops>();
321 }
322 void registerTestLinalgTiledLoopFusionTransforms() {
323   PassRegistration<TestLinalgFusionTransformsTiledLoops>();
324 }
325 void registerTestLinalgGreedyFusion() {
326   PassRegistration<TestLinalgGreedyFusion>();
327 }
328 void registerTestLinalgTileAndFuseSequencePass() {
329   PassRegistration<TestLinalgTileAndFuseSequencePass>();
330 }
331 
332 } // namespace test
333 } // namespace mlir
334