1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing fusion of elementwise operations in 10 // Linalg, mainly linalg options. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Pass/PassManager.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 using namespace mlir; 23 24 static void addOperands(Operation *op, SetVector<Value> &operandSet) { 25 if (!op) 26 return; 27 TypeSwitch<Operation *, void>(op) 28 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 29 SmallVector<Value> inputOperands = linalgOp.getInputOperands(); 30 operandSet.insert(inputOperands.begin(), inputOperands.end()); 31 }) 32 .Default([&](Operation *operation) { 33 operandSet.insert(operation->operand_begin(), operation->operand_end()); 34 }); 35 } 36 37 template <int limit = 3> 38 static bool setFusedOpOperandLimit(const OpResult &producer, 39 const OpOperand &consumer) { 40 SetVector<Value> fusedOpOperands; 41 if (producer.getOwner()->getNumResults() != 1) 42 return false; 43 addOperands(consumer.getOwner(), fusedOpOperands); 44 fusedOpOperands.remove(producer); 45 addOperands(producer.getOwner(), fusedOpOperands); 46 return fusedOpOperands.size() <= limit; 47 } 48 49 namespace { 50 struct TestLinalgElementwiseFusion 51 : public PassWrapper<TestLinalgElementwiseFusion, 52 OperationPass<func::FuncOp>> { 53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgElementwiseFusion) 54 55 TestLinalgElementwiseFusion() = default; 56 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass) 57 : PassWrapper(pass) {} 58 void getDependentDialects(DialectRegistry ®istry) const override { 59 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 60 tensor::TensorDialect>(); 61 } 62 StringRef getArgument() const final { 63 return "test-linalg-elementwise-fusion-patterns"; 64 } 65 StringRef getDescription() const final { 66 return "Test Linalg element wise operation fusion patterns"; 67 } 68 69 Option<bool> fuseGenericOps{ 70 *this, "fuse-generic-ops", 71 llvm::cl::desc("Test fusion of generic operations."), 72 llvm::cl::init(false)}; 73 74 Option<bool> fuseWithReshapeByExpansion{ 75 *this, "fuse-with-reshape-by-expansion", 76 llvm::cl::desc( 77 "Test fusion of generic operations with reshape by expansion"), 78 llvm::cl::init(false)}; 79 80 Option<bool> controlFuseByExpansion{ 81 *this, "control-fusion-by-expansion", 82 llvm::cl::desc( 83 "Test controlling fusion of reshape with generic op by expansion"), 84 llvm::cl::init(false)}; 85 86 Option<bool> fuseWithReshapeByCollapsing{ 87 *this, "fuse-with-reshape-by-collapsing", 88 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " 89 "collapse the iteration space of the consumer"), 90 llvm::cl::init(false)}; 91 92 Option<bool> fuseWithReshapeByCollapsingWithControlFn{ 93 *this, "fuse-with-reshape-by-collapsing-control", 94 llvm::cl::desc("Test controlling the linalg expand_shape -> generic " 95 "fusion patterns that " 96 "collapse the iteration space of the consumer"), 97 llvm::cl::init(false)}; 98 99 void runOnOperation() override { 100 MLIRContext *context = &this->getContext(); 101 func::FuncOp funcOp = this->getOperation(); 102 103 if (fuseGenericOps) { 104 RewritePatternSet fusionPatterns(context); 105 linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, 106 setFusedOpOperandLimit<4>); 107 108 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 109 std::move(fusionPatterns)); 110 return; 111 } 112 113 if (fuseWithReshapeByExpansion) { 114 RewritePatternSet fusionPatterns(context); 115 linalg::populateFoldReshapeOpsByExpansionPatterns( 116 fusionPatterns, [](const OpResult & /*producer*/, 117 OpOperand & /*consumer*/) { return true; }); 118 if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), 119 std::move(fusionPatterns)))) 120 return signalPassFailure(); 121 return; 122 } 123 124 if (controlFuseByExpansion) { 125 RewritePatternSet fusionPatterns(context); 126 127 linalg::ControlFusionFn controlReshapeFusionFn = 128 [](const OpResult &producer, OpOperand &consumer) { 129 if (auto collapseOp = 130 producer.getDefiningOp<tensor::CollapseShapeOp>()) { 131 if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) { 132 return false; 133 } 134 } 135 if (auto expandOp = 136 dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) { 137 if (expandOp->hasOneUse()) { 138 OpOperand &use = *expandOp->getUses().begin(); 139 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); 140 if (linalgOp && linalgOp.isOutputTensor(&use)) 141 return true; 142 } 143 return false; 144 } 145 return true; 146 }; 147 148 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, 149 controlReshapeFusionFn); 150 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 151 std::move(fusionPatterns)); 152 return; 153 } 154 155 if (fuseWithReshapeByCollapsing) { 156 RewritePatternSet patterns(context); 157 linalg::populateFoldReshapeOpsByCollapsingPatterns( 158 patterns, [](const OpResult & /*producer*/, 159 OpOperand & /*consumer*/) { return true; }); 160 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 161 } 162 163 if (fuseWithReshapeByCollapsingWithControlFn) { 164 RewritePatternSet patterns(context); 165 linalg::ControlFusionFn controlFn = [](const OpResult &producer, 166 OpOperand &consumer) -> bool { 167 if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) { 168 // Skip fusing the first operand. 169 return consumer.getOperandNumber(); 170 } 171 return true; 172 }; 173 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); 174 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 175 } 176 } 177 }; 178 179 } // namespace 180 181 namespace mlir { 182 namespace test { 183 void registerTestLinalgElementwiseFusion() { 184 PassRegistration<TestLinalgElementwiseFusion>(); 185 } 186 } // namespace test 187 } // namespace mlir 188