1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing fusion of elementwise operations in 10 // Linalg, mainly linalg options. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Pass/PassManager.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace mlir; 22 23 static void addOperands(Operation *op, SetVector<Value> &operandSet) { 24 if (!op) 25 return; 26 TypeSwitch<Operation *, void>(op) 27 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 28 SmallVector<Value> inputOperands = linalgOp.getInputOperands(); 29 operandSet.insert(inputOperands.begin(), inputOperands.end()); 30 }) 31 .Default([&](Operation *operation) { 32 operandSet.insert(operation->operand_begin(), operation->operand_end()); 33 }); 34 } 35 36 template <int limit = 3> 37 static bool setFusedOpOperandLimit(const OpResult &producer, 38 const OpOperand &consumer) { 39 SetVector<Value> fusedOpOperands; 40 if (producer.getOwner()->getNumResults() != 1) 41 return false; 42 addOperands(consumer.getOwner(), fusedOpOperands); 43 fusedOpOperands.remove(producer); 44 addOperands(producer.getOwner(), fusedOpOperands); 45 return fusedOpOperands.size() <= limit; 46 } 47 48 namespace { 49 struct TestLinalgElementwiseFusion 50 : public PassWrapper<TestLinalgElementwiseFusion, 51 OperationPass<func::FuncOp>> { 52 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion) 53 54 TestLinalgElementwiseFusion() = default; 55 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass) 56 : PassWrapper(pass) {} 57 void getDependentDialects(DialectRegistry ®istry) const override { 58 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 59 tensor::TensorDialect>(); 60 } 61 StringRef getArgument() const final { 62 return "test-linalg-elementwise-fusion-patterns"; 63 } 64 StringRef getDescription() const final { 65 return "Test Linalg element wise operation fusion patterns"; 66 } 67 68 Option<bool> fuseGenericOps{ 69 *this, "fuse-generic-ops", 70 llvm::cl::desc("Test fusion of generic operations."), 71 llvm::cl::init(false)}; 72 73 Option<bool> fuseWithReshapeByExpansion{ 74 *this, "fuse-with-reshape-by-expansion", 75 llvm::cl::desc( 76 "Test fusion of generic operations with reshape by expansion"), 77 llvm::cl::init(false)}; 78 79 Option<bool> controlFuseByExpansion{ 80 *this, "control-fusion-by-expansion", 81 llvm::cl::desc( 82 "Test controlling fusion of reshape with generic op by expansion"), 83 llvm::cl::init(false)}; 84 85 Option<bool> fuseWithReshapeByCollapsing{ 86 *this, "fuse-with-reshape-by-collapsing", 87 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " 88 "collapse the iteration space of the consumer"), 89 llvm::cl::init(false)}; 90 91 Option<bool> fuseWithReshapeByCollapsingWithControlFn{ 92 *this, "fuse-with-reshape-by-collapsing-control", 93 llvm::cl::desc("Test controlling the linalg expand_shape -> generic " 94 "fusion patterns that " 95 "collapse the iteration space of the consumer"), 96 llvm::cl::init(false)}; 97 98 void runOnOperation() override { 99 MLIRContext *context = &this->getContext(); 100 func::FuncOp funcOp = this->getOperation(); 101 102 if (fuseGenericOps) { 103 RewritePatternSet fusionPatterns(context); 104 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, 105 setFusedOpOperandLimit<4>); 106 107 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 108 std::move(fusionPatterns)); 109 return; 110 } 111 112 if (fuseWithReshapeByExpansion) { 113 RewritePatternSet fusionPatterns(context); 114 linalg::populateFoldReshapeOpsByExpansionPatterns( 115 fusionPatterns, [](const OpResult & /*producer*/, 116 OpOperand & /*consumer*/) { return true; }); 117 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), 118 std::move(fusionPatterns)))) 119 return signalPassFailure(); 120 return; 121 } 122 123 if (controlFuseByExpansion) { 124 RewritePatternSet fusionPatterns(context); 125 126 linalg::ControlFusionFn controlReshapeFusionFn = 127 [](const OpResult &producer, OpOperand &consumer) { 128 if (auto collapseOp = 129 producer.getDefiningOp<tensor::CollapseShapeOp>()) { 130 if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) { 131 return false; 132 } 133 } 134 if (auto expandOp = 135 dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) { 136 if (expandOp->hasOneUse()) { 137 OpOperand &use = *expandOp->getUses().begin(); 138 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); 139 if (linalgOp && linalgOp.isOutputTensor(&use)) 140 return true; 141 } 142 return false; 143 } 144 return true; 145 }; 146 147 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, 148 controlReshapeFusionFn); 149 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 150 std::move(fusionPatterns)); 151 return; 152 } 153 154 if (fuseWithReshapeByCollapsing) { 155 RewritePatternSet patterns(context); 156 linalg::populateFoldReshapeOpsByCollapsingPatterns( 157 patterns, [](const OpResult & /*producer*/, 158 OpOperand & /*consumer*/) { return true; }); 159 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 160 } 161 162 if (fuseWithReshapeByCollapsingWithControlFn) { 163 RewritePatternSet patterns(context); 164 linalg::ControlFusionFn controlFn = [](const OpResult &producer, 165 OpOperand &consumer) -> bool { 166 if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) { 167 // Skip fusing the first operand. 168 return consumer.getOperandNumber(); 169 } 170 return true; 171 }; 172 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); 173 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 174 } 175 } 176 }; 177 178 } // namespace 179 180 namespace mlir { 181 namespace test { 182 void registerTestLinalgElementwiseFusion() { 183 PassRegistration<TestLinalgElementwiseFusion>(); 184 } 185 } // namespace test 186 } // namespace mlir 187