//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};
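
// For illustration, the model above rewrites, e.g.,
//   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
// into a memref.cast on the bufferized source:
//   %1 = memref.cast %m : memref<?xf32> to memref<4xf32>
// where the result layout and memory space are derived from the source buffer.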

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      auto bufferType = buffer.getType().cast<MemRefType>();
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}
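
// As a rough illustration, for a tensor.from_elements producing a
// tensor<2x3xf32>, the recursion above emits one memref.store per element in
// row-major order (the index constants %c0/%c1/%c2 are created by the caller;
// value names are schematic):
//   memref.store %e00, %buf[%c0, %c0]
//   memref.store %e01, %buf[%c0, %c1]
//   ...
//   memref.store %e12, %buf[%c1, %c2]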

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
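    //
    // As a rough sketch (value names are illustrative), a 1-d tensor.generate
    // bufferizes to something like:
    //   scf.parallel (%i) = (%c0) to (%ub) step (%c1) {
    //     ... inlined tensor.generate body ...
    //     memref.store %yielded, %alloc[%i]
    //     scf.yield
    //   }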
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st.source() != sti.dest() &&
      !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}
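
// For illustration, the following pair matches: same underlying tensor %t and
// identical offsets/sizes/strides.
//   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
//   %1 = tensor.insert_slice %x into %t[%a][%c][1]
//       : tensor<?xf32> into tensor<?xf32>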

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
  });
}
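
// Typical usage (sketch): attach these external models to the registry before
// creating the MLIRContext, alongside the dialects that bufferization needs.
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);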