//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}
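
// Illustrative sketch (not part of the original file), assuming
// `kBufferLayoutAttrName` is spelled "bufferization.buffer_layout": given a
// signature such as
//
//   func.func @foo(%arg0: tensor<4x?xf32>
//       {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>})
//
// the helper below yields a memref<4x?xf32> carrying that layout map; without
// the attribute, a fully dynamic layout map is used instead.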

/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user. If no layout map is specified, a fully dynamic map is
/// used.
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");
  BaseMemRefType memrefType = getMemRefType(tensorType, options);

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}
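
// Illustrative sketch (not part of the original file): assuming the one-shot
// analysis has already processed
//
//   func.func @callee(%t: tensor<?xf32>) -> tensor<?xf32> {
//     return %t : tensor<?xf32>
//   }
//
// `equivalentFuncArgs[@callee]` would map return value #0 to bbArg #0, so
// `getEquivalentFuncArgIdx(@callee, state, 0)` returns 0. For an unanalyzed or
// bodiless callee, no mapping exists and the interfaces below fall back to
// conservative assumptions.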

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
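
  // Illustrative sketch (not part of the original file): when the callee's
  // tensor result #0 is known to be equivalent to bbArg #0, a call such as
  //
  //   %r = func.call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
  //
  // bufferizes roughly to
  //
  //   func.call @callee(%m) : (memref<?xf32, ...>) -> ()
  //
  // where %m is the buffer of %t and every use of %r is replaced by %m. Tensor
  // results that are not equivalent to any bbArg stay as results of the new
  // CallOp and require `allowReturnAllocs`.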

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensor return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.

    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (Optional<int64_t> bbArgIdx =
              getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
        // Return values that are equivalent to some bbArg are not returned.
        FailureOr<Value> bufferOrFailure =
            state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
        if (failed(bufferOrFailure))
          return failure();
        replacementValues[returnValIdx] = *bufferOrFailure;
        newOperands[*bbArgIdx] = *bufferOrFailure;
        continue;
      }

      if (!options.allowReturnAllocs)
        return callOp->emitError(
            "call to FuncOp that returns non-equivalent tensors not supported");

      // Returning a memref. This memref is not equivalent to any bbArg. It is
      // likely a newly allocated buffer. We may want to hoist such allocations
      // to the call site in the future.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }
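
    // Illustrative note (not part of the original file): after step 1, for a
    // call whose result #0 is an i32 and whose result #1 is a tensor
    // equivalent to operand #0, `resultTypes` contains only the i32 type,
    // `retValMapping` is [0, None], and both `replacementValues[1]` and
    // `newOperands[0]` hold the buffer of operand #0.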

    // 2. Rewrite tensor operands as memrefs based on the callee's bufferized
    //    function type (`funcType`).
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer type does not match the memref type expected by the
      // callee, introduce an extra memref.cast that will either canonicalize
      // away or fail compilation until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return failure();
  }
};
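
// Illustrative sketch (not part of the original file): the model below rewrites
//
//   func.func @fn(%t: tensor<?xf32>) -> tensor<?xf32> {
//     ...
//     return %r : tensor<?xf32>
//   }
//
// into roughly
//
//   func.func @fn(%arg0: memref<?xf32, ...>) {
//     %t = bufferization.to_tensor %arg0 : ...
//     ...
//     return
//   }
//
// when the returned tensor is equivalent to the bbArg; non-equivalent tensor
// results remain in the signature as memref results (subject to
// `allowReturnAllocs`).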

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form (using the
  /// canonical memref layout for now). This function bufferizes the function
  /// signature and the ReturnOp. When the entire function body has been
  /// bufferized, function return types can be switched to more concise memref
  /// types as part of `foldMemRefCasts`.
  ///
  /// When a tensor function argument is known to be equivalent to a tensor
  /// result, it is dropped from the return values.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  ///
  /// Note: Returning a memref is possible, but the corresponding CallOp
  /// bufferization fails unless `allowReturnAllocs` is enabled.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const FuncAnalysisState &funcState =
        getFuncAnalysisState(state.getAnalysisState());
    const BufferizationOptions &options = state.getOptions();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // functions that do not return any tensors are supported at the moment.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }
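
    // Illustrative note (not part of the original file): after step 1, a
    // tensor bbArg such as %arg0 has memref type and its former uses are
    // routed through
    //   %0 = bufferization.to_tensor %arg0 : memref<?xf32, ...>
    // so the not-yet-bufferized body keeps operating on tensor SSA values.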

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();

      // If not a tensor type just forward it.
      if (!returnVal.getType().isa<RankedTensorType>()) {
        returnValues.push_back(returnVal);
        continue;
      }

      // If return operand is equivalent to some bbArg, no need to return it.
      if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
              funcOp, funcState, returnOperand.getOperandNumber())) {
        rewriter.setInsertionPoint(returnOp);
        Location loc = returnOp.getLoc();
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            loc, getMemRefType(returnVal.getType().cast<TensorType>(), options),
            returnVal);
        BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
        // Note: This copy will fold away. It must be inserted here to ensure
        // that `returnVal` still has at least one use and does not fold away.
        if (failed(
                createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
          return funcOp->emitError("could not generate copy for bbArg");
        continue;
      }

      returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}
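
// Usage sketch (illustrative, not part of the original file): clients typically
// register these external models on a DialectRegistry before creating the
// context, e.g.
//
//   DialectRegistry registry;
//   registry.insert<func::FuncDialect, bufferization::BufferizationDialect>();
//   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
//       registry);
//   MLIRContext context(registry);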