1 //===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass for testing fusion of elementwise operations in 10 // Linalg, mainly linalg options. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassManager.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 20 namespace mlir { 21 22 static void addOperands(Operation *op, SetVector<Value> &operandSet) { 23 if (!op) 24 return; 25 TypeSwitch<Operation *, void>(op) 26 .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { 27 SmallVector<Value> inputOperands = linalgOp.getInputOperands(); 28 operandSet.insert(inputOperands.begin(), inputOperands.end()); 29 }) 30 .Default([&](Operation *operation) { 31 operandSet.insert(operation->operand_begin(), operation->operand_end()); 32 }); 33 } 34 35 template <int limit = 3> 36 static bool setFusedOpOperandLimit(const OpResult &producer, 37 const OpOperand &consumer) { 38 SetVector<Value> fusedOpOperands; 39 if (producer.getOwner()->getNumResults() != 1) 40 return false; 41 addOperands(consumer.getOwner(), fusedOpOperands); 42 fusedOpOperands.remove(producer); 43 addOperands(producer.getOwner(), fusedOpOperands); 44 return fusedOpOperands.size() <= limit; 45 } 46 47 namespace { 48 struct TestLinalgElementwiseFusion 49 : public PassWrapper<TestLinalgElementwiseFusion, OperationPass<FuncOp>> { 50 TestLinalgElementwiseFusion() = default; 51 TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass) 52 : PassWrapper(pass) {} 53 void getDependentDialects(DialectRegistry ®istry) const override { 54 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 55 tensor::TensorDialect>(); 56 } 57 StringRef getArgument() const final { 58 return "test-linalg-elementwise-fusion-patterns"; 59 } 60 StringRef getDescription() const final { 61 return "Test Linalg element wise operation fusion patterns"; 62 } 63 64 Option<bool> fuseGenericOps{ 65 *this, "fuse-generic-ops", 66 llvm::cl::desc("Test fusion of generic operations."), 67 llvm::cl::init(false)}; 68 69 Option<bool> controlFuseByExpansion{ 70 *this, "control-fusion-by-expansion", 71 llvm::cl::desc( 72 "Test controlling fusion of reshape with generic op by expansion"), 73 llvm::cl::init(false)}; 74 75 Option<bool> pushExpandingReshape{ 76 *this, "push-expanding-reshape", 77 llvm::cl::desc("Test linalg expand_shape -> generic " 78 "to generic -> expand_shape pattern"), 79 llvm::cl::init(false)}; 80 81 Option<bool> fuseWithReshapeByCollapsing{ 82 *this, "fuse-with-reshape-by-collapsing", 83 llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " 84 "collapse the iteration space of the consumer"), 85 llvm::cl::init(false)}; 86 87 Option<bool> fuseWithReshapeByCollapsingWithControlFn{ 88 *this, "fuse-with-reshape-by-collapsing-control", 89 llvm::cl::desc("Test controlling the linalg expand_shape -> generic " 90 "fusion patterns that " 91 "collapse the iteration space of the consumer"), 92 llvm::cl::init(false)}; 93 94 void runOnOperation() override { 95 MLIRContext *context = &this->getContext(); 96 FuncOp funcOp = this->getOperation(); 97 98 if (fuseGenericOps) { 99 RewritePatternSet fusionPatterns(context); 100 linalg::populateElementwiseOpsFusionPatterns( 101 fusionPatterns, 102 linalg::LinalgElementwiseFusionOptions() 103 .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>)); 104 105 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 106 std::move(fusionPatterns)); 107 return; 108 } 109 110 if (controlFuseByExpansion) { 111 RewritePatternSet fusionPatterns(context); 112 113 linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn = 114 [](const OpResult &producer, OpOperand &consumer) { 115 if (auto collapseOp = 116 producer.getDefiningOp<tensor::CollapseShapeOp>()) { 117 if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) { 118 return false; 119 } 120 } 121 if (auto expandOp = 122 dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) { 123 if (expandOp->hasOneUse()) { 124 OpOperand &use = *expandOp->getUses().begin(); 125 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner()); 126 if (linalgOp && linalgOp.isOutputTensor(&use)) 127 return true; 128 } 129 } 130 return linalg::skipUnitDimReshape(producer, consumer); 131 }; 132 133 linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, 134 controlReshapeFusionFn); 135 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), 136 std::move(fusionPatterns)); 137 return; 138 } 139 140 if (pushExpandingReshape) { 141 RewritePatternSet patterns(context); 142 linalg::populatePushReshapeOpsPatterns(patterns); 143 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 144 } 145 146 if (fuseWithReshapeByCollapsing) { 147 RewritePatternSet patterns(context); 148 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns); 149 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 150 } 151 152 if (fuseWithReshapeByCollapsingWithControlFn) { 153 RewritePatternSet patterns(context); 154 linalg::ControlElementwiseOpsFusionFn controlFn = 155 [](const OpResult &producer, OpOperand &consumer) -> bool { 156 if (isa<tensor::ExpandShapeOp>(producer.getDefiningOp())) { 157 // Skip fusing the first operand. 158 return consumer.getOperandNumber(); 159 } 160 return true; 161 }; 162 linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); 163 (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); 164 } 165 } 166 }; 167 168 } // namespace 169 170 namespace test { 171 void registerTestLinalgElementwiseFusion() { 172 PassRegistration<TestLinalgElementwiseFusion>(); 173 } 174 } // namespace test 175 176 } // namespace mlir 177