1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing fusion of elementwise operations in 10 // Linalg, mainly linalg options. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassManager.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 20 namespace mlir { 21 22 static void addOperands(Operation *op, SetVector<Value> &operandSet) { 23 if (!op) 24 return; 25 TypeSwitch<Operation *, void>(op) 26 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 27 SmallVector<Value> inputOperands = linalgOp.getInputOperands(); 28 operandSet.insert(inputOperands.begin(), inputOperands.end()); 29 }) 30 .Default([&](Operation *operation) { 31 operandSet.insert(operation->operand_begin(), operation->operand_end()); 32 }); 33 } 34 35 template <int limit = 3> 36 static bool setFusedOpOperandLimit(const OpResult &producer, 37 const OpOperand &consumer) { 38 SetVector<Value> fusedOpOperands; 39 if (producer.getOwner()->getNumResults() != 1) 40 return false; 41 addOperands(consumer.getOwner(), fusedOpOperands); 42 fusedOpOperands.remove(producer); 43 addOperands(producer.getOwner(), fusedOpOperands); 44 return fusedOpOperands.size() <= limit; 45 } 46 47 namespace { 48 struct TestLinalgElementwiseFusion 49 : public PassWrapper<TestLinalgElementwiseFusion, OperationPass<FuncOp>> { 50 void getDependentDialects(DialectRegistry ®istry) const override { 51 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 52 tensor::TensorDialect>(); 53 } 54 StringRef getArgument() const final { 55 return "test-linalg-elementwise-fusion-patterns"; 56 } 57 StringRef getDescription() const final { 58 return "Test Linalg element wise operation fusion patterns"; 59 } 60 61 void runOnOperation() override { 62 MLIRContext *context = &this->getContext(); 63 FuncOp funcOp = this->getOperation(); 64 RewritePatternSet fusionPatterns(context); 65 66 linalg::populateElementwiseOpsFusionPatterns( 67 fusionPatterns, 68 linalg::LinalgElementwiseFusionOptions() 69 .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>)); 70 71 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 72 std::move(fusionPatterns)); 73 } 74 }; 75 76 struct TestLinalgControlFuseByExpansion 77 : public PassWrapper<TestLinalgControlFuseByExpansion, 78 OperationPass<FuncOp>> { 79 void getDependentDialects(DialectRegistry ®istry) const override { 80 registry 81 .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>(); 82 } 83 StringRef getArgument() const final { 84 return "test-linalg-control-fusion-by-expansion"; 85 } 86 StringRef getDescription() const final { 87 return "Test controlling of fusion of elementwise ops with reshape by " 88 "expansion"; 89 } 90 91 void runOnOperation() override { 92 MLIRContext *context = &this->getContext(); 93 FuncOp funcOp = this->getOperation(); 94 RewritePatternSet fusionPatterns(context); 95 96 linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn = 97 [](const OpResult &producer, OpOperand &consumer) { 98 if (auto collapseOp = 99 producer.getDefiningOp<tensor::CollapseShapeOp>()) { 100 if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) { 101 return false; 102 } 103 } 104 if (auto expandOp = 105 dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) { 106 if (expandOp->hasOneUse()) { 107 OpOperand &use = *expandOp->getUses().begin(); 108 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); 109 if (linalgOp && linalgOp.isOutputTensor(&use)) 110 return true; 111 } 112 } 113 return linalg::skipUnitDimReshape(producer, consumer); 114 }; 115 116 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, 117 controlReshapeFusionFn); 118 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 119 std::move(fusionPatterns)); 120 } 121 }; 122 123 struct TestPushExpandingReshape 124 : public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> { 125 void getDependentDialects(DialectRegistry ®istry) const override { 126 registry 127 .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>(); 128 } 129 StringRef getArgument() const final { return "test-linalg-push-reshape"; } 130 StringRef getDescription() const final { 131 return "Test Linalg reshape push patterns"; 132 } 133 134 void runOnOperation() override { 135 MLIRContext *context = &this->getContext(); 136 FuncOp funcOp = this->getOperation(); 137 RewritePatternSet patterns(context); 138 linalg::populatePushReshapeOpsPatterns(patterns); 139 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 140 } 141 }; 142 } // namespace 143 144 namespace test { 145 void registerTestLinalgElementwiseFusion() { 146 PassRegistration<TestLinalgElementwiseFusion>(); 147 } 148 149 void registerTestLinalgControlFuseByExpansion() { 150 PassRegistration<TestLinalgControlFuseByExpansion>(); 151 } 152 153 void registerTestPushExpandingReshape() { 154 PassRegistration<TestPushExpandingReshape>(); 155 } 156 } // namespace test 157 158 } // namespace mlir 159