//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
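///
/// For illustration (schematic IR, names made up; not taken from a test):
///   %0 = tensor.collapse_shape %t [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// bufferizes, given the buffer %m of %t, roughly to:
///   %0 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>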
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
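///
/// For illustration (schematic IR, names made up; not taken from a test):
///   %0 = tensor.expand_shape %t [[0, 1]]
///       : tensor<6xf32> into tensor<2x3xf32>
/// bufferizes, given the buffer %m of %t, roughly to:
///   %0 = memref.expand_shape %m [[0, 1]]
///       : memref<6xf32> into memref<2x3xf32>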
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
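    // For illustration (schematic IR, not taken from a test): a rank-reducing
    // slice such as
    //   %0 = tensor.extract_slice %t[0, 0][4, 1][1, 1]
    //       : tensor<8x8xf32> to tensor<4xf32>
    // drops the unit dimension, so its result rank (1) is smaller than the
    // rank of the source buffer (2).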
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
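///
/// For illustration (schematic IR, names made up; not taken from a test):
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
/// bufferizes roughly to an allocation plus one memref.store per element:
///   %m = memref.alloc() : memref<2xf32>
///   memref.store %a, %m[%c0] : memref<2xf32>
///   memref.store %b, %m[%c1] : memref<2xf32>
/// where %c0 and %c1 are arith.constant index values.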
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType = getContiguousMemRefType(tensorType);
    FailureOr<Value> maybeBuffer =
        createAlloc(rewriter, loc, resultType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
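    // For illustration (schematic IR, not taken from a test): a 1-d
    // tensor.generate bufferizes roughly to
    //   %alloc = memref.alloc(%ub) : memref<?xf32>
    //   scf.parallel (%i) = (%c0) to (%ub) step (%c1) {
    //     ... inlined generate body yielding %elem ...
    //     memref.store %elem, %alloc[%i] : memref<?xf32>
    //   }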
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given InsertSliceOp.
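///
/// For illustration (schematic IR, names made up): %0 below matches the
/// insert_slice because it is extracted from the same tensor with the same
/// offsets, sizes and strides:
///   %0 = tensor.extract_slice %t[%a][%c][1]
///   %1 = tensor.insert_slice %0 into %t[%a][%c][1]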
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp are not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
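    // For illustration (schematic IR, names made up): with a matching pair
    //   %0 = tensor.extract_slice %t[%a][%c][1]
    //   ... in-place ops on %0 ...
    //   %1 = tensor.insert_slice %0 into %t[%a][%c][1]
    // source and destination bufferize to the same subview of %t's buffer, so
    // the copy is a self-copy that can later fold away.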
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}
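
// Typical registration (illustrative sketch; the exact dialect set depends on
// the client): a tool wires these external models into its context before
// bufferizing, e.g.
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect,
//                   scf::SCFDialect, arith::ArithmeticDialect>();
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);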