//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory writes
    // in the analysis. This is a useful pattern for all ops that have tensor
    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
    // ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
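    // Tensor result types are converted to memref types (as determined by the
    // bufferization options); all other result types are kept as is.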
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is not
  // allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    //   scf.yield %1
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    SmallVector<Value> thenYieldValues;
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    SmallVector<Value> elseYieldValues;
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
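    // E.g., if the then-branch yields %a and the else-branch yields %b, the
    // result is reported as Equivalent to its yield values only if %a and %b
    // bufferize to equivalent buffers; otherwise BufferRelation::None is
    // returned.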
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  DenseSet<int64_t> result;
  int64_t counter = 0;
  for (const auto &it : llvm::zip(bbArgs, yieldedValues)) {
    if (!std::get<0>(it).getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(std::get<0>(it), std::get<1>(it)))
      result.insert(counter);
    counter++;
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                     MutableArrayRef<OpOperand> operands,
                                     BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
      if (failed(resultBuffer))
        return {};
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition). If the given tensor
/// is equivalent to the corresponding block argument (as indicated by
/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
/// copy must be yielded.
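///
/// For example, if the loop body yields its iter_arg bbArg unchanged, the
/// bbArg's buffer can be yielded directly. If it yields a tensor that may
/// bufferize to a different buffer, a new allocation (plus a copy) is yielded
/// instead.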
///
/// According to the `BufferizableOpInterface` implementation of scf loops, a
/// bufferized OpResult may alias only with the corresponding bufferized
/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
/// the i-th init_arg; but not with any other OpOperand. If a corresponding
/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
/// alias with any buffer.)
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                              BaseMemRefType type, bool isEquivalent,
                              BufferizationState &state) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  Value yieldedVal =
      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());

  if (isEquivalent)
    // Yielded value is equivalent to the corresponding iter_arg bbArg.
    // Yield the value directly. Most IR should be like that. Everything
    // else must be resolved with copies and is potentially inefficient.
    // By default, such problematic IR would already have been rejected
    // during `verifyAnalysis`, unless `allow-return-allocs` is set.
    return castBuffer(rewriter, yieldedVal, type);

  // It is not certain that the yielded value and the iter_arg bbArg
  // have the same buffer. Allocate a new buffer and copy. The yielded
  // buffer will get deallocated by `deallocateBuffers`.

  // TODO: There are cases in which it is not necessary to return a new
  // buffer allocation. E.g., when equivalent values are yielded in a
  // different order. This could be resolved with copies.
  Optional<Value> yieldedAlloc = state.createAlloc(
      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
  // TODO: We should rollback, but for now just assume that this always
  // succeeds.
  assert(yieldedAlloc.hasValue() && "could not create alloc");
  LogicalResult copyStatus = bufferization::createMemCpy(
      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions());
  (void)copyStatus;
  assert(succeeded(copyStatus) && "could not create memcpy");

  // The iter_arg memref type may have a layout map. Cast the new buffer
  // to the same type if needed.
  return castBuffer(rewriter, *yieldedAlloc, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static SmallVector<Value>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<Value(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
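/// Values at the indices in `tensorIndices` are bufferized via
/// `getYieldedBuffer` (which may insert a buffer copy); all other values are
/// passed through unchanged.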
SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                    TypeRange bufferizedTypes,
                                    const DenseSet<int64_t> &tensorIndices,
                                    const DenseSet<int64_t> &equivalentTensors,
                                    BufferizationState &state) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                equivalentTensors.contains(index), state);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  return convertTensorValues(
      bbArgs, tensorIndices, [&](Value val, int64_t index) {
        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
      });
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto oldYieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields =
        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
                             state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), state);
    if (initArgs.size() != indices.size())
      return failure();

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    return {whileOp->getResult(opOperand.getOperandNumber())};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(whileOp.getInits());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
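    // Note: The "before" block yields through scf.condition, the "after" block
    // through scf.yield, so equivalence is computed separately for both
    // terminators.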
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
        state.getAnalysisState());
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
        state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), state);
    if (initArgs.size() != indices.size())
      return failure();

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRange(initArgs);
    TypeRange argsTypes(argsRange);
    auto newWhileOp =
        rewriter.create<scf::WhileOp>(whileOp.getLoc(), argsTypes, initArgs);
    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocs(initArgs.size(), whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypes, bbArgLocs);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypes, bbArgLocs);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indices);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypes, indices,
                         equivalentYieldsBefore, state);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(), indices);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypes, indices,
                         equivalentYieldsAfter, state);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
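  /// The "before" block yields its values through scf.condition, the "after"
  /// block through scf.yield; both sets of operands are verified below.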
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should avoid returning/yielding
    // allocations whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}