//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType;
    if (resultTensorType.isa<RankedTensorType>()) {
      resultMemRefType =
          getContiguousMemRefType(resultTensorType, layout, memorySpace);
    } else {
      resultMemRefType =
          getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
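///
/// Illustrative sketch (the SSA names and types are invented for
/// illustration and do not appear in the original source):
///
///   %d = tensor.dim %t, %c0 : tensor<?x8xf32>
///
/// becomes, once %t has been bufferized to some memref %m:
///
///   %d = memref.dim %m, %c0 : memref<?x8xf32>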
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
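    // For example (illustrative, not from the original source): extracting a
    // rank-reducing tensor<4xf32> slice out of a memref<8x8xf32> source infers
    // a rank-1 subview type that carries a strided, non-identity layout rather
    // than a plain contiguous memref<4xf32>.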
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    FailureOr<Value> maybeBuffer =
        createAlloc(rewriter, loc, resultType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
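
    // The element operands are now written into `buffer` one by one. For
    // example (illustrative): `tensor.from_elements %a, %b : tensor<2xf32>`
    // becomes a memref<2xf32> allocation followed by two memref.store ops at
    // indices 0 and 1.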

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
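///
/// Illustrative sketch (SSA names and types invented for illustration):
///
///   %0 = tensor.insert %f into %t[%idx] : tensor<8xf32>
///
/// becomes, once %t has been bufferized to some memref %m:
///
///   memref.store %f, %m[%idx] : memref<8xf32>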
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return op->getOpResult(0);
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches, i.e., the
/// operand / result are equivalent and the ops have the same
/// offset/sizes/strides specification.
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
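///
/// Illustrative sketch (SSA names, shapes and buffer names are invented for
/// illustration):
///
///   %0 = tensor.insert_slice %src into %dst[%o] [%sz] [1]
///       : tensor<?xf32> into tensor<?xf32>
///
/// bufferizes to a subview of the destination buffer plus a copy:
///
///   %sv = memref.subview %dst_buffer[%o] [%sz] [1]
///   (copy of %src_buffer into %sv, emitted via createMemCpy)
///
/// If the insert_slice has a matching extract_slice, the copy is expected to
/// eventually fold away (see the comment in `bufferize` below).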
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates a new buffer.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
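///
/// Illustrative sketch (SSA names and types invented for illustration):
///
///   %r = tensor.rank %t : tensor<*xf32>
///
/// becomes, once %t has been bufferized to some memref %m:
///
///   %r = memref.rank %m : memref<*xf32>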
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}
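
// Typical use (an illustrative sketch, not part of the original file): a
// client registers these external models on its DialectRegistry before
// creating the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);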