//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

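  // Illustrative sketch (hypothetical, unverified IR) of the rewrite performed
  // by `bufferize` below; memref layouts are omitted for brevity:
  //
  //   %r = scf.execute_region -> tensor<?xf32> {
  //     scf.yield %t : tensor<?xf32>
  //   }
  //
  // is rewritten to an op that yields memrefs, with the result converted back
  // for remaining tensor users:
  //
  //   %m = scf.execute_region -> memref<?xf32> {
  //     %0 = bufferization.to_memref %t : memref<?xf32>
  //     scf.yield %0 : memref<?xf32>
  //   }
  //   %r = bufferization.to_tensor %m : memref<?xf32>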
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // were not considered a "write", the analysis would detect the following
    // conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

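  // Illustrative sketch (hypothetical, unverified IR) of the rewrite performed
  // by `bufferize` below; memref layouts are omitted for brevity:
  //
  //   %r = scf.if %c -> (tensor<?xf32>) {
  //     scf.yield %t0 : tensor<?xf32>
  //   } else {
  //     scf.yield %t1 : tensor<?xf32>
  //   }
  //
  // becomes an scf.if on memrefs whose branches yield bufferization.to_memref
  // of the original values; the op results are then replaced via
  // `replaceOpWithBufferizedValues`:
  //
  //   %m = scf.if %c -> (memref<?xf32>) {
  //     %0 = bufferization.to_memref %t0 : memref<?xf32>
  //     scf.yield %0 : memref<?xf32>
  //   } else {
  //     %1 = bufferization.to_memref %t1 : memref<?xf32>
  //     scf.yield %1 : memref<?xf32>
  //   }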
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes,
                                              ifOp.getCondition(),
                                              /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp =
        cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    SmallVector<Value> thenYieldValues;
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp =
        cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    SmallVector<Value> elseYieldValues;
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

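  // Sketch of the rule implemented by `bufferRelation` below (values are
  // illustrative): the result is only known to be equivalent to its yielded
  // values if the two branches yield equivalent buffers. E.g.,
  //
  //   %r = scf.if %c -> (tensor<?xf32>) {
  //     scf.yield %t : tensor<?xf32>
  //   } else {
  //     scf.yield %t : tensor<?xf32>
  //   }
  //
  // yields the same value in both branches, so %r is Equivalent; if the
  // branches may yield different buffers, the relation is None.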
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

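  // Illustrative sketch (hypothetical, unverified IR) of the input handled by
  // `bufferize` below; memref layouts are omitted for brevity:
  //
  //   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
  //       -> (tensor<?xf32>) {
  //     %0 = "some_use"(%t) : (tensor<?xf32>) -> tensor<?xf32>
  //     scf.yield %0 : tensor<?xf32>
  //   }
  //
  // The new scf.for takes the buffer of %init as an iter_arg. Because the old
  // loop body still operates on tensors, the memref iter_arg is wrapped in a
  // bufferization.to_tensor at the top of the body, and each yielded tensor is
  // converted back to a memref (by casting, or by alloc + copy when it is not
  // equivalent to its bbArg) before being yielded.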
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Helper function for casting MemRef buffers.
    auto castBuffer = [&](Value buffer, Type type) {
      assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
      assert(buffer.getType().isa<BaseMemRefType>() &&
             "expected BaseMemRefType");
      // If the buffer already has the correct type, no cast is needed.
      if (buffer.getType() == type)
        return buffer;
      // TODO: In case `type` has a layout map that is not the fully dynamic
      // one, we may not be able to cast the buffer. In that case, the loop
      // iter_arg's layout map must be changed (see uses of `castBuffer`).
      assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
             "scf.for op bufferization: cast incompatible");
      return rewriter.create<memref::CastOp>(buffer.getLoc(), type, buffer)
          .getResult();
    };

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices;
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    SmallVector<bool> equivalentYields;
    for (const auto &it : llvm::enumerate(forOp.getInitArgs())) {
      if (it.value().getType().isa<TensorType>()) {
        indices.insert(it.index());
        BufferRelation relation = bufferizableOp.bufferRelation(
            forOp->getResult(it.index()), state.getAnalysisState());
        equivalentYields.push_back(relation == BufferRelation::Equivalent);
      } else {
        equivalentYields.push_back(false);
      }
    }

    // Given a range of values, apply `func` to those marked in `indices`.
    // Otherwise, store the unmodified value in the result vector.
    auto convert = [&](ValueRange values,
                       llvm::function_ref<Value(Value, int64_t)> func) {
      SmallVector<Value> result;
      for (const auto &it : llvm::enumerate(values)) {
        size_t idx = it.index();
        Value val = it.value();
        result.push_back(indices.contains(idx) ? func(val, idx) : val);
      }
      return result;
    };

    // Construct a new scf.for op with memref instead of tensor values.
    SmallVector<Value> initArgs;
    for (OpOperand &opOperand : forOp.getIterOpOperands()) {
      if (opOperand.get().getType().isa<TensorType>()) {
        FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
        if (failed(resultBuffer))
          return failure();
        initArgs.push_back(*resultBuffer);
      } else {
        initArgs.push_back(opOperand.get());
      }
    }
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        convert(newForOp.getRegionIterArgs(), [&](Value val, int64_t index) {
          return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
        });
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        convert(yieldOp.getResults(), [&](Value val, int64_t index) {
          Type initArgType = initArgs[index].getType();
          ensureToMemrefOpIsValid(val, initArgType);
          Value yieldedVal =
              bufferization::lookupBuffer(rewriter, val, state.getOptions());

          if (equivalentYields[index])
            // Yielded value is equivalent to the corresponding iter_arg bbArg.
            // Yield the value directly. Most IR should be like that.
            // Everything else must be resolved with copies and is potentially
            // inefficient. By default, such problematic IR would already have
            // been rejected during `verifyAnalysis`, unless
            // `allow-return-allocs`.
            return castBuffer(yieldedVal, initArgType);

          // It is not certain that the yielded value and the iter_arg bbArg
          // have the same buffer. Allocate a new buffer and copy. The yielded
          // buffer will get deallocated by `deallocateBuffers`.

          // TODO: There are cases in which it is not necessary to return a new
          // buffer allocation. E.g., when equivalent values are yielded in a
          // different order. This could be resolved with copies.
          Optional<Value> yieldedAlloc = state.createAlloc(
              rewriter, val.getLoc(), yieldedVal, /*deallocMemref=*/false);
          // TODO: We should rollback, but for now just assume that this always
          // succeeds.
          assert(yieldedAlloc.hasValue() && "could not create alloc");
          LogicalResult copyStatus =
              bufferization::createMemCpy(rewriter, val.getLoc(), yieldedVal,
                                          *yieldedAlloc, state.getOptions());
          (void)copyStatus;
          assert(succeeded(copyStatus) && "could not create memcpy");

          // The iter_arg memref type may have a layout map. Cast the new
          // buffer to the same type if needed.
          return castBuffer(*yieldedAlloc, initArgType);
        });
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

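  // Example (sketch, hypothetical IR) of a loop that `verifyAnalysis` below
  // rejects unless `allow-return-allocs` is enabled: the yielded tensor is not
  // equivalent to the iter_arg bbArg, so bufferization would have to yield a
  // new allocation from the loop.
  //
  //   %r = scf.for %iv = %lb to %ub step %s iter_args(%t = %init)
  //       -> (tensor<?xf32>) {
  //     %0 = "make_new_tensor"() : () -> tensor<?xf32>
  //     scf.yield %0 : tensor<?xf32>
  //   }
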
  /// Verify that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. Otherwise, an alloc + copy is inserted and yielded
  /// from the loop. This could be a performance problem, so it must be
  /// explicitly activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpOperand &operand : yieldOp->getOpOperands()) {
      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
      if (!tensorType)
        continue;

      OpOperand &forOperand = forOp.getOpOperandForResult(
          forOp->getResult(operand.getOperandNumber()));
      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
        return yieldOp->emitError()
               << "Yield operand #" << operand.getOperandNumber()
               << " does not bufferize to a buffer that is aliasing the "
                  "matching enclosing scf::for operand";
    }
    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of its enclosing op, so this
/// is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
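
// Note: A sketch (not part of this file) of how a client typically attaches
// these external models before creating an MLIRContext:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);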