//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, the default
/// layout map (as per `options.functionBoundaryTypeConversion`) is used.
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best we
    // can do at the moment is "fully dynamic".
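    // For illustration (assumed example; the exact printed form depends on
    // the MLIR version): a ranked tensor<?x8xf32> argument would be
    // bufferized to a memref with a fully dynamic strided layout, roughly
    // memref<?x8xf32, strided<[?, ?], offset: ?>>, so that any caller-side
    // buffer can be passed in without further layout constraints.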
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType &&
         "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  if (!maybeState.hasValue())
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = maybeState.getValue()->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
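      // (Conservative assumption: treating the operand as read can only make
      // the analysis more restrictive, at worst inserting an unnecessary
      // copy; missing a real read could lead to an incorrect in-place
      // decision.)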
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // Function not analyzed yet. The conservative answer is "None".
      return BufferRelation::None;
    }

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    Optional<int64_t> maybeEquiv =
        getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
    if (maybeEquiv.hasValue()) {
#ifndef NDEBUG
      SmallVector<OpOperand *> aliasingOpOperands =
          getAliasingOpOperand(op, opResult, state);
      assert(aliasingOpOperands.size() == 1 &&
             "expected exactly 1 aliasing OpOperand");
      assert(aliasingOpOperands.front()->getOperandNumber() ==
                 maybeEquiv.getValue() &&
             "inconsistent analysis state");
#endif
      return BufferRelation::Equivalent;
    }
    return BufferRelation::None;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // A mapping from return value indices of the old CallOp to return value
    // indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // 1. Compute the result types of the new CallOp.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned as is.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the callee's bufferized
    // function type (`funcType`).
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer does not match the memref type expected by the callee,
      // introduce an extra memref.cast that will either canonicalize away or
      // fail compilation until we can do something better.
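      // Illustration (assumed types): if the operand's buffer has type
      // memref<128xf32> but the callee expects the fully dynamic
      // memref<?xf32, strided<[?], offset: ?>>, a memref.cast bridges the
      // two; such casts are cast-compatible and typically fold away later.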
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // support functions that don't return any tensors atm.
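    // For example (hypothetical declarations): a bodiless
    //   func.func private @ext(%arg0: tensor<?xf32>)
    // is accepted (the argument type is simply rewritten to a memref type),
    // whereas
    //   func.func private @ext() -> tensor<?xf32>
    // is rejected below, because we cannot tell which buffer the result would
    // alias or who owns it.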
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. Bufferize the return values: wrap each tensor return value in a
    // to_memref op with the chosen function boundary layout.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator to return the new (memref) values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
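    // Example (assumed input IR): a bbArg can be forced read-only with an
    // argument attribute such as
    //   func.func @fn(%t: tensor<?xf32> {bufferization.writable = false})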
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}
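
// Typical usage (sketch, not part of this file): clients register these
// external models on their DialectRegistry before running One-Shot Bufferize,
// e.g.:
//   DialectRegistry registry;
//   registry.insert<func::FuncDialect, bufferization::BufferizationDialect>();
//   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
//       registry);
//   MLIRContext context(registry);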