//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};
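
// Illustrative sketch of the rewrite performed by CastOpInterface::bufferize
// above. The memref types are assumptions (they depend on the bufferization
// options and on the layout of the source buffer):
//
//   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
//
// bufferizes to roughly the following, where %m0 is the buffer of %0:
//
//   %m1 = memref.cast %m0 : memref<?xf32> to memref<4xf32>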

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};
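
// Illustrative sketch of the rewrite performed by DimOpInterface::bufferize
// above, where %m is the buffer of %t (the memref type is an assumption):
//
//   %d = tensor.dim %t, %c0 : tensor<?xf32>
//
// bufferizes to:
//
//   %d = memref.dim %m, %c0 : memref<?xf32>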

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });

    // Bufferize to subview.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            dstTensorType.getRank(), srcMemrefType, mixedOffsets, mixedSizes,
            mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};
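
// Illustrative sketch of the rewrite performed by
// ExtractSliceOpInterface::bufferize above, where %m is the buffer of %t and
// #map stands for the strided layout inferred for the subview (both are
// assumptions). In the in-place case:
//
//   %0 = tensor.extract_slice %t[%o][%s][1] : tensor<?xf32> to tensor<?xf32>
//
// bufferizes to a subview of the source buffer:
//
//   %0 = memref.subview %m[%o][%s][1] : memref<?xf32> to memref<?xf32, #map>
//
// In the out-of-place case, a new buffer is allocated and the subview is
// copied into it (unless the result is never read).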

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType = getContiguousMemRefType(tensorType);
    FailureOr<Value> maybeBuffer =
        createAlloc(rewriter, loc, resultType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};
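
// Illustrative sketch of the rewrite performed by
// FromElementsOpInterface::bufferize above (buffer and constant names are
// placeholders):
//
//   %t = tensor.from_elements %a, %b : tensor<2xf32>
//
// bufferizes to roughly an allocation plus one store per element:
//
//   %buf = memref.alloc() : memref<2xf32>
//   memref.store %a, %buf[%c0] : memref<2xf32>
//   memref.store %b, %buf[%c1] : memref<2xf32>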

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};
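
// Illustrative sketch of the rewrite performed by
// GenerateOpInterface::bufferize above (names are placeholders):
//
//   %t = tensor.generate %sz {
//   ^bb0(%i : index):
//     ...
//     tensor.yield %elem : f32
//   } : tensor<?xf32>
//
// bufferizes to roughly:
//
//   %buf = memref.alloc(%sz) : memref<?xf32>
//   scf.parallel (%i) = (%c0) to (%sz) step (%c1) {
//     ...
//     memref.store %elem, %buf[%i] : memref<?xf32>
//     scf.yield
//   }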

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};
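
// Illustrative sketch of the rewrite performed by InsertOpInterface::bufferize
// above, where %m is the buffer of %t (the memref type is an assumption):
//
//   %r = tensor.insert %f into %t[%idx] : tensor<?xf32>
//
// bufferizes (in the in-place case) to:
//
//   memref.store %f, %m[%idx] : memref<?xf32>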

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });

    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};
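
// Illustrative sketch of the rewrite performed by
// InsertSliceOpInterface::bufferize above, where %md / %ms are the buffers of
// %t / %src and #map is the inferred subview layout (all assumptions):
//
//   %r = tensor.insert_slice %src into %t[%o][%s][1]
//       : tensor<?xf32> into tensor<?xf32>
//
// bufferizes (in the in-place case) to a subview of the destination buffer
// plus a copy (e.g. memref.copy, depending on the options):
//
//   %sv = memref.subview %md[%o][%s][1] : memref<?xf32> to memref<?xf32, #map>
//   memref.copy %ms, %sv : memref<?xf32> to memref<?xf32, #map>
//
// If %src comes from a matching tensor.extract_slice of %t, the copy is
// expected to fold away later.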

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}
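
// Typical use of the registration entry point above (a sketch; the surrounding
// setup is an assumption and not part of this file):
//
//   DialectRegistry registry;
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   mlirContext.appendDialectRegistry(registry);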