//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

    // Compute new result types.
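    // For example (illustrative; the exact memref layout depends on the
    // bufferization options), a `tensor<?xf32>` result type becomes a
    // `memref<?xf32>` with a fully dynamic layout map, while non-tensor
    // result types are kept unchanged.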
    SmallVector<Type> newResultTypes;
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newResultTypes.push_back(
            getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update terminator.
    assert(newOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getRegion().front();
    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
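    //
    // Example (illustrative):
    //
    //   %r = scf.if %c -> (tensor<?xf32>) {
    //     scf.yield %t0 : tensor<?xf32>
    //   } else {
    //     scf.yield %t1 : tensor<?xf32>
    //   }
    //
    // Here, %r is considered to alias with both %t0 and %t1.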
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    //   scf.yield %1
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // were not considered a "write", the analysis would detect the following
    // conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto ifOp = cast<scf::IfOp>(op);

    // Compute new types of the bufferized scf.if op.
    SmallVector<Type> newTypes;
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newTypes.push_back(returnType);
      }
    }

    // Create new op.
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Remove terminators.
    if (!newIfOp.thenBlock()->empty()) {
      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
    }

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Update scf.yield of new then-block.
    auto thenYieldOp =
        cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
    rewriter.setInsertionPoint(thenYieldOp);
    SmallVector<Value> thenYieldValues;
    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Update scf.yield of new else-block.
    auto elseYieldOp =
        cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
    rewriter.setInsertionPoint(elseYieldOp);
    SmallVector<Value> elseYieldValues;
    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
      if (operand.get().getType().isa<TensorType>()) {
        ensureToMemrefOpIsValid(operand.get(),
                                newTypes[operand.getOperandNumber()]);
        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
            operand.get());
        operand.set(toMemrefOp);
      }
    }

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }
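
  // Sketch of the rewrite performed by `bufferize` above (illustrative; the
  // memref layout maps produced by `getMemRefType` are omitted):
  //
  //   %r = scf.if %c -> (tensor<5xf32>) {
  //     scf.yield %t : tensor<5xf32>
  //   } else {
  //     scf.yield %u : tensor<5xf32>
  //   }
  //
  // is rewritten to:
  //
  //   %0 = scf.if %c -> (memref<5xf32>) {
  //     %m = bufferization.to_memref %t : memref<5xf32>
  //     scf.yield %m : memref<5xf32>
  //   } else {
  //     %m2 = bufferization.to_memref %u : memref<5xf32>
  //     scf.yield %m2 : memref<5xf32>
  //   }
  //
  // Uses of %r are rewired via `replaceOpWithBufferizedValues`, which inserts
  // `bufferization.to_tensor` ops where needed.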

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
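/// For example (illustrative): for values of types
/// (index, tensor<5xf32>, f32, tensor<?xf32>), this returns {1, 3}.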
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic one,
  // we may not be able to cast the buffer. In that case, the loop iter_arg's
  // layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                     MutableArrayRef<OpOperand> operands,
                                     BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
      if (failed(resultBuffer))
        return {};
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition). If the given
/// tensor is equivalent to the corresponding block argument (as indicated by
/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new
/// buffer copy must be yielded.
///
/// According to the `BufferizableOpInterface` implementation of scf loops, a
/// bufferized OpResult may alias only with the corresponding bufferized
/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
/// the i-th init_arg, but not with any other OpOperand. If a corresponding
/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by
/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
/// cannot be sure and must yield a new buffer copy. (New buffer copies do not
/// alias with any buffer.)
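///
/// In the non-equivalent case, the generated IR looks roughly as follows
/// (illustrative; assuming the default alloc/copy callbacks, with types and
/// casts omitted):
///
///   %alloc = memref.alloc(...) : memref<?xf32>
///   memref.copy %yielded, %alloc
///
/// and %alloc (possibly casted) is yielded instead of the original buffer.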
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                              BaseMemRefType type, bool isEquivalent,
                              BufferizationState &state) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  Value yieldedVal =
      bufferization::lookupBuffer(rewriter, tensor, state.getOptions());

  if (isEquivalent)
    // Yielded value is equivalent to the corresponding iter_arg bbArg. Yield
    // the value directly. Most IR should be like that. Everything else must
    // be resolved with copies and is potentially inefficient. By default,
    // such problematic IR would already have been rejected during
    // `verifyAnalysis`, unless `allow-return-allocs` is enabled.
    return castBuffer(rewriter, yieldedVal, type);

  // It is not certain that the yielded value and the iter_arg bbArg have the
  // same buffer. Allocate a new buffer and copy. The yielded buffer will get
  // deallocated by `deallocateBuffers`.

  // TODO: There are cases in which it is not necessary to return a new buffer
  // allocation. E.g., when equivalent values are yielded in a different
  // order. This could be resolved with copies.
  Optional<Value> yieldedAlloc = state.createAlloc(
      rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
  // TODO: We should rollback, but for now just assume that this always
  // succeeds.
  assert(yieldedAlloc.hasValue() && "could not create alloc");
  LogicalResult copyStatus = state.getOptions().createMemCpy(
      rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc);
  (void)copyStatus;
  assert(succeeded(copyStatus) && "could not create memcpy");

  // The iter_arg memref type may have a layout map. Cast the new buffer to
  // the same type if needed.
  return castBuffer(rewriter, *yieldedAlloc, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static SmallVector<Value>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<Value(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                    TypeRange bufferizedTypes,
                                    const DenseSet<int64_t> &tensorIndices,
                                    const DenseSet<int64_t> &equivalentTensors,
                                    BufferizationState &state) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                equivalentTensors.contains(index), state);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  return convertTensorValues(
      bbArgs, tensorIndices, [&](Value val, int64_t index) {
        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
      });
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
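    //
    // Example (illustrative):
    //
    //   %r = scf.for ... iter_args(%arg = %t) -> (tensor<?xf32>) {
    //     %w = "some_op"(%arg) : (tensor<?xf32>) -> tensor<?xf32>
    //     scf.yield %w : tensor<?xf32>
    //   }
    //
    // If the analysis found %w and %arg to bufferize to equivalent buffers
    // (e.g., because "some_op" bufferizes in place), %r is equivalent to %t;
    // otherwise the relation is `BufferRelation::None`.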
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed inplace from
    // the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    auto oldYieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields =
        getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(),
                             state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), state);

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Erase terminator if present.
    if (iterArgs.size() == 1)
      rewriter.eraseOp(loopBody->getTerminator());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Update scf.yield of new loop.
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues =
        getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
                         equivalentYields, state);
    yieldOp.getResultsMutable().assign(yieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }
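
  // Sketch of the rewrite performed by `bufferize` above (illustrative;
  // layout maps and casts omitted). Before:
  //
  //   %r = scf.for %i = %lb to %ub step %s iter_args(%arg = %t)
  //       -> (tensor<?xf32>) {
  //     ... // ops on tensors
  //     scf.yield %next : tensor<?xf32>
  //   }
  //
  // After (roughly):
  //
  //   %init = <buffer of %t>
  //   %0 = scf.for %i = %lb to %ub step %s iter_args(%arg = %init)
  //       -> (memref<?xf32>) {
  //     %arg_t = bufferization.to_tensor %arg : memref<?xf32>
  //     ... // original body, now using %arg_t
  //     scf.yield <buffer of %next> : memref<?xf32>
  //   }
  //
  // If %next is not equivalent to %arg, `getYieldedBuffer` yields a new
  // allocation + copy instead of the original buffer.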

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
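    // For example (illustrative), an init operand of type `tensor<5xf32>`
    // whose corresponding result has type `tensor<?xf32>` is not considered
    // to alias with any OpResult.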
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed inplace
    // from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(),
        state.getAnalysisState());
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(),
        state.getAnalysisState());

    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), state);

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          return state.getBufferType(bbArg).cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new
    // block in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, equivalentYieldsBefore, state);
    newConditionOp.getArgsMutable().assign(newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new
    // block in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, equivalentYieldsAfter, state);
    newYieldOp.getResultsMutable().assign(newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }
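
  // Sketch of the rewrite performed by `bufferize` above (illustrative;
  // layout maps and casts omitted). Before:
  //
  //   %r = scf.while (%b = %t) : (tensor<?xf32>) -> tensor<?xf32> {
  //     ...
  //     scf.condition(%cond) %b : tensor<?xf32>
  //   } do {
  //   ^bb0(%a: tensor<?xf32>):
  //     ...
  //     scf.yield %next : tensor<?xf32>
  //   }
  //
  // After (roughly):
  //
  //   %0 = scf.while (%b = <buffer of %t>) : (memref<?xf32>) -> memref<?xf32> {
  //     %b_t = bufferization.to_tensor %b : memref<?xf32>
  //     ...
  //     scf.condition(%cond) <buffer of %b_t> : memref<?xf32>
  //   } do {
  //   ^bb0(%a: memref<?xf32>):
  //     %a_t = bufferization.to_tensor %a : memref<?xf32>
  //     ...
  //     scf.yield <buffer of %next> : memref<?xf32>
  //   }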

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
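  ///
  /// Example of IR rejected when `allow-return-allocs` is disabled
  /// (illustrative):
  ///
  ///   %r = scf.while (%b = %t) : (tensor<?xf32>) -> tensor<?xf32> {
  ///     ...
  ///     scf.condition(%cond) %other : tensor<?xf32>
  ///   } do { ... }
  ///
  /// Here, %other does not bufferize to a buffer equivalent to %b, so the
  /// condition arg is flagged.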
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of the enclosing op, so this
/// is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy may
    // be generated inside the block. We want to avoid yielding allocations
    // whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");
    return success();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}