1 //===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass drops return values from functions if they are equivalent to one of 10 // their arguments. E.g.: 11 // 12 // ``` 13 // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) { 14 // return %m : memref<?xf32> 15 // } 16 // ``` 17 // 18 // This functions is rewritten to: 19 // 20 // ``` 21 // func.func @foo(%m : memref<?xf32>) { 22 // return 23 // } 24 // ``` 25 // 26 // All call sites are updated accordingly. If a function returns a cast of a 27 // function argument, it is also considered equivalent. A cast is inserted at 28 // the call site in that case. 29 30 #include "PassDetail.h" 31 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 32 #include "mlir/Dialect/Func/IR/FuncOps.h" 33 #include "mlir/Dialect/MemRef/IR/MemRef.h" 34 #include "mlir/IR/Operation.h" 35 #include "mlir/Pass/Pass.h" 36 37 using namespace mlir; 38 39 /// Return the unique ReturnOp that terminates `funcOp`. 40 /// Return nullptr if there is no such unique ReturnOp. getAssumedUniqueReturnOp(func::FuncOp funcOp)41static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { 42 func::ReturnOp returnOp; 43 for (Block &b : funcOp.getBody()) { 44 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 45 if (returnOp) 46 return nullptr; 47 returnOp = candidateOp; 48 } 49 } 50 return returnOp; 51 } 52 53 /// Return the func::FuncOp called by `callOp`. getCalledFunction(CallOpInterface callOp)54static func::FuncOp getCalledFunction(CallOpInterface callOp) { 55 SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); 56 if (!sym) 57 return nullptr; 58 return dyn_cast_or_null<func::FuncOp>( 59 SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 60 } 61 62 LogicalResult dropEquivalentBufferResults(ModuleOp module)63mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { 64 IRRewriter rewriter(module.getContext()); 65 66 for (auto funcOp : module.getOps<func::FuncOp>()) { 67 if (funcOp.isExternal()) 68 continue; 69 func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 70 // TODO: Support functions with multiple blocks. 71 if (!returnOp) 72 continue; 73 74 // Compute erased results. 75 SmallVector<Value> newReturnValues; 76 BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults()); 77 DenseMap<int64_t, int64_t> resultToArgs; 78 for (const auto &it : llvm::enumerate(returnOp.operands())) { 79 bool erased = false; 80 for (BlockArgument bbArg : funcOp.getArguments()) { 81 Value val = it.value(); 82 while (auto castOp = val.getDefiningOp<memref::CastOp>()) 83 val = castOp.getSource(); 84 85 if (val == bbArg) { 86 resultToArgs[it.index()] = bbArg.getArgNumber(); 87 erased = true; 88 break; 89 } 90 } 91 92 if (erased) { 93 erasedResultIndices.set(it.index()); 94 } else { 95 newReturnValues.push_back(it.value()); 96 } 97 } 98 99 // Update function. 100 funcOp.eraseResults(erasedResultIndices); 101 returnOp.operandsMutable().assign(newReturnValues); 102 103 // Update function calls. 104 module.walk([&](func::CallOp callOp) { 105 if (getCalledFunction(callOp) != funcOp) 106 return WalkResult::skip(); 107 108 rewriter.setInsertionPoint(callOp); 109 auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp, 110 callOp.operands()); 111 SmallVector<Value> newResults; 112 int64_t nextResult = 0; 113 for (int64_t i = 0; i < callOp.getNumResults(); ++i) { 114 if (!resultToArgs.count(i)) { 115 // This result was not erased. 116 newResults.push_back(newCallOp.getResult(nextResult++)); 117 continue; 118 } 119 120 // This result was erased. 121 Value replacement = callOp.getOperand(resultToArgs[i]); 122 Type expectedType = callOp.getResult(i).getType(); 123 if (replacement.getType() != expectedType) { 124 // A cast must be inserted at the call site. 125 replacement = rewriter.create<memref::CastOp>( 126 callOp.getLoc(), expectedType, replacement); 127 } 128 newResults.push_back(replacement); 129 } 130 rewriter.replaceOp(callOp, newResults); 131 return WalkResult::advance(); 132 }); 133 } 134 135 return success(); 136 } 137 138 namespace { 139 struct DropEquivalentBufferResultsPass 140 : DropEquivalentBufferResultsBase<DropEquivalentBufferResultsPass> { runOnOperation__anonfa0fa9820211::DropEquivalentBufferResultsPass141 void runOnOperation() override { 142 if (failed(bufferization::dropEquivalentBufferResults(getOperation()))) 143 return signalPassFailure(); 144 } 145 }; 146 } // namespace 147 148 std::unique_ptr<Pass> createDropEquivalentBufferResultsPass()149mlir::bufferization::createDropEquivalentBufferResultsPass() { 150 return std::make_unique<DropEquivalentBufferResultsPass>(); 151 } 152