1 //===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Module Bufferization is an extension of One-Shot Bufferize that 10 // bufferizes function boundaries. It provides `BufferizableOpInterface` 11 // implementations for FuncOp, CallOp and ReturnOp. 12 // 13 // Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. 14 // This function analyzes the given module and determines the order of analysis 15 // and bufferization: Functions that are called are processed before their 16 // respective callers. 17 // 18 // After analyzing a FuncOp, additional information about its bbArgs is 19 // gathered and stored in `FuncAnalysisState`. 20 // 21 // * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs 22 // for 23 // each tensor return value (if any). 24 // * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is 25 // read/written. 26 // 27 // Module Bufferization implements the following calling convention. 28 // 29 // * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always 30 // be written to in-place. 31 // * If a tensor operand of a CallOp is read after the CallOp, the operand of 32 // the CallOp must bufferize out-of-place. 33 // 34 // Example: The tensor.insert op bufferizes in-place because it is allowed to 35 // modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize 36 // out-of-place because `%t0` is modified by the callee but read by the 37 // tensor.extract op. The analysis of CallOps decides whether an OpOperand must 38 // bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`. 
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get the FuncAnalysisState dialect state attached to `state`.
/// Asserts that the state has already been created (i.e. the mutable overload
/// below must have run first).
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Get or create FuncAnalysisState.
89 static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { 90 return state.getOrCreateDialectState<FuncAnalysisState>( 91 func::FuncDialect::getDialectNamespace()); 92 } 93 94 /// Return the state (phase) of analysis of the FuncOp. 95 /// Used for debug modes. 96 LLVM_ATTRIBUTE_UNUSED 97 static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, 98 func::FuncOp funcOp) { 99 const FuncAnalysisState &funcState = getFuncAnalysisState(state); 100 auto it = funcState.analyzedFuncOps.find(funcOp); 101 if (it == funcState.analyzedFuncOps.end()) 102 return FuncOpAnalysisState::NotAnalyzed; 103 return it->second; 104 } 105 106 /// Return the unique ReturnOp that terminates `funcOp`. 107 /// Return nullptr if there is no such unique ReturnOp. 108 static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { 109 func::ReturnOp returnOp; 110 for (Block &b : funcOp.getBody()) { 111 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { 112 if (returnOp) 113 return nullptr; 114 returnOp = candidateOp; 115 } 116 } 117 return returnOp; 118 } 119 120 namespace { 121 122 /// Annotate IR with the results of the analysis. For testing purposes only. 
123 static void annotateEquivalentReturnBbArg(OpOperand &returnVal, 124 BlockArgument bbArg) { 125 const char *kEquivalentArgsAttr = "__equivalent_func_args__"; 126 Operation *op = returnVal.getOwner(); 127 128 SmallVector<int64_t> equivBbArgs; 129 if (op->hasAttr(kEquivalentArgsAttr)) { 130 auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>(); 131 equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { 132 return a.cast<IntegerAttr>().getValue().getSExtValue(); 133 })); 134 } else { 135 equivBbArgs.append(op->getNumOperands(), -1); 136 } 137 equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); 138 139 OpBuilder b(op->getContext()); 140 op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); 141 } 142 143 /// Store function BlockArguments that are equivalent to/aliasing a returned 144 /// value in FuncAnalysisState. 145 static LogicalResult aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, 146 OneShotAnalysisState &state) { 147 FuncAnalysisState &funcState = getFuncAnalysisState(state); 148 149 // Support only single return-terminated block in the function. 
150 func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 151 assert(returnOp && "expected func with single return op"); 152 153 for (OpOperand &returnVal : returnOp->getOpOperands()) 154 if (returnVal.get().getType().isa<RankedTensorType>()) 155 for (BlockArgument bbArg : funcOp.getArguments()) 156 if (bbArg.getType().isa<RankedTensorType>()) { 157 int64_t returnIdx = returnVal.getOperandNumber(); 158 int64_t bbArgIdx = bbArg.getArgNumber(); 159 if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { 160 funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx; 161 if (state.getOptions().testAnalysisOnly) 162 annotateEquivalentReturnBbArg(returnVal, bbArg); 163 } 164 if (state.areAliasingBufferizedValues(returnVal.get(), bbArg)) { 165 funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); 166 funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx); 167 } 168 } 169 170 return success(); 171 } 172 173 static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg, 174 bool isRead, bool isWritten) { 175 OpBuilder b(funcOp.getContext()); 176 Attribute accessType; 177 if (isRead && isWritten) { 178 accessType = b.getStringAttr("read-write"); 179 } else if (isRead) { 180 accessType = b.getStringAttr("read"); 181 } else if (isWritten) { 182 accessType = b.getStringAttr("write"); 183 } else { 184 accessType = b.getStringAttr("none"); 185 } 186 funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType); 187 } 188 189 /// Determine which FuncOp bbArgs are read and which are written. When run on a 190 /// function with unknown ops, we conservatively assume that such ops bufferize 191 /// to a read + write. 192 static LogicalResult funcOpBbArgReadWriteAnalysis(FuncOp funcOp, 193 OneShotAnalysisState &state) { 194 FuncAnalysisState &funcState = getFuncAnalysisState(state); 195 196 // If the function has no body, conservatively assume that all args are 197 // read + written. 
198 if (funcOp.getBody().empty()) { 199 for (BlockArgument bbArg : funcOp.getArguments()) { 200 funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); 201 funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); 202 } 203 204 return success(); 205 } 206 207 for (BlockArgument bbArg : funcOp.getArguments()) { 208 if (!bbArg.getType().isa<TensorType>()) 209 continue; 210 bool isRead = state.isValueRead(bbArg); 211 bool isWritten = state.isValueWritten(bbArg); 212 if (state.getOptions().testAnalysisOnly) 213 annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); 214 if (isRead) 215 funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); 216 if (isWritten) 217 funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); 218 } 219 220 return success(); 221 } 222 } // namespace 223 224 /// Remove bufferization attributes on FuncOp arguments. 225 static void removeBufferizationAttributes(BlockArgument bbArg) { 226 auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp()); 227 funcOp.removeArgAttr(bbArg.getArgNumber(), 228 BufferizationDialect::kBufferLayoutAttrName); 229 funcOp.removeArgAttr(bbArg.getArgNumber(), 230 BufferizationDialect::kWritableAttrName); 231 } 232 233 /// Return the func::FuncOp called by `callOp`. 234 static func::FuncOp getCalledFunction(CallOpInterface callOp) { 235 SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); 236 if (!sym) 237 return nullptr; 238 return dyn_cast_or_null<func::FuncOp>( 239 SymbolTable::lookupNearestSymbolFrom(callOp, sym)); 240 } 241 242 /// Gather equivalence info of CallOps. 243 /// Note: This only adds new equivalence info if the called function was already 244 /// analyzed. 245 // TODO: This does not handle cyclic function call graphs etc. 
246 static void equivalenceAnalysis(func::FuncOp funcOp, 247 BufferizationAliasInfo &aliasInfo, 248 OneShotAnalysisState &state) { 249 FuncAnalysisState &funcState = getFuncAnalysisState(state); 250 funcOp->walk([&](func::CallOp callOp) { 251 func::FuncOp calledFunction = getCalledFunction(callOp); 252 assert(calledFunction && "could not retrieved called func::FuncOp"); 253 254 // No equivalence info available for the called function. 255 if (!funcState.equivalentFuncArgs.count(calledFunction)) 256 return WalkResult::skip(); 257 258 for (auto it : funcState.equivalentFuncArgs[calledFunction]) { 259 int64_t returnIdx = it.first; 260 int64_t bbargIdx = it.second; 261 if (!state.isInPlace(callOp->getOpOperand(bbargIdx))) 262 continue; 263 Value returnVal = callOp.getResult(returnIdx); 264 Value argVal = callOp->getOperand(bbargIdx); 265 aliasInfo.unionEquivalenceClasses(returnVal, argVal); 266 } 267 268 return WalkResult::advance(); 269 }); 270 } 271 272 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by 273 /// callee-caller order (i.e. callees without callers first). 274 /// Store the map of FuncOp to all its callers in `callerMap`. 275 /// Return `failure()` if a cycle of calls is detected or if we are unable to 276 /// retrieve the called FuncOp from any CallOpInterface. 277 static LogicalResult 278 getFuncOpsOrderedByCalls(ModuleOp moduleOp, 279 SmallVectorImpl<func::FuncOp> &orderedFuncOps, 280 FuncCallerMap &callerMap) { 281 // For each FuncOp, the set of functions called by it (i.e. the union of 282 // symbols of all nested CallOpInterfaceOp). 283 DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; 284 // For each FuncOp, the number of CallOpInterface it contains. 
285 DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; 286 WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult { 287 if (!funcOp.getBody().empty()) { 288 func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 289 if (!returnOp) 290 return funcOp->emitError() 291 << "cannot bufferize a FuncOp with tensors and " 292 "without a unique ReturnOp"; 293 } 294 295 numberCallOpsContainedInFuncOp[funcOp] = 0; 296 return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { 297 // Only support CallOp for now. 298 if (!isa<func::CallOp>(callOp.getOperation())) 299 return callOp->emitError() << "expected a CallOp"; 300 func::FuncOp calledFunction = getCalledFunction(callOp); 301 assert(calledFunction && "could not retrieved called func::FuncOp"); 302 callerMap[calledFunction].insert(callOp); 303 if (calledBy[calledFunction].insert(funcOp).second) { 304 numberCallOpsContainedInFuncOp[funcOp]++; 305 } 306 return WalkResult::advance(); 307 }); 308 }); 309 if (res.wasInterrupted()) 310 return failure(); 311 // Iteratively remove function operation that do not call any of the 312 // functions remaining in the callCounter map and add them to the worklist. 313 while (!numberCallOpsContainedInFuncOp.empty()) { 314 auto it = llvm::find_if(numberCallOpsContainedInFuncOp, 315 [](auto entry) { return entry.getSecond() == 0; }); 316 if (it == numberCallOpsContainedInFuncOp.end()) 317 return moduleOp.emitOpError( 318 "expected callgraph to be free of circular dependencies."); 319 orderedFuncOps.push_back(it->getFirst()); 320 for (auto callee : calledBy[it->getFirst()]) 321 numberCallOpsContainedInFuncOp[callee]--; 322 numberCallOpsContainedInFuncOp.erase(it); 323 } 324 return success(); 325 } 326 327 /// Set the attribute that triggers inplace bufferization on a FuncOp argument 328 /// `bbArg`. 
329 static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) { 330 auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp()); 331 funcOp.setArgAttr(bbArg.getArgNumber(), 332 BufferizableOpInterface::kInplaceableAttrName, 333 BoolAttr::get(bbArg.getContext(), inPlace)); 334 } 335 336 /// Annotate the IR with the result of the analysis. For testing/debugging only. 337 static void annotateOpsWithBufferizationMarkers(func::FuncOp funcOp, 338 const AnalysisState &state) { 339 auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation()); 340 for (BlockArgument bbArg : funcOp.getArguments()) 341 if (bbArg.getType().isa<TensorType>()) 342 setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); 343 } 344 345 /// Fold return values that are memref casts and update function return types. 346 /// 347 /// During FuncOp bufferization, the exact type of the returned memrefs (if any) 348 /// is not known yet. Therefore, the bufferization uses memref types with the 349 /// most generic layout map as function return types. After bufferizing the 350 /// entire function body, a more concise memref type can potentially be used for 351 /// the return type of the function. 
352 static void foldMemRefCasts(func::FuncOp funcOp) { 353 if (funcOp.getBody().empty()) 354 return; 355 356 func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); 357 SmallVector<Type> resultTypes; 358 359 for (OpOperand &operand : returnOp->getOpOperands()) { 360 if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) { 361 operand.set(castOp.source()); 362 resultTypes.push_back(castOp.source().getType()); 363 } else { 364 resultTypes.push_back(operand.get().getType()); 365 } 366 } 367 368 auto newFuncType = FunctionType::get( 369 funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes); 370 funcOp.setType(newFuncType); 371 } 372 373 LogicalResult 374 mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, 375 OneShotAnalysisState &state) { 376 OneShotBufferizationOptions options = 377 static_cast<const OneShotBufferizationOptions &>(state.getOptions()); 378 assert(options.bufferizeFunctionBoundaries && 379 "expected that function boundary bufferization is activated"); 380 FuncAnalysisState &funcState = getFuncAnalysisState(state); 381 BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); 382 383 // A list of functions in the order in which they are analyzed + bufferized. 384 SmallVector<func::FuncOp> orderedFuncOps; 385 386 // A mapping of FuncOps to their callers. 387 FuncCallerMap callerMap; 388 389 if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) 390 return failure(); 391 392 // Analyze ops. 393 for (func::FuncOp funcOp : orderedFuncOps) { 394 // No body => no analysis. 395 if (funcOp.getBody().empty()) 396 continue; 397 398 // Now analyzing function. 399 funcState.startFunctionAnalysis(funcOp); 400 401 // Gather equivalence info for CallOps. 402 equivalenceAnalysis(funcOp, aliasInfo, state); 403 404 // Analyze funcOp. 405 if (failed(analyzeOp(funcOp, state))) 406 return failure(); 407 408 // Run some extra function analyses. 
409 if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state)) || 410 failed(funcOpBbArgReadWriteAnalysis(funcOp, state))) 411 return failure(); 412 413 // Mark op as fully analyzed. 414 funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; 415 416 // Add annotations to function arguments. 417 if (options.testAnalysisOnly) 418 annotateOpsWithBufferizationMarkers(funcOp, state); 419 } 420 421 return success(); 422 } 423 424 LogicalResult mlir::bufferization::bufferizeModuleOp( 425 ModuleOp moduleOp, const OneShotAnalysisState &analysisState) { 426 auto const &options = static_cast<const OneShotBufferizationOptions &>( 427 analysisState.getOptions()); 428 assert(options.bufferizeFunctionBoundaries && 429 "expected that function boundary bufferization is activated"); 430 IRRewriter rewriter(moduleOp.getContext()); 431 BufferizationState bufferizationState(analysisState); 432 433 // A list of functions in the order in which they are analyzed + bufferized. 434 SmallVector<func::FuncOp> orderedFuncOps; 435 436 // A mapping of FuncOps to their callers. 437 FuncCallerMap callerMap; 438 439 if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) 440 return failure(); 441 442 // Bufferize functions. 443 for (func::FuncOp funcOp : orderedFuncOps) { 444 // Note: It would be good to apply cleanups here but we cannot as aliasInfo 445 // would be invalidated. 446 if (failed(bufferizeOp(funcOp, bufferizationState))) 447 return failure(); 448 // Change buffer return types to more precise layout maps. 449 if (options.functionBoundaryTypeConversion == 450 BufferizationOptions::LayoutMapOption::InferLayoutMap) 451 foldMemRefCasts(funcOp); 452 } 453 454 // Post-pass cleanup of function argument attributes. 
455 moduleOp.walk([&](func::FuncOp op) { 456 for (BlockArgument bbArg : op.getArguments()) 457 removeBufferizationAttributes(bbArg); 458 }); 459 460 return success(); 461 } 462 463 LogicalResult mlir::bufferization::runOneShotModuleBufferize( 464 ModuleOp moduleOp, const OneShotBufferizationOptions &options) { 465 assert(options.bufferizeFunctionBoundaries && 466 "expected that function boundary bufferization is activated"); 467 OneShotAnalysisState analysisState(moduleOp, options); 468 if (failed(analyzeModuleOp(moduleOp, analysisState))) 469 return failure(); 470 if (options.testAnalysisOnly) 471 return success(); 472 if (failed(bufferizeModuleOp(moduleOp, analysisState))) 473 return failure(); 474 return success(); 475 } 476