1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing fusion of elementwise operations in 10 // Linalg, mainly linalg options. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassManager.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 20 namespace mlir { 21 22 static void addOperands(Operation *op, SetVector<Value> &operandSet) { 23 if (!op) 24 return; 25 TypeSwitch<Operation *, void>(op) 26 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 27 SmallVector<Value> inputOperands = linalgOp.getInputOperands(); 28 operandSet.insert(inputOperands.begin(), inputOperands.end()); 29 }) 30 .Default([&](Operation *operation) { 31 operandSet.insert(operation->operand_begin(), operation->operand_end()); 32 }); 33 } 34 35 template <int limit = 3> 36 static bool setFusedOpOperandLimit(const OpResult &producer, 37 const OpOperand &consumer) { 38 SetVector<Value> fusedOpOperands; 39 if (producer.getOwner()->getNumResults() != 1) 40 return false; 41 addOperands(consumer.getOwner(), fusedOpOperands); 42 fusedOpOperands.remove(producer); 43 addOperands(producer.getOwner(), fusedOpOperands); 44 return fusedOpOperands.size() <= limit; 45 } 46 47 namespace { 48 struct TestLinalgElementwiseFusion 49 : public PassWrapper<TestLinalgElementwiseFusion, FunctionPass> { 50 void getDependentDialects(DialectRegistry ®istry) const override { 51 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 52 tensor::TensorDialect>(); 53 } 54 55 void runOnFunction() override { 56 MLIRContext *context = &this->getContext(); 57 FuncOp funcOp = this->getFunction(); 58 RewritePatternSet fusionPatterns(context); 59 60 linalg::populateElementwiseOpsFusionPatterns( 61 fusionPatterns, 62 linalg::LinalgElementwiseFusionOptions() 63 .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>)); 64 65 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 66 std::move(fusionPatterns)); 67 } 68 }; 69 70 struct TestPushExpandingReshape 71 : public PassWrapper<TestPushExpandingReshape, FunctionPass> { 72 void getDependentDialects(DialectRegistry ®istry) const override { 73 registry 74 .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>(); 75 } 76 77 void runOnFunction() override { 78 MLIRContext *context = &this->getContext(); 79 FuncOp funcOp = this->getFunction(); 80 RewritePatternSet patterns(context); 81 linalg::populatePushReshapeOpsPatterns(patterns); 82 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 83 } 84 }; 85 } // namespace 86 87 namespace test { 88 void registerTestLinalgElementwiseFusion() { 89 PassRegistration<TestLinalgElementwiseFusion> testElementwiseFusionPass( 90 "test-linalg-elementwise-fusion-patterns", 91 "Test Linalg element wise operation fusion patterns"); 92 } 93 94 void registerTestPushExpandingReshape() { 95 PassRegistration<TestPushExpandingReshape> testPushExpandingReshapePass( 96 "test-linalg-push-reshape", "Test Linalg reshape push patterns"); 97 } 98 } // namespace test 99 100 } // namespace mlir 101