1 //===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass drops return values from functions if they are equivalent to one of
10 // their arguments. E.g.:
11 //
12 // ```
13 // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
14 //   return %m : memref<?xf32>
15 // }
16 // ```
17 //
// This function is rewritten to:
19 //
20 // ```
21 // func.func @foo(%m : memref<?xf32>) {
22 //   return
23 // }
24 // ```
25 //
26 // All call sites are updated accordingly. If a function returns a cast of a
27 // function argument, it is also considered equivalent. A cast is inserted at
28 // the call site in that case.
29 
30 #include "PassDetail.h"
31 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
32 #include "mlir/Dialect/Func/IR/FuncOps.h"
33 #include "mlir/Dialect/MemRef/IR/MemRef.h"
34 #include "mlir/IR/Operation.h"
35 #include "mlir/Pass/Pass.h"
36 
37 using namespace mlir;
38 
39 /// Return the unique ReturnOp that terminates `funcOp`.
40 /// Return nullptr if there is no such unique ReturnOp.
getAssumedUniqueReturnOp(func::FuncOp funcOp)41 static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
42   func::ReturnOp returnOp;
43   for (Block &b : funcOp.getBody()) {
44     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
45       if (returnOp)
46         return nullptr;
47       returnOp = candidateOp;
48     }
49   }
50   return returnOp;
51 }
52 
53 /// Return the func::FuncOp called by `callOp`.
getCalledFunction(CallOpInterface callOp)54 static func::FuncOp getCalledFunction(CallOpInterface callOp) {
55   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
56   if (!sym)
57     return nullptr;
58   return dyn_cast_or_null<func::FuncOp>(
59       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
60 }
61 
62 LogicalResult
dropEquivalentBufferResults(ModuleOp module)63 mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
64   IRRewriter rewriter(module.getContext());
65 
66   for (auto funcOp : module.getOps<func::FuncOp>()) {
67     if (funcOp.isExternal())
68       continue;
69     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
70     // TODO: Support functions with multiple blocks.
71     if (!returnOp)
72       continue;
73 
74     // Compute erased results.
75     SmallVector<Value> newReturnValues;
76     BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
77     DenseMap<int64_t, int64_t> resultToArgs;
78     for (const auto &it : llvm::enumerate(returnOp.operands())) {
79       bool erased = false;
80       for (BlockArgument bbArg : funcOp.getArguments()) {
81         Value val = it.value();
82         while (auto castOp = val.getDefiningOp<memref::CastOp>())
83           val = castOp.getSource();
84 
85         if (val == bbArg) {
86           resultToArgs[it.index()] = bbArg.getArgNumber();
87           erased = true;
88           break;
89         }
90       }
91 
92       if (erased) {
93         erasedResultIndices.set(it.index());
94       } else {
95         newReturnValues.push_back(it.value());
96       }
97     }
98 
99     // Update function.
100     funcOp.eraseResults(erasedResultIndices);
101     returnOp.operandsMutable().assign(newReturnValues);
102 
103     // Update function calls.
104     module.walk([&](func::CallOp callOp) {
105       if (getCalledFunction(callOp) != funcOp)
106         return WalkResult::skip();
107 
108       rewriter.setInsertionPoint(callOp);
109       auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
110                                                      callOp.operands());
111       SmallVector<Value> newResults;
112       int64_t nextResult = 0;
113       for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
114         if (!resultToArgs.count(i)) {
115           // This result was not erased.
116           newResults.push_back(newCallOp.getResult(nextResult++));
117           continue;
118         }
119 
120         // This result was erased.
121         Value replacement = callOp.getOperand(resultToArgs[i]);
122         Type expectedType = callOp.getResult(i).getType();
123         if (replacement.getType() != expectedType) {
124           // A cast must be inserted at the call site.
125           replacement = rewriter.create<memref::CastOp>(
126               callOp.getLoc(), expectedType, replacement);
127         }
128         newResults.push_back(replacement);
129       }
130       rewriter.replaceOp(callOp, newResults);
131       return WalkResult::advance();
132     });
133   }
134 
135   return success();
136 }
137 
138 namespace {
139 struct DropEquivalentBufferResultsPass
140     : DropEquivalentBufferResultsBase<DropEquivalentBufferResultsPass> {
runOnOperation__anonfa0fa9820211::DropEquivalentBufferResultsPass141   void runOnOperation() override {
142     if (failed(bufferization::dropEquivalentBufferResults(getOperation())))
143       return signalPassFailure();
144   }
145 };
146 } // namespace
147 
148 std::unique_ptr<Pass>
createDropEquivalentBufferResultsPass()149 mlir::bufferization::createDropEquivalentBufferResultsPass() {
150   return std::make_unique<DropEquivalentBufferResultsPass>();
151 }
152