188539c5bSMatthias Springer //===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
288539c5bSMatthias Springer //
388539c5bSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
488539c5bSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
588539c5bSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
688539c5bSMatthias Springer //
788539c5bSMatthias Springer //===----------------------------------------------------------------------===//
888539c5bSMatthias Springer //
988539c5bSMatthias Springer // This pass drops return values from functions if they are equivalent to one of
1088539c5bSMatthias Springer // their arguments. E.g.:
1188539c5bSMatthias Springer //
1288539c5bSMatthias Springer // ```
1388539c5bSMatthias Springer // func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
1488539c5bSMatthias Springer //   return %m : memref<?xf32>
1588539c5bSMatthias Springer // }
1688539c5bSMatthias Springer // ```
1788539c5bSMatthias Springer //
// This function is rewritten to:
1988539c5bSMatthias Springer //
2088539c5bSMatthias Springer // ```
2188539c5bSMatthias Springer // func.func @foo(%m : memref<?xf32>) {
2288539c5bSMatthias Springer //   return
2388539c5bSMatthias Springer // }
2488539c5bSMatthias Springer // ```
2588539c5bSMatthias Springer //
2688539c5bSMatthias Springer // All call sites are updated accordingly. If a function returns a cast of a
2788539c5bSMatthias Springer // function argument, it is also considered equivalent. A cast is inserted at
2888539c5bSMatthias Springer // the call site in that case.
2988539c5bSMatthias Springer 
3088539c5bSMatthias Springer #include "PassDetail.h"
3188539c5bSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
3288539c5bSMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
3388539c5bSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
3488539c5bSMatthias Springer #include "mlir/IR/Operation.h"
3588539c5bSMatthias Springer #include "mlir/Pass/Pass.h"
3688539c5bSMatthias Springer 
3788539c5bSMatthias Springer using namespace mlir;
3888539c5bSMatthias Springer 
3988539c5bSMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`.
4088539c5bSMatthias Springer /// Return nullptr if there is no such unique ReturnOp.
getAssumedUniqueReturnOp(func::FuncOp funcOp)4188539c5bSMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
4288539c5bSMatthias Springer   func::ReturnOp returnOp;
4388539c5bSMatthias Springer   for (Block &b : funcOp.getBody()) {
4488539c5bSMatthias Springer     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
4588539c5bSMatthias Springer       if (returnOp)
4688539c5bSMatthias Springer         return nullptr;
4788539c5bSMatthias Springer       returnOp = candidateOp;
4888539c5bSMatthias Springer     }
4988539c5bSMatthias Springer   }
5088539c5bSMatthias Springer   return returnOp;
5188539c5bSMatthias Springer }
5288539c5bSMatthias Springer 
5388539c5bSMatthias Springer /// Return the func::FuncOp called by `callOp`.
getCalledFunction(CallOpInterface callOp)5488539c5bSMatthias Springer static func::FuncOp getCalledFunction(CallOpInterface callOp) {
5588539c5bSMatthias Springer   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
5688539c5bSMatthias Springer   if (!sym)
5788539c5bSMatthias Springer     return nullptr;
5888539c5bSMatthias Springer   return dyn_cast_or_null<func::FuncOp>(
5988539c5bSMatthias Springer       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
6088539c5bSMatthias Springer }
6188539c5bSMatthias Springer 
6288539c5bSMatthias Springer LogicalResult
dropEquivalentBufferResults(ModuleOp module)6388539c5bSMatthias Springer mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
6488539c5bSMatthias Springer   IRRewriter rewriter(module.getContext());
6588539c5bSMatthias Springer 
6688539c5bSMatthias Springer   for (auto funcOp : module.getOps<func::FuncOp>()) {
6788539c5bSMatthias Springer     if (funcOp.isExternal())
6888539c5bSMatthias Springer       continue;
6988539c5bSMatthias Springer     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
7088539c5bSMatthias Springer     // TODO: Support functions with multiple blocks.
7188539c5bSMatthias Springer     if (!returnOp)
7288539c5bSMatthias Springer       continue;
7388539c5bSMatthias Springer 
7488539c5bSMatthias Springer     // Compute erased results.
7588539c5bSMatthias Springer     SmallVector<Value> newReturnValues;
7688539c5bSMatthias Springer     BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
7788539c5bSMatthias Springer     DenseMap<int64_t, int64_t> resultToArgs;
7888539c5bSMatthias Springer     for (const auto &it : llvm::enumerate(returnOp.operands())) {
7988539c5bSMatthias Springer       bool erased = false;
8088539c5bSMatthias Springer       for (BlockArgument bbArg : funcOp.getArguments()) {
8188539c5bSMatthias Springer         Value val = it.value();
8288539c5bSMatthias Springer         while (auto castOp = val.getDefiningOp<memref::CastOp>())
83*136d746eSJacques Pienaar           val = castOp.getSource();
8488539c5bSMatthias Springer 
8588539c5bSMatthias Springer         if (val == bbArg) {
8688539c5bSMatthias Springer           resultToArgs[it.index()] = bbArg.getArgNumber();
8788539c5bSMatthias Springer           erased = true;
8888539c5bSMatthias Springer           break;
8988539c5bSMatthias Springer         }
9088539c5bSMatthias Springer       }
9188539c5bSMatthias Springer 
9288539c5bSMatthias Springer       if (erased) {
9388539c5bSMatthias Springer         erasedResultIndices.set(it.index());
9488539c5bSMatthias Springer       } else {
9588539c5bSMatthias Springer         newReturnValues.push_back(it.value());
9688539c5bSMatthias Springer       }
9788539c5bSMatthias Springer     }
9888539c5bSMatthias Springer 
9988539c5bSMatthias Springer     // Update function.
10088539c5bSMatthias Springer     funcOp.eraseResults(erasedResultIndices);
10188539c5bSMatthias Springer     returnOp.operandsMutable().assign(newReturnValues);
10288539c5bSMatthias Springer 
10388539c5bSMatthias Springer     // Update function calls.
10488539c5bSMatthias Springer     module.walk([&](func::CallOp callOp) {
10588539c5bSMatthias Springer       if (getCalledFunction(callOp) != funcOp)
10688539c5bSMatthias Springer         return WalkResult::skip();
10788539c5bSMatthias Springer 
10888539c5bSMatthias Springer       rewriter.setInsertionPoint(callOp);
10988539c5bSMatthias Springer       auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
11088539c5bSMatthias Springer                                                      callOp.operands());
11188539c5bSMatthias Springer       SmallVector<Value> newResults;
11288539c5bSMatthias Springer       int64_t nextResult = 0;
11388539c5bSMatthias Springer       for (int64_t i = 0; i < callOp.getNumResults(); ++i) {
11488539c5bSMatthias Springer         if (!resultToArgs.count(i)) {
11588539c5bSMatthias Springer           // This result was not erased.
11688539c5bSMatthias Springer           newResults.push_back(newCallOp.getResult(nextResult++));
11788539c5bSMatthias Springer           continue;
11888539c5bSMatthias Springer         }
11988539c5bSMatthias Springer 
12088539c5bSMatthias Springer         // This result was erased.
12188539c5bSMatthias Springer         Value replacement = callOp.getOperand(resultToArgs[i]);
12288539c5bSMatthias Springer         Type expectedType = callOp.getResult(i).getType();
12388539c5bSMatthias Springer         if (replacement.getType() != expectedType) {
12488539c5bSMatthias Springer           // A cast must be inserted at the call site.
12588539c5bSMatthias Springer           replacement = rewriter.create<memref::CastOp>(
12688539c5bSMatthias Springer               callOp.getLoc(), expectedType, replacement);
12788539c5bSMatthias Springer         }
12888539c5bSMatthias Springer         newResults.push_back(replacement);
12988539c5bSMatthias Springer       }
13088539c5bSMatthias Springer       rewriter.replaceOp(callOp, newResults);
13188539c5bSMatthias Springer       return WalkResult::advance();
13288539c5bSMatthias Springer     });
13388539c5bSMatthias Springer   }
13488539c5bSMatthias Springer 
13588539c5bSMatthias Springer   return success();
13688539c5bSMatthias Springer }
13788539c5bSMatthias Springer 
13888539c5bSMatthias Springer namespace {
13988539c5bSMatthias Springer struct DropEquivalentBufferResultsPass
14088539c5bSMatthias Springer     : DropEquivalentBufferResultsBase<DropEquivalentBufferResultsPass> {
runOnOperation__anonfa0fa9820211::DropEquivalentBufferResultsPass14188539c5bSMatthias Springer   void runOnOperation() override {
14288539c5bSMatthias Springer     if (failed(bufferization::dropEquivalentBufferResults(getOperation())))
14388539c5bSMatthias Springer       return signalPassFailure();
14488539c5bSMatthias Springer   }
14588539c5bSMatthias Springer };
14688539c5bSMatthias Springer } // namespace
14788539c5bSMatthias Springer 
14888539c5bSMatthias Springer std::unique_ptr<Pass>
createDropEquivalentBufferResultsPass()14988539c5bSMatthias Springer mlir::bufferization::createDropEquivalentBufferResultsPass() {
15088539c5bSMatthias Springer   return std::make_unique<DropEquivalentBufferResultsPass>();
15188539c5bSMatthias Springer }
152