//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of
// analysis and bufferization: Functions that are called are processed before
// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg
//   is read/written.
//
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
// Bufferization currently fails if other tensors (in particular tensors that
// bufferize out-of-place and result in a new buffer allocation) are returned.
// In the future, such allocations could be hoisted to the caller.
//
// Example: `foo` fails bufferization because %0 is not equivalent to any
// bbArg.
// ```
// func @foo() -> tensor<?xf32> {
//   %0 = bufferization.alloc_tensor(...) : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op afterwards. The analysis of CallOps decides whether an
// OpOperand must bufferize out-of-place based on results of
// `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %t0[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis
// conservatively assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be
// marked as "not reading" and/or "not writing".
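//
// For illustration, a possible bufferized form of the `caller` example above.
// This is a sketch only: the actual copy/allocation placement is decided by
// the analysis, and layout maps on function boundary types are omitted for
// brevity. The CallOp operand bufferizes out-of-place, so the callee writes
// to a copy while the buffer of `%t0` stays intact for the later read.
// ```
// func @caller() -> () {
//   %t0 = ... : memref<?xf32>
//   %alloc = memref.alloc(...) : memref<?xf32>
//   memref.copy %t0, %alloc : memref<?xf32> to memref<?xf32>
//   %1 = call @callee(%alloc) : (memref<?xf32>) -> (memref<?xf32>)
//   %2 = memref.load %t0[...] : memref<?xf32>
// }
// ```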
72 73 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 74 75 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 76 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 77 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 78 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" 79 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 80 #include "mlir/Dialect/Func/IR/FuncOps.h" 81 #include "mlir/Dialect/MemRef/IR/MemRef.h" 82 #include "mlir/IR/Operation.h" 83 84 using namespace mlir; 85 using namespace mlir::bufferization; 86 using namespace mlir::bufferization::func_ext; 87 88 /// A mapping of FuncOps to their callers. 89 using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>; 90 91 /// Get FuncAnalysisState. 92 static const FuncAnalysisState & 93 getFuncAnalysisState(const AnalysisState &state) { 94 Optional<const FuncAnalysisState *> maybeState = 95 state.getDialectState<FuncAnalysisState>( 96 func::FuncDialect::getDialectNamespace()); 97 assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); 98 return **maybeState; 99 } 100 101 /// Get or create FuncAnalysisState. 102 static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { 103 return state.getOrCreateDialectState<FuncAnalysisState>( 104 func::FuncDialect::getDialectNamespace()); 105 } 106 107 /// Return the state (phase) of analysis of the FuncOp. 108 /// Used for debug modes. 109 LLVM_ATTRIBUTE_UNUSED 110 static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, 111 func::FuncOp funcOp) { 112 const FuncAnalysisState &funcState = getFuncAnalysisState(state); 113 auto it = funcState.analyzedFuncOps.find(funcOp); 114 if (it == funcState.analyzedFuncOps.end()) 115 return FuncOpAnalysisState::NotAnalyzed; 116 return it->second; 117 } 118 119 /// Return the unique ReturnOp that terminates `funcOp`. 120 /// Return nullptr if there is no such unique ReturnOp. 121 static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { 122 func::ReturnOp returnOp; 123 for (Block &b : funcOp.getBody()) { 124 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 125 if (returnOp) 126 return nullptr; 127 returnOp = candidateOp; 128 } 129 } 130 return returnOp; 131 } 132 133 namespace { 134 135 /// Annotate IR with the results of the analysis. For testing purposes only. 136 static void annotateEquivalentReturnBbArg(OpOperand &returnVal, 137 BlockArgument bbArg) { 138 const char *kEquivalentArgsAttr = "__equivalent_func_args__"; 139 Operation *op = returnVal.getOwner(); 140 141 SmallVector<int64_t> equivBbArgs; 142 if (op->hasAttr(kEquivalentArgsAttr)) { 143 auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>(); 144 equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { 145 return a.cast<IntegerAttr>().getValue().getSExtValue(); 146 })); 147 } else { 148 equivBbArgs.append(op->getNumOperands(), -1); 149 } 150 equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); 151 152 OpBuilder b(op->getContext()); 153 op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); 154 } 155 156 /// Store function BlockArguments that are equivalent to/aliasing a returned 157 /// value in FuncAnalysisState. 158 static LogicalResult aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, 159 OneShotAnalysisState &state) { 160 FuncAnalysisState &funcState = getFuncAnalysisState(state); 161 162 // Support only single return-terminated block in the function. 

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult aliasingFuncOpBBArgsAnalysis(func::FuncOp funcOp,
                                                  OneShotAnalysisState &state) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);

  // Support only functions with a unique ReturnOp.
  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  for (OpOperand &returnVal : returnOp->getOpOperands())
    if (returnVal.get().getType().isa<RankedTensorType>())
      for (BlockArgument bbArg : funcOp.getArguments())
        if (bbArg.getType().isa<RankedTensorType>()) {
          int64_t returnIdx = returnVal.getOperandNumber();
          int64_t bbArgIdx = bbArg.getArgNumber();
          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
            if (state.getOptions().testAnalysisOnly)
              annotateEquivalentReturnBbArg(returnVal, bbArg);
          }
          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg)) {
            funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
          }
        }

  return success();
}

static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
                                  bool isRead, bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on
/// a function with unknown ops, we conservatively assume that such ops
/// bufferize to a read + write.
static LogicalResult funcOpBbArgReadWriteAnalysis(func::FuncOp funcOp,
                                                  OneShotAnalysisState &state) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);

  // If the function has no body, conservatively assume that all args are
  // both read and written.
  if (funcOp.getBody().empty()) {
    for (BlockArgument bbArg : funcOp.getArguments()) {
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
    }

    return success();
  }

  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (!bbArg.getType().isa<TensorType>())
      continue;
    bool isRead = state.isValueRead(bbArg);
    bool isWritten = state.isValueWritten(bbArg);
    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
  }

  return success();
}
} // namespace
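
// For example (test-only annotation, sketch): a function whose first tensor
// argument is only read and whose second is both read and written would be
// annotated as
// ```
// func @fn(%a: tensor<?xf32> {bufferization.access = "read"},
//          %b: tensor<?xf32> {bufferization.access = "read-write"}) ...
// ```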

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                BufferizationAliasInfo &aliasInfo,
                                FuncAnalysisState &funcState) {
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbArgIdx = it.second;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbArgIdx);
      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}
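
// For example (sketch): if the analysis of `@callee` recorded that its (sole)
// return value is equivalent to bbArg #0, then for
// ```
// %r = call @callee(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
// ```
// `%r` and `%t` are unioned into one equivalence class.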

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                         FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions that call it (i.e. the set of
  // FuncOps that contain a CallOpInterface referencing it).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct functions it calls.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    if (!funcOp.getBody().empty()) {
      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
      // Only support CallOp for now.
      if (!isa<func::CallOp>(callOp.getOperation()))
        return callOp->emitError() << "expected a CallOp";
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();
  // Iteratively remove functions that do not call any of the functions
  // remaining in `numberCallOpsContainedInFuncOp` and append them to
  // `orderedFuncOps`.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies.");
    orderedFuncOps.push_back(it->getFirst());
    for (auto caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}

/// Set the attribute that triggers inplace bufferization on a FuncOp argument
/// `bbArg`.
static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.setArgAttr(bbArg.getArgNumber(),
                    BufferizableOpInterface::kInplaceableAttrName,
                    BoolAttr::get(bbArg.getContext(), inPlace));
}

/// Annotate the IR with the result of the analysis. For testing/debugging
/// only.
static void annotateOpsWithBufferizationMarkers(func::FuncOp funcOp,
                                                const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
  for (BlockArgument bbArg : funcOp.getArguments())
    if (bbArg.getType().isa<TensorType>())
      setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing
/// the entire function body, a more concise memref type can potentially be
/// used for the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
  if (funcOp.getBody().empty())
    return;

  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  SmallVector<Type> resultTypes;

  for (OpOperand &operand : returnOp->getOpOperands()) {
    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
      operand.set(castOp.source());
      resultTypes.push_back(castOp.source().getType());
    } else {
      resultTypes.push_back(operand.get().getType());
    }
  }

  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}
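
// For example (sketch; `#generic` stands for a fully dynamic strided layout
// map): after bufferizing the function body, the returned value may be a
// memref.cast to the generic return type.
// ```
// func @fn(...) -> memref<?xf32, #generic> {
//   ...
//   %0 = memref.cast %m : memref<?xf32> to memref<?xf32, #generic>
//   return %0 : memref<?xf32, #generic>
// }
// ```
// `foldMemRefCasts` folds the cast into the ReturnOp and tightens the
// function type:
// ```
// func @fn(...) -> memref<?xf32> {
//   ...
//   return %m : memref<?xf32>
// }
// ```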

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, OneShotBufferizationOptions options) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());
  OneShotAnalysisState analysisState(moduleOp, options);
  BufferizationState bufferizationState(analysisState);
  FuncAnalysisState &funcState = getFuncAnalysisState(analysisState);
  BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo();

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Analyze ops.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // No body => no analysis.
    if (funcOp.getBody().empty())
      continue;

    // Now analyzing function.
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, aliasInfo, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, analysisState)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, analysisState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, analysisState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;

    // Add annotations to function arguments.
    if (options.testAnalysisOnly)
      annotateOpsWithBufferizationMarkers(funcOp, analysisState);
  }

  if (options.testAnalysisOnly)
    return success();

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.
    if (failed(bufferizeOp(funcOp, bufferizationState)))
      return failure();
    // Change buffer return types to more precise layout maps.
    if (options.functionBoundaryTypeConversion ==
        BufferizationOptions::LayoutMapOption::InferLayoutMap)
      foldMemRefCasts(funcOp);
  }

  // Check result.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!options.allowReturnAllocs &&
        llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) {
          return t.isa<MemRefType, UnrankedMemRefType>();
        })) {
      funcOp->emitError("memref return type is unsupported");
      return failure();
    }
  }

  // Finalize all buffers.
  if (failed(finalizeBuffers(moduleOp, options)))
    return failure();

  // Post-pass cleanup of function argument attributes.
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });

  return success();
}
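
// A minimal usage sketch (hypothetical; assumes an OperationPass<ModuleOp>
// subclass, where `getOperation()` yields the ModuleOp and
// `signalPassFailure()` is the standard mlir::Pass hook):
//
//   void runOnOperation() override {
//     OneShotBufferizationOptions options;
//     options.bufferizeFunctionBoundaries = true;
//     if (failed(runOneShotModuleBufferize(getOperation(), options)))
//       signalPassFailure();
//   }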