1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 
18 // Updates the func op and entry block.
19 //
20 // Any args appended to the entry block are added to `appendedEntryArgs`.
21 static void updateFuncOp(func::FuncOp func,
22                          SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
23   auto functionType = func.getFunctionType();
24 
25   // Collect information about the results will become appended arguments.
26   SmallVector<Type, 6> erasedResultTypes;
27   BitVector erasedResultIndices(functionType.getNumResults());
28   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
29     if (resultType.value().isa<BaseMemRefType>()) {
30       erasedResultIndices.set(resultType.index());
31       erasedResultTypes.push_back(resultType.value());
32     }
33   }
34 
35   // Add the new arguments to the function type.
36   auto newArgTypes = llvm::to_vector<6>(
37       llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
38   auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
39                                            functionType.getResults());
40   func.setType(newFunctionType);
41 
42   // Transfer the result attributes to arg attributes.
43   auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
44   for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
45     func.setArgAttrs(functionType.getNumInputs() + i,
46                      func.getResultAttrs(*erasedIndicesIt));
47   }
48 
49   // Erase the results.
50   func.eraseResults(erasedResultIndices);
51 
52   // Add the new arguments to the entry block if the function is not external.
53   if (func.isExternal())
54     return;
55   Location loc = func.getLoc();
56   for (Type type : erasedResultTypes)
57     appendedEntryArgs.push_back(func.front().addArgument(type, loc));
58 }
59 
60 // Updates all ReturnOps in the scope of the given func::FuncOp by either
61 // keeping them as return values or copying the associated buffer contents into
62 // the given out-params.
63 static void updateReturnOps(func::FuncOp func,
64                             ArrayRef<BlockArgument> appendedEntryArgs) {
65   func.walk([&](func::ReturnOp op) {
66     SmallVector<Value, 6> copyIntoOutParams;
67     SmallVector<Value, 6> keepAsReturnOperands;
68     for (Value operand : op.getOperands()) {
69       if (operand.getType().isa<BaseMemRefType>())
70         copyIntoOutParams.push_back(operand);
71       else
72         keepAsReturnOperands.push_back(operand);
73     }
74     OpBuilder builder(op);
75     for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
76       builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
77                                      std::get<1>(t));
78     builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
79     op.erase();
80   });
81 }
82 
83 // Updates all CallOps in the scope of the given ModuleOp by allocating
84 // temporary buffers for newly introduced out params.
85 static LogicalResult updateCalls(ModuleOp module) {
86   bool didFail = false;
87   module.walk([&](func::CallOp op) {
88     SmallVector<Value, 6> replaceWithNewCallResults;
89     SmallVector<Value, 6> replaceWithOutParams;
90     for (OpResult result : op.getResults()) {
91       if (result.getType().isa<BaseMemRefType>())
92         replaceWithOutParams.push_back(result);
93       else
94         replaceWithNewCallResults.push_back(result);
95     }
96     SmallVector<Value, 6> outParams;
97     OpBuilder builder(op);
98     for (Value memref : replaceWithOutParams) {
99       if (!memref.getType().cast<BaseMemRefType>().hasStaticShape()) {
100         op.emitError()
101             << "cannot create out param for dynamically shaped result";
102         didFail = true;
103         return;
104       }
105       Value outParam = builder.create<memref::AllocOp>(
106           op.getLoc(), memref.getType().cast<MemRefType>());
107       memref.replaceAllUsesWith(outParam);
108       outParams.push_back(outParam);
109     }
110 
111     auto newOperands = llvm::to_vector<6>(op.getOperands());
112     newOperands.append(outParams.begin(), outParams.end());
113     auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
114         replaceWithNewCallResults, [](Value v) { return v.getType(); }));
115     auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
116                                                 newResultTypes, newOperands);
117     for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
118       std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
119     op.erase();
120   });
121 
122   return failure(didFail);
123 }
124 
125 namespace {
126 struct BufferResultsToOutParamsPass
127     : BufferResultsToOutParamsBase<BufferResultsToOutParamsPass> {
128   void runOnOperation() override {
129     ModuleOp module = getOperation();
130 
131     for (auto func : module.getOps<func::FuncOp>()) {
132       SmallVector<BlockArgument, 6> appendedEntryArgs;
133       updateFuncOp(func, appendedEntryArgs);
134       if (func.isExternal())
135         continue;
136       updateReturnOps(func, appendedEntryArgs);
137     }
138     if (failed(updateCalls(module)))
139       return signalPassFailure();
140   }
141 };
142 } // namespace
143 
144 std::unique_ptr<Pass>
145 mlir::bufferization::createBufferResultsToOutParamsPass() {
146   return std::make_unique<BufferResultsToOutParamsPass>();
147 }
148