1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/Dialect.h" 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/PatternMatch.h" 20 21 using namespace mlir; 22 using namespace mlir::bufferization; 23 using namespace mlir::scf; 24 25 namespace mlir { 26 namespace scf { 27 namespace { 28 29 // bufferization.to_memref is not allowed to change the rank. 30 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { 31 #ifndef NDEBUG 32 auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>(); 33 assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() == 34 rankedTensorType.getRank())) && 35 "to_memref would be invalid: mismatching ranks"); 36 #endif 37 } 38 39 /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not 40 /// fully implemented at the moment. 41 struct ExecuteRegionOpInterface 42 : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface, 43 scf::ExecuteRegionOp> { 44 SmallVector<OpOperand *> 45 getAliasingOpOperand(Operation *op, OpResult opResult, 46 const AnalysisState &state) const { 47 // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be 48 // any SSA value that is in scope. To allow for use-def chain traversal 49 // through ExecuteRegionOps in the analysis, the corresponding yield value 50 // is considered to be aliasing with the result. 51 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 52 size_t resultNum = std::distance(op->getOpResults().begin(), 53 llvm::find(op->getOpResults(), opResult)); 54 // TODO: Support multiple blocks. 55 assert(executeRegionOp.getRegion().getBlocks().size() == 1 && 56 "expected exactly 1 block"); 57 auto yieldOp = dyn_cast<scf::YieldOp>( 58 executeRegionOp.getRegion().front().getTerminator()); 59 assert(yieldOp && "expected scf.yield terminator in scf.execute_region"); 60 return {&yieldOp->getOpOperand(resultNum)}; 61 } 62 63 // TODO: For better bufferization results, this could return `true` only if 64 // there is a memory write in the region. 65 bool isMemoryWrite(Operation *op, OpResult opResult, 66 const AnalysisState &state) const { 67 // Similar to scf.if, results of this op are always considered memory writes 68 // in the analysis. This is a useful pattern for all ops that have tensor 69 // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is 70 // implemented in terms of `bufferizesToMemoryWrite`, which does not work on 71 // ops without OpOperands. 72 return true; 73 } 74 75 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 76 BufferizationState &state) const { 77 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); 78 79 // Compute new result types. 
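// For illustration (hypothetical types, not taken from a test case): a result
// of type tensor<10xf32> is mapped to the memref type computed by
// getMemRefType, e.g. memref<10xf32> with whatever layout the bufferization
// options prescribe; non-tensor results keep their original type.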
80 SmallVector<Type> newResultTypes; 81 for (Type type : executeRegionOp->getResultTypes()) { 82 if (auto tensorType = type.dyn_cast<TensorType>()) { 83 // TODO: Infer the result type instead of computing it. 84 newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); 85 } else { 86 newResultTypes.push_back(type); 87 } 88 } 89 90 // Create new op and move over region. 91 auto newOp = 92 rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes); 93 newOp.getRegion().takeBody(executeRegionOp.getRegion()); 94 95 // Update terminator. 96 assert(newOp.getRegion().getBlocks().size() == 1 && 97 "only 1 block supported"); 98 Block *newBlock = &newOp.getRegion().front(); 99 auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator()); 100 rewriter.setInsertionPoint(yieldOp); 101 SmallVector<Value> newYieldValues; 102 for (const auto &it : llvm::enumerate(yieldOp.getResults())) { 103 Value val = it.value(); 104 if (val.getType().isa<TensorType>()) { 105 newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>( 106 yieldOp.getLoc(), newResultTypes[it.index()], val)); 107 } else { 108 newYieldValues.push_back(val); 109 } 110 } 111 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); 112 113 // Update all uses of the old op. 114 rewriter.setInsertionPointAfter(newOp); 115 SmallVector<Value> newResults; 116 for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { 117 if (it.value().isa<TensorType>()) { 118 newResults.push_back(rewriter.create<bufferization::ToTensorOp>( 119 executeRegionOp.getLoc(), newOp->getResult(it.index()))); 120 } else { 121 newResults.push_back(newOp->getResult(it.index())); 122 } 123 } 124 125 // Replace old op. 126 rewriter.replaceOp(executeRegionOp, newResults); 127 128 return success(); 129 } 130 131 BufferRelation bufferRelation(Operation *op, OpResult opResult, 132 const AnalysisState &state) const { 133 return BufferRelation::Equivalent; 134 } 135 }; 136 137 /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. 138 struct IfOpInterface 139 : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> { 140 SmallVector<OpOperand *> 141 getAliasingOpOperand(Operation *op, OpResult opResult, 142 const AnalysisState &state) const { 143 // IfOps do not have tensor OpOperands. The yielded value can be any SSA 144 // value that is in scope. To allow for use-def chain traversal through 145 // IfOps in the analysis, both corresponding yield values from the then/else 146 // branches are considered to be aliasing with the result. 147 auto ifOp = cast<scf::IfOp>(op); 148 size_t resultNum = std::distance(op->getOpResults().begin(), 149 llvm::find(op->getOpResults(), opResult)); 150 return {&ifOp.thenYield()->getOpOperand(resultNum), 151 &ifOp.elseYield()->getOpOperand(resultNum)}; 152 } 153 154 // TODO: For better bufferization results, this could return `true` only if 155 // there is a memory write in one (or both) of the branches. Since this is not 156 // allowed at the moment, we should never encounter scf.ifs that yield 157 // unmodified tensors. Such scf.yield ops could just fold away. 158 bool isMemoryWrite(Operation *op, OpResult opResult, 159 const AnalysisState &state) const { 160 // IfOp results are always considered memory writes in the analysis. This 161 // design decision simplifies the analysis considerably. 
E.g., consider the 162 // following test case: 163 // 164 // %0 = "some_writing_op" : tensor<?xf32> 165 // %r = scf.if %c -> (tensor<?xf32>) { 166 // scf.yield %0 167 // } else { 168 // %1 = "another_writing_op"(%0) : tensor<?xf32> 169 // } 170 // "some_reading_op"(%r) 171 // 172 // "another_writing_op" in the above example should be able to bufferize 173 // inplace in the absence of another read of %0. However, if the scf.if op 174 // would not be considered a "write", the analysis would detect the 175 // following conflict: 176 // 177 // * read = some_reading_op 178 // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.) 179 // * conflictingWrite = %1 180 // 181 // For more details, check the "scf.IfOp" section of the design document. 182 return true; 183 } 184 185 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 186 BufferizationState &state) const { 187 auto ifOp = cast<scf::IfOp>(op); 188 189 // Compute new types of the bufferized scf.if op. 190 SmallVector<Type> newTypes; 191 for (Type returnType : ifOp->getResultTypes()) { 192 if (auto tensorType = returnType.dyn_cast<TensorType>()) { 193 // TODO: Infer the result type instead of computing it. 194 newTypes.push_back(getMemRefType(tensorType, state.getOptions())); 195 } else { 196 newTypes.push_back(returnType); 197 } 198 } 199 200 // Create new op. 201 auto newIfOp = 202 rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(), 203 /*withElseRegion=*/true); 204 205 // Remove terminators. 206 if (!newIfOp.thenBlock()->empty()) { 207 rewriter.eraseOp(newIfOp.thenBlock()->getTerminator()); 208 rewriter.eraseOp(newIfOp.elseBlock()->getTerminator()); 209 } 210 211 // Move over then/else blocks. 212 rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); 213 rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock()); 214 215 // Update scf.yield of new then-block. 216 auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator()); 217 rewriter.setInsertionPoint(thenYieldOp); 218 SmallVector<Value> thenYieldValues; 219 for (OpOperand &operand : thenYieldOp->getOpOperands()) { 220 if (operand.get().getType().isa<TensorType>()) { 221 ensureToMemrefOpIsValid(operand.get(), 222 newTypes[operand.getOperandNumber()]); 223 Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 224 operand.get().getLoc(), newTypes[operand.getOperandNumber()], 225 operand.get()); 226 operand.set(toMemrefOp); 227 } 228 } 229 230 // Update scf.yield of new else-block. 231 auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator()); 232 rewriter.setInsertionPoint(elseYieldOp); 233 SmallVector<Value> elseYieldValues; 234 for (OpOperand &operand : elseYieldOp->getOpOperands()) { 235 if (operand.get().getType().isa<TensorType>()) { 236 ensureToMemrefOpIsValid(operand.get(), 237 newTypes[operand.getOperandNumber()]); 238 Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( 239 operand.get().getLoc(), newTypes[operand.getOperandNumber()], 240 operand.get()); 241 operand.set(toMemrefOp); 242 } 243 } 244 245 // Replace op results. 246 replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults()); 247 248 return success(); 249 } 250 251 BufferRelation bufferRelation(Operation *op, OpResult opResult, 252 const AnalysisState &state) const { 253 // IfOp results are equivalent to their corresponding yield values if both 254 // yield values are equivalent to each other. 
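// For illustration (hypothetical IR): in
//
//   %r = scf.if %c -> (tensor<5xf32>) {
//     scf.yield %t : tensor<5xf32>
//   } else {
//     scf.yield %t : tensor<5xf32>
//   }
//
// both branches yield the same value, so the two yielded buffers are
// equivalent and %r is "Equivalent" to them. If the branches yielded two
// unrelated tensors, the relation would be "None".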
255 auto bufferizableOp = cast<BufferizableOpInterface>(op);
256 SmallVector<OpOperand *> yieldValues =
257 bufferizableOp.getAliasingOpOperand(opResult, state);
258 assert(yieldValues.size() == 2 && "expected 2 yield values");
259 bool equivalentYields = state.areEquivalentBufferizedValues(
260 yieldValues[0]->get(), yieldValues[1]->get());
261 return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
262 }
263 };
264
265 /// Helper function for loop bufferization. Return the indices of all values
266 /// that have a tensor type.
267 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
268 DenseSet<int64_t> result;
269 for (const auto &it : llvm::enumerate(values))
270 if (it.value().getType().isa<TensorType>())
271 result.insert(it.index());
272 return result;
273 }
274
275 /// Helper function for loop bufferization. Return the indices of all
276 /// bbArg/yielded value pairs whose buffer relation is "Equivalent".
277 DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
278 ValueRange yieldedValues,
279 const AnalysisState &state) {
280 unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
281 DenseSet<int64_t> result;
282 for (unsigned int i = 0; i < minSize; ++i) {
283 if (!bbArgs[i].getType().isa<TensorType>() ||
284 !yieldedValues[i].getType().isa<TensorType>())
285 continue;
286 if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
287 result.insert(i);
288 }
289 return result;
290 }
291
292 /// Helper function for loop bufferization. Cast the given buffer to the given
293 /// memref type.
294 static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
295 assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
296 assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
297 // If the buffer already has the correct type, no cast is needed.
298 if (buffer.getType() == type)
299 return buffer;
300 // TODO: In case `type` has a layout map that is not the fully dynamic
301 // one, we may not be able to cast the buffer. In that case, the loop
302 // iter_arg's layout map must be changed (see uses of `castBuffer`).
303 assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
304 "scf.while op bufferization: cast incompatible");
305 return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
306 }
307
308 /// Helper function for loop bufferization. Return the bufferized values of the
309 /// given OpOperands. If an operand is not a tensor, return the original value.
310 static SmallVector<Value> getBuffers(RewriterBase &rewriter,
311 MutableArrayRef<OpOperand> operands,
312 BufferizationState &state) {
313 SmallVector<Value> result;
314 for (OpOperand &opOperand : operands) {
315 if (opOperand.get().getType().isa<TensorType>()) {
316 FailureOr<Value> resultBuffer = state.getBuffer(rewriter, opOperand);
317 if (failed(resultBuffer))
318 return {};
319 result.push_back(*resultBuffer);
320 } else {
321 result.push_back(opOperand.get());
322 }
323 }
324 return result;
325 }
326
327 /// Helper function for loop bufferization. Compute the buffer that should be
328 /// yielded from a loop block (loop body or loop condition). If the given tensor
329 /// is equivalent to the corresponding block argument (as indicated by
330 /// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer
331 /// copy must be yielded.
332 ///
333 /// According to the `BufferizableOpInterface` implementation of scf loops,
334 /// a bufferized OpResult may alias only with the corresponding bufferized
335 /// init_arg and with no other buffers. I.e., the i-th OpResult may alias with
336 /// the i-th init_arg; but not with any other OpOperand. If a corresponding
337 /// OpResult/init_arg pair bufferizes to equivalent buffers (as indicated by
338 /// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we
339 /// cannot be sure and must yield a new buffer copy. (New buffer copies do not
340 /// alias with any buffer.)
341 static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
342 BaseMemRefType type, bool isEquivalent,
343 BufferizationState &state) {
344 assert(tensor.getType().isa<TensorType>() && "expected tensor");
345 ensureToMemrefOpIsValid(tensor, type);
346 Value yieldedVal =
347 bufferization::lookupBuffer(rewriter, tensor, state.getOptions());
348
349 if (isEquivalent)
350 // Yielded value is equivalent to the corresponding iter_arg bbArg.
351 // Yield the value directly. Most IR should be like that. Everything
352 // else must be resolved with copies and is potentially inefficient.
353 // By default, such problematic IR would already have been rejected
354 // during `verifyAnalysis`, unless `allow-return-allocs`.
355 return castBuffer(rewriter, yieldedVal, type);
356
357 // It is not certain that the yielded value and the iter_arg bbArg
358 // have the same buffer. Allocate a new buffer and copy. The yielded
359 // buffer will get deallocated by `deallocateBuffers`.
360
361 // TODO: There are cases in which it is not necessary to return a new
362 // buffer allocation. E.g., when equivalent values are yielded in a
363 // different order. This could be resolved with copies.
364 Optional<Value> yieldedAlloc = state.createAlloc(
365 rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false);
366 // TODO: We should roll back, but for now just assume that this always
367 // succeeds.
368 assert(yieldedAlloc.hasValue() && "could not create alloc");
369 LogicalResult copyStatus = state.getOptions().createMemCpy(
370 rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc);
371 (void)copyStatus;
372 assert(succeeded(copyStatus) && "could not create memcpy");
373
374 // The iter_arg memref type may have a layout map. Cast the new buffer
375 // to the same type if needed.
376 return castBuffer(rewriter, *yieldedAlloc, type);
377 }
378
379 /// Helper function for loop bufferization. Given a range of values, apply
380 /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
381 /// value in the result vector.
382 static SmallVector<Value>
383 convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
384 llvm::function_ref<Value(Value, int64_t)> func) {
385 SmallVector<Value> result;
386 for (const auto &it : llvm::enumerate(values)) {
387 size_t idx = it.index();
388 Value val = it.value();
389 result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
390 }
391 return result;
392 }
393
394 /// Helper function for loop bufferization. Given a list of pre-bufferization
395 /// yielded values, compute the list of bufferized yielded values.
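/// For illustration (hypothetical values): with values = {%t0, %i, %t1},
/// tensorIndices = {0, 2} and equivalentTensors = {0}, the result is the
/// buffer of %t0 (possibly casted), %i unchanged, and a fresh alloc+copy of
/// %t1's buffer.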
396 SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values, 397 TypeRange bufferizedTypes, 398 const DenseSet<int64_t> &tensorIndices, 399 const DenseSet<int64_t> &equivalentTensors, 400 BufferizationState &state) { 401 return convertTensorValues( 402 values, tensorIndices, [&](Value val, int64_t index) { 403 return getYieldedBuffer(rewriter, val, 404 bufferizedTypes[index].cast<BaseMemRefType>(), 405 equivalentTensors.contains(index), state); 406 }); 407 } 408 409 /// Helper function for loop bufferization. Given a list of bbArgs of the new 410 /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into 411 /// ToTensorOps, so that the block body can be moved over to the new op. 412 SmallVector<Value> 413 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, 414 const DenseSet<int64_t> &tensorIndices) { 415 return convertTensorValues( 416 bbArgs, tensorIndices, [&](Value val, int64_t index) { 417 return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val); 418 }); 419 } 420 421 /// Bufferization of scf.for. Replace with a new scf.for that operates on 422 /// memrefs. 423 struct ForOpInterface 424 : public BufferizableOpInterface::ExternalModel<ForOpInterface, 425 scf::ForOp> { 426 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 427 const AnalysisState &state) const { 428 // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of 429 // its matching bbArg may. 430 auto forOp = cast<scf::ForOp>(op); 431 return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand)); 432 } 433 434 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 435 const AnalysisState &state) const { 436 // Tensor iter_args of scf::ForOps are always considered as a write. 437 return true; 438 } 439 440 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 441 const AnalysisState &state) const { 442 auto forOp = cast<scf::ForOp>(op); 443 return {forOp.getResultForOpOperand(opOperand)}; 444 } 445 446 BufferRelation bufferRelation(Operation *op, OpResult opResult, 447 const AnalysisState &state) const { 448 // ForOp results are equivalent to their corresponding init_args if the 449 // corresponding iter_args and yield values are equivalent. 450 auto forOp = cast<scf::ForOp>(op); 451 OpOperand &forOperand = forOp.getOpOperandForResult(opResult); 452 auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); 453 auto yieldOp = 454 cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator()); 455 bool equivalentYield = state.areEquivalentBufferizedValues( 456 bbArg, yieldOp->getOperand(opResult.getResultNumber())); 457 return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None; 458 } 459 460 bool isWritable(Operation *op, Value value, 461 const AnalysisState &state) const { 462 // Interestingly, scf::ForOp's bbArg can **always** be viewed 463 // inplace from the perspective of ops nested under: 464 // 1. Either the matching iter operand is not bufferized inplace and an 465 // alloc + optional copy makes the bbArg itself inplaceable. 466 // 2. Or the matching iter operand is bufferized inplace and bbArg just 467 // bufferizes to that too. 
468 return true; 469 } 470 471 LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, 472 const AnalysisState &state) const { 473 auto bufferizableOp = cast<BufferizableOpInterface>(op); 474 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) 475 return failure(); 476 477 if (!state.getOptions().enforceAliasingInvariants) 478 return success(); 479 480 // According to the `getAliasing...` implementations, a bufferized OpResult 481 // may alias only with the corresponding bufferized init_arg and with no 482 // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg; 483 // but not with any other OpOperand. If a corresponding OpResult/init_arg 484 // pair bufferizes to equivalent buffers, this aliasing requirement is 485 // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy. 486 // (New buffer copies do not alias with any buffer.) 487 auto forOp = cast<scf::ForOp>(op); 488 auto yieldOp = 489 cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator()); 490 OpBuilder::InsertionGuard g(rewriter); 491 rewriter.setInsertionPoint(yieldOp); 492 493 // Indices of all iter_args that have tensor type. These are the ones that 494 // are bufferized. 495 DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs()); 496 // For every yielded value, is the value equivalent to its corresponding 497 // bbArg? 498 DenseSet<int64_t> equivalentYields = getEquivalentBuffers( 499 forOp.getRegionIterArgs(), yieldOp.getResults(), state); 500 SmallVector<Value> yieldValues; 501 for (int64_t idx = 0; 502 idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) { 503 Value value = yieldOp.getResults()[idx]; 504 if (!indices.contains(idx) || equivalentYields.contains(idx)) { 505 yieldValues.push_back(value); 506 continue; 507 } 508 Value alloc = rewriter.create<bufferization::AllocTensorOp>( 509 yieldOp.getLoc(), value.getType().cast<RankedTensorType>(), 510 /*dynamicSizes=*/ValueRange(), value, /*escape=*/true); 511 yieldValues.push_back(alloc); 512 } 513 514 rewriter.updateRootInPlace( 515 yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); }); 516 return success(); 517 } 518 519 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 520 BufferizationState &state) const { 521 auto forOp = cast<scf::ForOp>(op); 522 auto oldYieldOp = 523 cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator()); 524 Block *oldLoopBody = &forOp.getLoopBody().front(); 525 526 // Indices of all iter_args that have tensor type. These are the ones that 527 // are bufferized. 528 DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs()); 529 // For every yielded value, is the value equivalent to its corresponding 530 // bbArg? 531 DenseSet<int64_t> equivalentYields = 532 getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(), 533 state.getAnalysisState()); 534 535 // The new memref init_args of the loop. 536 SmallVector<Value> initArgs = 537 getBuffers(rewriter, forOp.getIterOpOperands(), state); 538 539 // Construct a new scf.for op with memref instead of tensor values. 540 auto newForOp = rewriter.create<scf::ForOp>( 541 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), 542 forOp.getStep(), initArgs); 543 newForOp->setAttrs(forOp->getAttrs()); 544 ValueRange initArgsRange(initArgs); 545 TypeRange initArgsTypes(initArgsRange); 546 Block *loopBody = &newForOp.getLoopBody().front(); 547 548 // Set up new iter_args. 
The loop body uses tensors, so wrap the (memref)
549 // iter_args of the new loop in ToTensorOps.
550 rewriter.setInsertionPointToStart(loopBody);
551 SmallVector<Value> iterArgs =
552 getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
553 iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
554
555 // Erase terminator if present.
556 if (iterArgs.size() == 1)
557 rewriter.eraseOp(loopBody->getTerminator());
558
559 // Move loop body to new loop.
560 rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
561
562 // Update scf.yield of new loop.
563 auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
564 rewriter.setInsertionPoint(yieldOp);
565 SmallVector<Value> yieldValues =
566 getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices,
567 equivalentYields, state);
568 yieldOp.getResultsMutable().assign(yieldValues);
569
570 // Replace loop results.
571 replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
572
573 return success();
574 }
575
576 /// Assert that yielded values of an scf.for op are equivalent to their
577 /// corresponding bbArgs. In that case, the buffer relations of the
578 /// corresponding OpResults are "Equivalent".
579 ///
580 /// If this is not the case, allocs+copies are inserted and yielded from
581 /// the loop. This could be a performance problem, so it must be explicitly
582 /// activated with `allow-return-allocs`.
583 LogicalResult verifyAnalysis(Operation *op,
584 const AnalysisState &state) const {
585 const auto &options =
586 static_cast<const OneShotBufferizationOptions &>(state.getOptions());
587 if (options.allowReturnAllocs)
588 return success();
589
590 auto forOp = cast<scf::ForOp>(op);
591 auto yieldOp =
592 cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
593 for (OpResult opResult : op->getOpResults()) {
594 if (!opResult.getType().isa<TensorType>())
595 continue;
596
597 // Note: This is overly strict. We should check for aliasing bufferized
598 // values. But we don't have a "must-alias" analysis yet.
599 if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
600 return yieldOp->emitError()
601 << "Yield operand #" << opResult.getResultNumber()
602 << " is not equivalent to the corresponding iter bbArg";
603 }
604
605 return success();
606 }
607 };
608
609 /// Bufferization of scf.while. Replace with a new scf.while that operates on
610 /// memrefs.
611 struct WhileOpInterface
612 : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
613 scf::WhileOp> {
614 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
615 const AnalysisState &state) const {
616 // Tensor iter_args of scf::WhileOps are always considered as a read.
617 return true;
618 }
619
620 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
621 const AnalysisState &state) const {
622 // Tensor iter_args of scf::WhileOps are always considered as a write.
623 return true;
624 }
625
626 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
627 const AnalysisState &state) const {
628 auto whileOp = cast<scf::WhileOp>(op);
629 unsigned int idx = opOperand.getOperandNumber();
630
631 // The OpResults and OpOperands may not match. They may not even have the
632 // same type. The number of OpResults and OpOperands can also differ.
633 if (idx >= op->getNumResults() ||
634 opOperand.get().getType() != op->getResult(idx).getType())
635 return {};
636
637 // The only aliasing OpResult may be the one at the same index.
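// For illustration (a rough sketch with hypothetical types, regions omitted):
// in
//
//   %r:2 = scf.while (%a = %t, %b = %i)
//       : (tensor<5xf32>, i32) -> (i32, tensor<5xf32>) { ... }
//
// operand #0 (tensor<5xf32>) and result #0 (i32) have different types, as do
// operand #1 and result #1, so neither operand is reported as aliasing a
// result here.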
638 return {whileOp->getResult(idx)}; 639 } 640 641 BufferRelation bufferRelation(Operation *op, OpResult opResult, 642 const AnalysisState &state) const { 643 // WhileOp results are equivalent to their corresponding init_args if the 644 // corresponding iter_args and yield values are equivalent (for both the 645 // "before" and the "after" block). 646 unsigned int resultNumber = opResult.getResultNumber(); 647 auto whileOp = cast<scf::WhileOp>(op); 648 649 // The "before" region bbArgs and the OpResults may not match. 650 if (resultNumber >= whileOp.getBeforeArguments().size()) 651 return BufferRelation::None; 652 if (opResult.getType() != 653 whileOp.getBeforeArguments()[resultNumber].getType()) 654 return BufferRelation::None; 655 656 auto conditionOp = whileOp.getConditionOp(); 657 BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber]; 658 Value conditionOperand = conditionOp.getArgs()[resultNumber]; 659 bool equivCondition = 660 state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand); 661 662 auto yieldOp = whileOp.getYieldOp(); 663 BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber]; 664 Value yieldOperand = yieldOp.getOperand(resultNumber); 665 bool equivYield = 666 state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand); 667 668 return equivCondition && equivYield ? BufferRelation::Equivalent 669 : BufferRelation::None; 670 } 671 672 bool isWritable(Operation *op, Value value, 673 const AnalysisState &state) const { 674 // Interestingly, scf::WhileOp's bbArg can **always** be viewed 675 // inplace from the perspective of ops nested under: 676 // 1. Either the matching iter operand is not bufferized inplace and an 677 // alloc + optional copy makes the bbArg itself inplaceable. 678 // 2. Or the matching iter operand is bufferized inplace and bbArg just 679 // bufferizes to that too. 680 return true; 681 } 682 683 LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, 684 const AnalysisState &state) const { 685 auto bufferizableOp = cast<BufferizableOpInterface>(op); 686 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) 687 return failure(); 688 689 if (!state.getOptions().enforceAliasingInvariants) 690 return success(); 691 692 // According to the `getAliasing...` implementations, a bufferized OpResult 693 // may alias only with the corresponding bufferized init_arg and with no 694 // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg; 695 // but not with any other OpOperand. If a corresponding OpResult/init_arg 696 // pair bufferizes to equivalent buffers, this aliasing requirement is 697 // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy. 698 // (New buffer copies do not alias with any buffer.) 699 OpBuilder::InsertionGuard g(rewriter); 700 auto whileOp = cast<scf::WhileOp>(op); 701 auto conditionOp = whileOp.getConditionOp(); 702 auto yieldOp = whileOp.getYieldOp(); 703 704 // Indices of all bbArgs that have tensor type. These are the ones that 705 // are bufferized. The "before" and "after" regions may have different args. 706 DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits()); 707 DenseSet<int64_t> indicesAfter = 708 getTensorIndices(whileOp.getAfterArguments()); 709 710 // For every yielded value, is the value equivalent to its corresponding 711 // bbArg? 
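// E.g., if the "before" block simply forwards its own bbArg through
// scf.condition, that index is equivalent and the value can be kept as is;
// otherwise a bufferization.alloc_tensor copy is materialized below.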
712 DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers( 713 whileOp.getBeforeArguments(), conditionOp.getArgs(), state); 714 DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers( 715 whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state); 716 717 // Update "before" region. 718 rewriter.setInsertionPoint(conditionOp); 719 SmallVector<Value> beforeYieldValues; 720 for (int64_t idx = 0; 721 idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) { 722 Value value = conditionOp.getArgs()[idx]; 723 if (!indicesBefore.contains(idx) || 724 equivalentYieldsBefore.contains(idx)) { 725 beforeYieldValues.push_back(value); 726 continue; 727 } 728 Value alloc = rewriter.create<bufferization::AllocTensorOp>( 729 conditionOp.getLoc(), value.getType().cast<RankedTensorType>(), 730 /*dynamicSizes=*/ValueRange(), value, /*escape=*/true); 731 beforeYieldValues.push_back(alloc); 732 } 733 rewriter.updateRootInPlace(conditionOp, [&]() { 734 conditionOp.getArgsMutable().assign(beforeYieldValues); 735 }); 736 737 // Update "after" region. 738 rewriter.setInsertionPoint(yieldOp); 739 SmallVector<Value> afterYieldValues; 740 for (int64_t idx = 0; 741 idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) { 742 Value value = yieldOp.getResults()[idx]; 743 if (!indicesAfter.contains(idx) || equivalentYieldsAfter.contains(idx)) { 744 afterYieldValues.push_back(value); 745 continue; 746 } 747 Value alloc = rewriter.create<bufferization::AllocTensorOp>( 748 yieldOp.getLoc(), value.getType().cast<RankedTensorType>(), 749 /*dynamicSizes=*/ValueRange(), value, /*escape=*/true); 750 afterYieldValues.push_back(alloc); 751 } 752 rewriter.updateRootInPlace(yieldOp, [&]() { 753 yieldOp.getResultsMutable().assign(afterYieldValues); 754 }); 755 756 return success(); 757 } 758 759 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 760 BufferizationState &state) const { 761 auto whileOp = cast<scf::WhileOp>(op); 762 763 assert(whileOp.getBefore().getBlocks().size() == 1 && 764 "regions with multiple blocks not supported"); 765 Block *beforeBody = &whileOp.getBefore().front(); 766 assert(whileOp.getAfter().getBlocks().size() == 1 && 767 "regions with multiple blocks not supported"); 768 Block *afterBody = &whileOp.getAfter().front(); 769 770 // Indices of all bbArgs that have tensor type. These are the ones that 771 // are bufferized. The "before" and "after" regions may have different args. 772 DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits()); 773 DenseSet<int64_t> indicesAfter = 774 getTensorIndices(whileOp.getAfterArguments()); 775 776 // For every yielded value, is the value equivalent to its corresponding 777 // bbArg? 778 DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers( 779 whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(), 780 state.getAnalysisState()); 781 DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers( 782 whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), 783 state.getAnalysisState()); 784 785 // The new memref init_args of the loop. 786 SmallVector<Value> initArgs = 787 getBuffers(rewriter, whileOp->getOpOperands(), state); 788 789 // The result types of a WhileOp are the same as the "after" bbArg types. 790 SmallVector<Type> argsTypesAfter = llvm::to_vector( 791 llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { 792 return state.getBufferType(bbArg).cast<Type>(); 793 })); 794 795 // Construct a new scf.while op with memref instead of tensor values. 
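// For illustration (a rough sketch, hypothetical types): a loop such as
//
//   %r = scf.while (%w = %t) : (tensor<5xf32>) -> (tensor<5xf32>) { ... }
//
// is rebuilt as
//
//   %r = scf.while (%w = %t_buffer) : (memref<5xf32>) -> (memref<5xf32>) { ... }
//
// (possibly with layout maps on the memref types), where the result types are
// the bufferized "after" bbArg types (argsTypesAfter) computed above.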
796 ValueRange argsRangeBefore(initArgs);
797 TypeRange argsTypesBefore(argsRangeBefore);
798 auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
799 argsTypesAfter, initArgs);
800
801 // Add before/after regions to the new op.
802 SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
803 SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
804 whileOp.getLoc());
805 Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
806 newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
807 Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
808 newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
809
810 // Set up new iter_args and move the loop condition block to the new op.
811 // The old block uses tensors, so wrap the (memref) bbArgs of the new block
812 // in ToTensorOps.
813 rewriter.setInsertionPointToStart(newBeforeBody);
814 SmallVector<Value> newBeforeArgs = getBbArgReplacements(
815 rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
816 rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
817
818 // Update scf.condition of new loop.
819 auto newConditionOp = newWhileOp.getConditionOp();
820 rewriter.setInsertionPoint(newConditionOp);
821 // Only equivalent buffers or new buffer allocations may be yielded to the
822 // "after" region.
823 // TODO: This could be relaxed for better bufferization results.
824 SmallVector<Value> newConditionArgs =
825 getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
826 indicesAfter, equivalentYieldsBefore, state);
827 newConditionOp.getArgsMutable().assign(newConditionArgs);
828
829 // Set up new iter_args and move the loop body block to the new op.
830 // The old block uses tensors, so wrap the (memref) bbArgs of the new block
831 // in ToTensorOps.
832 rewriter.setInsertionPointToStart(newAfterBody);
833 SmallVector<Value> newAfterArgs = getBbArgReplacements(
834 rewriter, newWhileOp.getAfterArguments(), indicesAfter);
835 rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
836
837 // Update scf.yield of the new loop.
838 auto newYieldOp = newWhileOp.getYieldOp();
839 rewriter.setInsertionPoint(newYieldOp);
840 // Only equivalent buffers or new buffer allocations may be yielded to the
841 // "before" region.
842 // TODO: This could be relaxed for better bufferization results.
843 SmallVector<Value> newYieldValues =
844 getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
845 indicesBefore, equivalentYieldsAfter, state);
846 newYieldOp.getResultsMutable().assign(newYieldValues);
847
848 // Replace loop results.
849 replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
850
851 return success();
852 }
853
854 /// Assert that yielded values of an scf.while op are equivalent to their
855 /// corresponding bbArgs. In that case, the buffer relations of the
856 /// corresponding OpResults are "Equivalent".
857 ///
858 /// If this is not the case, allocs+copies are inserted and yielded from
859 /// the loop. This could be a performance problem, so it must be explicitly
860 /// activated with `allow-return-allocs`.
861 ///
862 /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
863 /// equivalence condition must be checked for both.
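/// For illustration (hypothetical IR): an "after" block that ends in
///
///   scf.yield %some_other_tensor : tensor<5xf32>
///
/// where %some_other_tensor is not equivalent to the block's own bbArg is
/// rejected here with "Yield operand #N is not equivalent to the
/// corresponding iter bbArg", unless `allow-return-allocs` is set.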
864 LogicalResult verifyAnalysis(Operation *op,
865 const AnalysisState &state) const {
866 auto whileOp = cast<scf::WhileOp>(op);
867 const auto &options =
868 static_cast<const OneShotBufferizationOptions &>(state.getOptions());
869 if (options.allowReturnAllocs)
870 return success();
871
872 auto conditionOp = whileOp.getConditionOp();
873 for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
874 if (!it.value().getType().isa<TensorType>())
875 continue;
876 if (!state.areEquivalentBufferizedValues(
877 it.value(), conditionOp->getBlock()->getArgument(it.index())))
878 return conditionOp->emitError()
879 << "Condition arg #" << it.index()
880 << " is not equivalent to the corresponding iter bbArg";
881 }
882
883 auto yieldOp = whileOp.getYieldOp();
884 for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
885 if (!it.value().getType().isa<TensorType>())
886 continue;
887 if (!state.areEquivalentBufferizedValues(
888 it.value(), yieldOp->getBlock()->getArgument(it.index())))
889 return yieldOp->emitError()
890 << "Yield operand #" << it.index()
891 << " is not equivalent to the corresponding iter bbArg";
892 }
893
894 return success();
895 }
896 };
897
898 /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
899 /// this is for analysis only.
900 struct YieldOpInterface
901 : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
902 scf::YieldOp> {
903 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
904 const AnalysisState &state) const {
905 return true;
906 }
907
908 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
909 const AnalysisState &state) const {
910 return false;
911 }
912
913 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
914 const AnalysisState &state) const {
915 if (isa<scf::IfOp>(op->getParentOp()))
916 return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
917 if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
918 return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
919 return {};
920 }
921
922 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
923 const AnalysisState &state) const {
924 // Yield operands always bufferize inplace. Otherwise, an alloc + copy
925 // may be generated inside the block. We should not return/yield allocations
926 // when possible.
927 return true;
928 }
929
930 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
931 BufferizationState &state) const {
932 auto yieldOp = cast<scf::YieldOp>(op);
933 if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
934 yieldOp->getParentOp()))
935 return yieldOp->emitError("unsupported scf::YieldOp parent");
936 return success();
937 }
938 };
939
940 using tensor::ExtractSliceOp;
941
942 /// Return the destinations that a ForeachThreadOp is inserting into. One per
943 /// ParallelInsertSliceOp.
944 static SmallVector<OpOperand *>
945 getInsertionDest(ForeachThreadOp foreachThreadOp) {
946 PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
947 SmallVector<OpOperand *> result;
948 terminator.walk([&](ParallelInsertSliceOp insertOp) {
949 result.push_back(&insertOp->getOpOperand(1) /*dest*/);
950 });
951 return result;
952 }
953
954 /// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
955 /// region. There are op interfaces for the terminators (PerformConcurrentlyOp
956 /// and ParallelInsertSliceOp), but these are only used during analysis, not
957 /// for bufferization.
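/// For illustration (a rough sketch; op syntax abbreviated and names
/// hypothetical): each ParallelInsertSliceOp of the form
///
///   parallel_insert_slice %tile into %dest[%off][%sz][%stride]
///
/// is rewritten during bufferization into a memref.subview of %dest's buffer
/// plus a memcpy of %tile's buffer into that subview, and the ForeachThreadOp
/// results are replaced by the destination buffers themselves.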
958 struct ForeachThreadOpInterface 959 : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface, 960 ForeachThreadOp> { 961 SmallVector<OpOperand *> 962 getAliasingOpOperand(Operation *op, OpResult opResult, 963 const AnalysisState &state) const { 964 // Get OpOperand (dest) from corresponding ParallelInsertSliceOp. 965 auto foreachThreadOp = cast<ForeachThreadOp>(op); 966 return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]}; 967 } 968 969 bool isMemoryWrite(Operation *op, OpResult opResult, 970 const AnalysisState &state) const { 971 // This op is a memory write. Stop lookup here to avoid finding false 972 // conflicts involving this op and one of the ops in the region. This is 973 // similar to how scf.if ops are analyzed. 974 return true; 975 } 976 977 BufferRelation bufferRelation(Operation *op, OpResult opResult, 978 const AnalysisState &state) const { 979 return BufferRelation::Equivalent; 980 } 981 982 LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, 983 const AnalysisState &state) const { 984 auto bufferizableOp = cast<BufferizableOpInterface>(op); 985 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) 986 return failure(); 987 988 OpBuilder::InsertionGuard g(rewriter); 989 auto foreachThreadOp = cast<ForeachThreadOp>(op); 990 for (OpResult opResult : foreachThreadOp->getOpResults()) { 991 SmallVector<OpOperand *> destOperands = 992 state.getAliasingOpOperand(opResult); 993 assert(destOperands.size() == 1 && 994 "expected exactly one aliasing OpOperand"); 995 assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) && 996 "expected ParallelInsertSliceOp"); 997 998 // Nothing to do if there is no conflict. 999 if (state.isInPlace(*destOperands.front())) 1000 continue; 1001 1002 // Create AllocTensorOp. 1003 bool isYielded = state.isTensorYielded(opResult); 1004 auto resultType = opResult.getType().cast<RankedTensorType>(); 1005 Value alloc = rewriter.create<bufferization::AllocTensorOp>( 1006 op->getLoc(), resultType, /*dynamicDims=*/ValueRange(), 1007 /*copy=*/destOperands.front()->get(), 1008 /*escape=*/isYielded); 1009 1010 // Update terminator operand. 1011 rewriter.updateRootInPlace(destOperands.front()->getOwner(), 1012 [&]() { destOperands.front()->set(alloc); }); 1013 } 1014 1015 return success(); 1016 } 1017 1018 LogicalResult bufferize(Operation *op, RewriterBase &b, 1019 BufferizationState &state) const { 1020 OpBuilder::InsertionGuard g(b); 1021 auto foreachThreadOp = cast<ForeachThreadOp>(op); 1022 1023 // Gather new results of the ForeachThreadOp. 1024 SmallVector<Value> newResults; 1025 for (OpResult opResult : foreachThreadOp->getOpResults()) { 1026 SmallVector<OpOperand *> insertDestOperands = 1027 state.getAnalysisState().getAliasingOpOperand(opResult); 1028 assert(insertDestOperands.size() == 1 && 1029 "expected exactly one aliasing OpOperand"); 1030 // Insert copies right before the PerformConcurrentlyOp terminator. They 1031 // should not be inside terminator (which would be the default insertion 1032 // point). 1033 Value buffer = *state.getBuffer(b, *insertDestOperands.front(), 1034 /*forceInPlace=*/llvm::None, 1035 /*customCopyInsertionPoint=*/op); 1036 newResults.push_back(buffer); 1037 } 1038 1039 // Create new ForeachThreadOp without any results and drop the automatically 1040 // introduced terminator. 
1041 TypeRange newResultTypes; 1042 auto newForeachThreadOp = 1043 b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes, 1044 foreachThreadOp.getNumThreads()); 1045 newForeachThreadOp.getBody()->getTerminator()->erase(); 1046 1047 // Move over block contents of the old op. 1048 b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(), 1049 {newForeachThreadOp.getBody()->getArguments()}); 1050 1051 // Bufferize terminator. 1052 auto performConcurrentlyOp = cast<PerformConcurrentlyOp>( 1053 newForeachThreadOp.getBody()->getTerminator()); 1054 b.setInsertionPoint(performConcurrentlyOp); 1055 unsigned resultCounter = 0; 1056 WalkResult walkResult = 1057 performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) { 1058 Location loc = insertOp.getLoc(); 1059 Type srcType = getMemRefType( 1060 insertOp.getSource().getType().cast<RankedTensorType>(), 1061 state.getOptions()); 1062 // ParallelInsertSliceOp bufferizes to a copy. 1063 auto srcMemref = b.create<bufferization::ToMemrefOp>( 1064 loc, srcType, insertOp.getSource()); 1065 Value destMemref = newResults[resultCounter++]; 1066 Value subview = b.create<memref::SubViewOp>( 1067 loc, destMemref, insertOp.getMixedOffsets(), 1068 insertOp.getMixedSizes(), insertOp.getMixedStrides()); 1069 // This memcpy will fold away if everything bufferizes in-place. 1070 if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(), 1071 srcMemref, subview))) 1072 return WalkResult::interrupt(); 1073 b.eraseOp(insertOp); 1074 return WalkResult::advance(); 1075 }); 1076 if (walkResult.wasInterrupted()) 1077 return failure(); 1078 1079 // Replace the op. 1080 replaceOpWithBufferizedValues(b, op, newResults); 1081 1082 return success(); 1083 } 1084 }; 1085 1086 /// Nothing to do for PerformConcurrentlyOp. 1087 struct PerformConcurrentlyOpInterface 1088 : public BufferizableOpInterface::ExternalModel< 1089 PerformConcurrentlyOpInterface, PerformConcurrentlyOp> { 1090 LogicalResult bufferize(Operation *op, RewriterBase &b, 1091 BufferizationState &state) const { 1092 assert(false && "op does not have any tensor OpOperands / OpResults"); 1093 return failure(); 1094 } 1095 }; 1096 1097 /// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e. 1098 /// equivalent operand / result and same offset/sizes/strides specification). 1099 static bool areEquivalentExtractSliceOps(const AnalysisState &state, 1100 ExtractSliceOp st, 1101 ParallelInsertSliceOp sti) { 1102 if (!st || !sti) 1103 return false; 1104 if (st != sti && 1105 !state.areEquivalentBufferizedValues(st.source(), sti.getDest())) 1106 return false; 1107 if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) 1108 return false; 1109 return true; 1110 } 1111 1112 /// Return true if `value` is originating from an ExtractSliceOp that matches 1113 /// the given InsertSliceOp. 1114 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, 1115 ParallelInsertSliceOp insertOp) { 1116 auto condition = [&](Value val) { 1117 if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) 1118 if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) 1119 return true; 1120 return false; 1121 }; 1122 1123 return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), 1124 condition); 1125 } 1126 1127 /// Analysis of ParallelInsertSliceOp. 
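/// For illustration (hypothetical IR): if the enclosing perform_concurrently
/// block contains two parallel_insert_slice ops, the `dest` of the first one
/// aliases result #0 of the parent ForeachThreadOp and the `dest` of the
/// second one aliases result #1 (see getAliasingOpResult below).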
1128 struct ParallelInsertSliceOpInterface 1129 : public BufferizableOpInterface::ExternalModel< 1130 ParallelInsertSliceOpInterface, ParallelInsertSliceOp> { 1131 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 1132 const AnalysisState &state) const { 1133 if (&opOperand != &op->getOpOperand(1) /*dest*/) 1134 return {}; 1135 1136 // ParallelInsertSliceOp itself has no results. Tensors are returned via 1137 // the parent op. 1138 auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>(); 1139 assert(foreachThreadOp && 1140 "could not find valid owner of parallel_insert_slice"); 1141 1142 // The i-th ParallelInsertSliceOp result is returned via the i-th OpResult 1143 // of the parent ForeachThreadOp. 1144 Block *block = op->getBlock(); 1145 unsigned int opIdx = 0; 1146 for (ParallelInsertSliceOp insertOp : 1147 block->getOps<ParallelInsertSliceOp>()) { 1148 if (insertOp.getOperation() == op) 1149 break; 1150 ++opIdx; 1151 } 1152 assert(opIdx < foreachThreadOp->getNumResults() && 1153 "could not find op inside terminator op"); 1154 1155 return {foreachThreadOp->getResult(opIdx)}; 1156 } 1157 1158 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1159 const AnalysisState &state) const { 1160 return true; 1161 } 1162 1163 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1164 const AnalysisState &state) const { 1165 return &opOperand == &op->getOpOperand(1) /*dest*/; 1166 } 1167 1168 BufferRelation bufferRelation(Operation *op, OpResult opResult, 1169 const AnalysisState &state) const { 1170 return BufferRelation::Equivalent; 1171 } 1172 1173 LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, 1174 const AnalysisState &state) const { 1175 return success(); 1176 } 1177 1178 LogicalResult bufferize(Operation *op, RewriterBase &b, 1179 BufferizationState &state) const { 1180 // Will be bufferized as part of ForeachThreadOp. 1181 return failure(); 1182 } 1183 1184 // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share 1185 // the code. 1186 bool isNotConflicting(Operation *op, OpOperand *uRead, 1187 OpOperand *uConflictingWrite, 1188 const AnalysisState &state) const { 1189 Operation *readingOp = uRead->getOwner(); 1190 Operation *conflictingWritingOp = uConflictingWrite->getOwner(); 1191 1192 // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If 1193 // uRead is an InsertSliceOp... 1194 if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) { 1195 // As an example, consider the following IR. 1196 // 1197 // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 1198 // %1 = linalg.fill %cst, %0 {inplace= [true] } 1199 // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 1200 // {inplace= [true] } 1201 1202 // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 1203 if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && 1204 hasMatchingExtractSliceOp(state, uConflictingWrite->get(), 1205 insertSliceOp)) 1206 // Case 1: The main insight is that InsertSliceOp reads only part of 1207 // the destination tensor. The overwritten area is not read. If 1208 // uConflictingWrite writes into exactly the memory location that is 1209 // being read by uRead, this is not a conflict. 1210 // 1211 // In the above example: 1212 // uRead = OpOperand 1 (%t) of tensor.insert_slice 1213 // uConflictingWrite = OpOperand 1 (%0) of linalg.fill 1214 // 1215 // The read of %t does not conflict with the write of the FillOp 1216 // (same aliases!) 
because the area that the FillOp operates on is 1217 // exactly the one that is *not* read via %t. 1218 return true; 1219 1220 if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && 1221 uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 1222 hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp)) 1223 // Case 2: The read of the source tensor and the write to the dest 1224 // tensor via an InsertSliceOp is not a conflict if the read is 1225 // reading exactly that part of an equivalent tensor that the 1226 // InsertSliceOp is writing. 1227 // 1228 // In the above example: 1229 // uRead = OpOperand 0 (%1) of tensor.insert_slice 1230 // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 1231 return true; 1232 } 1233 1234 // If uConflictingWrite is an InsertSliceOp... 1235 if (auto insertSliceOp = 1236 dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp)) 1237 // As an example, consider the following IR. 1238 // 1239 // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } 1240 // %1 = linalg.fill %cst, %0 {inplace= [true] } 1241 // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] 1242 // {inplace= [true] } 1243 // %3 = vector.transfer_read %1, %cst 1244 // 1245 // In the above example: 1246 // uRead = OpOperand 0 (%1) of vector.transfer_read 1247 // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice 1248 // lastWrite = %1 1249 // 1250 // This is not a conflict because the InsertSliceOp overwrites the 1251 // memory segment of %1 with the exact same data. (Effectively, there 1252 // is no memory write here.) 1253 if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && 1254 state.areEquivalentBufferizedValues(uRead->get(), 1255 insertSliceOp.getSource()) && 1256 hasMatchingExtractSliceOp(state, insertSliceOp.getSource(), 1257 insertSliceOp)) 1258 return true; 1259 1260 return false; 1261 } 1262 }; 1263 1264 } // namespace 1265 } // namespace scf 1266 } // namespace mlir 1267 1268 void mlir::scf::registerBufferizableOpInterfaceExternalModels( 1269 DialectRegistry ®istry) { 1270 registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { 1271 ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx); 1272 ForOp::attachInterface<ForOpInterface>(*ctx); 1273 IfOp::attachInterface<IfOpInterface>(*ctx); 1274 ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx); 1275 ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>( 1276 *ctx); 1277 PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>( 1278 *ctx); 1279 WhileOp::attachInterface<WhileOpInterface>(*ctx); 1280 YieldOp::attachInterface<YieldOpInterface>(*ctx); 1281 }); 1282 } 1283