//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered through PostAnalysisStepFns and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
// Bufferization currently fails if other tensors (in particular tensors that
// bufferize out-of-place and result in a new buffer allocation) are returned.
// In the future, such allocations could be hoisted to the caller.
//
// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
// ```
// func @foo() -> tensor<?xf32> {
//   %0 = linalg.init_tensor [...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
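//
// For illustration, under this convention the example above may bufferize
// roughly as follows (a sketch; the exact layout maps, casts and the position
// of the buffer copy depend on the bufferization options):
// ```
// func @callee(%t1 : memref<?xf32>) -> memref<?xf32> {
//   %f = ... : f32
//   memref.store %f, %t1[...] : memref<?xf32>
//   return %t1 : memref<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : memref<?xf32>
//   %alloc = memref.alloc(...) : memref<?xf32>
//   memref.copy %t0, %alloc : memref<?xf32> to memref<?xf32>
//   %1 = call @callee(%alloc) : (memref<?xf32>) -> (memref<?xf32>)
//   %2 = memref.load %1[...] : memref<?xf32>
// }
// ```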
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Get or create FuncAnalysisState.
static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) {
  return state.getOrCreateDialectState<FuncAnalysisState>(
      func::FuncDialect::getDialectNamespace());
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  func::FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return a.cast<IntegerAttr>().getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}
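
// For example, if return operand #1 is found equivalent to bbArg #0 (and
// return operand #0 is not equivalent to any bbArg), the ReturnOp ends up
// annotated as follows (a sketch; only emitted in `testAnalysisOnly` mode):
// ```
// return %a, %b {__equivalent_func_args__ = [-1, 0]}
//     : tensor<?xf32>, tensor<?xf32>
// ```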

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state,
                             BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);

  // Support only a single ReturnOp-terminated block in the function.
  auto funcOp = cast<func::FuncOp>(op);
  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  for (OpOperand &returnVal : returnOp->getOpOperands())
    if (returnVal.get().getType().isa<RankedTensorType>())
      for (BlockArgument bbArg : funcOp.getArguments())
        if (bbArg.getType().isa<RankedTensorType>()) {
          int64_t returnIdx = returnVal.getOperandNumber();
          int64_t bbArgIdx = bbArg.getArgNumber();
          if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
            if (state.getOptions().testAnalysisOnly)
              annotateEquivalentReturnBbArg(returnVal, bbArg);
          }
          if (aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) {
            funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
          }
        }

  return success();
}

/// Return true if the buffer of the given tensor value is written to. Must not
/// be called for values inside not yet analyzed functions. (Post-analysis
/// steps do not have to be run yet, i.e., "in progress" is also OK.)
static bool isValueWritten(Value value, const AnalysisState &state,
                           const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  func::FuncOp funcOp;
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    Operation *owner = bbArg.getOwner()->getParentOp();
    funcOp = isa<func::FuncOp>(owner) ? cast<func::FuncOp>(owner)
                                      : owner->getParentOfType<func::FuncOp>();
  } else {
    funcOp = value.getDefiningOp()->getParentOfType<func::FuncOp>();
  }
  assert(getFuncOpAnalysisState(state, funcOp) !=
             FuncOpAnalysisState::NotAnalyzed &&
         "FuncOp must be fully analyzed or analysis in progress");
#endif // NDEBUG

  bool isWritten = false;
  aliasInfo.applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
                                  bool isRead, bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType);
}
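
// For example, a tensor bbArg that is read but never written is annotated as
// follows (a sketch; only emitted in `testAnalysisOnly` mode):
// ```
// func @f(%arg0: tensor<?xf32> {bufferization.access = "read"}) { ... }
// ```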

/// Determine which FuncOp bbArgs are read and which are written. If this
/// PostAnalysisStepFn is run on a function with unknown ops, it will
/// conservatively assume that such ops bufferize to a read + write.
static LogicalResult
funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
                             BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto funcOp = cast<func::FuncOp>(op);

  // If the function has no body, conservatively assume that all args are
  // read + written.
  if (funcOp.getBody().empty()) {
    for (BlockArgument bbArg : funcOp.getArguments()) {
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
    }

    return success();
  }

  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (!bbArg.getType().isa<TensorType>())
      continue;
    bool isRead = state.isValueRead(bbArg);
    bool isWritten = isValueWritten(bbArg, state, aliasInfo);
    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                BufferizationAliasInfo &aliasInfo,
                                FuncAnalysisState &funcState) {
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}
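
// For example, if return value #0 of @callee is known to be equivalent to
// bbArg #1, then for a call `%r = call @callee(%a, %b)`, `%r` and `%b` are
// merged into the same equivalence class.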

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                         FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions that call it (i.e., the set of its
  // callers).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct functions it calls (via nested
  // CallOpInterface ops).
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    if (!funcOp.getBody().empty()) {
      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
      // Only support CallOp for now.
      if (!isa<func::CallOp>(callOp.getOperation()))
        return callOp->emitError() << "expected a CallOp";
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
      it.first->getSecond().insert(callOp);
      if (calledBy[calledFunction].count(funcOp) == 0) {
        calledBy[calledFunction].insert(funcOp);
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();
  // Iteratively remove functions that do not call any of the functions
  // remaining in `numberCallOpsContainedInFuncOp` and append them to
  // `orderedFuncOps`.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies.");
    orderedFuncOps.push_back(it->getFirst());
    for (auto caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}
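
// For example, given @a calling @b and @b calling @c, the computed order is
// [@c, @b, @a]: @c calls no other function and is popped first, which drops
// @b's pending-callee count to zero, and so on (Kahn's algorithm on the call
// graph).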

/// Set the attribute that triggers inplace bufferization on a FuncOp argument
/// `bbArg`.
static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.setArgAttr(bbArg.getArgNumber(),
                    BufferizableOpInterface::kInplaceableAttrName,
                    BoolAttr::get(bbArg.getContext(), inPlace));
}

/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void annotateOpsWithBufferizationMarkers(func::FuncOp funcOp,
                                                const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
  for (BlockArgument bbArg : funcOp.getArguments())
    if (bbArg.getType().isa<TensorType>())
      setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used
/// for the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
  if (funcOp.getBody().empty())
    return;

  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  SmallVector<Type> resultTypes;

  for (OpOperand &operand : returnOp->getOpOperands()) {
    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
      operand.set(castOp.source());
      resultTypes.push_back(castOp.source().getType());
    } else {
      resultTypes.push_back(operand.get().getType());
    }
  }

  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}
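
// For example (a sketch):
// ```
// Before: %0 = memref.cast %m : memref<8xf32> to memref<?xf32>
//         return %0 : memref<?xf32>
// After:  return %m : memref<8xf32>
// ```
// The return type of the enclosing function is updated to memref<8xf32>
// accordingly.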

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, OneShotBufferizationOptions options) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());
  OneShotAnalysisState analysisState(moduleOp, options);
  BufferizationState bufferizationState(analysisState);
  FuncAnalysisState &funcState = getFuncAnalysisState(analysisState);
  BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo();

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Collect bbArg/return value information after the analysis.
  options.addPostAnalysisStep(aliasingFuncOpBBArgsAnalysis);
  options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis);

  // Analyze ops.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // No body => no analysis.
    if (funcOp.getBody().empty())
      continue;

    // Now analyzing function.
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, aliasInfo, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, analysisState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;

    // Add annotations to function arguments.
    if (options.testAnalysisOnly)
      annotateOpsWithBufferizationMarkers(funcOp, analysisState);
  }

  if (options.testAnalysisOnly)
    return success();

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.
    if (failed(bufferizeOp(funcOp, bufferizationState)))
      return failure();
    foldMemRefCasts(funcOp);
  }

  // Check result.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!options.allowReturnAllocs &&
        llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) {
          return t.isa<MemRefType, UnrankedMemRefType>();
        })) {
      funcOp->emitError("memref return type is unsupported");
      return failure();
    }
  }

  // Finalize all buffers.
  if (failed(finalizeBuffers(moduleOp, options)))
    return failure();

  // Post-pass cleanup of function argument attributes.
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });

  return success();
}
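
// Example usage (a minimal sketch; `MyModuleBufferizePass` and the pass
// boilerplate are hypothetical, not part of this file):
//
//   void MyModuleBufferizePass::runOnOperation() {
//     OneShotBufferizationOptions options;
//     options.bufferizeFunctionBoundaries = true;
//     if (failed(runOneShotModuleBufferize(getOperation(), options)))
//       signalPassFailure();
//   }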