//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
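///
/// Illustrative sketch (shapes and element types are made up for exposition):
///
///   %c = tensor.collapse_shape %t [[0, 1]]
///       : tensor<2x4xf32> into tensor<8xf32>
///
/// bufferizes to the same reshape on the source buffer:
///
///   %m = memref.collapse_shape %buf [[0, 1]]
///       : memref<2x4xf32> into memref<8xf32>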
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
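///
/// Illustrative sketch (shapes are made up for exposition):
///
///   %e = tensor.expand_shape %t [[0, 1]]
///       : tensor<8xf32> into tensor<2x4xf32>
///
/// bufferizes to the same reshape on the source buffer:
///
///   %m = memref.expand_shape %buf [[0, 1]]
///       : memref<8xf32> into memref<2x4xf32>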
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
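    // expandToRank pads the three lists up to the rank of `srcMemref`; the
    // lambda below supplies the size of each padded dimension (a constant
    // index attribute for static dims, a memref.dim op otherwise).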
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
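///
/// Illustrative sketch (a hypothetical 2x2 case; %c0/%c1 are index constants):
///
///   %t = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
///
/// becomes, roughly, an allocation followed by one memref.store per element
/// in row-major order:
///
///   %buf = memref.alloc() : memref<2x2xf32>
///   memref.store %a, %buf[%c0, %c0] : memref<2x2xf32>
///   memref.store %b, %buf[%c0, %c1] : memref<2x2xf32>
///   memref.store %c, %buf[%c1, %c0] : memref<2x2xf32>
///   memref.store %d, %buf[%c1, %c1] : memref<2x2xf32>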
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
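    //
    // Rough sketch of the result (hypothetical 1-D case):
    //
    //   scf.parallel (%i) = (%c0) to (%ub) step (%c1) {
    //     ...  // inlined tensor.generate body, producing %elem
    //     memref.store %elem, %result[%i]
    //     scf.yield
    //   }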
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
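///
/// Illustrative sketch (same IR pattern as the examples further below):
///
///   %0 = tensor.extract_slice %t[%a][%c][1]
///   %1 = linalg.fill %cst, %0
///   %2 = tensor.insert_slice %1 into %t[%a][%c][1]
///
/// Here, %1 has a matching ExtractSliceOp (%0) for the InsertSliceOp %2:
/// every value on its reverse use-def chain leads back to such a slice.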
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
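    // (Sketch of the rationale: with in-place bufferization, the matching
    // extract_slice becomes a subview of the same destination buffer, so the
    // copy below is a self-copy that later cleanup/canonicalization can
    // remove.)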
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
  });
}
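
// Typical registration sketch (the surrounding setup below is illustrative
// only and not part of this file):
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);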