1e07a7fd5SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2e07a7fd5SMatthias Springer // 3e07a7fd5SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4e07a7fd5SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5e07a7fd5SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6e07a7fd5SMatthias Springer // 7e07a7fd5SMatthias Springer //===----------------------------------------------------------------------===// 8e07a7fd5SMatthias Springer 9e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" 10e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 11e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 14e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 15e07a7fd5SMatthias Springer #include "mlir/IR/Dialect.h" 16e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h" 17e07a7fd5SMatthias Springer 18e07a7fd5SMatthias Springer namespace mlir { 19e07a7fd5SMatthias Springer namespace bufferization { 20e07a7fd5SMatthias Springer namespace func_ext { 21e07a7fd5SMatthias Springer 22e07a7fd5SMatthias Springer void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { 23e07a7fd5SMatthias Springer analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress; 24e07a7fd5SMatthias Springer auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping()); 25e07a7fd5SMatthias Springer auto createdAliasingOperands = 26e07a7fd5SMatthias Springer aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping()); 27e07a7fd5SMatthias Springer auto createdAliasingResults = 28e07a7fd5SMatthias Springer 
aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping()); 29e07a7fd5SMatthias Springer auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet()); 30e07a7fd5SMatthias Springer auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet()); 31e07a7fd5SMatthias Springer (void)createdEquiv; 32e07a7fd5SMatthias Springer (void)createdAliasingOperands; 33e07a7fd5SMatthias Springer (void)createdAliasingResults; 34e07a7fd5SMatthias Springer (void)createdRead; 35e07a7fd5SMatthias Springer (void)createdWritten; 36e07a7fd5SMatthias Springer #ifndef NDEBUG 37e07a7fd5SMatthias Springer assert(createdEquiv.second && "equivalence info exists already"); 38e07a7fd5SMatthias Springer assert(createdAliasingOperands.second && "aliasing info exists already"); 39e07a7fd5SMatthias Springer assert(createdAliasingResults.second && "aliasing info exists already"); 40e07a7fd5SMatthias Springer assert(createdRead.second && "bbarg access info exists already"); 41e07a7fd5SMatthias Springer assert(createdWritten.second && "bbarg access info exists already"); 42e07a7fd5SMatthias Springer #endif // NDEBUG 43e07a7fd5SMatthias Springer } 44e07a7fd5SMatthias Springer 45e07a7fd5SMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`. 46e07a7fd5SMatthias Springer /// Return nullptr if there is no such unique ReturnOp. 
47e07a7fd5SMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) { 48e07a7fd5SMatthias Springer func::ReturnOp returnOp; 49e07a7fd5SMatthias Springer for (Block &b : funcOp.getBody()) { 50e07a7fd5SMatthias Springer if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 51e07a7fd5SMatthias Springer if (returnOp) 52e07a7fd5SMatthias Springer return nullptr; 53e07a7fd5SMatthias Springer returnOp = candidateOp; 54e07a7fd5SMatthias Springer } 55e07a7fd5SMatthias Springer } 56e07a7fd5SMatthias Springer return returnOp; 57e07a7fd5SMatthias Springer } 58e07a7fd5SMatthias Springer 59e07a7fd5SMatthias Springer /// Return the index-th bufferized function argument type. This assumes that the 60e07a7fd5SMatthias Springer /// specified argument is a tensor. If the tensor is ranked, a layout map may be 61*f287da8aSMatthias Springer /// specified by the user. If no layout map is specified, the default layout map 62*f287da8aSMatthias Springer /// (as per `options.functionBoundaryTypeConversion`) is used. 63e07a7fd5SMatthias Springer static BaseMemRefType 64e07a7fd5SMatthias Springer getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, 65e07a7fd5SMatthias Springer const BufferizationOptions &options) { 66e07a7fd5SMatthias Springer auto tensorType = 67e07a7fd5SMatthias Springer funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>(); 68e07a7fd5SMatthias Springer assert(tensorType && "expected TensorType"); 69*f287da8aSMatthias Springer 70*f287da8aSMatthias Springer BaseMemRefType memrefType; 71*f287da8aSMatthias Springer if (options.functionBoundaryTypeConversion == 72*f287da8aSMatthias Springer BufferizationOptions::LayoutMapOption::IdentityLayoutMap) { 73*f287da8aSMatthias Springer memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); 74*f287da8aSMatthias Springer } else { 75*f287da8aSMatthias Springer // Note: Layout maps on function parameters cannot be inferred. 
The best we 76*f287da8aSMatthias Springer // can do at the moment is "fully dynamic". 77*f287da8aSMatthias Springer memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType); 78*f287da8aSMatthias Springer } 79e07a7fd5SMatthias Springer 80e07a7fd5SMatthias Springer auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>( 81e07a7fd5SMatthias Springer index, BufferizationDialect::kBufferLayoutAttrName); 82e07a7fd5SMatthias Springer if (!layoutAttr) 83e07a7fd5SMatthias Springer return memrefType; 84e07a7fd5SMatthias Springer 85e07a7fd5SMatthias Springer auto rankedMemrefType = memrefType.dyn_cast<MemRefType>(); 86e07a7fd5SMatthias Springer assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); 87e07a7fd5SMatthias Springer return MemRefType::get( 88e07a7fd5SMatthias Springer rankedMemrefType.getShape(), rankedMemrefType.getElementType(), 89e07a7fd5SMatthias Springer layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt()); 90e07a7fd5SMatthias Springer } 91e07a7fd5SMatthias Springer 92e07a7fd5SMatthias Springer /// Return the FuncOp called by `callOp`. 93e07a7fd5SMatthias Springer static FuncOp getCalledFunction(CallOpInterface callOp) { 94e07a7fd5SMatthias Springer SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); 95e07a7fd5SMatthias Springer if (!sym) 96e07a7fd5SMatthias Springer return nullptr; 97e07a7fd5SMatthias Springer return dyn_cast_or_null<FuncOp>( 98e07a7fd5SMatthias Springer SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 99e07a7fd5SMatthias Springer } 100e07a7fd5SMatthias Springer 101e07a7fd5SMatthias Springer /// Get FuncAnalysisState. 
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  // The FuncAnalysisState is stored as dialect-level state under the func
  // dialect's namespace; it must have been created beforehand.
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  // Functions that were never registered count as "NotAnalyzed".
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

/// Bufferization model for func::CallOp. Read/write/aliasing queries are
/// answered from the callee's analysis state when available; if the callee
/// has not been analyzed yet, conservative answers are returned.
struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    // The operand is read iff the corresponding bbArg was recorded as read.
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    // The operand is written iff the corresponding bbArg was recorded as
    // written.
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  /// Bufferize the CallOp: compute the bufferized result types, materialize
  /// memref buffers for all tensor operands (inserting memref.cast on
  /// caller/callee type mismatch), and replace the op with a new CallOp on
  /// memrefs.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the results
    // of the bufferized CallOp, unless a tensor result folds onto an operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensor return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.

    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> bbArgIdx =
                getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
          // Return operands that are equivalent to some bbArg, are not
          // returned. The operand's buffer serves both as the new operand and
          // as the replacement for the dropped result.
          FailureOr<Value> bufferOrFailure =
              state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
          if (failed(bufferOrFailure))
            return failure();
          replacementValues[returnValIdx] = *bufferOrFailure;
          newOperands[*bbArgIdx] = *bufferOrFailure;
          continue;
        }
      }

      if (!options.allowReturnAllocs)
        return callOp->emitError(
            "call to FuncOp that returns non-equivalent tensors not supported");

      // Returning a memref. This memref is not equivalent to any bbArg. It is
      // likely a newly allocated buffer. We may want to hoist such allocations
      // to the call site in the future.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers, whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memref than necessary.
      // If the memref type of the callee fails, introduce an extra memref.cast
      // that will either canonicalize away or fail compilation until we can do
      // something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

/// Bufferization model for func::ReturnOp. A ReturnOp only reads its operands;
/// the actual rewrite happens as part of the enclosing FuncOp's bufferization.
struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return failure();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// When a tensor function argument is known to be equivalent to a tensor
  /// result, it is dropped from the return values.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  ///
  /// Note: Returning a memref is possible, but corresponding CallOp
  /// bufferizations fail unless `allowReturnAllocs`.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Construct the bufferized function type: tensor inputs become memrefs,
    // all other input types are kept as-is.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // support functions that don't return any tensors atm.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      // Only the signature changes; there is no body to rewrite.
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg before mutating its type, so they can
      // be redirected afterwards.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      // If return operand is equivalent to some bbArg, no need to return it.
      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
                funcOp, funcState, returnOperand.getOperandNumber())) {
          // TODO: Use memref type with fully dynamic layout map and add folder
          // for memref.cast + memref.copy.
          Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
              loc, getMemRefType(tensorType, options), returnVal);
          BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
          // Note: This copy will fold away. It must be inserted here to ensure
          // that `returnVal` still has at least one use and does not fold away.
          if (failed(
                  options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg)))
            return funcOp->emitError("could not generate copy for bbArg");
          continue;
        }
      }

      // Non-equivalent tensor results are returned as memrefs, with the
      // layout chosen per `functionBoundaryTypeConversion`.
      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
550e07a7fd5SMatthias Springer return true; 551e07a7fd5SMatthias Springer } 552e07a7fd5SMatthias Springer 553e07a7fd5SMatthias Springer bool isAllocationHoistingBarrier(Operation *op) const { return true; } 554e07a7fd5SMatthias Springer }; 555e07a7fd5SMatthias Springer 556e07a7fd5SMatthias Springer } // namespace func_ext 557e07a7fd5SMatthias Springer } // namespace bufferization 558e07a7fd5SMatthias Springer } // namespace mlir 559e07a7fd5SMatthias Springer 560e07a7fd5SMatthias Springer void mlir::bufferization::func_ext:: 561e07a7fd5SMatthias Springer registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { 562e07a7fd5SMatthias Springer registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { 563e07a7fd5SMatthias Springer func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx); 564e07a7fd5SMatthias Springer func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx); 565e07a7fd5SMatthias Springer func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx); 566e07a7fd5SMatthias Springer }); 567e07a7fd5SMatthias Springer } 568