//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered through PostAnalysisStepFns and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
// Bufferization currently fails if other tensors (in particular tensors that
// bufferize out-of-place and result in a new buffer allocation) are returned.
// In the future, such allocations could be hoisted to the caller.
//
// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
// ```
// func @foo() -> tensor<?xf32> {
//   %0 = bufferization.alloc_tensor(...) : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
//   be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %t0[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
72e07a7fd5SMatthias Springer 73e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 74e07a7fd5SMatthias Springer 75e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 76e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 77e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 78e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" 79e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 80e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 81e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 82e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h" 83e07a7fd5SMatthias Springer 84e07a7fd5SMatthias Springer using namespace mlir; 85e07a7fd5SMatthias Springer using namespace mlir::bufferization; 86e07a7fd5SMatthias Springer using namespace mlir::bufferization::func_ext; 87e07a7fd5SMatthias Springer 88e07a7fd5SMatthias Springer /// A mapping of FuncOps to their callers. 89e07a7fd5SMatthias Springer using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>; 90e07a7fd5SMatthias Springer 91e07a7fd5SMatthias Springer /// Get FuncAnalysisState. 92e07a7fd5SMatthias Springer static const FuncAnalysisState & 93e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) { 94e07a7fd5SMatthias Springer Optional<const FuncAnalysisState *> maybeState = 95e07a7fd5SMatthias Springer state.getDialectState<FuncAnalysisState>( 96e07a7fd5SMatthias Springer func::FuncDialect::getDialectNamespace()); 97e07a7fd5SMatthias Springer assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); 98e07a7fd5SMatthias Springer return **maybeState; 99e07a7fd5SMatthias Springer } 100e07a7fd5SMatthias Springer 101e07a7fd5SMatthias Springer /// Get or create FuncAnalysisState. 
102e07a7fd5SMatthias Springer static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { 103e07a7fd5SMatthias Springer return state.getOrCreateDialectState<FuncAnalysisState>( 104e07a7fd5SMatthias Springer func::FuncDialect::getDialectNamespace()); 105e07a7fd5SMatthias Springer } 106e07a7fd5SMatthias Springer 107e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp. 108*02d3499aSStella Laurenzo /// Used for debug modes. 109*02d3499aSStella Laurenzo LLVM_ATTRIBUTE_UNUSED 110e07a7fd5SMatthias Springer static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, 111e07a7fd5SMatthias Springer func::FuncOp funcOp) { 112e07a7fd5SMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state); 113e07a7fd5SMatthias Springer auto it = funcState.analyzedFuncOps.find(funcOp); 114e07a7fd5SMatthias Springer if (it == funcState.analyzedFuncOps.end()) 115e07a7fd5SMatthias Springer return FuncOpAnalysisState::NotAnalyzed; 116e07a7fd5SMatthias Springer return it->second; 117e07a7fd5SMatthias Springer } 118e07a7fd5SMatthias Springer 119e07a7fd5SMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`. 120e07a7fd5SMatthias Springer /// Return nullptr if there is no such unique ReturnOp. 
121e07a7fd5SMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { 122e07a7fd5SMatthias Springer func::ReturnOp returnOp; 123e07a7fd5SMatthias Springer for (Block &b : funcOp.getBody()) { 124e07a7fd5SMatthias Springer if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 125e07a7fd5SMatthias Springer if (returnOp) 126e07a7fd5SMatthias Springer return nullptr; 127e07a7fd5SMatthias Springer returnOp = candidateOp; 128e07a7fd5SMatthias Springer } 129e07a7fd5SMatthias Springer } 130e07a7fd5SMatthias Springer return returnOp; 131e07a7fd5SMatthias Springer } 132e07a7fd5SMatthias Springer 133e07a7fd5SMatthias Springer namespace { 134e07a7fd5SMatthias Springer 135e07a7fd5SMatthias Springer /// Annotate IR with the results of the analysis. For testing purposes only. 136e07a7fd5SMatthias Springer static void annotateEquivalentReturnBbArg(OpOperand &returnVal, 137e07a7fd5SMatthias Springer BlockArgument bbArg) { 138e07a7fd5SMatthias Springer const char *kEquivalentArgsAttr = "__equivalent_func_args__"; 139e07a7fd5SMatthias Springer Operation *op = returnVal.getOwner(); 140e07a7fd5SMatthias Springer 141e07a7fd5SMatthias Springer SmallVector<int64_t> equivBbArgs; 142e07a7fd5SMatthias Springer if (op->hasAttr(kEquivalentArgsAttr)) { 143e07a7fd5SMatthias Springer auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>(); 144e07a7fd5SMatthias Springer equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { 145e07a7fd5SMatthias Springer return a.cast<IntegerAttr>().getValue().getSExtValue(); 146e07a7fd5SMatthias Springer })); 147e07a7fd5SMatthias Springer } else { 148e07a7fd5SMatthias Springer equivBbArgs.append(op->getNumOperands(), -1); 149e07a7fd5SMatthias Springer } 150e07a7fd5SMatthias Springer equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); 151e07a7fd5SMatthias Springer 152e07a7fd5SMatthias Springer OpBuilder b(op->getContext()); 153e07a7fd5SMatthias Springer 
op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); 154e07a7fd5SMatthias Springer } 155e07a7fd5SMatthias Springer 156e07a7fd5SMatthias Springer /// Store function BlockArguments that are equivalent to/aliasing a returned 157e07a7fd5SMatthias Springer /// value in FuncAnalysisState. 158e07a7fd5SMatthias Springer static LogicalResult 159e07a7fd5SMatthias Springer aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state, 160e07a7fd5SMatthias Springer BufferizationAliasInfo &aliasInfo, 161e07a7fd5SMatthias Springer SmallVector<Operation *> &newOps) { 162e07a7fd5SMatthias Springer FuncAnalysisState &funcState = getFuncAnalysisState(state); 163e07a7fd5SMatthias Springer 164e07a7fd5SMatthias Springer // Support only single return-terminated block in the function. 165e07a7fd5SMatthias Springer auto funcOp = cast<func::FuncOp>(op); 166e07a7fd5SMatthias Springer func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 167e07a7fd5SMatthias Springer assert(returnOp && "expected func with single return op"); 168e07a7fd5SMatthias Springer 169e07a7fd5SMatthias Springer for (OpOperand &returnVal : returnOp->getOpOperands()) 170e07a7fd5SMatthias Springer if (returnVal.get().getType().isa<RankedTensorType>()) 171e07a7fd5SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) 172e07a7fd5SMatthias Springer if (bbArg.getType().isa<RankedTensorType>()) { 173e07a7fd5SMatthias Springer int64_t returnIdx = returnVal.getOperandNumber(); 174e07a7fd5SMatthias Springer int64_t bbArgIdx = bbArg.getArgNumber(); 175e07a7fd5SMatthias Springer if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { 176e07a7fd5SMatthias Springer funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx; 177e07a7fd5SMatthias Springer if (state.getOptions().testAnalysisOnly) 178e07a7fd5SMatthias Springer annotateEquivalentReturnBbArg(returnVal, bbArg); 179e07a7fd5SMatthias Springer } 180e07a7fd5SMatthias Springer if 
(aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) { 181e07a7fd5SMatthias Springer funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); 182e07a7fd5SMatthias Springer funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx); 183e07a7fd5SMatthias Springer } 184e07a7fd5SMatthias Springer } 185e07a7fd5SMatthias Springer 186e07a7fd5SMatthias Springer return success(); 187e07a7fd5SMatthias Springer } 188e07a7fd5SMatthias Springer 189e07a7fd5SMatthias Springer /// Return true if the buffer of the given tensor value is written to. Must not 190e07a7fd5SMatthias Springer /// be called for values inside not yet analyzed functions. (Post-analysis 191e07a7fd5SMatthias Springer /// steps do not have to be run yet, i.e., "in progress" is also OK.) 192e07a7fd5SMatthias Springer static bool isValueWritten(Value value, const AnalysisState &state, 193e07a7fd5SMatthias Springer const BufferizationAliasInfo &aliasInfo) { 194e07a7fd5SMatthias Springer #ifndef NDEBUG 195e07a7fd5SMatthias Springer assert(value.getType().isa<TensorType>() && "expected TensorType"); 196e07a7fd5SMatthias Springer func::FuncOp funcOp; 197e07a7fd5SMatthias Springer if (auto bbArg = value.dyn_cast<BlockArgument>()) { 198e07a7fd5SMatthias Springer Operation *owner = bbArg.getOwner()->getParentOp(); 199e07a7fd5SMatthias Springer funcOp = isa<func::FuncOp>(owner) ? 
cast<func::FuncOp>(owner) 200e07a7fd5SMatthias Springer : owner->getParentOfType<func::FuncOp>(); 201e07a7fd5SMatthias Springer } else { 202e07a7fd5SMatthias Springer funcOp = value.getDefiningOp()->getParentOfType<func::FuncOp>(); 203e07a7fd5SMatthias Springer } 204e07a7fd5SMatthias Springer assert(getFuncOpAnalysisState(state, funcOp) != 205e07a7fd5SMatthias Springer FuncOpAnalysisState::NotAnalyzed && 206e07a7fd5SMatthias Springer "FuncOp must be fully analyzed or analysis in progress"); 207e07a7fd5SMatthias Springer #endif // NDEBUG 208e07a7fd5SMatthias Springer 209e07a7fd5SMatthias Springer bool isWritten = false; 210e07a7fd5SMatthias Springer aliasInfo.applyOnAliases(value, [&](Value val) { 211e07a7fd5SMatthias Springer for (OpOperand &use : val.getUses()) 212e07a7fd5SMatthias Springer if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use)) 213e07a7fd5SMatthias Springer isWritten = true; 214e07a7fd5SMatthias Springer }); 215e07a7fd5SMatthias Springer return isWritten; 216e07a7fd5SMatthias Springer } 217e07a7fd5SMatthias Springer 218e07a7fd5SMatthias Springer static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg, 219e07a7fd5SMatthias Springer bool isRead, bool isWritten) { 220e07a7fd5SMatthias Springer OpBuilder b(funcOp.getContext()); 221e07a7fd5SMatthias Springer Attribute accessType; 222e07a7fd5SMatthias Springer if (isRead && isWritten) { 223e07a7fd5SMatthias Springer accessType = b.getStringAttr("read-write"); 224e07a7fd5SMatthias Springer } else if (isRead) { 225e07a7fd5SMatthias Springer accessType = b.getStringAttr("read"); 226e07a7fd5SMatthias Springer } else if (isWritten) { 227e07a7fd5SMatthias Springer accessType = b.getStringAttr("write"); 228e07a7fd5SMatthias Springer } else { 229e07a7fd5SMatthias Springer accessType = b.getStringAttr("none"); 230e07a7fd5SMatthias Springer } 231e07a7fd5SMatthias Springer funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType); 232e07a7fd5SMatthias Springer } 
233e07a7fd5SMatthias Springer 234e07a7fd5SMatthias Springer /// Determine which FuncOp bbArgs are read and which are written. If this 235e07a7fd5SMatthias Springer /// PostAnalysisStepFn is run on a function with unknown ops, it will 236e07a7fd5SMatthias Springer /// conservatively assume that such ops bufferize to a read + write. 237e07a7fd5SMatthias Springer static LogicalResult 238e07a7fd5SMatthias Springer funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state, 239e07a7fd5SMatthias Springer BufferizationAliasInfo &aliasInfo, 240e07a7fd5SMatthias Springer SmallVector<Operation *> &newOps) { 241e07a7fd5SMatthias Springer FuncAnalysisState &funcState = getFuncAnalysisState(state); 242e07a7fd5SMatthias Springer auto funcOp = cast<func::FuncOp>(op); 243e07a7fd5SMatthias Springer 244e07a7fd5SMatthias Springer // If the function has no body, conservatively assume that all args are 245e07a7fd5SMatthias Springer // read + written. 246e07a7fd5SMatthias Springer if (funcOp.getBody().empty()) { 247e07a7fd5SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) { 248e07a7fd5SMatthias Springer funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); 249e07a7fd5SMatthias Springer funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); 250e07a7fd5SMatthias Springer } 251e07a7fd5SMatthias Springer 252e07a7fd5SMatthias Springer return success(); 253e07a7fd5SMatthias Springer } 254e07a7fd5SMatthias Springer 255e07a7fd5SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) { 256e07a7fd5SMatthias Springer if (!bbArg.getType().isa<TensorType>()) 257e07a7fd5SMatthias Springer continue; 258e07a7fd5SMatthias Springer bool isRead = state.isValueRead(bbArg); 259e07a7fd5SMatthias Springer bool isWritten = isValueWritten(bbArg, state, aliasInfo); 260e07a7fd5SMatthias Springer if (state.getOptions().testAnalysisOnly) 261e07a7fd5SMatthias Springer annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); 262e07a7fd5SMatthias Springer if (isRead) 
263e07a7fd5SMatthias Springer funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); 264e07a7fd5SMatthias Springer if (isWritten) 265e07a7fd5SMatthias Springer funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); 266e07a7fd5SMatthias Springer } 267e07a7fd5SMatthias Springer 268e07a7fd5SMatthias Springer return success(); 269e07a7fd5SMatthias Springer } 270e07a7fd5SMatthias Springer } // namespace 271e07a7fd5SMatthias Springer 272e07a7fd5SMatthias Springer /// Remove bufferization attributes on FuncOp arguments. 273e07a7fd5SMatthias Springer static void removeBufferizationAttributes(BlockArgument bbArg) { 274e07a7fd5SMatthias Springer auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp()); 275e07a7fd5SMatthias Springer funcOp.removeArgAttr(bbArg.getArgNumber(), 276e07a7fd5SMatthias Springer BufferizationDialect::kBufferLayoutAttrName); 277e07a7fd5SMatthias Springer funcOp.removeArgAttr(bbArg.getArgNumber(), 278e07a7fd5SMatthias Springer BufferizationDialect::kWritableAttrName); 279e07a7fd5SMatthias Springer } 280e07a7fd5SMatthias Springer 281e07a7fd5SMatthias Springer /// Return the func::FuncOp called by `callOp`. 282e07a7fd5SMatthias Springer static func::FuncOp getCalledFunction(CallOpInterface callOp) { 283e07a7fd5SMatthias Springer SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); 284e07a7fd5SMatthias Springer if (!sym) 285e07a7fd5SMatthias Springer return nullptr; 286e07a7fd5SMatthias Springer return dyn_cast_or_null<func::FuncOp>( 287e07a7fd5SMatthias Springer SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 288e07a7fd5SMatthias Springer } 289e07a7fd5SMatthias Springer 290e07a7fd5SMatthias Springer /// Gather equivalence info of CallOps. 291e07a7fd5SMatthias Springer /// Note: This only adds new equivalence info if the called function was already 292e07a7fd5SMatthias Springer /// analyzed. 293e07a7fd5SMatthias Springer // TODO: This does not handle cyclic function call graphs etc. 
294e07a7fd5SMatthias Springer static void equivalenceAnalysis(func::FuncOp funcOp, 295e07a7fd5SMatthias Springer BufferizationAliasInfo &aliasInfo, 296e07a7fd5SMatthias Springer FuncAnalysisState &funcState) { 297e07a7fd5SMatthias Springer funcOp->walk([&](func::CallOp callOp) { 298e07a7fd5SMatthias Springer func::FuncOp calledFunction = getCalledFunction(callOp); 299e07a7fd5SMatthias Springer assert(calledFunction && "could not retrieved called func::FuncOp"); 300e07a7fd5SMatthias Springer 301e07a7fd5SMatthias Springer // No equivalence info available for the called function. 302e07a7fd5SMatthias Springer if (!funcState.equivalentFuncArgs.count(calledFunction)) 303e07a7fd5SMatthias Springer return WalkResult::skip(); 304e07a7fd5SMatthias Springer 305e07a7fd5SMatthias Springer for (auto it : funcState.equivalentFuncArgs[calledFunction]) { 306e07a7fd5SMatthias Springer int64_t returnIdx = it.first; 307e07a7fd5SMatthias Springer int64_t bbargIdx = it.second; 308e07a7fd5SMatthias Springer Value returnVal = callOp.getResult(returnIdx); 309e07a7fd5SMatthias Springer Value argVal = callOp->getOperand(bbargIdx); 310e07a7fd5SMatthias Springer aliasInfo.unionEquivalenceClasses(returnVal, argVal); 311e07a7fd5SMatthias Springer } 312e07a7fd5SMatthias Springer 313e07a7fd5SMatthias Springer return WalkResult::advance(); 314e07a7fd5SMatthias Springer }); 315e07a7fd5SMatthias Springer } 316e07a7fd5SMatthias Springer 317e07a7fd5SMatthias Springer /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by 318e07a7fd5SMatthias Springer /// callee-caller order (i.e. callees without callers first). 319e07a7fd5SMatthias Springer /// Store the map of FuncOp to all its callers in `callerMap`. 320e07a7fd5SMatthias Springer /// Return `failure()` if a cycle of calls is detected or if we are unable to 321e07a7fd5SMatthias Springer /// retrieve the called FuncOp from any CallOpInterface. 
322e07a7fd5SMatthias Springer static LogicalResult 323e07a7fd5SMatthias Springer getFuncOpsOrderedByCalls(ModuleOp moduleOp, 324e07a7fd5SMatthias Springer SmallVectorImpl<func::FuncOp> &orderedFuncOps, 325e07a7fd5SMatthias Springer FuncCallerMap &callerMap) { 326e07a7fd5SMatthias Springer // For each FuncOp, the set of functions called by it (i.e. the union of 327e07a7fd5SMatthias Springer // symbols of all nested CallOpInterfaceOp). 328e07a7fd5SMatthias Springer DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; 329e07a7fd5SMatthias Springer // For each FuncOp, the number of CallOpInterface it contains. 330e07a7fd5SMatthias Springer DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; 331e07a7fd5SMatthias Springer WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult { 332e07a7fd5SMatthias Springer if (!funcOp.getBody().empty()) { 333e07a7fd5SMatthias Springer func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 334e07a7fd5SMatthias Springer if (!returnOp) 335e07a7fd5SMatthias Springer return funcOp->emitError() 336e07a7fd5SMatthias Springer << "cannot bufferize a FuncOp with tensors and " 337e07a7fd5SMatthias Springer "without a unique ReturnOp"; 338e07a7fd5SMatthias Springer } 339e07a7fd5SMatthias Springer 340e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[funcOp] = 0; 341e07a7fd5SMatthias Springer return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { 342e07a7fd5SMatthias Springer // Only support CallOp for now. 
343e07a7fd5SMatthias Springer if (!isa<func::CallOp>(callOp.getOperation())) 344e07a7fd5SMatthias Springer return callOp->emitError() << "expected a CallOp"; 345e07a7fd5SMatthias Springer func::FuncOp calledFunction = getCalledFunction(callOp); 346e07a7fd5SMatthias Springer assert(calledFunction && "could not retrieved called func::FuncOp"); 34786fd1c13SBenjamin Kramer callerMap[calledFunction].insert(callOp); 34886fd1c13SBenjamin Kramer if (calledBy[calledFunction].insert(funcOp).second) { 349e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[funcOp]++; 350e07a7fd5SMatthias Springer } 351e07a7fd5SMatthias Springer return WalkResult::advance(); 352e07a7fd5SMatthias Springer }); 353e07a7fd5SMatthias Springer }); 354e07a7fd5SMatthias Springer if (res.wasInterrupted()) 355e07a7fd5SMatthias Springer return failure(); 356e07a7fd5SMatthias Springer // Iteratively remove function operation that do not call any of the 357e07a7fd5SMatthias Springer // functions remaining in the callCounter map and add them to the worklist. 
358e07a7fd5SMatthias Springer while (!numberCallOpsContainedInFuncOp.empty()) { 359e07a7fd5SMatthias Springer auto it = llvm::find_if(numberCallOpsContainedInFuncOp, 360e07a7fd5SMatthias Springer [](auto entry) { return entry.getSecond() == 0; }); 361e07a7fd5SMatthias Springer if (it == numberCallOpsContainedInFuncOp.end()) 362e07a7fd5SMatthias Springer return moduleOp.emitOpError( 363e07a7fd5SMatthias Springer "expected callgraph to be free of circular dependencies."); 364e07a7fd5SMatthias Springer orderedFuncOps.push_back(it->getFirst()); 365e07a7fd5SMatthias Springer for (auto callee : calledBy[it->getFirst()]) 366e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[callee]--; 367e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp.erase(it); 368e07a7fd5SMatthias Springer } 369e07a7fd5SMatthias Springer return success(); 370e07a7fd5SMatthias Springer } 371e07a7fd5SMatthias Springer 372e07a7fd5SMatthias Springer /// Set the attribute that triggers inplace bufferization on a FuncOp argument 373e07a7fd5SMatthias Springer /// `bbArg`. 374e07a7fd5SMatthias Springer static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) { 375e07a7fd5SMatthias Springer auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp()); 376e07a7fd5SMatthias Springer funcOp.setArgAttr(bbArg.getArgNumber(), 377e07a7fd5SMatthias Springer BufferizableOpInterface::kInplaceableAttrName, 378e07a7fd5SMatthias Springer BoolAttr::get(bbArg.getContext(), inPlace)); 379e07a7fd5SMatthias Springer } 380e07a7fd5SMatthias Springer 381e07a7fd5SMatthias Springer /// Annotate the IR with the result of the analysis. For testing/debugging only. 
382e07a7fd5SMatthias Springer static void annotateOpsWithBufferizationMarkers(func::FuncOp funcOp, 383e07a7fd5SMatthias Springer const AnalysisState &state) { 384e07a7fd5SMatthias Springer auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation()); 385e07a7fd5SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) 386e07a7fd5SMatthias Springer if (bbArg.getType().isa<TensorType>()) 387e07a7fd5SMatthias Springer setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); 388e07a7fd5SMatthias Springer } 389e07a7fd5SMatthias Springer 390e07a7fd5SMatthias Springer /// Fold return values that are memref casts and update function return types. 391e07a7fd5SMatthias Springer /// 392e07a7fd5SMatthias Springer /// During FuncOp bufferization, the exact type of the returned memrefs (if any) 393e07a7fd5SMatthias Springer /// is not known yet. Therefore, the bufferization uses memref types with the 394e07a7fd5SMatthias Springer /// most generic layout map as function return types. After bufferizing the 395e07a7fd5SMatthias Springer /// entire function body, a more concise memref type can potentially be used for 396e07a7fd5SMatthias Springer /// the return type of the function. 
397e07a7fd5SMatthias Springer static void foldMemRefCasts(func::FuncOp funcOp) { 398e07a7fd5SMatthias Springer if (funcOp.getBody().empty()) 399e07a7fd5SMatthias Springer return; 400e07a7fd5SMatthias Springer 401e07a7fd5SMatthias Springer func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 402e07a7fd5SMatthias Springer SmallVector<Type> resultTypes; 403e07a7fd5SMatthias Springer 404e07a7fd5SMatthias Springer for (OpOperand &operand : returnOp->getOpOperands()) { 405e07a7fd5SMatthias Springer if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) { 406e07a7fd5SMatthias Springer operand.set(castOp.source()); 407e07a7fd5SMatthias Springer resultTypes.push_back(castOp.source().getType()); 408e07a7fd5SMatthias Springer } else { 409e07a7fd5SMatthias Springer resultTypes.push_back(operand.get().getType()); 410e07a7fd5SMatthias Springer } 411e07a7fd5SMatthias Springer } 412e07a7fd5SMatthias Springer 413e07a7fd5SMatthias Springer auto newFuncType = FunctionType::get( 414e07a7fd5SMatthias Springer funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes); 415e07a7fd5SMatthias Springer funcOp.setType(newFuncType); 416e07a7fd5SMatthias Springer } 417e07a7fd5SMatthias Springer 418e07a7fd5SMatthias Springer LogicalResult mlir::bufferization::runOneShotModuleBufferize( 419e07a7fd5SMatthias Springer ModuleOp moduleOp, OneShotBufferizationOptions options) { 420d6dab38aSMatthias Springer assert(options.bufferizeFunctionBoundaries && 421d6dab38aSMatthias Springer "expected that function boundary bufferization is activated"); 422e07a7fd5SMatthias Springer IRRewriter rewriter(moduleOp.getContext()); 423e07a7fd5SMatthias Springer OneShotAnalysisState analysisState(moduleOp, options); 424e07a7fd5SMatthias Springer BufferizationState bufferizationState(analysisState); 425e07a7fd5SMatthias Springer FuncAnalysisState &funcState = getFuncAnalysisState(analysisState); 426e07a7fd5SMatthias Springer BufferizationAliasInfo &aliasInfo = 
analysisState.getAliasInfo(); 427e07a7fd5SMatthias Springer 428e07a7fd5SMatthias Springer // A list of functions in the order in which they are analyzed + bufferized. 429e07a7fd5SMatthias Springer SmallVector<func::FuncOp> orderedFuncOps; 430e07a7fd5SMatthias Springer 431e07a7fd5SMatthias Springer // A mapping of FuncOps to their callers. 432e07a7fd5SMatthias Springer FuncCallerMap callerMap; 433e07a7fd5SMatthias Springer 434e07a7fd5SMatthias Springer if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) 435e07a7fd5SMatthias Springer return failure(); 436e07a7fd5SMatthias Springer 437e07a7fd5SMatthias Springer // Collect bbArg/return value information after the analysis. 438e07a7fd5SMatthias Springer options.addPostAnalysisStep(aliasingFuncOpBBArgsAnalysis); 439e07a7fd5SMatthias Springer options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis); 440e07a7fd5SMatthias Springer 441e07a7fd5SMatthias Springer // Analyze ops. 442e07a7fd5SMatthias Springer for (func::FuncOp funcOp : orderedFuncOps) { 443e07a7fd5SMatthias Springer // No body => no analysis. 444e07a7fd5SMatthias Springer if (funcOp.getBody().empty()) 445e07a7fd5SMatthias Springer continue; 446e07a7fd5SMatthias Springer 447e07a7fd5SMatthias Springer // Now analyzing function. 448e07a7fd5SMatthias Springer funcState.startFunctionAnalysis(funcOp); 449e07a7fd5SMatthias Springer 450e07a7fd5SMatthias Springer // Gather equivalence info for CallOps. 451e07a7fd5SMatthias Springer equivalenceAnalysis(funcOp, aliasInfo, funcState); 452e07a7fd5SMatthias Springer 453e07a7fd5SMatthias Springer // Analyze funcOp. 454e07a7fd5SMatthias Springer if (failed(analyzeOp(funcOp, analysisState))) 455e07a7fd5SMatthias Springer return failure(); 456e07a7fd5SMatthias Springer 457e07a7fd5SMatthias Springer // Mark op as fully analyzed. 
458e07a7fd5SMatthias Springer funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; 459e07a7fd5SMatthias Springer 460e07a7fd5SMatthias Springer // Add annotations to function arguments. 461e07a7fd5SMatthias Springer if (options.testAnalysisOnly) 462e07a7fd5SMatthias Springer annotateOpsWithBufferizationMarkers(funcOp, analysisState); 463e07a7fd5SMatthias Springer } 464e07a7fd5SMatthias Springer 465e07a7fd5SMatthias Springer if (options.testAnalysisOnly) 466e07a7fd5SMatthias Springer return success(); 467e07a7fd5SMatthias Springer 468e07a7fd5SMatthias Springer // Bufferize functions. 469e07a7fd5SMatthias Springer for (func::FuncOp funcOp : orderedFuncOps) { 470e07a7fd5SMatthias Springer // Note: It would be good to apply cleanups here but we cannot as aliasInfo 471e07a7fd5SMatthias Springer // would be invalidated. 472e07a7fd5SMatthias Springer if (failed(bufferizeOp(funcOp, bufferizationState))) 473e07a7fd5SMatthias Springer return failure(); 474f287da8aSMatthias Springer // Change buffer return types to more precise layout maps. 475f287da8aSMatthias Springer if (options.functionBoundaryTypeConversion == 476f287da8aSMatthias Springer BufferizationOptions::LayoutMapOption::InferLayoutMap) 477e07a7fd5SMatthias Springer foldMemRefCasts(funcOp); 478e07a7fd5SMatthias Springer } 479e07a7fd5SMatthias Springer 480e07a7fd5SMatthias Springer // Check result. 
481e07a7fd5SMatthias Springer for (func::FuncOp funcOp : orderedFuncOps) { 482e07a7fd5SMatthias Springer if (!options.allowReturnAllocs && 483e07a7fd5SMatthias Springer llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) { 484e07a7fd5SMatthias Springer return t.isa<MemRefType, UnrankedMemRefType>(); 485e07a7fd5SMatthias Springer })) { 486e07a7fd5SMatthias Springer funcOp->emitError("memref return type is unsupported"); 487e07a7fd5SMatthias Springer return failure(); 488e07a7fd5SMatthias Springer } 489e07a7fd5SMatthias Springer } 490e07a7fd5SMatthias Springer 491e07a7fd5SMatthias Springer // Finalize all buffers. 492e07a7fd5SMatthias Springer if (failed(finalizeBuffers(moduleOp, options))) 493e07a7fd5SMatthias Springer return failure(); 494e07a7fd5SMatthias Springer 495e07a7fd5SMatthias Springer // Post-pass cleanup of function argument attributes. 496e07a7fd5SMatthias Springer moduleOp.walk([&](func::FuncOp op) { 497e07a7fd5SMatthias Springer for (BlockArgument bbArg : op.getArguments()) 498e07a7fd5SMatthias Springer removeBufferizationAttributes(bbArg); 499e07a7fd5SMatthias Springer }); 500e07a7fd5SMatthias Springer 501e07a7fd5SMatthias Springer return success(); 502e07a7fd5SMatthias Springer } 503