1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass for testing fusion of elementwise operations in
10 // Linalg, mainly linalg options.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 namespace mlir {
21 
22 static void addOperands(Operation *op, SetVector<Value> &operandSet) {
23   if (!op)
24     return;
25   TypeSwitch<Operation *, void>(op)
26       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
27         SmallVector<Value> inputOperands = linalgOp.getInputOperands();
28         operandSet.insert(inputOperands.begin(), inputOperands.end());
29       })
30       .Default([&](Operation *operation) {
31         operandSet.insert(operation->operand_begin(), operation->operand_end());
32       });
33 }
34 
35 template <int limit = 3>
36 static bool setFusedOpOperandLimit(const OpResult &producer,
37                                    const OpOperand &consumer) {
38   SetVector<Value> fusedOpOperands;
39   if (producer.getOwner()->getNumResults() != 1)
40     return false;
41   addOperands(consumer.getOwner(), fusedOpOperands);
42   fusedOpOperands.remove(producer);
43   addOperands(producer.getOwner(), fusedOpOperands);
44   return fusedOpOperands.size() <= limit;
45 }
46 
47 namespace {
48 struct TestLinalgElementwiseFusion
49     : public PassWrapper<TestLinalgElementwiseFusion, OperationPass<FuncOp>> {
50   void getDependentDialects(DialectRegistry &registry) const override {
51     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
52                     tensor::TensorDialect>();
53   }
54   StringRef getArgument() const final {
55     return "test-linalg-elementwise-fusion-patterns";
56   }
57   StringRef getDescription() const final {
58     return "Test Linalg element wise operation fusion patterns";
59   }
60 
61   void runOnOperation() override {
62     MLIRContext *context = &this->getContext();
63     FuncOp funcOp = this->getOperation();
64     RewritePatternSet fusionPatterns(context);
65 
66     linalg::populateElementwiseOpsFusionPatterns(
67         fusionPatterns,
68         linalg::LinalgElementwiseFusionOptions()
69             .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
70 
71     (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
72                                        std::move(fusionPatterns));
73   }
74 };
75 
76 struct TestLinalgControlFuseByExpansion
77     : public PassWrapper<TestLinalgControlFuseByExpansion,
78                          OperationPass<FuncOp>> {
79   void getDependentDialects(DialectRegistry &registry) const override {
80     registry
81         .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
82   }
83   StringRef getArgument() const final {
84     return "test-linalg-control-fusion-by-expansion";
85   }
86   StringRef getDescription() const final {
87     return "Test controlling of fusion of elementwise ops with reshape by "
88            "expansion";
89   }
90 
91   void runOnOperation() override {
92     MLIRContext *context = &this->getContext();
93     FuncOp funcOp = this->getOperation();
94     RewritePatternSet fusionPatterns(context);
95 
96     linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
97         [](const OpResult &producer, OpOperand &consumer) {
98           if (auto collapseOp =
99                   producer.getDefiningOp<tensor::CollapseShapeOp>()) {
100             if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
101               return false;
102             }
103           }
104           if (auto expandOp =
105                   dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
106             if (expandOp->hasOneUse()) {
107               OpOperand &use = *expandOp->getUses().begin();
108               auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
109               if (linalgOp && linalgOp.isOutputTensor(&use))
110                 return true;
111             }
112           }
113           return linalg::skipUnitDimReshape(producer, consumer);
114         };
115 
116     linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
117                                                       controlReshapeFusionFn);
118     (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
119                                        std::move(fusionPatterns));
120   }
121 };
122 
123 struct TestPushExpandingReshape
124     : public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
125   void getDependentDialects(DialectRegistry &registry) const override {
126     registry
127         .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
128   }
129   StringRef getArgument() const final { return "test-linalg-push-reshape"; }
130   StringRef getDescription() const final {
131     return "Test Linalg reshape push patterns";
132   }
133 
134   void runOnOperation() override {
135     MLIRContext *context = &this->getContext();
136     FuncOp funcOp = this->getOperation();
137     RewritePatternSet patterns(context);
138     linalg::populatePushReshapeOpsPatterns(patterns);
139     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
140   }
141 };
142 } // namespace
143 
144 namespace test {
145 void registerTestLinalgElementwiseFusion() {
146   PassRegistration<TestLinalgElementwiseFusion>();
147 }
148 
149 void registerTestLinalgControlFuseByExpansion() {
150   PassRegistration<TestLinalgControlFuseByExpansion>();
151 }
152 
153 void registerTestPushExpandingReshape() {
154   PassRegistration<TestPushExpandingReshape>();
155 }
156 } // namespace test
157 
158 } // namespace mlir
159