1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass for testing fusion of elementwise operations in
// Linalg, mainly the fusion pattern options and their control functions.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 #include "llvm/ADT/TypeSwitch.h"
21
22 using namespace mlir;
23
addOperands(Operation * op,SetVector<Value> & operandSet)24 static void addOperands(Operation *op, SetVector<Value> &operandSet) {
25 if (!op)
26 return;
27 TypeSwitch<Operation *, void>(op)
28 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
29 SmallVector<Value> inputOperands = linalgOp.getInputOperands();
30 operandSet.insert(inputOperands.begin(), inputOperands.end());
31 })
32 .Default([&](Operation *operation) {
33 operandSet.insert(operation->operand_begin(), operation->operand_end());
34 });
35 }
36
37 template <int limit = 3>
setFusedOpOperandLimit(const OpResult & producer,const OpOperand & consumer)38 static bool setFusedOpOperandLimit(const OpResult &producer,
39 const OpOperand &consumer) {
40 SetVector<Value> fusedOpOperands;
41 if (producer.getOwner()->getNumResults() != 1)
42 return false;
43 addOperands(consumer.getOwner(), fusedOpOperands);
44 fusedOpOperands.remove(producer);
45 addOperands(producer.getOwner(), fusedOpOperands);
46 return fusedOpOperands.size() <= limit;
47 }
48
49 namespace {
50 struct TestLinalgElementwiseFusion
51 : public PassWrapper<TestLinalgElementwiseFusion,
52 OperationPass<func::FuncOp>> {
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)
54
55 TestLinalgElementwiseFusion() = default;
TestLinalgElementwiseFusion__anona9de7fa00311::TestLinalgElementwiseFusion56 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
57 : PassWrapper(pass) {}
getDependentDialects__anona9de7fa00311::TestLinalgElementwiseFusion58 void getDependentDialects(DialectRegistry ®istry) const override {
59 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
60 tensor::TensorDialect>();
61 }
getArgument__anona9de7fa00311::TestLinalgElementwiseFusion62 StringRef getArgument() const final {
63 return "test-linalg-elementwise-fusion-patterns";
64 }
getDescription__anona9de7fa00311::TestLinalgElementwiseFusion65 StringRef getDescription() const final {
66 return "Test Linalg element wise operation fusion patterns";
67 }
68
69 Option<bool> fuseGenericOps{
70 *this, "fuse-generic-ops",
71 llvm::cl::desc("Test fusion of generic operations."),
72 llvm::cl::init(false)};
73
74 Option<bool> fuseWithReshapeByExpansion{
75 *this, "fuse-with-reshape-by-expansion",
76 llvm::cl::desc(
77 "Test fusion of generic operations with reshape by expansion"),
78 llvm::cl::init(false)};
79
80 Option<bool> controlFuseByExpansion{
81 *this, "control-fusion-by-expansion",
82 llvm::cl::desc(
83 "Test controlling fusion of reshape with generic op by expansion"),
84 llvm::cl::init(false)};
85
86 Option<bool> fuseWithReshapeByCollapsing{
87 *this, "fuse-with-reshape-by-collapsing",
88 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
89 "collapse the iteration space of the consumer"),
90 llvm::cl::init(false)};
91
92 Option<bool> fuseWithReshapeByCollapsingWithControlFn{
93 *this, "fuse-with-reshape-by-collapsing-control",
94 llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
95 "fusion patterns that "
96 "collapse the iteration space of the consumer"),
97 llvm::cl::init(false)};
98
runOnOperation__anona9de7fa00311::TestLinalgElementwiseFusion99 void runOnOperation() override {
100 MLIRContext *context = &this->getContext();
101 func::FuncOp funcOp = this->getOperation();
102
103 if (fuseGenericOps) {
104 RewritePatternSet fusionPatterns(context);
105 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
106 setFusedOpOperandLimit<4>);
107
108 (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
109 std::move(fusionPatterns));
110 return;
111 }
112
113 if (fuseWithReshapeByExpansion) {
114 RewritePatternSet fusionPatterns(context);
115 linalg::populateFoldReshapeOpsByExpansionPatterns(
116 fusionPatterns, [](const OpResult & /*producer*/,
117 OpOperand & /*consumer*/) { return true; });
118 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
119 std::move(fusionPatterns))))
120 return signalPassFailure();
121 return;
122 }
123
124 if (controlFuseByExpansion) {
125 RewritePatternSet fusionPatterns(context);
126
127 linalg::ControlFusionFn controlReshapeFusionFn =
128 [](const OpResult &producer, OpOperand &consumer) {
129 if (auto collapseOp =
130 producer.getDefiningOp<tensor::CollapseShapeOp>()) {
131 if (!collapseOp.getSrc().getDefiningOp<linalg::LinalgOp>()) {
132 return false;
133 }
134 }
135 if (auto expandOp =
136 dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
137 if (expandOp->hasOneUse()) {
138 OpOperand &use = *expandOp->getUses().begin();
139 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
140 if (linalgOp && linalgOp.isOutputTensor(&use))
141 return true;
142 }
143 return false;
144 }
145 return true;
146 };
147
148 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
149 controlReshapeFusionFn);
150 (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
151 std::move(fusionPatterns));
152 return;
153 }
154
155 if (fuseWithReshapeByCollapsing) {
156 RewritePatternSet patterns(context);
157 linalg::populateFoldReshapeOpsByCollapsingPatterns(
158 patterns, [](const OpResult & /*producer*/,
159 OpOperand & /*consumer*/) { return true; });
160 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
161 }
162
163 if (fuseWithReshapeByCollapsingWithControlFn) {
164 RewritePatternSet patterns(context);
165 linalg::ControlFusionFn controlFn = [](const OpResult &producer,
166 OpOperand &consumer) -> bool {
167 if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) {
168 // Skip fusing the first operand.
169 return consumer.getOperandNumber();
170 }
171 return true;
172 };
173 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
174 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
175 }
176 }
177 };
178
179 } // namespace
180
181 namespace mlir {
182 namespace test {
registerTestLinalgElementwiseFusion()183 void registerTestLinalgElementwiseFusion() {
184 PassRegistration<TestLinalgElementwiseFusion>();
185 }
186 } // namespace test
187 } // namespace mlir
188