//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of
// analysis and bufferization: Functions that are called are processed before
// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %t0[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
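//
// For illustration, the example above could bufferize to roughly the
// following IR. This is only a sketch: the exact memref types (layout maps)
// and the placement of copies depend on the bufferization options.
// ```
// func @callee(%t1 : memref<?xf32>) -> memref<?xf32> {
//   %f = ... : f32
//   memref.store %f, %t1[...] : memref<?xf32>
//   return %t1 : memref<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : memref<?xf32>
//   %alloc = memref.alloc(...) : memref<?xf32>
//   memref.copy %t0, %alloc : memref<?xf32> to memref<?xf32>
//   %1 = call @callee(%alloc) : (memref<?xf32>) -> (memref<?xf32>)
//   %2 = memref.load %t0[...] : memref<?xf32>
// }
// ```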

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Get or create FuncAnalysisState.
static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) {
  return state.getOrCreateDialectState<FuncAnalysisState>(
      func::FuncDialect::getDialectNamespace());
}

/// Return the state (phase) of analysis of the FuncOp.
/// Used for debug modes.
LLVM_ATTRIBUTE_UNUSED
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  func::FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return a.cast<IntegerAttr>().getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}
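
// For illustration, with `testAnalysisOnly` enabled, a function that returns
// its second tensor argument unchanged would be annotated roughly as follows
// (a sketch; `-1` marks return values without an equivalent bbArg):
//
//   func @fn(%a: tensor<?xf32>, %b: tensor<?xf32>)
//       -> (tensor<?xf32>, tensor<?xf32>) {
//     %0 = ... : tensor<?xf32>
//     return %0, %b {__equivalent_func_args__ = [-1, 1]}
//         : tensor<?xf32>, tensor<?xf32>
//   }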

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult aliasingFuncOpBBArgsAnalysis(func::FuncOp funcOp,
                                                  OneShotAnalysisState &state) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);

  // Only support functions whose body is terminated by a single ReturnOp.
  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  for (OpOperand &returnVal : returnOp->getOpOperands())
    if (returnVal.get().getType().isa<RankedTensorType>())
      for (BlockArgument bbArg : funcOp.getArguments())
        if (bbArg.getType().isa<RankedTensorType>()) {
          int64_t returnIdx = returnVal.getOperandNumber();
          int64_t bbArgIdx = bbArg.getArgNumber();
          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
            if (state.getOptions().testAnalysisOnly)
              annotateEquivalentReturnBbArg(returnVal, bbArg);
          }
          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg)) {
            funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx);
            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
          }
        }

  return success();
}

static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
                                  bool isRead, bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult funcOpBbArgReadWriteAnalysis(func::FuncOp funcOp,
                                                  OneShotAnalysisState &state) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);

  // If the function has no body, conservatively assume that all args are
  // read + written.
  if (funcOp.getBody().empty()) {
    for (BlockArgument bbArg : funcOp.getArguments()) {
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
    }

    return success();
  }

  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (!bbArg.getType().isa<TensorType>())
      continue;
    bool isRead = state.isValueRead(bbArg);
    bool isWritten = state.isValueWritten(bbArg);
    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
  }

  return success();
}
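
// For illustration, with `testAnalysisOnly` enabled, the analysis above would
// annotate tensor bbArgs roughly as follows (a sketch):
//
//   func @fn(%a: tensor<?xf32> {bufferization.access = "read"},
//            %b: tensor<?xf32> {bufferization.access = "read-write"}) -> ... {
//     ...
//   }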
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                BufferizationAliasInfo &aliasInfo,
                                OneShotAnalysisState &state) {
  FuncAnalysisState &funcState = getFuncAnalysisState(state);
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbArgIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbArgIdx)))
        continue;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbArgIdx);
      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}
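
// For illustration, if the analysis determined that return value #0 of
// @callee is equivalent to bbArg #0, then for an in-place CallOp operand the
// walk above unions the equivalence classes of `%r` and `%t` (a sketch):
//
//   %r = call @callee(%t) : (tensor<?xf32>) -> (tensor<?xf32>)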

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                         FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions that call it (i.e. the set of
  // functions that contain a CallOpInterface op referencing it).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct functions it calls.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    if (!funcOp.getBody().empty()) {
      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
      // Only support CallOp for now.
      if (!isa<func::CallOp>(callOp.getOperation()))
        return callOp->emitError() << "expected a CallOp";
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();
  // Iteratively remove functions that do not call any of the functions
  // remaining in `numberCallOpsContainedInFuncOp` and append them to the
  // ordered list.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies");
    orderedFuncOps.push_back(it->getFirst());
    for (auto caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more precise memref type can potentially be used
/// for the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
  if (funcOp.getBody().empty())
    return;

  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  SmallVector<Type> resultTypes;

  for (OpOperand &operand : returnOp->getOpOperands()) {
    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
      operand.set(castOp.getSource());
      resultTypes.push_back(castOp.getSource().getType());
    } else {
      resultTypes.push_back(operand.get().getType());
    }
  }

  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}
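
// For illustration, `foldMemRefCasts` rewrites roughly the following IR (a
// sketch; the generic strided layout shown depends on the options):
//
//   func @fn() -> memref<?xf32, strided<[?], offset: ?>> {
//     %0 = memref.alloc(...) : memref<?xf32>
//     %1 = memref.cast %0
//         : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
//     return %1 : memref<?xf32, strided<[?], offset: ?>>
//   }
//
// into:
//
//   func @fn() -> memref<?xf32> {
//     %0 = memref.alloc(...) : memref<?xf32>
//     return %0 : memref<?xf32>
//   }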

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state) {
  const auto &options =
      static_cast<const OneShotBufferizationOptions &>(state.getOptions());
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getFuncAnalysisState(state);
  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();

  // A list of functions in the order in which they are analyzed.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Analyze ops.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // No body => no analysis.
    if (funcOp.getBody().empty())
      continue;

    // Mark the function as "analysis in progress".
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, aliasInfo, state);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  return success();
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotAnalysisState &analysisState) {
  const auto &options = static_cast<const OneShotBufferizationOptions &>(
      analysisState.getOptions());
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

  // A list of functions in the order in which they are bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
    return failure();

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.
    if (failed(bufferizeOp(funcOp, options, /*copyBeforeWrite=*/false)))
      return failure();
    // Change buffer return types to more precise layout maps.
    if (options.functionBoundaryTypeConversion ==
        BufferizationOptions::LayoutMapOption::InferLayoutMap)
      foldMemRefCasts(funcOp);
  }

  // Post-pass cleanup of function argument attributes.
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  OneShotAnalysisState analysisState(moduleOp, options);
  // Note: `insertTensorCopies` runs the analysis (for function boundary
  // bufferization, `analyzeModuleOp`) internally and resolves RaW conflicts
  // by inserting tensor copies.
  if (failed(insertTensorCopies(moduleOp, options)))
    return failure();
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, analysisState)))
    return failure();
  return success();
}
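
// Example usage (a sketch for illustration; `MyModuleBufferizePass` is a
// hypothetical pass, not part of this file):
//
//   void MyModuleBufferizePass::runOnOperation() {
//     OneShotBufferizationOptions options;
//     options.bufferizeFunctionBoundaries = true;
//     if (failed(runOneShotModuleBufferize(getOperation(), options)))
//       signalPassFailure();
//   }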