//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace bufferization {
namespace func_ext {

/// Mark `funcOp` as "in progress" and create empty analysis state entries for
/// its equivalence, aliasing, and bbArg read/write information.
void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
  analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
  auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
  auto createdAliasingOperands =
      aliasingFuncArgs.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdAliasingResults =
      aliasingReturnVals.try_emplace(funcOp, IndexToIndexListMapping());
  auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
  auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
  (void)createdEquiv;
  (void)createdAliasingOperands;
  (void)createdAliasingResults;
  (void)createdRead;
  (void)createdWritten;
#ifndef NDEBUG
  assert(createdEquiv.second && "equivalence info exists already");
  assert(createdAliasingOperands.second && "aliasing info exists already");
  assert(createdAliasingResults.second && "aliasing info exists already");
  assert(createdRead.second && "bbarg access info exists already");
  assert(createdWritten.second && "bbarg access info exists already");
#endif // NDEBUG
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

/// Return the index-th bufferized function argument type. This assumes that
/// the specified argument is a tensor. If the tensor is ranked, a layout map
/// may be specified by the user. If no layout map is specified, the default
/// layout map (as per `options.functionBoundaryTypeConversion`) is used.
static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                             const BufferizationOptions &options) {
  auto tensorType =
      funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
  assert(tensorType && "expected TensorType");

  BaseMemRefType memrefType;
  if (options.functionBoundaryTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
    memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
  } else {
    // Note: Layout maps on function parameters cannot be inferred. The best we
    // can do at the moment is "fully dynamic".
    memrefType = getMemRefTypeWithFullyDynamicLayout(tensorType);
  }

  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
      index, BufferizationDialect::kBufferLayoutAttrName);
  if (!layoutAttr)
    return memrefType;

  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
  return MemRefType::get(
      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Return the FuncAnalysisState attached to `state`. Asserts that it exists.
static const FuncAnalysisState &
getFuncAnalysisState(const AnalysisState &state) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  assert(maybeState && "FuncAnalysisState does not exist");
  return **maybeState;
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                  FuncOp funcOp) {
  Optional<const FuncAnalysisState *> maybeState =
      state.getDialectState<FuncAnalysisState>(
          func::FuncDialect::getDialectNamespace());
  if (!maybeState.has_value())
    return FuncOpAnalysisState::NotAnalyzed;
  const auto &analyzedFuncOps = maybeState.value()->analyzedFuncOps;
  auto it = analyzedFuncOps.find(funcOp);
  if (it == analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
                                                 const FuncAnalysisState &state,
                                                 int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

/// Bufferization of func.call. Reads/writes and aliasing of tensor operands
/// and results are derived from the analysis state of the callee.
struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
                                                    func::CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.readBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    return funcState.writtenBbArgs.lookup(funcOp).contains(
        opOperand.getOperandNumber());
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpResult may be aliasing.
      SmallVector<OpResult> result;
      for (OpResult opResult : op->getOpResults())
        if (opResult.getType().isa<TensorType>())
          result.push_back(opResult);
      return result;
    }

    // Get aliasing results from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingReturnVals =
        funcState.aliasingReturnVals.lookup(funcOp).lookup(
            opOperand.getOperandNumber());
    SmallVector<OpResult> result;
    for (int64_t resultIdx : aliasingReturnVals)
      result.push_back(callOp->getOpResult(resultIdx));
    return result;
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
      SmallVector<OpOperand *> result;
      for (OpOperand &opOperand : op->getOpOperands())
        if (opOperand.get().getType().isa<TensorType>())
          result.push_back(&opOperand);
      return result;
    }

    // Get aliasing bbArgs from state.
    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup(
        opResult.getResultNumber());
    SmallVector<OpOperand *> result;
    for (int64_t bbArgIdx : aliasingFuncArgs)
      result.push_back(&callOp->getOpOperand(bbArgIdx));
    return result;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    if (getFuncOpAnalysisState(state, funcOp) !=
        FuncOpAnalysisState::Analyzed) {
      // Function not analyzed yet. The conservative answer is "None".
      return BufferRelation::None;
    }

    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
    Optional<int64_t> maybeEquiv =
        getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber());
    if (maybeEquiv) {
#ifndef NDEBUG
      SmallVector<OpOperand *> aliasingOpOperands =
          getAliasingOpOperand(op, opResult, state);
      assert(aliasingOpOperands.size() == 1 &&
             "expected exactly 1 aliasing OpOperand");
      assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv &&
             "inconsistent analysis state");
#endif
      return BufferRelation::Equivalent;
    }
    return BufferRelation::None;
  }

  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    FunctionType funcType = funcOp.getFunctionType();

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // A mapping from return val indices of the old CallOp to return val
    // indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // 1. Compute the result types of the new CallOp.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!returnType.isa<TensorType>()) {
        // Non-tensor values are returned as-is.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      // Returning a memref.
      retValMapping[returnValIdx] = resultTypes.size();
      resultTypes.push_back(funcType.getResult(resultTypes.size()));
    }

    // 2. Rewrite tensor operands as memrefs based on the bufferized callee
    // type (`funcType`).
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer)
        buffer = getBuffer(rewriter, opOperand.get(), options);

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the buffer type does not match the memref type expected by the
      // callee, introduce an extra memref.cast that will either canonicalize
      // away or fail compilation until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 3. Create the new CallOp.
    Operation *newCallOp = rewriter.create<func::CallOp>(
        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 4. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

/// Bufferization of func.return. ReturnOps are bufferized together with their
/// enclosing FuncOp; this model only provides the analysis queries.
struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    func::ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG

    // ReturnOps are bufferized as part of FuncOps.
    return success();
  }
};

/// Bufferization of func.func.
struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  /// Rewrite function bbArgs and return values into buffer form. This function
  /// bufferizes the function signature and the ReturnOp. When the entire
  /// function body has been bufferized, function return types can be switched
  /// to more concise memref types as part of `foldMemRefCasts`.
  ///
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();

    // Construct the bufferized function type.
    SmallVector<Type> argTypes;
    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
      Type argType = it.value();
      if (auto tensorType = argType.dyn_cast<TensorType>()) {
        argTypes.push_back(
            getBufferizedFunctionArgType(funcOp, it.index(), options));
        continue;
      }
      argTypes.push_back(argType);
    }

    // Bodiless functions are assumed opaque and we cannot know the
    // bufferization contract they want to enforce. As a consequence, only
    // functions that do not return any tensors are supported at the moment.
    if (funcOp.getBody().empty()) {
      SmallVector<Type> retTypes;
      for (Type resultType : funcType.getResults()) {
        if (resultType.isa<TensorType>())
          return funcOp->emitError() << "cannot bufferize bodiless function "
                                     << "that returns a tensor";
        retTypes.push_back(resultType);
      }
      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
      return success();
    }

    // TODO: Support functions with multiple returns.
    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");
    Location loc = returnOp.getLoc();

    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
    Block &frontBlock = funcOp.getBody().front();
    for (BlockArgument &bbArg : frontBlock.getArguments()) {
      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
      // Non-tensor types stay the same.
      if (!tensorType)
        continue;

      // Collect all uses of the bbArg.
      SmallVector<OpOperand *> bbArgUses;
      for (OpOperand &use : bbArg.getUses())
        bbArgUses.push_back(&use);

      // Change the bbArg type to memref.
      Type memrefType =
          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
      bbArg.setType(memrefType);

      // Replace all uses of the original tensor bbArg.
      rewriter.setInsertionPointToStart(&frontBlock);
      if (!bbArgUses.empty()) {
        // Insert to_tensor because the remaining function body has not been
        // bufferized yet.
        Value toTensorOp =
            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
        for (OpOperand *use : bbArgUses)
          use->set(toTensorOp);
      }
    }

    // 2. For each result, keep track of which inplace argument it reuses.
    SmallVector<Value> returnValues;
    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
      Value returnVal = returnOperand.get();
      auto tensorType = returnVal.getType().dyn_cast<TensorType>();
      rewriter.setInsertionPoint(returnOp);

      // If not a tensor type, just forward it.
      if (!tensorType) {
        returnValues.push_back(returnVal);
        continue;
      }

      BaseMemRefType resultType;
      if (options.functionBoundaryTypeConversion ==
          BufferizationOptions::LayoutMapOption::IdentityLayoutMap) {
        resultType = getMemRefTypeWithStaticIdentityLayout(tensorType);
      } else {
        // Note: If `InferLayoutMap`, casts are later folded away.
        resultType = getMemRefTypeWithFullyDynamicLayout(tensorType);
      }
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          loc, resultType, returnVal);
      returnValues.push_back(toMemrefOp);
    }

    // 3. Rewrite the terminator without the in-place bufferizable values.
    returnOp.operandsMutable().assign(returnValues);

    // 4. Rewrite the FuncOp type to buffer form.
    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
                                     ValueRange(returnValues).getTypes()));

    return success();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "bufferization.writable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName))
      return writable.getValue();

    // All function arguments are writable by default.
    return true;
  }
};

} // namespace func_ext
} // namespace bufferization
} // namespace mlir

/// Register the BufferizableOpInterface external models for the func dialect
/// ops (func.call, func.func, func.return).
void mlir::bufferization::func_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
    func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
    func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
    func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
  });
}