1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements passes for testing Linalg elementwise operation fusion
// (in particular its fusion control options) and the push-reshape patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 namespace mlir {
21 
22 static void addOperands(Operation *op, SetVector<Value> &operandSet) {
23   if (!op)
24     return;
25   TypeSwitch<Operation *, void>(op)
26       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
27         SmallVector<Value> inputOperands = linalgOp.getInputOperands();
28         operandSet.insert(inputOperands.begin(), inputOperands.end());
29       })
30       .Default([&](Operation *operation) {
31         operandSet.insert(operation->operand_begin(), operation->operand_end());
32       });
33 }
34 
35 template <int limit = 3>
36 static bool setFusedOpOperandLimit(const OpResult &producer,
37                                    const OpOperand &consumer) {
38   SetVector<Value> fusedOpOperands;
39   if (producer.getOwner()->getNumResults() != 1)
40     return false;
41   addOperands(consumer.getOwner(), fusedOpOperands);
42   fusedOpOperands.remove(producer);
43   addOperands(producer.getOwner(), fusedOpOperands);
44   return fusedOpOperands.size() <= limit;
45 }
46 
47 namespace {
48 struct TestLinalgElementwiseFusion
49     : public PassWrapper<TestLinalgElementwiseFusion, FunctionPass> {
50   void getDependentDialects(DialectRegistry &registry) const override {
51     registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
52                     tensor::TensorDialect>();
53   }
54 
55   void runOnFunction() override {
56     MLIRContext *context = &this->getContext();
57     FuncOp funcOp = this->getFunction();
58     RewritePatternSet fusionPatterns(context);
59 
60     linalg::populateElementwiseOpsFusionPatterns(
61         fusionPatterns,
62         linalg::LinalgElementwiseFusionOptions()
63             .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
64 
65     (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
66                                        std::move(fusionPatterns));
67   }
68 };
69 
70 struct TestPushExpandingReshape
71     : public PassWrapper<TestPushExpandingReshape, FunctionPass> {
72   void getDependentDialects(DialectRegistry &registry) const override {
73     registry
74         .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
75   }
76 
77   void runOnFunction() override {
78     MLIRContext *context = &this->getContext();
79     FuncOp funcOp = this->getFunction();
80     RewritePatternSet patterns(context);
81     linalg::populatePushReshapeOpsPatterns(patterns);
82     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
83   }
84 };
85 } // namespace
86 
87 namespace test {
88 void registerTestLinalgElementwiseFusion() {
89   PassRegistration<TestLinalgElementwiseFusion> testElementwiseFusionPass(
90       "test-linalg-elementwise-fusion-patterns",
91       "Test Linalg element wise operation fusion patterns");
92 }
93 
94 void registerTestPushExpandingReshape() {
95   PassRegistration<TestPushExpandingReshape> testPushExpandingReshapePass(
96       "test-linalg-push-reshape", "Test Linalg reshape push patterns");
97 }
98 } // namespace test
99 
100 } // namespace mlir
101