//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, the default
/// layout map (as per `options.functionBoundaryTypeConversion`) is used.
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best we
    // can do at the moment is "fully dynamic".
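    // Illustration (hypothetical example; the exact printed form depends on
    // the MLIR version): a ranked argument such as tensor<4x8xf32> becomes a
    // memref whose offset and strides are all dynamic, e.g.
    //   memref<4x8xf32,
    //          affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>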
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Get FuncAnalysisState.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  if (!maybeState.hasValue())
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = maybeState.getValue()->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // Function not analyzed yet. The conservative answer is "None".
      return BufferRelation::None;
    }

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    Optional<int64_t> maybeEquiv =
        getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
    if (maybeEquiv) {
#ifndef NDEBUG
      SmallVector<OpOperand *> aliasingOpOperands =
          getAliasingOpOperand(op, opResult, state);
      assert(aliasingOpOperands.size() == 1 &&
             "expected exactly 1 aliasing OpOperand");
      assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv &&
             "inconsistent analysis state");
#endif
      return BufferRelation::Equivalent;
    }
    return BufferRelation::None;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // 1. Compute the result types of the new CallOp.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned as-is.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, opOperand.get(), options);
        if (failed(maybeBuffer))
          return failure();
        buffer = *maybeBuffer;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the callee's memref type does not match, introduce an extra
      // memref.cast that will either canonicalize away or fail compilation
      // until we can do something better.
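      // Hypothetical illustration: a caller-side buffer of type
      //   memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
      // may be cast to a callee signature expecting memref<?xf32>; such pairs
      // are cast-compatible, so the assertion below holds and the cast can
      // later fold away.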
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // functions that do not return any tensors are supported for now.
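    // Illustration (hypothetical declarations): an external declaration such
    // as `func.func private @use(tensor<?xf32>)` merely gets its signature
    // rewritten to take a memref, whereas `func.func private @make() ->
    // tensor<?xf32>` is rejected below because its bufferization contract for
    // the returned tensor is unknown.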
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type, just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
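    // Illustration (hypothetical function): an argument can be annotated as
    //   func.func @f(%arg0: tensor<?xf32> {bufferization.writable = false})
    // to make writes to %arg0 bufferize out-of-place.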
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}
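
// Example usage (sketch): clients typically register these external models
// while setting up their dialect registry, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<func::FuncDialect, bufferization::BufferizationDialect>();
//   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
//       registry);
//   MLIRContext context(registry);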