//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, the default
/// layout map (as per `options.functionBoundaryTypeConversion`) is used.
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best
    // we can do at the moment is "fully dynamic".
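    // A fully dynamic layout has dynamic strides and a dynamic offset, so any
    // strided buffer of matching shape and element type can be cast to it.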
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
  auto it = funcState.analyzedFuncOps.find(funcOp);
  if (it == funcState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // Function not analyzed yet. The conservative answer is "None".
      return BufferRelation::None;
    }

    Optional<int64_t> maybeEquiv =
        getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
    if (maybeEquiv.hasValue()) {
#ifndef NDEBUG
      SmallVector<OpOperand *> aliasingOpOperands =
          getAliasingOpOperand(op, opResult, state);
      assert(aliasingOpOperands.size() == 1 &&
             "expected exactly 1 aliasing OpOperand");
      assert(aliasingOpOperands.front()->getOperandNumber() ==
                 maybeEquiv.getValue() &&
             "inconsistent analysis state");
#endif
      return BufferRelation::Equivalent;
    }
    return BufferRelation::None;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
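  ///
  /// The call is bufferized by rewriting tensor operands to buffers (inserting
  /// memref.cast ops where an operand buffer type differs from the callee
  /// signature) and by creating a new CallOp whose result types match the
  /// bufferized callee.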
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // 1. Compute the result types of the new CallOp.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into a more dynamic memref than necessary.
      // If the buffer does not match the memref type of the callee, introduce
      // an extra memref.cast that will either canonicalize away or fail
      // compilation until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
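    // Note: replaceOpWithBufferizedValues inserts to_tensor ops where a tensor
    // result is replaced by a memref value, so remaining tensor users of the
    // old results keep type-checking.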
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
    const OneShotBufferizationOptions &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // support functions that don't return any tensors at the moment.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
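      // (Collected up front because the uses are redirected to a to_tensor op
      // below; mutating uses while iterating over getUses() would invalidate
      // the iteration.)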
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type, just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}