1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass for testing fusion of elementwise operations in
// Linalg, mainly the options and control functions that govern fusion.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 namespace mlir {
21 
22 static void addOperands(Operation *op, SetVector<Value> &operandSet) {
23   if (!op)
24     return;
25   TypeSwitch<Operation *, void>(op)
26       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
27         SmallVector<Value> inputOperands = linalgOp.getInputOperands();
28         operandSet.insert(inputOperands.begin(), inputOperands.end());
29       })
30       .Default([&](Operation *operation) {
31         operandSet.insert(operation->operand_begin(), operation->operand_end());
32       });
33 }
34 
35 template <int limit = 3>
36 static bool setFusedOpOperandLimit(const OpResult &producer,
37                                    const OpOperand &consumer) {
38   SetVector<Value> fusedOpOperands;
39   if (producer.getOwner()->getNumResults() != 1)
40     return false;
41   addOperands(consumer.getOwner(), fusedOpOperands);
42   fusedOpOperands.remove(producer);
43   addOperands(producer.getOwner(), fusedOpOperands);
44   return fusedOpOperands.size() <= limit;
45 }
46 
47 namespace {
48 struct TestLinalgElementwiseFusion
49     : public PassWrapper<TestLinalgElementwiseFusion, OperationPass<FuncOp>> {
50   TestLinalgElementwiseFusion() = default;
51   TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
52       : PassWrapper(pass) {}
53   void getDependentDialects(DialectRegistry &registry) const override {
54     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
55                     tensor::TensorDialect>();
56   }
57   StringRef getArgument() const final {
58     return "test-linalg-elementwise-fusion-patterns";
59   }
60   StringRef getDescription() const final {
61     return "Test Linalg element wise operation fusion patterns";
62   }
63 
64   Option<bool> fuseGenericOps{
65       *this, "fuse-generic-ops",
66       llvm::cl::desc("Test fusion of generic operations."),
67       llvm::cl::init(false)};
68 
69   Option<bool> controlFuseByExpansion{
70       *this, "control-fusion-by-expansion",
71       llvm::cl::desc(
72           "Test controlling fusion of reshape with generic op by expansion"),
73       llvm::cl::init(false)};
74 
75   Option<bool> pushExpandingReshape{
76       *this, "push-expanding-reshape",
77       llvm::cl::desc("Test linalg expand_shape -> generic "
78                      "to generic -> expand_shape pattern"),
79       llvm::cl::init(false)};
80 
81   Option<bool> fuseWithReshapeByCollapsing{
82       *this, "fuse-with-reshape-by-collapsing",
83       llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
84                      "collapse the iteration space of the consumer"),
85       llvm::cl::init(false)};
86 
87   Option<bool> fuseWithReshapeByCollapsingWithControlFn{
88       *this, "fuse-with-reshape-by-collapsing-control",
89       llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
90                      "fusion patterns that "
91                      "collapse the iteration space of the consumer"),
92       llvm::cl::init(false)};
93 
94   void runOnOperation() override {
95     MLIRContext *context = &this->getContext();
96     FuncOp funcOp = this->getOperation();
97 
98     if (fuseGenericOps) {
99       RewritePatternSet fusionPatterns(context);
100       linalg::populateElementwiseOpsFusionPatterns(
101           fusionPatterns,
102           linalg::LinalgElementwiseFusionOptions()
103               .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
104 
105       (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
106                                          std::move(fusionPatterns));
107       return;
108     }
109 
110     if (controlFuseByExpansion) {
111       RewritePatternSet fusionPatterns(context);
112 
113       linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
114           [](const OpResult &producer, OpOperand &consumer) {
115             if (auto collapseOp =
116                     producer.getDefiningOp<tensor::CollapseShapeOp>()) {
117               if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
118                 return false;
119               }
120             }
121             if (auto expandOp =
122                     dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
123               if (expandOp->hasOneUse()) {
124                 OpOperand &use = *expandOp->getUses().begin();
125                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
126                 if (linalgOp && linalgOp.isOutputTensor(&use))
127                   return true;
128               }
129             }
130             return linalg::skipUnitDimReshape(producer, consumer);
131           };
132 
133       linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
134                                                         controlReshapeFusionFn);
135       (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
136                                          std::move(fusionPatterns));
137       return;
138     }
139 
140     if (pushExpandingReshape) {
141       RewritePatternSet patterns(context);
142       linalg::populatePushReshapeOpsPatterns(patterns);
143       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
144     }
145 
146     if (fuseWithReshapeByCollapsing) {
147       RewritePatternSet patterns(context);
148       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns);
149       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
150     }
151 
152     if (fuseWithReshapeByCollapsingWithControlFn) {
153       RewritePatternSet patterns(context);
154       linalg::ControlElementwiseOpsFusionFn controlFn =
155           [](const OpResult &producer, OpOperand &consumer) -> bool {
156         if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) {
157           // Skip fusing the first operand.
158           return consumer.getOperandNumber();
159         }
160         return true;
161       };
162       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
163       (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
164     }
165   }
166 };
167 
168 } // namespace
169 
170 namespace test {
171 void registerTestLinalgElementwiseFusion() {
172   PassRegistration<TestLinalgElementwiseFusion>();
173 }
174 } // namespace test
175 
176 } // namespace mlir
177