//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, a fully
/// dynamic map is used.
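///
/// Illustrative example (hypothetical IR, not taken from any test): an
/// argument declared as
///
///   %arg0: tensor<4xf32> {bufferization.buffer_layout =
///                         affine_map<(d0) -> (d0)>}
///
/// bufferizes to memref<4xf32, affine_map<(d0) -> (d0)>>. Without the
/// attribute, a memref type with a fully dynamic layout map is returned.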
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");
  BaseMemRefType memrefType = getMemRefType(tensorType, options);

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType &&
         "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Return the FuncAnalysisState attached to the given AnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
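  ///
  /// Illustrative sketch (hypothetical IR, layout maps omitted): given
  ///
  ///   %r = func.call @f(%t) : (tensor<?xf32>) -> tensor<?xf32>
  ///
  /// where the analysis proved %r equivalent to %t, the bufferized call takes
  /// the buffer of %t and, when `dropEquivalentFuncResults` is set, no longer
  /// returns the equivalent result:
  ///
  ///   func.call @f(%m) : (memref<?xf32>) -> ()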
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For results that are still returned: A mapping from return val indices
    // of the old CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensor return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.

    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> bbArgIdx =
                getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
          // Return operands that are equivalent to some bbArg are not
          // returned.
          FailureOr<Value> bufferOrFailure =
              state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
          if (failed(bufferOrFailure))
            return failure();
          replacementValues[returnValIdx] = *bufferOrFailure;
          newOperands[*bbArgIdx] = *bufferOrFailure;
          continue;
        }
      }

      if (!options.allowReturnAllocs)
        return callOp->emitError(
            "call to FuncOp that returns non-equivalent tensors not supported");

      // Returning a memref. This memref is not equivalent to any bbArg. It is
      // likely a newly allocated buffer. We may want to hoist such allocations
      // to the call site in the future.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the bufferized callee
    // type (`funcType`).
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer type does not match the callee's memref type, introduce
      // an extra memref.cast that will either canonicalize away or fail
      // compilation until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return failure();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form (using the
  /// canonical memref layout for now). This function bufferizes the function
  /// signature and the ReturnOp. When the entire function body has been
  /// bufferized, function return types can be switched to more concise memref
  /// types as part of `foldMemRefCasts`.
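  ///
  /// Illustrative sketch (hypothetical IR, layout maps omitted): a tensor
  /// bbArg becomes a memref bbArg, and a to_tensor op is inserted so that the
  /// not-yet-bufferized body still sees a tensor value:
  ///
  ///   func.func @f(%t: tensor<?xf32>) { ... uses of %t ... }
  ///
  /// becomes
  ///
  ///   func.func @f(%m: memref<?xf32>) {
  ///     %t = bufferization.to_tensor %m : memref<?xf32>
  ///     ... uses of %t ...
  ///   }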
  ///
  /// When a tensor function argument is known to be equivalent to a tensor
  /// result, it is dropped from the return values.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  ///
  /// Note: Returning a memref is possible, but corresponding CallOp
  /// bufferizations fail unless `allowReturnAllocs` is enabled.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // support functions that do not return any tensors at the moment.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();

      // If not a tensor type, just forward it.
      if (!returnVal.getType().isa<RankedTensorType>()) {
        returnValues.push_back(returnVal);
        continue;
      }

      // If the return operand is equivalent to some bbArg, there is no need
      // to return it.
      if (options.dropEquivalentFuncResults) {
        if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
                funcOp, funcState, returnOperand.getOperandNumber())) {
          rewriter.setInsertionPoint(returnOp);
          Location loc = returnOp.getLoc();
          Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
              loc,
              getMemRefType(returnVal.getType().cast<TensorType>(), options),
              returnVal);
          BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
          // Note: This copy will fold away. It must be inserted here to ensure
          // that `returnVal` still has at least one use and does not fold away.
          if (failed(
                  createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
            return funcOp->emitError("could not generate copy for bbArg");
          continue;
        }
      }

      returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}