//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto yieldOp =
        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Reconcile type mismatches between then/else branches by inserting memref
    // casts.
    SmallVector<Value> thenResults, elseResults;
    bool insertedCast = false;
    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
      Value thenValue = thenYieldOp.getResults()[i];
      Value elseValue = elseYieldOp.getResults()[i];
      if (thenValue.getType() == elseValue.getType()) {
        thenResults.push_back(thenValue);
        elseResults.push_back(elseValue);
        continue;
      }

      // Type mismatch between then/else yield value. Cast both to a memref
      // type with a fully dynamic layout map.
      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
      if (thenMemrefType.getMemorySpaceAsInt() !=
          elseMemrefType.getMemorySpaceAsInt())
        return op->emitError("inconsistent memory space on then/else branches");
      rewriter.setInsertionPoint(thenYieldOp);
      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
          ifOp.getResultTypes()[i].cast<TensorType>(),
          thenMemrefType.getMemorySpaceAsInt());
      thenResults.push_back(rewriter.create<memref::CastOp>(
          thenYieldOp.getLoc(), memrefType, thenValue));
      rewriter.setInsertionPoint(elseYieldOp);
      elseResults.push_back(rewriter.create<memref::CastOp>(
          elseYieldOp.getLoc(), memrefType, elseValue));
      insertedCast = true;
    }

    if (insertedCast) {
      rewriter.setInsertionPoint(thenYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
      rewriter.setInsertionPoint(elseYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    ValueRange resultsValueRange(thenResults);
    TypeRange newTypes(resultsValueRange);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                                         BaseMemRefType type,
                                         const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
  if (failed(yieldedVal))
    return failure();
  return castBuffer(rewriter, *yieldedVal, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      FailureOr<Value> maybeVal = func(val, idx);
      if (failed(maybeVal))
        return failure();
      result.push_back(*maybeVal);
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase &rewriter, ValueRange values,
                 TypeRange bufferizedTypes,
                 const DenseSet<int64_t> &tensorIndices,
                 const BufferizationOptions &options) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                options);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
                                                 value, /*escape=*/true);
      yieldValues.push_back(alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, BlockArgument bbArg,
                const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    return bufferization::getBufferType(
        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
                                                 value, /*escape=*/true);
      beforeYieldValues.push_back(alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region.
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
                                                 value, /*escape=*/true);
      afterYieldValues.push_back(alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          // TODO: error handling
          return bufferization::getBufferType(bbArg, options)->cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    if (failed(newConditionArgs))
      return failure();
    newConditionOp.getArgsMutable().assign(*newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    if (failed(newYieldValues))
      return failure();
    newYieldOp.getResultsMutable().assign(*newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should avoid returning/yielding
    // allocations whenever possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
    // together with scf.while.)
    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
      return success();

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (value.getType().isa<TensorType>()) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType =
              cast<BufferizableOpInterface>(forOp.getOperation())
                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
                                 options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

using tensor::ExtractSliceOp;

/// Return the destinations that a ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
  SmallVector<OpOperand *> result;
  terminator.walk([&](ParallelInsertSliceOp insertOp) {
    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
  });
  return result;
}

/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of
/// the region. There are op interfaces for the terminators
/// (PerformConcurrentlyOp and ParallelInsertSliceOp), but these are only used
/// during analysis, not for bufferization.
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    OpBuilder::InsertionGuard g(rewriter);
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    for (OpResult opResult : foreachThreadOp->getOpResults()) {
      SmallVector<OpOperand *> destOperands =
          state.getAliasingOpOperand(opResult);
      assert(destOperands.size() == 1 &&
             "expected exactly one aliasing OpOperand");
      assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
             "expected ParallelInsertSliceOp");

      // Nothing to do if there is no conflict.
      if (state.isInPlace(*destOperands.front()))
        continue;

      // Insert tensor allocation.
      bool isYielded = state.isTensorYielded(opResult);
      Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
                                                 destOperands.front()->get(),
                                                 /*escape=*/isYielded);

      // Update terminator operand.
      rewriter.updateRootInPlace(destOperands.front()->getOwner(),
                                 [&]() { destOperands.front()->set(alloc); });
    }

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

#ifndef NDEBUG
    // ParallelInsertSliceOpInterface replaces all uses.
    for (OpResult opResult : foreachThreadOp->getOpResults())
      assert(opResult.getUses().empty() &&
             "expected that all uses were already replaced");
#endif // NDEBUG

    // Create new ForeachThreadOp without any results and drop the
    // automatically introduced terminator.
    TypeRange newResultTypes;
    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
        foreachThreadOp.getLoc(), newResultTypes,
        foreachThreadOp.getNumThreads(),
        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    rewriter.mergeBlocks(foreachThreadOp.getBody(),
                         newForeachThreadOp.getBody(),
                         {newForeachThreadOp.getBody()->getArguments()});

    // Remove the old op.
    rewriter.eraseOp(op);

    return success();
  }
};

/// Nothing to do for PerformConcurrentlyOp.
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
/// (i.e., equivalent operand / result and same offset/sizes/strides
/// specification).
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st,
                                         ParallelInsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      ParallelInsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results. Tensors are returned via
    // the parent op.
    auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
    assert(foreachThreadOp &&
           "could not find valid owner of parallel_insert_slice");

    // The i-th ParallelInsertSliceOp result is returned via the i-th OpResult
    // of the parent ForeachThreadOp.
    Block *block = op->getBlock();
    unsigned int opIdx = 0;
    for (ParallelInsertSliceOp insertOp :
         block->getOps<ParallelInsertSliceOp>()) {
      if (insertOp.getOperation() == op)
        break;
      ++opIdx;
    }
    assert(opIdx < foreachThreadOp->getNumResults() &&
           "could not find op inside terminator op");

    return {foreachThreadOp->getResult(opIdx)};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto insertOp = cast<ParallelInsertSliceOp>(op);
    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
    auto foreachThreadOp =
        cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());

    // If the op bufferizes out-of-place, allocate the copy before the
    // ForeachThreadOp.
    rewriter.setInsertionPoint(foreachThreadOp);
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destBuffer))
      return failure();

    // Bufferize the ParallelInsertSliceOp outside of the
    // PerformConcurrentlyOp.
    rewriter.setInsertionPoint(performConcurrentlyOp);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, insertOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();
    Value subview = rewriter.create<memref::SubViewOp>(
        insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
                                    subview)))
      return failure();
    rewriter.eraseOp(op);

    // Replace all uses of ForeachThreadOp (just the corresponding result).
    rewriter.setInsertionPointAfter(foreachThreadOp);
    Value toTensorOp =
        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
    unsigned resultNum = 0;
    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
      if (&nextOp == op)
        break;
      resultNum++;
    }
    assert(resultNum < foreachThreadOp->getNumResults() &&
           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
    SmallVector<OpOperand *> resultUses = llvm::to_vector(
        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
                        [](OpOperand &use) { return &use; }));
    for (OpOperand *use : resultUses) {
      rewriter.updateRootInPlace(use->getOwner(),
                                 [&]() { use->set(toTensorOp); });
    }
    return success();
  }

  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
  // the code.
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}