1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/Dialect/MemRef/IR/MemRef.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 // Updates the func op and entry block. 19 // 20 // Any args appended to the entry block are added to `appendedEntryArgs`. 21 static void updateFuncOp(func::FuncOp func, 22 SmallVectorImpl<BlockArgument> &appendedEntryArgs) { 23 auto functionType = func.getFunctionType(); 24 25 // Collect information about the results will become appended arguments. 26 SmallVector<Type, 6> erasedResultTypes; 27 BitVector erasedResultIndices(functionType.getNumResults()); 28 for (const auto &resultType : llvm::enumerate(functionType.getResults())) { 29 if (resultType.value().isa<BaseMemRefType>()) { 30 erasedResultIndices.set(resultType.index()); 31 erasedResultTypes.push_back(resultType.value()); 32 } 33 } 34 35 // Add the new arguments to the function type. 36 auto newArgTypes = llvm::to_vector<6>( 37 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); 38 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, 39 functionType.getResults()); 40 func.setType(newFunctionType); 41 42 // Transfer the result attributes to arg attributes. 43 auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); 44 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { 45 func.setArgAttrs(functionType.getNumInputs() + i, 46 func.getResultAttrs(*erasedIndicesIt)); 47 } 48 49 // Erase the results. 50 func.eraseResults(erasedResultIndices); 51 52 // Add the new arguments to the entry block if the function is not external. 53 if (func.isExternal()) 54 return; 55 Location loc = func.getLoc(); 56 for (Type type : erasedResultTypes) 57 appendedEntryArgs.push_back(func.front().addArgument(type, loc)); 58 } 59 60 // Updates all ReturnOps in the scope of the given func::FuncOp by either 61 // keeping them as return values or copying the associated buffer contents into 62 // the given out-params. 63 static void updateReturnOps(func::FuncOp func, 64 ArrayRef<BlockArgument> appendedEntryArgs) { 65 func.walk([&](func::ReturnOp op) { 66 SmallVector<Value, 6> copyIntoOutParams; 67 SmallVector<Value, 6> keepAsReturnOperands; 68 for (Value operand : op.getOperands()) { 69 if (operand.getType().isa<BaseMemRefType>()) 70 copyIntoOutParams.push_back(operand); 71 else 72 keepAsReturnOperands.push_back(operand); 73 } 74 OpBuilder builder(op); 75 for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) 76 builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t), 77 std::get<1>(t)); 78 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); 79 op.erase(); 80 }); 81 } 82 83 // Updates all CallOps in the scope of the given ModuleOp by allocating 84 // temporary buffers for newly introduced out params. 85 static LogicalResult updateCalls(ModuleOp module) { 86 bool didFail = false; 87 module.walk([&](func::CallOp op) { 88 SmallVector<Value, 6> replaceWithNewCallResults; 89 SmallVector<Value, 6> replaceWithOutParams; 90 for (OpResult result : op.getResults()) { 91 if (result.getType().isa<BaseMemRefType>()) 92 replaceWithOutParams.push_back(result); 93 else 94 replaceWithNewCallResults.push_back(result); 95 } 96 SmallVector<Value, 6> outParams; 97 OpBuilder builder(op); 98 for (Value memref : replaceWithOutParams) { 99 if (!memref.getType().cast<BaseMemRefType>().hasStaticShape()) { 100 op.emitError() 101 << "cannot create out param for dynamically shaped result"; 102 didFail = true; 103 return; 104 } 105 Value outParam = builder.create<memref::AllocOp>( 106 op.getLoc(), memref.getType().cast<MemRefType>()); 107 memref.replaceAllUsesWith(outParam); 108 outParams.push_back(outParam); 109 } 110 111 auto newOperands = llvm::to_vector<6>(op.getOperands()); 112 newOperands.append(outParams.begin(), outParams.end()); 113 auto newResultTypes = llvm::to_vector<6>(llvm::map_range( 114 replaceWithNewCallResults, [](Value v) { return v.getType(); })); 115 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), 116 newResultTypes, newOperands); 117 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) 118 std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); 119 op.erase(); 120 }); 121 122 return failure(didFail); 123 } 124 125 namespace { 126 struct BufferResultsToOutParamsPass 127 : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> { 128 void runOnOperation() override { 129 ModuleOp module = getOperation(); 130 131 for (auto func : module.getOps<func::FuncOp>()) { 132 SmallVector<BlockArgument, 6> appendedEntryArgs; 133 updateFuncOp(func, appendedEntryArgs); 134 if (func.isExternal()) 135 continue; 136 updateReturnOps(func, appendedEntryArgs); 137 } 138 if (failed(updateCalls(module))) 139 return signalPassFailure(); 140 } 141 }; 142 } // namespace 143 144 std::unique_ptr<Pass> 145 mlir::bufferization::createBufferResultsToOutParamsPass() { 146 return std::make_unique<BufferResultsToOutParamsPass>(); 147 } 148