//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered through PostAnalysisStepFns and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// By default, only tensors that are equivalent to some FuncOp bbArg may be
// returned: unless the `allowReturnAllocs` option is set, bufferization fails
// if other tensors (in particular tensors that bufferize out-of-place and
// result in a new buffer allocation) are returned. In the future, such
// allocations could be hoisted to the caller.
//
// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
// ```
// func @foo() -> tensor<?xf32> {
//   %0 = linalg.init_tensor [...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
//   be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %t0[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
72*e07a7fd5SMatthias Springer 73*e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 74*e07a7fd5SMatthias Springer 75*e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 76*e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 77*e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 78*e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" 79*e07a7fd5SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 80*e07a7fd5SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 81*e07a7fd5SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 82*e07a7fd5SMatthias Springer #include "mlir/IR/Operation.h" 83*e07a7fd5SMatthias Springer 84*e07a7fd5SMatthias Springer using namespace mlir; 85*e07a7fd5SMatthias Springer using namespace mlir::bufferization; 86*e07a7fd5SMatthias Springer using namespace mlir::bufferization::func_ext; 87*e07a7fd5SMatthias Springer 88*e07a7fd5SMatthias Springer /// A mapping of FuncOps to their callers. 89*e07a7fd5SMatthias Springer using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>; 90*e07a7fd5SMatthias Springer 91*e07a7fd5SMatthias Springer /// Get FuncAnalysisState. 
92*e07a7fd5SMatthias Springer static const FuncAnalysisState & 93*e07a7fd5SMatthias Springer getFuncAnalysisState(const AnalysisState &state) { 94*e07a7fd5SMatthias Springer Optional<const FuncAnalysisState *> maybeState = 95*e07a7fd5SMatthias Springer state.getDialectState<FuncAnalysisState>( 96*e07a7fd5SMatthias Springer func::FuncDialect::getDialectNamespace()); 97*e07a7fd5SMatthias Springer assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); 98*e07a7fd5SMatthias Springer return **maybeState; 99*e07a7fd5SMatthias Springer } 100*e07a7fd5SMatthias Springer 101*e07a7fd5SMatthias Springer /// Get or create FuncAnalysisState. 102*e07a7fd5SMatthias Springer static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { 103*e07a7fd5SMatthias Springer return state.getOrCreateDialectState<FuncAnalysisState>( 104*e07a7fd5SMatthias Springer func::FuncDialect::getDialectNamespace()); 105*e07a7fd5SMatthias Springer } 106*e07a7fd5SMatthias Springer 107*e07a7fd5SMatthias Springer /// Return the state (phase) of analysis of the FuncOp. 108*e07a7fd5SMatthias Springer static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, 109*e07a7fd5SMatthias Springer func::FuncOp funcOp) { 110*e07a7fd5SMatthias Springer const FuncAnalysisState &funcState = getFuncAnalysisState(state); 111*e07a7fd5SMatthias Springer auto it = funcState.analyzedFuncOps.find(funcOp); 112*e07a7fd5SMatthias Springer if (it == funcState.analyzedFuncOps.end()) 113*e07a7fd5SMatthias Springer return FuncOpAnalysisState::NotAnalyzed; 114*e07a7fd5SMatthias Springer return it->second; 115*e07a7fd5SMatthias Springer } 116*e07a7fd5SMatthias Springer 117*e07a7fd5SMatthias Springer /// Return the unique ReturnOp that terminates `funcOp`. 118*e07a7fd5SMatthias Springer /// Return nullptr if there is no such unique ReturnOp. 
119*e07a7fd5SMatthias Springer static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { 120*e07a7fd5SMatthias Springer func::ReturnOp returnOp; 121*e07a7fd5SMatthias Springer for (Block &b : funcOp.getBody()) { 122*e07a7fd5SMatthias Springer if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 123*e07a7fd5SMatthias Springer if (returnOp) 124*e07a7fd5SMatthias Springer return nullptr; 125*e07a7fd5SMatthias Springer returnOp = candidateOp; 126*e07a7fd5SMatthias Springer } 127*e07a7fd5SMatthias Springer } 128*e07a7fd5SMatthias Springer return returnOp; 129*e07a7fd5SMatthias Springer } 130*e07a7fd5SMatthias Springer 131*e07a7fd5SMatthias Springer namespace { 132*e07a7fd5SMatthias Springer 133*e07a7fd5SMatthias Springer /// Annotate IR with the results of the analysis. For testing purposes only. 134*e07a7fd5SMatthias Springer static void annotateEquivalentReturnBbArg(OpOperand &returnVal, 135*e07a7fd5SMatthias Springer BlockArgument bbArg) { 136*e07a7fd5SMatthias Springer const char *kEquivalentArgsAttr = "__equivalent_func_args__"; 137*e07a7fd5SMatthias Springer Operation *op = returnVal.getOwner(); 138*e07a7fd5SMatthias Springer 139*e07a7fd5SMatthias Springer SmallVector<int64_t> equivBbArgs; 140*e07a7fd5SMatthias Springer if (op->hasAttr(kEquivalentArgsAttr)) { 141*e07a7fd5SMatthias Springer auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>(); 142*e07a7fd5SMatthias Springer equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { 143*e07a7fd5SMatthias Springer return a.cast<IntegerAttr>().getValue().getSExtValue(); 144*e07a7fd5SMatthias Springer })); 145*e07a7fd5SMatthias Springer } else { 146*e07a7fd5SMatthias Springer equivBbArgs.append(op->getNumOperands(), -1); 147*e07a7fd5SMatthias Springer } 148*e07a7fd5SMatthias Springer equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); 149*e07a7fd5SMatthias Springer 150*e07a7fd5SMatthias Springer OpBuilder b(op->getContext()); 151*e07a7fd5SMatthias 
Springer op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); 152*e07a7fd5SMatthias Springer } 153*e07a7fd5SMatthias Springer 154*e07a7fd5SMatthias Springer /// Store function BlockArguments that are equivalent to/aliasing a returned 155*e07a7fd5SMatthias Springer /// value in FuncAnalysisState. 156*e07a7fd5SMatthias Springer static LogicalResult 157*e07a7fd5SMatthias Springer aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state, 158*e07a7fd5SMatthias Springer BufferizationAliasInfo &aliasInfo, 159*e07a7fd5SMatthias Springer SmallVector<Operation *> &newOps) { 160*e07a7fd5SMatthias Springer FuncAnalysisState &funcState = getFuncAnalysisState(state); 161*e07a7fd5SMatthias Springer 162*e07a7fd5SMatthias Springer // Support only single return-terminated block in the function. 163*e07a7fd5SMatthias Springer auto funcOp = cast<func::FuncOp>(op); 164*e07a7fd5SMatthias Springer func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 165*e07a7fd5SMatthias Springer assert(returnOp && "expected func with single return op"); 166*e07a7fd5SMatthias Springer 167*e07a7fd5SMatthias Springer for (OpOperand &returnVal : returnOp->getOpOperands()) 168*e07a7fd5SMatthias Springer if (returnVal.get().getType().isa<RankedTensorType>()) 169*e07a7fd5SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) 170*e07a7fd5SMatthias Springer if (bbArg.getType().isa<RankedTensorType>()) { 171*e07a7fd5SMatthias Springer int64_t returnIdx = returnVal.getOperandNumber(); 172*e07a7fd5SMatthias Springer int64_t bbArgIdx = bbArg.getArgNumber(); 173*e07a7fd5SMatthias Springer if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { 174*e07a7fd5SMatthias Springer funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx; 175*e07a7fd5SMatthias Springer if (state.getOptions().testAnalysisOnly) 176*e07a7fd5SMatthias Springer annotateEquivalentReturnBbArg(returnVal, bbArg); 177*e07a7fd5SMatthias Springer } 178*e07a7fd5SMatthias Springer if 
(aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) { 179*e07a7fd5SMatthias Springer funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); 180*e07a7fd5SMatthias Springer funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx); 181*e07a7fd5SMatthias Springer } 182*e07a7fd5SMatthias Springer } 183*e07a7fd5SMatthias Springer 184*e07a7fd5SMatthias Springer return success(); 185*e07a7fd5SMatthias Springer } 186*e07a7fd5SMatthias Springer 187*e07a7fd5SMatthias Springer /// Return true if the buffer of the given tensor value is written to. Must not 188*e07a7fd5SMatthias Springer /// be called for values inside not yet analyzed functions. (Post-analysis 189*e07a7fd5SMatthias Springer /// steps do not have to be run yet, i.e., "in progress" is also OK.) 190*e07a7fd5SMatthias Springer static bool isValueWritten(Value value, const AnalysisState &state, 191*e07a7fd5SMatthias Springer const BufferizationAliasInfo &aliasInfo) { 192*e07a7fd5SMatthias Springer #ifndef NDEBUG 193*e07a7fd5SMatthias Springer assert(value.getType().isa<TensorType>() && "expected TensorType"); 194*e07a7fd5SMatthias Springer func::FuncOp funcOp; 195*e07a7fd5SMatthias Springer if (auto bbArg = value.dyn_cast<BlockArgument>()) { 196*e07a7fd5SMatthias Springer Operation *owner = bbArg.getOwner()->getParentOp(); 197*e07a7fd5SMatthias Springer funcOp = isa<func::FuncOp>(owner) ? 
cast<func::FuncOp>(owner) 198*e07a7fd5SMatthias Springer : owner->getParentOfType<func::FuncOp>(); 199*e07a7fd5SMatthias Springer } else { 200*e07a7fd5SMatthias Springer funcOp = value.getDefiningOp()->getParentOfType<func::FuncOp>(); 201*e07a7fd5SMatthias Springer } 202*e07a7fd5SMatthias Springer assert(getFuncOpAnalysisState(state, funcOp) != 203*e07a7fd5SMatthias Springer FuncOpAnalysisState::NotAnalyzed && 204*e07a7fd5SMatthias Springer "FuncOp must be fully analyzed or analysis in progress"); 205*e07a7fd5SMatthias Springer #endif // NDEBUG 206*e07a7fd5SMatthias Springer 207*e07a7fd5SMatthias Springer bool isWritten = false; 208*e07a7fd5SMatthias Springer aliasInfo.applyOnAliases(value, [&](Value val) { 209*e07a7fd5SMatthias Springer for (OpOperand &use : val.getUses()) 210*e07a7fd5SMatthias Springer if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use)) 211*e07a7fd5SMatthias Springer isWritten = true; 212*e07a7fd5SMatthias Springer }); 213*e07a7fd5SMatthias Springer return isWritten; 214*e07a7fd5SMatthias Springer } 215*e07a7fd5SMatthias Springer 216*e07a7fd5SMatthias Springer static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg, 217*e07a7fd5SMatthias Springer bool isRead, bool isWritten) { 218*e07a7fd5SMatthias Springer OpBuilder b(funcOp.getContext()); 219*e07a7fd5SMatthias Springer Attribute accessType; 220*e07a7fd5SMatthias Springer if (isRead && isWritten) { 221*e07a7fd5SMatthias Springer accessType = b.getStringAttr("read-write"); 222*e07a7fd5SMatthias Springer } else if (isRead) { 223*e07a7fd5SMatthias Springer accessType = b.getStringAttr("read"); 224*e07a7fd5SMatthias Springer } else if (isWritten) { 225*e07a7fd5SMatthias Springer accessType = b.getStringAttr("write"); 226*e07a7fd5SMatthias Springer } else { 227*e07a7fd5SMatthias Springer accessType = b.getStringAttr("none"); 228*e07a7fd5SMatthias Springer } 229*e07a7fd5SMatthias Springer funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType); 
230*e07a7fd5SMatthias Springer } 231*e07a7fd5SMatthias Springer 232*e07a7fd5SMatthias Springer /// Determine which FuncOp bbArgs are read and which are written. If this 233*e07a7fd5SMatthias Springer /// PostAnalysisStepFn is run on a function with unknown ops, it will 234*e07a7fd5SMatthias Springer /// conservatively assume that such ops bufferize to a read + write. 235*e07a7fd5SMatthias Springer static LogicalResult 236*e07a7fd5SMatthias Springer funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state, 237*e07a7fd5SMatthias Springer BufferizationAliasInfo &aliasInfo, 238*e07a7fd5SMatthias Springer SmallVector<Operation *> &newOps) { 239*e07a7fd5SMatthias Springer FuncAnalysisState &funcState = getFuncAnalysisState(state); 240*e07a7fd5SMatthias Springer auto funcOp = cast<func::FuncOp>(op); 241*e07a7fd5SMatthias Springer 242*e07a7fd5SMatthias Springer // If the function has no body, conservatively assume that all args are 243*e07a7fd5SMatthias Springer // read + written. 244*e07a7fd5SMatthias Springer if (funcOp.getBody().empty()) { 245*e07a7fd5SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) { 246*e07a7fd5SMatthias Springer funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); 247*e07a7fd5SMatthias Springer funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); 248*e07a7fd5SMatthias Springer } 249*e07a7fd5SMatthias Springer 250*e07a7fd5SMatthias Springer return success(); 251*e07a7fd5SMatthias Springer } 252*e07a7fd5SMatthias Springer 253*e07a7fd5SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) { 254*e07a7fd5SMatthias Springer if (!bbArg.getType().isa<TensorType>()) 255*e07a7fd5SMatthias Springer continue; 256*e07a7fd5SMatthias Springer bool isRead = state.isValueRead(bbArg); 257*e07a7fd5SMatthias Springer bool isWritten = isValueWritten(bbArg, state, aliasInfo); 258*e07a7fd5SMatthias Springer if (state.getOptions().testAnalysisOnly) 259*e07a7fd5SMatthias Springer annotateFuncArgAccess(funcOp, bbArg, 
isRead, isWritten); 260*e07a7fd5SMatthias Springer if (isRead) 261*e07a7fd5SMatthias Springer funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); 262*e07a7fd5SMatthias Springer if (isWritten) 263*e07a7fd5SMatthias Springer funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); 264*e07a7fd5SMatthias Springer } 265*e07a7fd5SMatthias Springer 266*e07a7fd5SMatthias Springer return success(); 267*e07a7fd5SMatthias Springer } 268*e07a7fd5SMatthias Springer } // namespace 269*e07a7fd5SMatthias Springer 270*e07a7fd5SMatthias Springer /// Remove bufferization attributes on FuncOp arguments. 271*e07a7fd5SMatthias Springer static void removeBufferizationAttributes(BlockArgument bbArg) { 272*e07a7fd5SMatthias Springer auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp()); 273*e07a7fd5SMatthias Springer funcOp.removeArgAttr(bbArg.getArgNumber(), 274*e07a7fd5SMatthias Springer BufferizationDialect::kBufferLayoutAttrName); 275*e07a7fd5SMatthias Springer funcOp.removeArgAttr(bbArg.getArgNumber(), 276*e07a7fd5SMatthias Springer BufferizationDialect::kWritableAttrName); 277*e07a7fd5SMatthias Springer } 278*e07a7fd5SMatthias Springer 279*e07a7fd5SMatthias Springer /// Return the func::FuncOp called by `callOp`. 280*e07a7fd5SMatthias Springer static func::FuncOp getCalledFunction(CallOpInterface callOp) { 281*e07a7fd5SMatthias Springer SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); 282*e07a7fd5SMatthias Springer if (!sym) 283*e07a7fd5SMatthias Springer return nullptr; 284*e07a7fd5SMatthias Springer return dyn_cast_or_null<func::FuncOp>( 285*e07a7fd5SMatthias Springer SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 286*e07a7fd5SMatthias Springer } 287*e07a7fd5SMatthias Springer 288*e07a7fd5SMatthias Springer /// Gather equivalence info of CallOps. 289*e07a7fd5SMatthias Springer /// Note: This only adds new equivalence info if the called function was already 290*e07a7fd5SMatthias Springer /// analyzed. 
291*e07a7fd5SMatthias Springer // TODO: This does not handle cyclic function call graphs etc. 292*e07a7fd5SMatthias Springer static void equivalenceAnalysis(func::FuncOp funcOp, 293*e07a7fd5SMatthias Springer BufferizationAliasInfo &aliasInfo, 294*e07a7fd5SMatthias Springer FuncAnalysisState &funcState) { 295*e07a7fd5SMatthias Springer funcOp->walk([&](func::CallOp callOp) { 296*e07a7fd5SMatthias Springer func::FuncOp calledFunction = getCalledFunction(callOp); 297*e07a7fd5SMatthias Springer assert(calledFunction && "could not retrieved called func::FuncOp"); 298*e07a7fd5SMatthias Springer 299*e07a7fd5SMatthias Springer // No equivalence info available for the called function. 300*e07a7fd5SMatthias Springer if (!funcState.equivalentFuncArgs.count(calledFunction)) 301*e07a7fd5SMatthias Springer return WalkResult::skip(); 302*e07a7fd5SMatthias Springer 303*e07a7fd5SMatthias Springer for (auto it : funcState.equivalentFuncArgs[calledFunction]) { 304*e07a7fd5SMatthias Springer int64_t returnIdx = it.first; 305*e07a7fd5SMatthias Springer int64_t bbargIdx = it.second; 306*e07a7fd5SMatthias Springer Value returnVal = callOp.getResult(returnIdx); 307*e07a7fd5SMatthias Springer Value argVal = callOp->getOperand(bbargIdx); 308*e07a7fd5SMatthias Springer aliasInfo.unionEquivalenceClasses(returnVal, argVal); 309*e07a7fd5SMatthias Springer } 310*e07a7fd5SMatthias Springer 311*e07a7fd5SMatthias Springer return WalkResult::advance(); 312*e07a7fd5SMatthias Springer }); 313*e07a7fd5SMatthias Springer } 314*e07a7fd5SMatthias Springer 315*e07a7fd5SMatthias Springer /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by 316*e07a7fd5SMatthias Springer /// callee-caller order (i.e. callees without callers first). 317*e07a7fd5SMatthias Springer /// Store the map of FuncOp to all its callers in `callerMap`. 
318*e07a7fd5SMatthias Springer /// Return `failure()` if a cycle of calls is detected or if we are unable to 319*e07a7fd5SMatthias Springer /// retrieve the called FuncOp from any CallOpInterface. 320*e07a7fd5SMatthias Springer static LogicalResult 321*e07a7fd5SMatthias Springer getFuncOpsOrderedByCalls(ModuleOp moduleOp, 322*e07a7fd5SMatthias Springer SmallVectorImpl<func::FuncOp> &orderedFuncOps, 323*e07a7fd5SMatthias Springer FuncCallerMap &callerMap) { 324*e07a7fd5SMatthias Springer // For each FuncOp, the set of functions called by it (i.e. the union of 325*e07a7fd5SMatthias Springer // symbols of all nested CallOpInterfaceOp). 326*e07a7fd5SMatthias Springer DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; 327*e07a7fd5SMatthias Springer // For each FuncOp, the number of CallOpInterface it contains. 328*e07a7fd5SMatthias Springer DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; 329*e07a7fd5SMatthias Springer WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult { 330*e07a7fd5SMatthias Springer if (!funcOp.getBody().empty()) { 331*e07a7fd5SMatthias Springer func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 332*e07a7fd5SMatthias Springer if (!returnOp) 333*e07a7fd5SMatthias Springer return funcOp->emitError() 334*e07a7fd5SMatthias Springer << "cannot bufferize a FuncOp with tensors and " 335*e07a7fd5SMatthias Springer "without a unique ReturnOp"; 336*e07a7fd5SMatthias Springer } 337*e07a7fd5SMatthias Springer 338*e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[funcOp] = 0; 339*e07a7fd5SMatthias Springer return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { 340*e07a7fd5SMatthias Springer // Only support CallOp for now. 
341*e07a7fd5SMatthias Springer if (!isa<func::CallOp>(callOp.getOperation())) 342*e07a7fd5SMatthias Springer return callOp->emitError() << "expected a CallOp"; 343*e07a7fd5SMatthias Springer func::FuncOp calledFunction = getCalledFunction(callOp); 344*e07a7fd5SMatthias Springer assert(calledFunction && "could not retrieved called func::FuncOp"); 345*e07a7fd5SMatthias Springer auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{}); 346*e07a7fd5SMatthias Springer it.first->getSecond().insert(callOp); 347*e07a7fd5SMatthias Springer if (calledBy[calledFunction].count(funcOp) == 0) { 348*e07a7fd5SMatthias Springer calledBy[calledFunction].insert(funcOp); 349*e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[funcOp]++; 350*e07a7fd5SMatthias Springer } 351*e07a7fd5SMatthias Springer return WalkResult::advance(); 352*e07a7fd5SMatthias Springer }); 353*e07a7fd5SMatthias Springer }); 354*e07a7fd5SMatthias Springer if (res.wasInterrupted()) 355*e07a7fd5SMatthias Springer return failure(); 356*e07a7fd5SMatthias Springer // Iteratively remove function operation that do not call any of the 357*e07a7fd5SMatthias Springer // functions remaining in the callCounter map and add them to the worklist. 
358*e07a7fd5SMatthias Springer while (!numberCallOpsContainedInFuncOp.empty()) { 359*e07a7fd5SMatthias Springer auto it = llvm::find_if(numberCallOpsContainedInFuncOp, 360*e07a7fd5SMatthias Springer [](auto entry) { return entry.getSecond() == 0; }); 361*e07a7fd5SMatthias Springer if (it == numberCallOpsContainedInFuncOp.end()) 362*e07a7fd5SMatthias Springer return moduleOp.emitOpError( 363*e07a7fd5SMatthias Springer "expected callgraph to be free of circular dependencies."); 364*e07a7fd5SMatthias Springer orderedFuncOps.push_back(it->getFirst()); 365*e07a7fd5SMatthias Springer for (auto callee : calledBy[it->getFirst()]) 366*e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp[callee]--; 367*e07a7fd5SMatthias Springer numberCallOpsContainedInFuncOp.erase(it); 368*e07a7fd5SMatthias Springer } 369*e07a7fd5SMatthias Springer return success(); 370*e07a7fd5SMatthias Springer } 371*e07a7fd5SMatthias Springer 372*e07a7fd5SMatthias Springer /// Set the attribute that triggers inplace bufferization on a FuncOp argument 373*e07a7fd5SMatthias Springer /// `bbArg`. 374*e07a7fd5SMatthias Springer static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) { 375*e07a7fd5SMatthias Springer auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp()); 376*e07a7fd5SMatthias Springer funcOp.setArgAttr(bbArg.getArgNumber(), 377*e07a7fd5SMatthias Springer BufferizableOpInterface::kInplaceableAttrName, 378*e07a7fd5SMatthias Springer BoolAttr::get(bbArg.getContext(), inPlace)); 379*e07a7fd5SMatthias Springer } 380*e07a7fd5SMatthias Springer 381*e07a7fd5SMatthias Springer /// Annotate the IR with the result of the analysis. For testing/debugging only. 
382*e07a7fd5SMatthias Springer static void annotateOpsWithBufferizationMarkers(func::FuncOp funcOp, 383*e07a7fd5SMatthias Springer const AnalysisState &state) { 384*e07a7fd5SMatthias Springer auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation()); 385*e07a7fd5SMatthias Springer for (BlockArgument bbArg : funcOp.getArguments()) 386*e07a7fd5SMatthias Springer if (bbArg.getType().isa<TensorType>()) 387*e07a7fd5SMatthias Springer setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); 388*e07a7fd5SMatthias Springer } 389*e07a7fd5SMatthias Springer 390*e07a7fd5SMatthias Springer /// Fold return values that are memref casts and update function return types. 391*e07a7fd5SMatthias Springer /// 392*e07a7fd5SMatthias Springer /// During FuncOp bufferization, the exact type of the returned memrefs (if any) 393*e07a7fd5SMatthias Springer /// is not known yet. Therefore, the bufferization uses memref types with the 394*e07a7fd5SMatthias Springer /// most generic layout map as function return types. After bufferizing the 395*e07a7fd5SMatthias Springer /// entire function body, a more concise memref type can potentially be used for 396*e07a7fd5SMatthias Springer /// the return type of the function. 
397*e07a7fd5SMatthias Springer static void foldMemRefCasts(func::FuncOp funcOp) { 398*e07a7fd5SMatthias Springer if (funcOp.getBody().empty()) 399*e07a7fd5SMatthias Springer return; 400*e07a7fd5SMatthias Springer 401*e07a7fd5SMatthias Springer func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 402*e07a7fd5SMatthias Springer SmallVector<Type> resultTypes; 403*e07a7fd5SMatthias Springer 404*e07a7fd5SMatthias Springer for (OpOperand &operand : returnOp->getOpOperands()) { 405*e07a7fd5SMatthias Springer if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) { 406*e07a7fd5SMatthias Springer operand.set(castOp.source()); 407*e07a7fd5SMatthias Springer resultTypes.push_back(castOp.source().getType()); 408*e07a7fd5SMatthias Springer } else { 409*e07a7fd5SMatthias Springer resultTypes.push_back(operand.get().getType()); 410*e07a7fd5SMatthias Springer } 411*e07a7fd5SMatthias Springer } 412*e07a7fd5SMatthias Springer 413*e07a7fd5SMatthias Springer auto newFuncType = FunctionType::get( 414*e07a7fd5SMatthias Springer funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes); 415*e07a7fd5SMatthias Springer funcOp.setType(newFuncType); 416*e07a7fd5SMatthias Springer } 417*e07a7fd5SMatthias Springer 418*e07a7fd5SMatthias Springer LogicalResult mlir::bufferization::runOneShotModuleBufferize( 419*e07a7fd5SMatthias Springer ModuleOp moduleOp, OneShotBufferizationOptions options) { 420*e07a7fd5SMatthias Springer IRRewriter rewriter(moduleOp.getContext()); 421*e07a7fd5SMatthias Springer OneShotAnalysisState analysisState(moduleOp, options); 422*e07a7fd5SMatthias Springer BufferizationState bufferizationState(analysisState); 423*e07a7fd5SMatthias Springer FuncAnalysisState &funcState = getFuncAnalysisState(analysisState); 424*e07a7fd5SMatthias Springer BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo(); 425*e07a7fd5SMatthias Springer 426*e07a7fd5SMatthias Springer // A list of functions in the order in which they are analyzed + bufferized. 
427*e07a7fd5SMatthias Springer SmallVector<func::FuncOp> orderedFuncOps; 428*e07a7fd5SMatthias Springer 429*e07a7fd5SMatthias Springer // A mapping of FuncOps to their callers. 430*e07a7fd5SMatthias Springer FuncCallerMap callerMap; 431*e07a7fd5SMatthias Springer 432*e07a7fd5SMatthias Springer if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) 433*e07a7fd5SMatthias Springer return failure(); 434*e07a7fd5SMatthias Springer 435*e07a7fd5SMatthias Springer // Collect bbArg/return value information after the analysis. 436*e07a7fd5SMatthias Springer options.addPostAnalysisStep(aliasingFuncOpBBArgsAnalysis); 437*e07a7fd5SMatthias Springer options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis); 438*e07a7fd5SMatthias Springer 439*e07a7fd5SMatthias Springer // Analyze ops. 440*e07a7fd5SMatthias Springer for (func::FuncOp funcOp : orderedFuncOps) { 441*e07a7fd5SMatthias Springer // No body => no analysis. 442*e07a7fd5SMatthias Springer if (funcOp.getBody().empty()) 443*e07a7fd5SMatthias Springer continue; 444*e07a7fd5SMatthias Springer 445*e07a7fd5SMatthias Springer // Now analyzing function. 446*e07a7fd5SMatthias Springer funcState.startFunctionAnalysis(funcOp); 447*e07a7fd5SMatthias Springer 448*e07a7fd5SMatthias Springer // Gather equivalence info for CallOps. 449*e07a7fd5SMatthias Springer equivalenceAnalysis(funcOp, aliasInfo, funcState); 450*e07a7fd5SMatthias Springer 451*e07a7fd5SMatthias Springer // Analyze funcOp. 452*e07a7fd5SMatthias Springer if (failed(analyzeOp(funcOp, analysisState))) 453*e07a7fd5SMatthias Springer return failure(); 454*e07a7fd5SMatthias Springer 455*e07a7fd5SMatthias Springer // Mark op as fully analyzed. 456*e07a7fd5SMatthias Springer funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; 457*e07a7fd5SMatthias Springer 458*e07a7fd5SMatthias Springer // Add annotations to function arguments. 
459*e07a7fd5SMatthias Springer if (options.testAnalysisOnly) 460*e07a7fd5SMatthias Springer annotateOpsWithBufferizationMarkers(funcOp, analysisState); 461*e07a7fd5SMatthias Springer } 462*e07a7fd5SMatthias Springer 463*e07a7fd5SMatthias Springer if (options.testAnalysisOnly) 464*e07a7fd5SMatthias Springer return success(); 465*e07a7fd5SMatthias Springer 466*e07a7fd5SMatthias Springer // Bufferize functions. 467*e07a7fd5SMatthias Springer for (func::FuncOp funcOp : orderedFuncOps) { 468*e07a7fd5SMatthias Springer // Note: It would be good to apply cleanups here but we cannot as aliasInfo 469*e07a7fd5SMatthias Springer // would be invalidated. 470*e07a7fd5SMatthias Springer if (failed(bufferizeOp(funcOp, bufferizationState))) 471*e07a7fd5SMatthias Springer return failure(); 472*e07a7fd5SMatthias Springer foldMemRefCasts(funcOp); 473*e07a7fd5SMatthias Springer } 474*e07a7fd5SMatthias Springer 475*e07a7fd5SMatthias Springer // Check result. 476*e07a7fd5SMatthias Springer for (func::FuncOp funcOp : orderedFuncOps) { 477*e07a7fd5SMatthias Springer if (!options.allowReturnAllocs && 478*e07a7fd5SMatthias Springer llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) { 479*e07a7fd5SMatthias Springer return t.isa<MemRefType, UnrankedMemRefType>(); 480*e07a7fd5SMatthias Springer })) { 481*e07a7fd5SMatthias Springer funcOp->emitError("memref return type is unsupported"); 482*e07a7fd5SMatthias Springer return failure(); 483*e07a7fd5SMatthias Springer } 484*e07a7fd5SMatthias Springer } 485*e07a7fd5SMatthias Springer 486*e07a7fd5SMatthias Springer // Finalize all buffers. 487*e07a7fd5SMatthias Springer if (failed(finalizeBuffers(moduleOp, options))) 488*e07a7fd5SMatthias Springer return failure(); 489*e07a7fd5SMatthias Springer 490*e07a7fd5SMatthias Springer // Post-pass cleanup of function argument attributes. 
491*e07a7fd5SMatthias Springer moduleOp.walk([&](func::FuncOp op) { 492*e07a7fd5SMatthias Springer for (BlockArgument bbArg : op.getArguments()) 493*e07a7fd5SMatthias Springer removeBufferizationAttributes(bbArg); 494*e07a7fd5SMatthias Springer }); 495*e07a7fd5SMatthias Springer 496*e07a7fd5SMatthias Springer return success(); 497*e07a7fd5SMatthias Springer } 498