//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
                                rankedTensorType.getRank())) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto yieldOp =
        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    return {&ifOp.thenYield()->getOpOperand(resultNum),
            &ifOp.elseYield()->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in one (or both) of the branches. Since this is
  // not allowed at the moment, we should never encounter scf.ifs that yield
  // unmodified tensors. Such scf.yield ops could just fold away.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // IfOp results are always considered memory writes in the analysis. This
    // design decision simplifies the analysis considerably. E.g., consider the
    // following test case:
    //
    // %0 = "some_writing_op" : tensor<?xf32>
    // %r = scf.if %c -> (tensor<?xf32>) {
    //   scf.yield %0
    // } else {
    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
    // }
    // "some_reading_op"(%r)
    //
    // "another_writing_op" in the above example should be able to bufferize
    // inplace in the absence of another read of %0. However, if the scf.if op
    // would not be considered a "write", the analysis would detect the
    // following conflict:
    //
    // * read = some_reading_op
    // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
    // * conflictingWrite = %1
    //
    // For more details, check the "scf.IfOp" section of the design document.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());

    // Reconcile type mismatches between then/else branches by inserting memref
    // casts.
    SmallVector<Value> thenResults, elseResults;
    bool insertedCast = false;
    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
      Value thenValue = thenYieldOp.getResults()[i];
      Value elseValue = elseYieldOp.getResults()[i];
      if (thenValue.getType() == elseValue.getType()) {
        thenResults.push_back(thenValue);
        elseResults.push_back(elseValue);
        continue;
      }

      // Type mismatch between then/else yield value. Cast both to a memref
      // type with a fully dynamic layout map.
      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
      if (thenMemrefType.getMemorySpaceAsInt() !=
          elseMemrefType.getMemorySpaceAsInt())
        return op->emitError("inconsistent memory space on then/else branches");
      rewriter.setInsertionPoint(thenYieldOp);
      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
          ifOp.getResultTypes()[i].cast<TensorType>(),
          thenMemrefType.getMemorySpaceAsInt());
      thenResults.push_back(rewriter.create<memref::CastOp>(
          thenYieldOp.getLoc(), memrefType, thenValue));
      rewriter.setInsertionPoint(elseYieldOp);
      elseResults.push_back(rewriter.create<memref::CastOp>(
          elseYieldOp.getLoc(), memrefType, elseValue));
      insertedCast = true;
    }

    if (insertedCast) {
      rewriter.setInsertionPoint(thenYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
      rewriter.setInsertionPoint(elseYieldOp);
      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    ValueRange resultsValueRange(thenResults);
    TypeRange newTypes(resultsValueRange);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // IfOp results are equivalent to their corresponding yield values if both
    // yield values are equivalent to each other.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    SmallVector<OpOperand *> yieldValues =
        bufferizableOp.getAliasingOpOperand(opResult, state);
    assert(yieldValues.size() == 2 && "expected 2 yield values");
    bool equivalentYields = state.areEquivalentBufferizedValues(
        yieldValues[0]->get(), yieldValues[1]->get());
    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
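/// Illustrative example (added, not from the original comment): for values
/// with types (index, tensor<4xf32>, f32, tensor<?xf32>), the returned set
/// would be {1, 3}.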
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (it.value().getType().isa<TensorType>())
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!bbArgs[i].getType().isa<TensorType>() ||
        !yieldedValues[i].getType().isa<TensorType>())
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
                                         BaseMemRefType type,
                                         const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
  if (failed(yieldedVal))
    return failure();
  return castBuffer(rewriter, *yieldedVal, type);
}

/// Helper function for loop bufferization. Given a range of values, apply
/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
/// value in the result vector.
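/// Illustrative sketch (added): with values = (%i : index, %t : tensor<4xf32>)
/// and tensorIndices = {1}, the result is (%i, func(%t, 1)); a failure of
/// `func` is propagated as failure().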
static FailureOr<SmallVector<Value>>
convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(values)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      FailureOr<Value> maybeVal = func(val, idx);
      if (failed(maybeVal))
        return failure();
      result.push_back(*maybeVal);
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of pre-bufferization
/// yielded values, compute the list of bufferized yielded values.
FailureOr<SmallVector<Value>>
getYieldedValues(RewriterBase &rewriter, ValueRange values,
                 TypeRange bufferizedTypes,
                 const DenseSet<int64_t> &tensorIndices,
                 const BufferizationOptions &options) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
                                options);
      });
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    auto forOp = cast<scf::ForOp>(op);
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    return {forOp.getResultForOpOperand(opOperand)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer
    // copy. (New buffer copies do not alias with any buffer.)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
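    // (Added note) If not, the loop below routes the yielded value through a
    // freshly allocated tensor copy so that the yielded buffer cannot alias
    // the corresponding init_arg.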
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
                                                 value, /*escape=*/true);
      yieldValues.push_back(alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, BlockArgument bbArg,
                const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    return bufferization::getBufferType(
        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), initArgs);
    newForOp->setAttrs(forOp->getAttrs());
    ValueRange initArgsRange(initArgs);
    TypeRange initArgsTypes(initArgsRange);
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
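  /// Illustrative example (added, not from the upstream comment): a loop that
  /// yields a freshly allocated tensor instead of a value equivalent to the
  /// iter_arg, e.g.
  ///   %r = scf.for ... iter_args(%a = %t) -> (tensor<?xf32>) {
  ///     %new = bufferization.alloc_tensor(%sz) : tensor<?xf32>
  ///     scf.yield %new : tensor<?xf32>
  ///   }
  /// would be rejected by this check unless `allow-return-allocs` is set.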
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!opResult.getType().isa<TensorType>())
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    return {whileOp->getResult(idx)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::None;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::None;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::None;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer
    // copy. (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();
    auto yieldOp = whileOp.getYieldOp();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!indicesBefore.contains(idx) ||
          equivalentYieldsBefore.contains(idx)) {
        beforeYieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
                                                 value, /*escape=*/true);
      beforeYieldValues.push_back(alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    // Update "after" region.
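    // (Added note) This mirrors the "before" region handling above, but
    // operates on the scf.yield terminator of the "after" block.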
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> afterYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) {
        afterYieldValues.push_back(value);
        continue;
      }
      Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
                                                 value, /*escape=*/true);
      afterYieldValues.push_back(alloc);
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getResultsMutable().assign(afterYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          // TODO: error handling
          return bufferization::getBufferType(bbArg, options)->cast<Type>();
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(initArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                    argsTypesAfter, initArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Update scf.condition of new loop.
    auto newConditionOp = newWhileOp.getConditionOp();
    rewriter.setInsertionPoint(newConditionOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "after" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                         indicesAfter, options);
    if (failed(newConditionArgs))
      return failure();
    newConditionOp.getArgsMutable().assign(*newConditionArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Update scf.yield of the new loop.
    auto newYieldOp = newWhileOp.getYieldOp();
    rewriter.setInsertionPoint(newYieldOp);
    // Only equivalent buffers or new buffer allocations may be yielded to the
    // "before" region.
    // TODO: This could be relaxed for better bufferization results.
    FailureOr<SmallVector<Value>> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                         indicesBefore, options);
    if (failed(newYieldValues))
      return failure();
    newYieldOp.getResultsMutable().assign(*newYieldValues);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), conditionOp->getBlock()->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      if (!it.value().getType().isa<TensorType>())
        continue;
      if (!state.areEquivalentBufferizedValues(
              it.value(), yieldOp->getBlock()->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of its enclosing op, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
    // together with scf.while.)
    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
      return success();

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (value.getType().isa<TensorType>()) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType =
              cast<BufferizableOpInterface>(forOp.getOperation())
                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
                                 options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

using tensor::ExtractSliceOp;

/// Return the destinations that a ForeachThreadOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *>
getInsertionDest(ForeachThreadOp foreachThreadOp) {
  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
  SmallVector<OpOperand *> result;
  terminator.walk([&](ParallelInsertSliceOp insertOp) {
    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
  });
  return result;
}

/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of
/// the region. There are op interfaces for the terminators
/// (PerformConcurrentlyOp and ParallelInsertSliceOp), but these are used only
/// during analysis, not for bufferization.
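/// Rough sketch of the rewrite (added for illustration): the op is replaced by
/// a result-less ForeachThreadOp; each ParallelInsertSliceOp becomes a
/// memref.subview of the destination buffer plus a copy of the source buffer
/// into that subview (see ParallelInsertSliceOpInterface below), and uses of
/// the old results are redirected to bufferization.to_tensor ops of the
/// destination buffers.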
struct ForeachThreadOpInterface
    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                    ForeachThreadOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
  }

  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // This op is a memory write. Stop lookup here to avoid finding false
    // conflicts involving this op and one of the ops in the region. This is
    // similar to how scf.if ops are analyzed.
    return true;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    OpBuilder::InsertionGuard g(rewriter);
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
    for (OpResult opResult : foreachThreadOp->getOpResults()) {
      SmallVector<OpOperand *> destOperands =
          state.getAliasingOpOperand(opResult);
      assert(destOperands.size() == 1 &&
             "expected exactly one aliasing OpOperand");
      assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
             "expected ParallelInsertSliceOp");

      // Nothing to do if there is no conflict.
      if (state.isInPlace(*destOperands.front()))
        continue;

      // Insert tensor allocation.
      bool isYielded = state.isTensorYielded(opResult);
      Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
                                                 destOperands.front()->get(),
                                                 /*escape=*/isYielded);

      // Update terminator operand.
      rewriter.updateRootInPlace(destOperands.front()->getOwner(),
                                 [&]() { destOperands.front()->set(alloc); });
    }

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto foreachThreadOp = cast<ForeachThreadOp>(op);

#ifndef NDEBUG
    // ParallelInsertSliceOpInterface replaces all uses.
    for (OpResult opResult : foreachThreadOp->getOpResults())
      assert(opResult.getUses().empty() &&
             "expected that all uses were already replaced");
#endif // NDEBUG

    // Create new ForeachThreadOp without any results and drop the
    // automatically introduced terminator.
    TypeRange newResultTypes;
    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
        foreachThreadOp.getLoc(), newResultTypes,
        foreachThreadOp.getNumThreads());
    newForeachThreadOp.getBody()->getTerminator()->erase();

    // Move over block contents of the old op.
    rewriter.mergeBlocks(foreachThreadOp.getBody(),
                         newForeachThreadOp.getBody(),
                         {newForeachThreadOp.getBody()->getArguments()});

    // Remove the old op.
    rewriter.eraseOp(op);

    return success();
  }
};

/// Nothing to do for PerformConcurrentlyOp.
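/// (Added note) Its tensor-related semantics are handled via the enclosing
/// ForeachThreadOp and the nested ParallelInsertSliceOps; the op itself has no
/// tensor OpOperands or OpResults, so `bufferize` is never expected to be
/// called on it.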
struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
/// (i.e., equivalent operand / result and same offset/sizes/strides
/// specification).
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st,
                                         ParallelInsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      ParallelInsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results. Tensors are returned via
    // the parent op.
    auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
    assert(foreachThreadOp &&
           "could not find valid owner of parallel_insert_slice");

    // The i-th ParallelInsertSliceOp result is returned via the i-th OpResult
    // of the parent ForeachThreadOp.
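    // (Added note) Determine that index by counting how many
    // ParallelInsertSliceOps precede this op in the terminator block.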
    Block *block = op->getBlock();
    unsigned int opIdx = 0;
    for (ParallelInsertSliceOp insertOp :
         block->getOps<ParallelInsertSliceOp>()) {
      if (insertOp.getOperation() == op)
        break;
      ++opIdx;
    }
    assert(opIdx < foreachThreadOp->getNumResults() &&
           "could not find op inside terminator op");

    return {foreachThreadOp->getResult(opIdx)};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto insertOp = cast<ParallelInsertSliceOp>(op);
    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
    auto foreachThreadOp =
        cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());

    // If the op bufferizes out-of-place, allocate the copy before the
    // ForeachThreadOp.
    rewriter.setInsertionPoint(foreachThreadOp);
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destBuffer))
      return failure();

    // Bufferize the ParallelInsertSliceOp outside of the
    // PerformConcurrentlyOp.
    rewriter.setInsertionPoint(performConcurrentlyOp);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, insertOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();
    Value subview = rewriter.create<memref::SubViewOp>(
        insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
                                    subview)))
      return failure();
    rewriter.eraseOp(op);

    // Replace all uses of ForeachThreadOp (just the corresponding result).
    rewriter.setInsertionPointAfter(foreachThreadOp);
    Value toTensorOp =
        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
    unsigned resultNum = 0;
    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
      if (&nextOp == op)
        break;
      resultNum++;
    }
    assert(resultNum < foreachThreadOp->getNumResults() &&
           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
    SmallVector<OpOperand *> resultUses = llvm::to_vector(
        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
                        [](OpOperand &use) { return &use; }));
    for (OpOperand *use : resultUses) {
      rewriter.updateRootInPlace(use->getOwner(),
                                 [&]() { use->set(toTensorOp); });
    }
    return success();
  }

  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
  // the code.
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
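// Example usage (added sketch, not part of the upstream file): a tool that
// runs One-Shot Bufferize on SCF-based IR would typically register these
// external models when building its context, roughly as follows:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect, memref::MemRefDialect,
//                   tensor::TensorDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);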