//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered through PostAnalysisStepFns and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
// Bufferization currently fails if other tensors (in particular tensors that
// bufferize out-of-place and result in a new buffer allocation) are returned.
// In the future, such allocations could be hoisted to the caller.
//
// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
// ```
// func @foo() -> tensor<?xf32> {
//   %0 = bufferization.alloc_tensor(...) : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on the results of
// `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis
// conservatively assumes that each tensor OpOperand is both read and written
// (see the example below).
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
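//
// Example (hypothetical IR): `@ext` below is external, so the analysis
// conservatively treats the tensor operand of every call to `@ext` as both
// read and written, forcing such operands to bufferize out-of-place whenever
// they are used again after the call.
// ```
// func private @ext(tensor<?xf32>) -> tensor<?xf32>
// ```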

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Get or create FuncAnalysisState.
static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) {
  return state.getOrCreateDialectState<FuncAnalysisState>(
      func::FuncDialect::getDialectNamespace());
}

/// Return the state (phase) of analysis of the FuncOp.
/// Used for debug modes.
LLVM_ATTRIBUTE_UNUSED
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  func::FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return a.cast<IntegerAttr>().getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
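///
/// Example (hypothetical IR, assuming the tensor.insert bufferizes in-place):
/// the returned value is equivalent to bbArg 0, so this analysis records
/// `equivalentFuncArgs[@f][0] = 0` (plus the corresponding aliasing info).
/// ```
/// func @f(%t : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
///   %0 = tensor.insert %f into %t[...] : tensor<?xf32>
///   return %0 : tensor<?xf32>
/// }
/// ```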
static LogicalResult
aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state,
                             BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);

  // Support only a single return-terminated block in the function.
  auto funcOp = cast<func::FuncOp>(op);
  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  for (OpOperand &returnVal : returnOp->getOpOperands())
    if (returnVal.get().getType().isa<RankedTensorType>())
      for (BlockArgument bbArg : funcOp.getArguments())
        if (bbArg.getType().isa<RankedTensorType>()) {
          int64_t returnIdx = returnVal.getOperandNumber();
          int64_t bbArgIdx = bbArg.getArgNumber();
          if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
            if (state.getOptions().testAnalysisOnly)
              annotateEquivalentReturnBbArg(returnVal, bbArg);
          }
          if (aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) {
            funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
          }
        }

  return success();
}

/// Return true if the buffer of the given tensor value is written to. Must not
/// be called for values inside not yet analyzed functions. (Post-analysis
/// steps do not have to be run yet, i.e., "in progress" is also OK.)
static bool isValueWritten(Value value, const AnalysisState &state,
                           const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  func::FuncOp funcOp;
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    Operation *owner = bbArg.getOwner()->getParentOp();
    funcOp = isa<func::FuncOp>(owner) ? cast<func::FuncOp>(owner)
                                      : owner->getParentOfType<func::FuncOp>();
  } else {
    funcOp = value.getDefiningOp()->getParentOfType<func::FuncOp>();
  }
  assert(getFuncOpAnalysisState(state, funcOp) !=
             FuncOpAnalysisState::NotAnalyzed &&
         "FuncOp must be fully analyzed or analysis in progress");
#endif // NDEBUG

  bool isWritten = false;
  aliasInfo.applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
                                  bool isRead, bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. If this
/// PostAnalysisStepFn is run on a function with unknown ops, it will
/// conservatively assume that such ops bufferize to a read + write.
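///
/// Example (hypothetical IR): with `testAnalysisOnly` set, a bbArg that is
/// written but never read would be annotated as follows:
/// ```
/// func @g(%t : tensor<?xf32> {bufferization.access = "write"}) -> ...
/// ```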
static LogicalResult
funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
                             BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto funcOp = cast<func::FuncOp>(op);

  // If the function has no body, conservatively assume that all args are
  // read + written.
  if (funcOp.getBody().empty()) {
    for (BlockArgument bbArg : funcOp.getArguments()) {
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
    }

    return success();
  }

  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (!bbArg.getType().isa<TensorType>())
      continue;
    bool isRead = state.isValueRead(bbArg);
    bool isWritten = isValueWritten(bbArg, state, aliasInfo);
    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                BufferizationAliasInfo &aliasInfo,
                                FuncAnalysisState &funcState) {
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
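///
/// Example (hypothetical module): given
/// ```
/// func @a() {
///   call @b() : () -> ()
///   return
/// }
/// func @b() {
///   call @c() : () -> ()
///   return
/// }
/// func @c() {
///   return
/// }
/// ```
/// the computed order is [`@c`, `@b`, `@a`], i.e., each callee is analyzed
/// and bufferized before its callers.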
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                         FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested CallOpInterfaceOp).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of CallOpInterface it contains.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    if (!funcOp.getBody().empty()) {
      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
      // Only support CallOp for now.
      if (!isa<func::CallOp>(callOp.getOperation()))
        return callOp->emitError() << "expected a CallOp";
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();
  // Iteratively remove function operations that do not call any of the
  // functions remaining in `numberCallOpsContainedInFuncOp` and append them
  // to `orderedFuncOps`.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies.");
    orderedFuncOps.push_back(it->getFirst());
    for (auto callee : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[callee]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}

/// Set the attribute that triggers inplace bufferization on a FuncOp argument
/// `bbArg`.
static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.setArgAttr(bbArg.getArgNumber(),
                    BufferizableOpInterface::kInplaceableAttrName,
                    BoolAttr::get(bbArg.getContext(), inPlace));
}

/// Annotate the IR with the result of the analysis. For testing/debugging
/// only.
static void annotateOpsWithBufferizationMarkers(func::FuncOp funcOp,
                                                const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
  for (BlockArgument bbArg : funcOp.getArguments())
    if (bbArg.getType().isa<TensorType>())
      setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used
/// for the return type of the function.
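///
/// Example (hypothetical IR; `#generic_layout` stands for a fully dynamic
/// layout map): a return value produced by
/// ```
/// %0 = memref.cast %m : memref<?xf32> to memref<?xf32, #generic_layout>
/// return %0 : memref<?xf32, #generic_layout>
/// ```
/// is rewritten to `return %m : memref<?xf32>` and the function's return type
/// is tightened to `memref<?xf32>`.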
static void foldMemRefCasts(func::FuncOp funcOp) {
  if (funcOp.getBody().empty())
    return;

  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  SmallVector<Type> resultTypes;

  for (OpOperand &operand : returnOp->getOpOperands()) {
    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
      operand.set(castOp.source());
      resultTypes.push_back(castOp.source().getType());
    } else {
      resultTypes.push_back(operand.get().getType());
    }
  }

  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, OneShotBufferizationOptions options) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());
  OneShotAnalysisState analysisState(moduleOp, options);
  BufferizationState bufferizationState(analysisState);
  FuncAnalysisState &funcState = getFuncAnalysisState(analysisState);
  BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo();

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Collect bbArg/return value information after the analysis.
  options.addPostAnalysisStep(aliasingFuncOpBBArgsAnalysis);
  options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis);

  // Analyze ops.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // No body => no analysis.
    if (funcOp.getBody().empty())
      continue;

    // Now analyzing function.
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, aliasInfo, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, analysisState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;

    // Add annotations to function arguments.
    if (options.testAnalysisOnly)
      annotateOpsWithBufferizationMarkers(funcOp, analysisState);
  }

  if (options.testAnalysisOnly)
    return success();

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.
    if (failed(bufferizeOp(funcOp, bufferizationState)))
      return failure();
    // Change buffer return types to more precise layout maps.
    if (options.functionBoundaryTypeConversion ==
        BufferizationOptions::LayoutMapOption::InferLayoutMap)
      foldMemRefCasts(funcOp);
  }

  // Check result.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!options.allowReturnAllocs &&
        llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) {
          return t.isa<MemRefType, UnrankedMemRefType>();
        })) {
      funcOp->emitError("memref return type is unsupported");
      return failure();
    }
  }

  // Finalize all buffers.
  if (failed(finalizeBuffers(moduleOp, options)))
    return failure();

  // Post-pass cleanup of function argument attributes.
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });

  return success();
}