//===- DropEquivalentBufferResults.cpp - Calling convention conversion ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass drops return values from functions if they are equivalent to one of
// their arguments. E.g.:
//
// ```
// func.func @foo(%m : memref<?xf32>) -> (memref<?xf32>) {
//   return %m : memref<?xf32>
// }
// ```
//
// This function is rewritten to:
//
// ```
// func.func @foo(%m : memref<?xf32>) {
//   return
// }
// ```
//
// All call sites are updated accordingly. If a function returns a cast
// (memref.cast) of a function argument, it is also considered equivalent. A
// cast is inserted at the call site in that case.
2988539c5bSMatthias Springer 3088539c5bSMatthias Springer #include "PassDetail.h" 3188539c5bSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 3288539c5bSMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 3388539c5bSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 3488539c5bSMatthias Springer #include "mlir/IR/Operation.h" 3588539c5bSMatthias Springer #include "mlir/Pass/Pass.h" 3688539c5bSMatthias Springer 3788539c5bSMatthias Springer using namespace mlir; 3888539c5bSMatthias Springer 3988539c5bSMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`. 4088539c5bSMatthias Springer /// Return nullptr if there is no such unique ReturnOp. getAssumedUniqueReturnOp(func::FuncOp funcOp)4188539c5bSMatthias Springerstatic func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { 4288539c5bSMatthias Springer func::ReturnOp returnOp; 4388539c5bSMatthias Springer for (Block &b : funcOp.getBody()) { 4488539c5bSMatthias Springer if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 4588539c5bSMatthias Springer if (returnOp) 4688539c5bSMatthias Springer return nullptr; 4788539c5bSMatthias Springer returnOp = candidateOp; 4888539c5bSMatthias Springer } 4988539c5bSMatthias Springer } 5088539c5bSMatthias Springer return returnOp; 5188539c5bSMatthias Springer } 5288539c5bSMatthias Springer 5388539c5bSMatthias Springer /// Return the func::FuncOp called by `callOp`. 
getCalledFunction(CallOpInterface callOp)5488539c5bSMatthias Springerstatic func::FuncOp getCalledFunction(CallOpInterface callOp) { 5588539c5bSMatthias Springer SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); 5688539c5bSMatthias Springer if (!sym) 5788539c5bSMatthias Springer return nullptr; 5888539c5bSMatthias Springer return dyn_cast_or_null<func::FuncOp>( 5988539c5bSMatthias Springer SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 6088539c5bSMatthias Springer } 6188539c5bSMatthias Springer 6288539c5bSMatthias Springer LogicalResult dropEquivalentBufferResults(ModuleOp module)6388539c5bSMatthias Springermlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { 6488539c5bSMatthias Springer IRRewriter rewriter(module.getContext()); 6588539c5bSMatthias Springer 6688539c5bSMatthias Springer for (auto funcOp : module.getOps<func::FuncOp>()) { 6788539c5bSMatthias Springer if (funcOp.isExternal()) 6888539c5bSMatthias Springer continue; 6988539c5bSMatthias Springer func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 7088539c5bSMatthias Springer // TODO: Support functions with multiple blocks. 7188539c5bSMatthias Springer if (!returnOp) 7288539c5bSMatthias Springer continue; 7388539c5bSMatthias Springer 7488539c5bSMatthias Springer // Compute erased results. 
7588539c5bSMatthias Springer SmallVector<Value> newReturnValues; 7688539c5bSMatthias Springer BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults()); 7788539c5bSMatthias Springer DenseMap<int64_t, int64_t> resultToArgs; 7888539c5bSMatthias Springer for (const auto &it : llvm::enumerate(returnOp.operands())) { 7988539c5bSMatthias Springer bool erased = false; 8088539c5bSMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) { 8188539c5bSMatthias Springer Value val = it.value(); 8288539c5bSMatthias Springer while (auto castOp = val.getDefiningOp<memref::CastOp>()) 83*136d746eSJacques Pienaar val = castOp.getSource(); 8488539c5bSMatthias Springer 8588539c5bSMatthias Springer if (val == bbArg) { 8688539c5bSMatthias Springer resultToArgs[it.index()] = bbArg.getArgNumber(); 8788539c5bSMatthias Springer erased = true; 8888539c5bSMatthias Springer break; 8988539c5bSMatthias Springer } 9088539c5bSMatthias Springer } 9188539c5bSMatthias Springer 9288539c5bSMatthias Springer if (erased) { 9388539c5bSMatthias Springer erasedResultIndices.set(it.index()); 9488539c5bSMatthias Springer } else { 9588539c5bSMatthias Springer newReturnValues.push_back(it.value()); 9688539c5bSMatthias Springer } 9788539c5bSMatthias Springer } 9888539c5bSMatthias Springer 9988539c5bSMatthias Springer // Update function. 10088539c5bSMatthias Springer funcOp.eraseResults(erasedResultIndices); 10188539c5bSMatthias Springer returnOp.operandsMutable().assign(newReturnValues); 10288539c5bSMatthias Springer 10388539c5bSMatthias Springer // Update function calls. 
10488539c5bSMatthias Springer module.walk([&](func::CallOp callOp) { 10588539c5bSMatthias Springer if (getCalledFunction(callOp) != funcOp) 10688539c5bSMatthias Springer return WalkResult::skip(); 10788539c5bSMatthias Springer 10888539c5bSMatthias Springer rewriter.setInsertionPoint(callOp); 10988539c5bSMatthias Springer auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp, 11088539c5bSMatthias Springer callOp.operands()); 11188539c5bSMatthias Springer SmallVector<Value> newResults; 11288539c5bSMatthias Springer int64_t nextResult = 0; 11388539c5bSMatthias Springer for (int64_t i = 0; i < callOp.getNumResults(); ++i) { 11488539c5bSMatthias Springer if (!resultToArgs.count(i)) { 11588539c5bSMatthias Springer // This result was not erased. 11688539c5bSMatthias Springer newResults.push_back(newCallOp.getResult(nextResult++)); 11788539c5bSMatthias Springer continue; 11888539c5bSMatthias Springer } 11988539c5bSMatthias Springer 12088539c5bSMatthias Springer // This result was erased. 12188539c5bSMatthias Springer Value replacement = callOp.getOperand(resultToArgs[i]); 12288539c5bSMatthias Springer Type expectedType = callOp.getResult(i).getType(); 12388539c5bSMatthias Springer if (replacement.getType() != expectedType) { 12488539c5bSMatthias Springer // A cast must be inserted at the call site. 
12588539c5bSMatthias Springer replacement = rewriter.create<memref::CastOp>( 12688539c5bSMatthias Springer callOp.getLoc(), expectedType, replacement); 12788539c5bSMatthias Springer } 12888539c5bSMatthias Springer newResults.push_back(replacement); 12988539c5bSMatthias Springer } 13088539c5bSMatthias Springer rewriter.replaceOp(callOp, newResults); 13188539c5bSMatthias Springer return WalkResult::advance(); 13288539c5bSMatthias Springer }); 13388539c5bSMatthias Springer } 13488539c5bSMatthias Springer 13588539c5bSMatthias Springer return success(); 13688539c5bSMatthias Springer } 13788539c5bSMatthias Springer 13888539c5bSMatthias Springer namespace { 13988539c5bSMatthias Springer struct DropEquivalentBufferResultsPass 14088539c5bSMatthias Springer : DropEquivalentBufferResultsBase<DropEquivalentBufferResultsPass> { runOnOperation__anonfa0fa9820211::DropEquivalentBufferResultsPass14188539c5bSMatthias Springer void runOnOperation() override { 14288539c5bSMatthias Springer if (failed(bufferization::dropEquivalentBufferResults(getOperation()))) 14388539c5bSMatthias Springer return signalPassFailure(); 14488539c5bSMatthias Springer } 14588539c5bSMatthias Springer }; 14688539c5bSMatthias Springer } // namespace 14788539c5bSMatthias Springer 14888539c5bSMatthias Springer std::unique_ptr<Pass> createDropEquivalentBufferResultsPass()14988539c5bSMatthias Springermlir::bufferization::createDropEquivalentBufferResultsPass() { 15088539c5bSMatthias Springer return std::make_unique<DropEquivalentBufferResultsPass>(); 15188539c5bSMatthias Springer } 152