//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return op->getResult(0);
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(0) /*source*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          createAlloc(rewriter, loc, extractSliceOp.result(),
                      state.getOptions().createDeallocs, state.getOptions());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType = getContiguousMemRefType(tensorType);
    FailureOr<Value> maybeBuffer =
        createAlloc(rewriter, loc, resultType, {},
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(),
                    /*deallocMemref=*/state.getOptions().createDeallocs,
                    state.getOptions());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return true;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return op->getOpResult(0);
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                      Value value, InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/
               ? op->getResult(0)
               : OpResult();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const BufferizationState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}
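
// Usage sketch (illustrative only, not part of this file's API surface): a
// client that wants the external models above to be available would typically
// register them on its DialectRegistry before constructing the MLIRContext.
// The exact set of dialects inserted below is an assumption for illustration.
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect,
//                   scf::SCFDialect, arith::ArithmeticDialect>();
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);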