1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing fusion of elementwise operations in 10 // Linalg, mainly linalg options. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassManager.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 20 namespace mlir { 21 22 static void addOperands(Operation *op, SetVector<Value> &operandSet) { 23 if (!op) 24 return; 25 TypeSwitch<Operation *, void>(op) 26 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 27 SmallVector<Value> inputOperands = linalgOp.getInputOperands(); 28 operandSet.insert(inputOperands.begin(), inputOperands.end()); 29 }) 30 .Default([&](Operation *operation) { 31 operandSet.insert(operation->operand_begin(), operation->operand_end()); 32 }); 33 } 34 35 template <int limit = 3> 36 static bool setFusedOpOperandLimit(const OpResult &producer, 37 const OpOperand &consumer) { 38 SetVector<Value> fusedOpOperands; 39 if (producer.getOwner()->getNumResults() != 1) 40 return false; 41 addOperands(consumer.getOwner(), fusedOpOperands); 42 fusedOpOperands.remove(producer); 43 addOperands(producer.getOwner(), fusedOpOperands); 44 return fusedOpOperands.size() <= limit; 45 } 46 47 namespace { 48 struct TestLinalgElementwiseFusion 49 : public PassWrapper<TestLinalgElementwiseFusion, FunctionPass> { 50 void getDependentDialects(DialectRegistry ®istry) const override { 51 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 52 tensor::TensorDialect>(); 53 } 54 StringRef getArgument() const final { 55 return "test-linalg-elementwise-fusion-patterns"; 56 } 57 StringRef getDescription() const final { 58 return "Test Linalg element wise operation fusion patterns"; 59 } 60 61 void runOnFunction() override { 62 MLIRContext *context = &this->getContext(); 63 FuncOp funcOp = this->getFunction(); 64 RewritePatternSet fusionPatterns(context); 65 66 linalg::populateElementwiseOpsFusionPatterns( 67 fusionPatterns, 68 linalg::LinalgElementwiseFusionOptions() 69 .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>)); 70 71 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 72 std::move(fusionPatterns)); 73 } 74 }; 75 76 struct TestLinalgControlFuseByExpansion 77 : public PassWrapper<TestLinalgControlFuseByExpansion, FunctionPass> { 78 void getDependentDialects(DialectRegistry ®istry) const override { 79 registry 80 .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>(); 81 } 82 StringRef getArgument() const final { 83 return "test-linalg-control-fusion-by-expansion"; 84 } 85 StringRef getDescription() const final { 86 return "Test controlling of fusion of elementwise ops with reshape by " 87 "expansion"; 88 } 89 90 void runOnFunction() override { 91 MLIRContext *context = &this->getContext(); 92 FuncOp funcOp = this->getFunction(); 93 RewritePatternSet fusionPatterns(context); 94 95 linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn = 96 [](const OpResult &producer, OpOperand &consumer) { 97 if (auto collapseOp = 98 producer.getDefiningOp<linalg::TensorCollapseShapeOp>()) { 99 if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) { 100 return false; 101 } 102 } 103 if (auto expandOp = 104 dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) { 105 if (expandOp->hasOneUse()) { 106 OpOperand &use = *expandOp->getUses().begin(); 107 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); 108 if (linalgOp && linalgOp.isOutputTensor(&use)) 109 return true; 110 } 111 } 112 return linalg::skipUnitDimReshape(producer, consumer); 113 }; 114 115 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, 116 controlReshapeFusionFn); 117 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 118 std::move(fusionPatterns)); 119 } 120 }; 121 122 struct TestPushExpandingReshape 123 : public PassWrapper<TestPushExpandingReshape, FunctionPass> { 124 void getDependentDialects(DialectRegistry ®istry) const override { 125 registry 126 .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>(); 127 } 128 StringRef getArgument() const final { return "test-linalg-push-reshape"; } 129 StringRef getDescription() const final { 130 return "Test Linalg reshape push patterns"; 131 } 132 133 void runOnFunction() override { 134 MLIRContext *context = &this->getContext(); 135 FuncOp funcOp = this->getFunction(); 136 RewritePatternSet patterns(context); 137 linalg::populatePushReshapeOpsPatterns(patterns); 138 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 139 } 140 }; 141 } // namespace 142 143 namespace test { 144 void registerTestLinalgElementwiseFusion() { 145 PassRegistration<TestLinalgElementwiseFusion>(); 146 } 147 148 void registerTestLinalgControlFuseByExpansion() { 149 PassRegistration<TestLinalgControlFuseByExpansion>(); 150 } 151 152 void registerTestPushExpandingReshape() { 153 PassRegistration<TestPushExpandingReshape>(); 154 } 155 } // namespace test 156 157 } // namespace mlir 158