//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
// Bufferization currently fails if other tensors (in particular tensors that
// bufferize out-of-place and result in a new buffer allocation) are returned.
// In the future, such allocations could be hoisted to the caller.
//
// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
// ```
// func @foo() -> tensor<?xf32> {
//   %0 = bufferization.alloc_tensor(...) : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
//   be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
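//
// Example (illustrative sketch, not from the test suite): `@ext` below has no
// body, so its tensor argument is conservatively assumed to be both read and
// written. Because `%t` is also read after the call, the CallOp operand would
// have to bufferize out-of-place.
// ```
// func private @ext(tensor<?xf32>) -> tensor<?xf32>
//
// func @caller_of_ext(%t : tensor<?xf32>) -> f32 {
//   %0 = call @ext(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %f = tensor.extract %t[...] : tensor<?xf32>
//   return %f : f32
// }
// ```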
72 73 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 74 75 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 76 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 77 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 78 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" 79 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 80 #include "mlir/Dialect/Func/IR/FuncOps.h" 81 #include "mlir/Dialect/MemRef/IR/MemRef.h" 82 #include "mlir/IR/Operation.h" 83 84 using namespace mlir; 85 using namespace mlir::bufferization; 86 using namespace mlir::bufferization::func_ext; 87 88 /// A mapping of FuncOps to their callers. 89 using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>; 90 91 /// Get FuncAnalysisState. 92 static const FuncAnalysisState & 93 getFuncAnalysisState(const AnalysisState &state) { 94 Optional<const FuncAnalysisState *> maybeState = 95 state.getDialectState<FuncAnalysisState>( 96 func::FuncDialect::getDialectNamespace()); 97 assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); 98 return **maybeState; 99 } 100 101 /// Get or create FuncAnalysisState. 102 static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { 103 return state.getOrCreateDialectState<FuncAnalysisState>( 104 func::FuncDialect::getDialectNamespace()); 105 } 106 107 /// Return the state (phase) of analysis of the FuncOp. 108 /// Used for debug modes. 109 LLVM_ATTRIBUTE_UNUSED 110 static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, 111 func::FuncOp funcOp) { 112 const FuncAnalysisState &funcState = getFuncAnalysisState(state); 113 auto it = funcState.analyzedFuncOps.find(funcOp); 114 if (it == funcState.analyzedFuncOps.end()) 115 return FuncOpAnalysisState::NotAnalyzed; 116 return it->second; 117 } 118 119 /// Return the unique ReturnOp that terminates `funcOp`. 120 /// Return nullptr if there is no such unique ReturnOp. 121 static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { 122 func::ReturnOp returnOp; 123 for (Block &b : funcOp.getBody()) { 124 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 125 if (returnOp) 126 return nullptr; 127 returnOp = candidateOp; 128 } 129 } 130 return returnOp; 131 } 132 133 namespace { 134 135 /// Annotate IR with the results of the analysis. For testing purposes only. 136 static void annotateEquivalentReturnBbArg(OpOperand &returnVal, 137 BlockArgument bbArg) { 138 const char *kEquivalentArgsAttr = "__equivalent_func_args__"; 139 Operation *op = returnVal.getOwner(); 140 141 SmallVector<int64_t> equivBbArgs; 142 if (op->hasAttr(kEquivalentArgsAttr)) { 143 auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>(); 144 equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { 145 return a.cast<IntegerAttr>().getValue().getSExtValue(); 146 })); 147 } else { 148 equivBbArgs.append(op->getNumOperands(), -1); 149 } 150 equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); 151 152 OpBuilder b(op->getContext()); 153 op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); 154 } 155 156 /// Store function BlockArguments that are equivalent to/aliasing a returned 157 /// value in FuncAnalysisState. 158 static LogicalResult aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, 159 OneShotAnalysisState &state) { 160 FuncAnalysisState &funcState = getFuncAnalysisState(state); 161 162 // Support only single return-terminated block in the function. 

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult aliasingFuncOpBBArgsAnalysis(func::FuncOp funcOp,
                                                  OneShotAnalysisState &state) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);

  // Support only single return-terminated block in the function.
  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  for (OpOperand &returnVal : returnOp->getOpOperands())
    if (returnVal.get().getType().isa<RankedTensorType>())
      for (BlockArgument bbArg : funcOp.getArguments())
        if (bbArg.getType().isa<RankedTensorType>()) {
          int64_t returnIdx = returnVal.getOperandNumber();
          int64_t bbArgIdx = bbArg.getArgNumber();
          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
            if (state.getOptions().testAnalysisOnly)
              annotateEquivalentReturnBbArg(returnVal, bbArg);
          }
          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg)) {
            funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
          }
        }

  return success();
}

/// Annotate a FuncOp argument with the read/write access determined by the
/// analysis. For testing purposes only.
static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
                                  bool isRead, bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType);
}
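
// Example (illustrative sketch): with `testAnalysisOnly` enabled, an argument
// that is read but never written would be printed roughly as
//   func @fn(%arg0: tensor<?xf32> {bufferization.access = "read"}) { ... }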

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult funcOpBbArgReadWriteAnalysis(func::FuncOp funcOp,
                                                  OneShotAnalysisState &state) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);

  // If the function has no body, conservatively assume that all args are
  // read + written.
  if (funcOp.getBody().empty()) {
    for (BlockArgument bbArg : funcOp.getArguments()) {
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
    }

    return success();
  }

  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (!bbArg.getType().isa<TensorType>())
      continue;
    bool isRead = state.isValueRead(bbArg);
    bool isWritten = state.isValueWritten(bbArg);
    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                BufferizationAliasInfo &aliasInfo,
                                OneShotAnalysisState &state) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
        continue;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                         FuncCallerMap &callerMap) {
  // For each FuncOp, the set of FuncOps that call it (i.e. its callers).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct FuncOps that it calls.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    if (!funcOp.getBody().empty()) {
      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
      // Only support CallOp for now.
      if (!isa<func::CallOp>(callOp.getOperation()))
        return callOp->emitError() << "expected a CallOp";
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();
  // Iteratively pick a function that does not call any function still
  // remaining in `numberCallOpsContainedInFuncOp`, append it to
  // `orderedFuncOps`, and decrement the counters of its callers.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies.");
    orderedFuncOps.push_back(it->getFirst());
    for (auto caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}
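
// Example (illustrative sketch): for a module in which @a calls @b and @b
// calls @c, `orderedFuncOps` would be [@c, @b, @a]; `callerMap` would map @c
// to the call site in @b and @b to the call site in @a.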

/// Set the attribute that triggers in-place bufferization on a FuncOp argument
/// `bbArg`.
static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.setArgAttr(bbArg.getArgNumber(),
                    BufferizableOpInterface::kInplaceableAttrName,
                    BoolAttr::get(bbArg.getContext(), inPlace));
}

/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void annotateOpsWithBufferizationMarkers(func::FuncOp funcOp,
                                                const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
  for (BlockArgument bbArg : funcOp.getArguments())
    if (bbArg.getType().isa<TensorType>())
      setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more precise memref type can potentially be used
/// for the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
  if (funcOp.getBody().empty())
    return;

  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  SmallVector<Type> resultTypes;

  for (OpOperand &operand : returnOp->getOpOperands()) {
    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
      operand.set(castOp.source());
      resultTypes.push_back(castOp.source().getType());
    } else {
      resultTypes.push_back(operand.get().getType());
    }
  }

  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}
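
// Example (illustrative sketch, layout notation schematic): if the bufferized
// body ends in
//   %0 = memref.cast %m : memref<?xf32> to memref<?xf32, #fully_dynamic_layout>
//   return %0 : memref<?xf32, #fully_dynamic_layout>
// then foldMemRefCasts rewrites the return to use %m directly and tightens the
// function result type to memref<?xf32>.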

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state) {
  const auto &options =
      static_cast<const OneShotBufferizationOptions &>(state.getOptions());
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getFuncAnalysisState(state);
  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Analyze ops.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // No body => no analysis.
    if (funcOp.getBody().empty())
      continue;

    // Now analyzing function.
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, aliasInfo, state);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;

    // Add annotations to function arguments.
    if (options.testAnalysisOnly)
      annotateOpsWithBufferizationMarkers(funcOp, state);
  }

  return success();
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotAnalysisState &analysisState) {
  const auto &options = static_cast<const OneShotBufferizationOptions &>(
      analysisState.getOptions());
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());
  BufferizationState bufferizationState(analysisState);

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.
    if (failed(bufferizeOp(funcOp, bufferizationState)))
      return failure();
    // Change buffer return types to more precise layout maps.
    if (options.functionBoundaryTypeConversion ==
        BufferizationOptions::LayoutMapOption::InferLayoutMap)
      foldMemRefCasts(funcOp);
  }

  // Check the results: unless `allowReturnAllocs` is set, returning a memref
  // from a function is unsupported.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!options.allowReturnAllocs &&
        llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) {
          return t.isa<MemRefType, UnrankedMemRefType>();
        })) {
      funcOp->emitError("memref return type is unsupported");
      return failure();
    }
  }

  // Post-pass cleanup of function argument attributes.
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  OneShotAnalysisState analysisState(moduleOp, options);
  if (failed(analyzeModuleOp(moduleOp, analysisState)))
    return failure();
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, analysisState)))
    return failure();
  return success();
}
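
// Usage sketch (illustrative, not part of this file): a pass driving One-Shot
// Module Bufferization would typically configure the options and invoke the
// entry point along these lines; the exact option values are assumptions.
//
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   options.allowReturnAllocs = false;
//   if (failed(runOneShotModuleBufferize(moduleOp, options)))
//     signalPassFailure();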