1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass for testing fusion of elementwise operations in
// Linalg, mainly the options and control functions that drive that fusion.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassManager.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 
23 static void addOperands(Operation *op, SetVector<Value> &operandSet) {
24   if (!op)
25     return;
26   TypeSwitch<Operation *, void>(op)
27       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
28         SmallVector<Value> inputOperands = linalgOp.getInputOperands();
29         operandSet.insert(inputOperands.begin(), inputOperands.end());
30       })
31       .Default([&](Operation *operation) {
32         operandSet.insert(operation->operand_begin(), operation->operand_end());
33       });
34 }
35 
36 template <int limit = 3>
37 static bool setFusedOpOperandLimit(const OpResult &producer,
38                                    const OpOperand &consumer) {
39   SetVector<Value> fusedOpOperands;
40   if (producer.getOwner()->getNumResults() != 1)
41     return false;
42   addOperands(consumer.getOwner(), fusedOpOperands);
43   fusedOpOperands.remove(producer);
44   addOperands(producer.getOwner(), fusedOpOperands);
45   return fusedOpOperands.size() <= limit;
46 }
47 
48 namespace {
49 struct TestLinalgElementwiseFusion
50     : public PassWrapper<TestLinalgElementwiseFusion,
51                          OperationPass<func::FuncOp>> {
52   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)
53 
54   TestLinalgElementwiseFusion() = default;
55   TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
56       : PassWrapper(pass) {}
57   void getDependentDialects(DialectRegistry &registry) const override {
58     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
59                     tensor::TensorDialect>();
60   }
61   StringRef getArgument() const final {
62     return "test-linalg-elementwise-fusion-patterns";
63   }
64   StringRef getDescription() const final {
65     return "Test Linalg element wise operation fusion patterns";
66   }
67 
68   Option<bool> fuseGenericOps{
69       *this, "fuse-generic-ops",
70       llvm::cl::desc("Test fusion of generic operations."),
71       llvm::cl::init(false)};
72 
73   Option<bool> fuseWithReshapeByExpansion{
74       *this, "fuse-with-reshape-by-expansion",
75       llvm::cl::desc(
76           "Test fusion of generic operations with reshape by expansion"),
77       llvm::cl::init(false)};
78 
79   Option<bool> controlFuseByExpansion{
80       *this, "control-fusion-by-expansion",
81       llvm::cl::desc(
82           "Test controlling fusion of reshape with generic op by expansion"),
83       llvm::cl::init(false)};
84 
85   Option<bool> fuseWithReshapeByCollapsing{
86       *this, "fuse-with-reshape-by-collapsing",
87       llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
88                      "collapse the iteration space of the consumer"),
89       llvm::cl::init(false)};
90 
91   Option<bool> fuseWithReshapeByCollapsingWithControlFn{
92       *this, "fuse-with-reshape-by-collapsing-control",
93       llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
94                      "fusion patterns that "
95                      "collapse the iteration space of the consumer"),
96       llvm::cl::init(false)};
97 
98   void runOnOperation() override {
99     MLIRContext *context = &this->getContext();
100     func::FuncOp funcOp = this->getOperation();
101 
102     if (fuseGenericOps) {
103       RewritePatternSet fusionPatterns(context);
104       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
105                                                    setFusedOpOperandLimit<4>);
106 
107       (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
108                                          std::move(fusionPatterns));
109       return;
110     }
111 
112     if (fuseWithReshapeByExpansion) {
113       RewritePatternSet fusionPatterns(context);
114       linalg::populateFoldReshapeOpsByExpansionPatterns(
115           fusionPatterns, [](const OpResult & /*producer*/,
116                              OpOperand & /*consumer*/) { return true; });
117       if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
118                                               std::move(fusionPatterns))))
119         return signalPassFailure();
120       return;
121     }
122 
123     if (controlFuseByExpansion) {
124       RewritePatternSet fusionPatterns(context);
125 
126       linalg::ControlFusionFn controlReshapeFusionFn =
127           [](const OpResult &producer, OpOperand &consumer) {
128             if (auto collapseOp =
129                     producer.getDefiningOp<tensor::CollapseShapeOp>()) {
130               if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
131                 return false;
132               }
133             }
134             if (auto expandOp =
135                     dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
136               if (expandOp->hasOneUse()) {
137                 OpOperand &use = *expandOp->getUses().begin();
138                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
139                 if (linalgOp && linalgOp.isOutputTensor(&use))
140                   return true;
141               }
142               return false;
143             }
144             return true;
145           };
146 
147       linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
148                                                         controlReshapeFusionFn);
149       (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
150                                          std::move(fusionPatterns));
151       return;
152     }
153 
154     if (fuseWithReshapeByCollapsing) {
155       RewritePatternSet patterns(context);
156       linalg::populateFoldReshapeOpsByCollapsingPatterns(
157           patterns, [](const OpResult & /*producer*/,
158                        OpOperand & /*consumer*/) { return true; });
159       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
160     }
161 
162     if (fuseWithReshapeByCollapsingWithControlFn) {
163       RewritePatternSet patterns(context);
164       linalg::ControlFusionFn controlFn = [](const OpResult &producer,
165                                              OpOperand &consumer) -> bool {
166         if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) {
167           // Skip fusing the first operand.
168           return consumer.getOperandNumber();
169         }
170         return true;
171       };
172       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
173       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
174     }
175   }
176 };
177 
178 } // namespace
179 
180 namespace mlir {
181 namespace test {
182 void registerTestLinalgElementwiseFusion() {
183   PassRegistration<TestLinalgElementwiseFusion>();
184 }
185 } // namespace test
186 } // namespace mlir
187