1e07a7fd5SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2e07a7fd5SMatthias Springer // 3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6e07a7fd5SMatthias Springer // 7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===// 8e07a7fd5SMatthias Springer 9e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" 10e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 14e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 15e07a7fd5SMatthias Springer #include "mlir/IR/Dialect.h" 16e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h" 17e07a7fd5SMatthias Springer 18e07a7fd5SMatthias Springer namespace mlir { 19e07a7fd5SMatthias Springer namespace bufferization { 20e07a7fd5SMatthias Springer namespace func_ext { 21e07a7fd5SMatthias Springer 22e07a7fd5SMatthias Springer void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { 23e07a7fd5SMatthias Springer analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress; 24e07a7fd5SMatthias Springer auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping()); 25e07a7fd5SMatthias Springer auto createdAliasingOperands = 26e07a7fd5SMatthias Springer aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping()); 27e07a7fd5SMatthias Springer auto createdAliasingResults = 28e07a7fd5SMatthias Springer 
aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping()); 29e07a7fd5SMatthias Springer auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet()); 30e07a7fd5SMatthias Springer auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet()); 31e07a7fd5SMatthias Springer (void)createdEquiv; 32e07a7fd5SMatthias Springer (void)createdAliasingOperands; 33e07a7fd5SMatthias Springer (void)createdAliasingResults; 34e07a7fd5SMatthias Springer (void)createdRead; 35e07a7fd5SMatthias Springer (void)createdWritten; 36e07a7fd5SMatthias Springer #ifndef NDEBUG 37e07a7fd5SMatthias Springer assert(createdEquiv.second && "equivalence info exists already"); 38e07a7fd5SMatthias Springer assert(createdAliasingOperands.second && "aliasing info exists already"); 39e07a7fd5SMatthias Springer assert(createdAliasingResults.second && "aliasing info exists already"); 40e07a7fd5SMatthias Springer assert(createdRead.second && "bbarg access info exists already"); 41e07a7fd5SMatthias Springer assert(createdWritten.second && "bbarg access info exists already"); 42e07a7fd5SMatthias Springer #endif // NDEBUG 43e07a7fd5SMatthias Springer } 44e07a7fd5SMatthias Springer 45e07a7fd5SMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`. 46e07a7fd5SMatthias Springer /// Return nullptr if there is no such unique ReturnOp. 
47e07a7fd5SMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) { 48e07a7fd5SMatthias Springer func::ReturnOp returnOp; 49e07a7fd5SMatthias Springer for (Block &b : funcOp.getBody()) { 50e07a7fd5SMatthias Springer if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 51e07a7fd5SMatthias Springer if (returnOp) 52e07a7fd5SMatthias Springer return nullptr; 53e07a7fd5SMatthias Springer returnOp = candidateOp; 54e07a7fd5SMatthias Springer } 55e07a7fd5SMatthias Springer } 56e07a7fd5SMatthias Springer return returnOp; 57e07a7fd5SMatthias Springer } 58e07a7fd5SMatthias Springer 59e07a7fd5SMatthias Springer /// Return the index-th bufferized function argument type. This assumes that the 60e07a7fd5SMatthias Springer /// specified argument is a tensor. If the tensor is ranked, a layout map may be 61e07a7fd5SMatthias Springer /// specified by the user. If no layout map is specified, a fully dynamic map is 62e07a7fd5SMatthias Springer /// used. 63e07a7fd5SMatthias Springer static BaseMemRefType 64e07a7fd5SMatthias Springer getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, 65e07a7fd5SMatthias Springer const BufferizationOptions &options) { 66e07a7fd5SMatthias Springer auto tensorType = 67e07a7fd5SMatthias Springer funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>(); 68e07a7fd5SMatthias Springer assert(tensorType && "expected TensorType"); 69e07a7fd5SMatthias Springer BaseMemRefType memrefType = getMemRefType(tensorType, options); 70e07a7fd5SMatthias Springer 71e07a7fd5SMatthias Springer auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>( 72e07a7fd5SMatthias Springer index, BufferizationDialect::kBufferLayoutAttrName); 73e07a7fd5SMatthias Springer if (!layoutAttr) 74e07a7fd5SMatthias Springer return memrefType; 75e07a7fd5SMatthias Springer 76e07a7fd5SMatthias Springer auto rankedMemrefType = memrefType.dyn_cast<MemRefType>(); 77e07a7fd5SMatthias Springer assert(rankedMemrefType && "buffer layout not supported 
on unranked tensors"); 78e07a7fd5SMatthias Springer return MemRefType::get( 79e07a7fd5SMatthias Springer rankedMemrefType.getShape(), rankedMemrefType.getElementType(), 80e07a7fd5SMatthias Springer layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt()); 81e07a7fd5SMatthias Springer } 82e07a7fd5SMatthias Springer 83e07a7fd5SMatthias Springer /// Return the FuncOp called by `callOp`. 84e07a7fd5SMatthias Springer static FuncOp getCalledFunction(CallOpInterface callOp) { 85e07a7fd5SMatthias Springer SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); 86e07a7fd5SMatthias Springer if (!sym) 87e07a7fd5SMatthias Springer return nullptr; 88e07a7fd5SMatthias Springer return dyn_cast_or_null<FuncOp>( 89e07a7fd5SMatthias Springer SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 90e07a7fd5SMatthias Springer } 91e07a7fd5SMatthias Springer 92e07a7fd5SMatthias Springer /// Get FuncAnalysisState. 93e07a7fd5SMatthias Springer static const FuncAnalysisState & 94e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) { 95e07a7fd5SMatthias Springer Optional<const FuncAnalysisState *> maybeState = 96e07a7fd5SMatthias Springer state.getDialectState<FuncAnalysisState>( 97e07a7fd5SMatthias Springer func::FuncDialect::getDialectNamespace()); 98e07a7fd5SMatthias Springer assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); 99e07a7fd5SMatthias Springer return **maybeState; 100e07a7fd5SMatthias Springer } 101e07a7fd5SMatthias Springer 102e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp. 
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  // Functions for which no analysis was ever started are "NotAnalyzed".
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

/// Bufferization model for func.call: read/write/aliasing queries are
/// answered from the callee's analysis state (conservatively if the callee
/// has not been analyzed yet).
struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  /// Return `true` if the callee may read from the buffer of `opOperand`.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    // The callee was analyzed: the operand is read iff its bbArg was
    // recorded as read.
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  /// Return `true` if the callee may write to the buffer of `opOperand`.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  /// Return the CallOp results that may alias the buffer of `opOperand`,
  /// according to the callee's aliasing analysis.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  /// Return the CallOp operands that may alias the buffer of `opResult`,
  /// according to the callee's aliasing analysis.
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  /// Aliasing operand/result pairs of a CallOp are treated as equivalent
  /// buffers.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the results
    // of the bufferized CallOp, unless a tensor result folds onto an operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensor return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.

    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> bbArgIdx =
                getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
          // Return operands that are equivalent to some bbArg are not
          // returned. The buffer of the corresponding operand serves as the
          // replacement value for this result.
          FailureOr<Value> bufferOrFailure =
              state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
          if (failed(bufferOrFailure))
            return failure();
          replacementValues[returnValIdx] = *bufferOrFailure;
          newOperands[*bbArgIdx] = *bufferOrFailure;
          continue;
        }
      }

      if (!options.allowReturnAllocs)
        return callOp->emitError(
            "call to FuncOp that returns non-equivalent tensors not supported");

      // Returning a memref. This memref is not equivalent to any bbArg. It is
      // likely a newly allocated buffer. We may want to hoist such allocations
      // to the call site in the future.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers, whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memref than necessary.
      // If the memref type of the callee does not match, introduce an extra
      // memref.cast that will either canonicalize away or fail compilation
      // until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

/// Bufferization model for func.return. The terminator itself is rewritten as
/// part of the enclosing FuncOp's bufferization, not here.
struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  /// Returned tensors are considered read (they escape the function).
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// A return never writes to its operands.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// A ReturnOp has no results, so nothing can alias an operand.
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return failure();
  }
};

/// Bufferization model for func.func: rewrites the signature, block arguments
/// and terminator into buffer form.
struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form (using the
  /// canonical memref layout for now). This function bufferizes the function
  /// signature and the ReturnOp. When the entire function body has been
  /// bufferized, function return types can be switched to more concise memref
  /// types as part of `foldMemRefCasts`.
  ///
  /// When a tensor function argument is known to be equivalent to a tensor
  /// result, it is dropped from the return values.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  ///
  /// Note: Returning a memref is possible, but corresponding CallOp
  /// bufferizations fail unless `allowReturnAllocs`.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Construct the bufferized function type: every tensor argument becomes a
    // memref argument; all other argument types are kept as-is.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // support functions that don't return any tensors atm.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg before mutating its type, so they can
      // be redirected afterwards.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();

      // If not a tensor type just forward it.
      if (!returnVal.getType().isa<RankedTensorType>()) {
        returnValues.push_back(returnVal);
        continue;
      }

      // If return operand is equivalent to some bbArg, no need to return it.
      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
                funcOp, funcState, returnOperand.getOperandNumber())) {
          rewriter.setInsertionPoint(returnOp);
          Location loc = returnOp.getLoc();
          Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
              loc,
              getMemRefType(returnVal.getType().cast<TensorType>(), options),
              returnVal);
          BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
          // Note: This copy will fold away. It must be inserted here to ensure
          // that `returnVal` still has at least one use and does not fold away.
          if (failed(
                  options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg)))
            return funcOp->emitError("could not generate copy for bbArg");
          continue;
        }
      }

      // Non-equivalent tensor result: return its buffer.
      returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
531e07a7fd5SMatthias Springer return true; 532e07a7fd5SMatthias Springer } 533e07a7fd5SMatthias Springer 534e07a7fd5SMatthias Springer bool isAllocationHoistingBarrier(Operation *op) const { return true; } 535e07a7fd5SMatthias Springer }; 536e07a7fd5SMatthias Springer 537e07a7fd5SMatthias Springer } // namespace func_ext 538e07a7fd5SMatthias Springer } // namespace bufferization 539e07a7fd5SMatthias Springer } // namespace mlir 540e07a7fd5SMatthias Springer 541e07a7fd5SMatthias Springer void mlir::bufferization::func_ext:: 542e07a7fd5SMatthias Springer registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { 543e07a7fd5SMatthias Springer registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { 544e07a7fd5SMatthias Springer func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx); 545e07a7fd5SMatthias Springer func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx); 546e07a7fd5SMatthias Springer func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx); 547e07a7fd5SMatthias Springer }); 548e07a7fd5SMatthias Springer } 549