1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/StandardOps/IR/Ops.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 // Updates the func op and entry block. 19 // 20 // Any args appended to the entry block are added to `appendedEntryArgs`. 21 static void updateFuncOp(FuncOp func, 22 SmallVectorImpl<BlockArgument> &appendedEntryArgs) { 23 auto functionType = func.getType(); 24 25 // Collect information about the results will become appended arguments. 26 SmallVector<Type, 6> erasedResultTypes; 27 SmallVector<unsigned, 6> erasedResultIndices; 28 for (const auto &resultType : llvm::enumerate(functionType.getResults())) { 29 if (resultType.value().isa<BaseMemRefType>()) { 30 erasedResultIndices.push_back(resultType.index()); 31 erasedResultTypes.push_back(resultType.value()); 32 } 33 } 34 35 // Add the new arguments to the function type. 36 auto newArgTypes = llvm::to_vector<6>( 37 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); 38 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, 39 functionType.getResults()); 40 func.setType(newFunctionType); 41 42 // Transfer the result attributes to arg attributes. 43 for (int i = 0, e = erasedResultTypes.size(); i < e; i++) 44 func.setArgAttrs(functionType.getNumInputs() + i, 45 func.getResultAttrs(erasedResultIndices[i])); 46 47 // Erase the results. 48 func.eraseResults(erasedResultIndices); 49 50 // Add the new arguments to the entry block if the function is not external. 51 if (func.isExternal()) 52 return; 53 Location loc = func.getLoc(); 54 for (Type type : erasedResultTypes) 55 appendedEntryArgs.push_back(func.front().addArgument(type, loc)); 56 } 57 58 // Updates all ReturnOps in the scope of the given FuncOp by either keeping them 59 // as return values or copying the associated buffer contents into the given 60 // out-params. 61 static void updateReturnOps(FuncOp func, 62 ArrayRef<BlockArgument> appendedEntryArgs) { 63 func.walk([&](ReturnOp op) { 64 SmallVector<Value, 6> copyIntoOutParams; 65 SmallVector<Value, 6> keepAsReturnOperands; 66 for (Value operand : op.getOperands()) { 67 if (operand.getType().isa<BaseMemRefType>()) 68 copyIntoOutParams.push_back(operand); 69 else 70 keepAsReturnOperands.push_back(operand); 71 } 72 OpBuilder builder(op); 73 for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) 74 builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t), 75 std::get<1>(t)); 76 builder.create<ReturnOp>(op.getLoc(), keepAsReturnOperands); 77 op.erase(); 78 }); 79 } 80 81 // Updates all CallOps in the scope of the given ModuleOp by allocating 82 // temporary buffers for newly introduced out params. 83 static LogicalResult updateCalls(ModuleOp module) { 84 bool didFail = false; 85 module.walk([&](CallOp op) { 86 SmallVector<Value, 6> replaceWithNewCallResults; 87 SmallVector<Value, 6> replaceWithOutParams; 88 for (OpResult result : op.getResults()) { 89 if (result.getType().isa<BaseMemRefType>()) 90 replaceWithOutParams.push_back(result); 91 else 92 replaceWithNewCallResults.push_back(result); 93 } 94 SmallVector<Value, 6> outParams; 95 OpBuilder builder(op); 96 for (Value memref : replaceWithOutParams) { 97 if (!memref.getType().cast<BaseMemRefType>().hasStaticShape()) { 98 op.emitError() 99 << "cannot create out param for dynamically shaped result"; 100 didFail = true; 101 return; 102 } 103 Value outParam = builder.create<memref::AllocOp>( 104 op.getLoc(), memref.getType().cast<MemRefType>()); 105 memref.replaceAllUsesWith(outParam); 106 outParams.push_back(outParam); 107 } 108 109 auto newOperands = llvm::to_vector<6>(op.getOperands()); 110 newOperands.append(outParams.begin(), outParams.end()); 111 auto newResultTypes = llvm::to_vector<6>(llvm::map_range( 112 replaceWithNewCallResults, [](Value v) { return v.getType(); })); 113 auto newCall = builder.create<CallOp>(op.getLoc(), op.getCalleeAttr(), 114 newResultTypes, newOperands); 115 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) 116 std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); 117 op.erase(); 118 }); 119 120 return failure(didFail); 121 } 122 123 namespace { 124 struct BufferResultsToOutParamsPass 125 : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> { 126 void runOnOperation() override { 127 ModuleOp module = getOperation(); 128 129 for (auto func : module.getOps<FuncOp>()) { 130 SmallVector<BlockArgument, 6> appendedEntryArgs; 131 updateFuncOp(func, appendedEntryArgs); 132 if (func.isExternal()) 133 continue; 134 updateReturnOps(func, appendedEntryArgs); 135 } 136 if (failed(updateCalls(module))) 137 return signalPassFailure(); 138 } 139 }; 140 } // namespace 141 142 std::unique_ptr<Pass> 143 mlir::bufferization::createBufferResultsToOutParamsPass() { 144 return std::make_unique<BufferResultsToOutParamsPass>(); 145 } 146