1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15
16 using namespace mlir;
17
18 /// Return `true` if the given MemRef type has a fully dynamic layout.
hasFullyDynamicLayoutMap(MemRefType type)19 static bool hasFullyDynamicLayoutMap(MemRefType type) {
20 int64_t offset;
21 SmallVector<int64_t, 4> strides;
22 if (failed(getStridesAndOffset(type, strides, offset)))
23 return false;
24 if (!llvm::all_of(strides, ShapedType::isDynamicStrideOrOffset))
25 return false;
26 if (!ShapedType::isDynamicStrideOrOffset(offset))
27 return false;
28 return true;
29 }
30
31 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
32 /// no layout).
hasStaticIdentityLayout(MemRefType type)33 static bool hasStaticIdentityLayout(MemRefType type) {
34 return type.getLayout().isIdentity();
35 }
36
37 // Updates the func op and entry block.
38 //
39 // Any args appended to the entry block are added to `appendedEntryArgs`.
40 static LogicalResult
updateFuncOp(func::FuncOp func,SmallVectorImpl<BlockArgument> & appendedEntryArgs)41 updateFuncOp(func::FuncOp func,
42 SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
43 auto functionType = func.getFunctionType();
44
45 // Collect information about the results will become appended arguments.
46 SmallVector<Type, 6> erasedResultTypes;
47 BitVector erasedResultIndices(functionType.getNumResults());
48 for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
49 if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
50 if (!hasStaticIdentityLayout(memrefType) &&
51 !hasFullyDynamicLayoutMap(memrefType)) {
52 // Only buffers with static identity layout can be allocated. These can
53 // be casted to memrefs with fully dynamic layout map. Other layout maps
54 // are not supported.
55 return func->emitError()
56 << "cannot create out param for result with unsupported layout";
57 }
58 erasedResultIndices.set(resultType.index());
59 erasedResultTypes.push_back(memrefType);
60 }
61 }
62
63 // Add the new arguments to the function type.
64 auto newArgTypes = llvm::to_vector<6>(
65 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
66 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
67 functionType.getResults());
68 func.setType(newFunctionType);
69
70 // Transfer the result attributes to arg attributes.
71 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
72 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
73 func.setArgAttrs(functionType.getNumInputs() + i,
74 func.getResultAttrs(*erasedIndicesIt));
75 }
76
77 // Erase the results.
78 func.eraseResults(erasedResultIndices);
79
80 // Add the new arguments to the entry block if the function is not external.
81 if (func.isExternal())
82 return success();
83 Location loc = func.getLoc();
84 for (Type type : erasedResultTypes)
85 appendedEntryArgs.push_back(func.front().addArgument(type, loc));
86
87 return success();
88 }
89
90 // Updates all ReturnOps in the scope of the given func::FuncOp by either
91 // keeping them as return values or copying the associated buffer contents into
92 // the given out-params.
updateReturnOps(func::FuncOp func,ArrayRef<BlockArgument> appendedEntryArgs)93 static void updateReturnOps(func::FuncOp func,
94 ArrayRef<BlockArgument> appendedEntryArgs) {
95 func.walk([&](func::ReturnOp op) {
96 SmallVector<Value, 6> copyIntoOutParams;
97 SmallVector<Value, 6> keepAsReturnOperands;
98 for (Value operand : op.getOperands()) {
99 if (operand.getType().isa<MemRefType>())
100 copyIntoOutParams.push_back(operand);
101 else
102 keepAsReturnOperands.push_back(operand);
103 }
104 OpBuilder builder(op);
105 for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
106 builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
107 std::get<1>(t));
108 builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
109 op.erase();
110 });
111 }
112
113 // Updates all CallOps in the scope of the given ModuleOp by allocating
114 // temporary buffers for newly introduced out params.
updateCalls(ModuleOp module)115 static LogicalResult updateCalls(ModuleOp module) {
116 bool didFail = false;
117 module.walk([&](func::CallOp op) {
118 SmallVector<Value, 6> replaceWithNewCallResults;
119 SmallVector<Value, 6> replaceWithOutParams;
120 for (OpResult result : op.getResults()) {
121 if (result.getType().isa<MemRefType>())
122 replaceWithOutParams.push_back(result);
123 else
124 replaceWithNewCallResults.push_back(result);
125 }
126 SmallVector<Value, 6> outParams;
127 OpBuilder builder(op);
128 for (Value memref : replaceWithOutParams) {
129 if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
130 op.emitError()
131 << "cannot create out param for dynamically shaped result";
132 didFail = true;
133 return;
134 }
135 auto memrefType = memref.getType().cast<MemRefType>();
136 auto allocType =
137 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
138 AffineMap(), memrefType.getMemorySpaceAsInt());
139 Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
140 if (!hasStaticIdentityLayout(memrefType)) {
141 // Layout maps are already checked in `updateFuncOp`.
142 assert(hasFullyDynamicLayoutMap(memrefType) &&
143 "layout map not supported");
144 outParam =
145 builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
146 }
147 memref.replaceAllUsesWith(outParam);
148 outParams.push_back(outParam);
149 }
150
151 auto newOperands = llvm::to_vector<6>(op.getOperands());
152 newOperands.append(outParams.begin(), outParams.end());
153 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
154 replaceWithNewCallResults, [](Value v) { return v.getType(); }));
155 auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
156 newResultTypes, newOperands);
157 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
158 std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
159 op.erase();
160 });
161
162 return failure(didFail);
163 }
164
165 LogicalResult
promoteBufferResultsToOutParams(ModuleOp module)166 mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
167 for (auto func : module.getOps<func::FuncOp>()) {
168 SmallVector<BlockArgument, 6> appendedEntryArgs;
169 if (failed(updateFuncOp(func, appendedEntryArgs)))
170 return failure();
171 if (func.isExternal())
172 continue;
173 updateReturnOps(func, appendedEntryArgs);
174 }
175 if (failed(updateCalls(module)))
176 return failure();
177 return success();
178 }
179
180 namespace {
181 struct BufferResultsToOutParamsPass
182 : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
runOnOperation__anon6cfd8ace0411::BufferResultsToOutParamsPass183 void runOnOperation() override {
184 if (failed(bufferization::promoteBufferResultsToOutParams(getOperation())))
185 return signalPassFailure();
186 }
187 };
188 } // namespace
189
190 std::unique_ptr<Pass>
createBufferResultsToOutParamsPass()191 mlir::bufferization::createBufferResultsToOutParamsPass() {
192 return std::make_unique<BufferResultsToOutParamsPass>();
193 }
194