1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing fusion of elementwise operations in 10 // Linalg, mainly linalg options. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Pass/PassManager.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace mlir; 22 23 static void addOperands(Operation *op, SetVector<Value> &operandSet) { 24 if (!op) 25 return; 26 TypeSwitch<Operation *, void>(op) 27 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 28 SmallVector<Value> inputOperands = linalgOp.getInputOperands(); 29 operandSet.insert(inputOperands.begin(), inputOperands.end()); 30 }) 31 .Default([&](Operation *operation) { 32 operandSet.insert(operation->operand_begin(), operation->operand_end()); 33 }); 34 } 35 36 template <int limit = 3> 37 static bool setFusedOpOperandLimit(const OpResult &producer, 38 const OpOperand &consumer) { 39 SetVector<Value> fusedOpOperands; 40 if (producer.getOwner()->getNumResults() != 1) 41 return false; 42 addOperands(consumer.getOwner(), fusedOpOperands); 43 fusedOpOperands.remove(producer); 44 addOperands(producer.getOwner(), fusedOpOperands); 45 return fusedOpOperands.size() <= limit; 46 } 47 48 namespace { 49 struct TestLinalgElementwiseFusion 50 : public PassWrapper<TestLinalgElementwiseFusion, OperationPass<FuncOp>> { 51 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion) 52 53 TestLinalgElementwiseFusion() = default; 54 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass) 55 : PassWrapper(pass) {} 56 void getDependentDialects(DialectRegistry ®istry) const override { 57 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 58 tensor::TensorDialect>(); 59 } 60 StringRef getArgument() const final { 61 return "test-linalg-elementwise-fusion-patterns"; 62 } 63 StringRef getDescription() const final { 64 return "Test Linalg element wise operation fusion patterns"; 65 } 66 67 Option<bool> fuseGenericOps{ 68 *this, "fuse-generic-ops", 69 llvm::cl::desc("Test fusion of generic operations."), 70 llvm::cl::init(false)}; 71 72 Option<bool> controlFuseByExpansion{ 73 *this, "control-fusion-by-expansion", 74 llvm::cl::desc( 75 "Test controlling fusion of reshape with generic op by expansion"), 76 llvm::cl::init(false)}; 77 78 Option<bool> pushExpandingReshape{ 79 *this, "push-expanding-reshape", 80 llvm::cl::desc("Test linalg expand_shape -> generic " 81 "to generic -> expand_shape pattern"), 82 llvm::cl::init(false)}; 83 84 Option<bool> fuseWithReshapeByCollapsing{ 85 *this, "fuse-with-reshape-by-collapsing", 86 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " 87 "collapse the iteration space of the consumer"), 88 llvm::cl::init(false)}; 89 90 Option<bool> fuseWithReshapeByCollapsingWithControlFn{ 91 *this, "fuse-with-reshape-by-collapsing-control", 92 llvm::cl::desc("Test controlling the linalg expand_shape -> generic " 93 "fusion patterns that " 94 "collapse the iteration space of the consumer"), 95 llvm::cl::init(false)}; 96 97 void runOnOperation() override { 98 MLIRContext *context = &this->getContext(); 99 FuncOp funcOp = this->getOperation(); 100 101 if (fuseGenericOps) { 102 RewritePatternSet fusionPatterns(context); 103 linalg::populateElementwiseOpsFusionPatterns( 104 fusionPatterns, 105 linalg::LinalgElementwiseFusionOptions() 106 .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>)); 107 108 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 109 std::move(fusionPatterns)); 110 return; 111 } 112 113 if (controlFuseByExpansion) { 114 RewritePatternSet fusionPatterns(context); 115 116 linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn = 117 [](const OpResult &producer, OpOperand &consumer) { 118 if (auto collapseOp = 119 producer.getDefiningOp<tensor::CollapseShapeOp>()) { 120 if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) { 121 return false; 122 } 123 } 124 if (auto expandOp = 125 dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) { 126 if (expandOp->hasOneUse()) { 127 OpOperand &use = *expandOp->getUses().begin(); 128 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); 129 if (linalgOp && linalgOp.isOutputTensor(&use)) 130 return true; 131 } 132 } 133 return linalg::skipUnitDimReshape(producer, consumer); 134 }; 135 136 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, 137 controlReshapeFusionFn); 138 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 139 std::move(fusionPatterns)); 140 return; 141 } 142 143 if (pushExpandingReshape) { 144 RewritePatternSet patterns(context); 145 linalg::populatePushReshapeOpsPatterns(patterns); 146 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 147 } 148 149 if (fuseWithReshapeByCollapsing) { 150 RewritePatternSet patterns(context); 151 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns); 152 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 153 } 154 155 if (fuseWithReshapeByCollapsingWithControlFn) { 156 RewritePatternSet patterns(context); 157 linalg::ControlElementwiseOpsFusionFn controlFn = 158 [](const OpResult &producer, OpOperand &consumer) -> bool { 159 if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) { 160 // Skip fusing the first operand. 161 return consumer.getOperandNumber(); 162 } 163 return true; 164 }; 165 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); 166 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 167 } 168 } 169 }; 170 171 } // namespace 172 173 namespace mlir { 174 namespace test { 175 void registerTestLinalgElementwiseFusion() { 176 PassRegistration<TestLinalgElementwiseFusion>(); 177 } 178 } // namespace test 179 } // namespace mlir 180