1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass for testing the fusion of elementwise operations
// in Linalg, mainly the options that control the Linalg fusion patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassManager.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 
23 static void addOperands(Operation *op, SetVector<Value> &operandSet) {
24   if (!op)
25     return;
26   TypeSwitch<Operation *, void>(op)
27       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
28         SmallVector<Value> inputOperands = linalgOp.getInputOperands();
29         operandSet.insert(inputOperands.begin(), inputOperands.end());
30       })
31       .Default([&](Operation *operation) {
32         operandSet.insert(operation->operand_begin(), operation->operand_end());
33       });
34 }
35 
36 template <int limit = 3>
37 static bool setFusedOpOperandLimit(const OpResult &producer,
38                                    const OpOperand &consumer) {
39   SetVector<Value> fusedOpOperands;
40   if (producer.getOwner()->getNumResults() != 1)
41     return false;
42   addOperands(consumer.getOwner(), fusedOpOperands);
43   fusedOpOperands.remove(producer);
44   addOperands(producer.getOwner(), fusedOpOperands);
45   return fusedOpOperands.size() <= limit;
46 }
47 
48 namespace {
49 struct TestLinalgElementwiseFusion
50     : public PassWrapper<TestLinalgElementwiseFusion, OperationPass<FuncOp>> {
51   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion)
52 
53   TestLinalgElementwiseFusion() = default;
54   TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
55       : PassWrapper(pass) {}
56   void getDependentDialects(DialectRegistry &registry) const override {
57     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
58                     tensor::TensorDialect>();
59   }
60   StringRef getArgument() const final {
61     return "test-linalg-elementwise-fusion-patterns";
62   }
63   StringRef getDescription() const final {
64     return "Test Linalg element wise operation fusion patterns";
65   }
66 
67   Option<bool> fuseGenericOps{
68       *this, "fuse-generic-ops",
69       llvm::cl::desc("Test fusion of generic operations."),
70       llvm::cl::init(false)};
71 
72   Option<bool> controlFuseByExpansion{
73       *this, "control-fusion-by-expansion",
74       llvm::cl::desc(
75           "Test controlling fusion of reshape with generic op by expansion"),
76       llvm::cl::init(false)};
77 
78   Option<bool> pushExpandingReshape{
79       *this, "push-expanding-reshape",
80       llvm::cl::desc("Test linalg expand_shape -> generic "
81                      "to generic -> expand_shape pattern"),
82       llvm::cl::init(false)};
83 
84   Option<bool> fuseWithReshapeByCollapsing{
85       *this, "fuse-with-reshape-by-collapsing",
86       llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
87                      "collapse the iteration space of the consumer"),
88       llvm::cl::init(false)};
89 
90   Option<bool> fuseWithReshapeByCollapsingWithControlFn{
91       *this, "fuse-with-reshape-by-collapsing-control",
92       llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
93                      "fusion patterns that "
94                      "collapse the iteration space of the consumer"),
95       llvm::cl::init(false)};
96 
97   void runOnOperation() override {
98     MLIRContext *context = &this->getContext();
99     FuncOp funcOp = this->getOperation();
100 
101     if (fuseGenericOps) {
102       RewritePatternSet fusionPatterns(context);
103       linalg::populateElementwiseOpsFusionPatterns(
104           fusionPatterns,
105           linalg::LinalgElementwiseFusionOptions()
106               .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
107 
108       (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
109                                          std::move(fusionPatterns));
110       return;
111     }
112 
113     if (controlFuseByExpansion) {
114       RewritePatternSet fusionPatterns(context);
115 
116       linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
117           [](const OpResult &producer, OpOperand &consumer) {
118             if (auto collapseOp =
119                     producer.getDefiningOp<tensor::CollapseShapeOp>()) {
120               if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
121                 return false;
122               }
123             }
124             if (auto expandOp =
125                     dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
126               if (expandOp->hasOneUse()) {
127                 OpOperand &use = *expandOp->getUses().begin();
128                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
129                 if (linalgOp && linalgOp.isOutputTensor(&use))
130                   return true;
131               }
132             }
133             return linalg::skipUnitDimReshape(producer, consumer);
134           };
135 
136       linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
137                                                         controlReshapeFusionFn);
138       (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
139                                          std::move(fusionPatterns));
140       return;
141     }
142 
143     if (pushExpandingReshape) {
144       RewritePatternSet patterns(context);
145       linalg::populatePushReshapeOpsPatterns(patterns);
146       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
147     }
148 
149     if (fuseWithReshapeByCollapsing) {
150       RewritePatternSet patterns(context);
151       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns);
152       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
153     }
154 
155     if (fuseWithReshapeByCollapsingWithControlFn) {
156       RewritePatternSet patterns(context);
157       linalg::ControlElementwiseOpsFusionFn controlFn =
158           [](const OpResult &producer, OpOperand &consumer) -> bool {
159         if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) {
160           // Skip fusing the first operand.
161           return consumer.getOperandNumber();
162         }
163         return true;
164       };
165       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
166       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
167     }
168   }
169 };
170 
171 } // namespace
172 
173 namespace mlir {
174 namespace test {
175 void registerTestLinalgElementwiseFusion() {
176   PassRegistration<TestLinalgElementwiseFusion>();
177 }
178 } // namespace test
179 } // namespace mlir
180