//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};
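// Illustrative example for the CastOpInterface above (not part of the original
// source):
//   %1 = tensor.cast %0 : tensor<?x16xf32> to tensor<8x16xf32>
// bufferizes to a cast of the source buffer, e.g.
//   %m1 = memref.cast %m0 : memref<?x16xf32> to memref<8x16xf32>
// with the exact layout and memory space depending on the bufferization
// options.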
/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};
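// Illustrative example for the DimOpInterface above (not part of the original
// source):
//   %d = tensor.dim %t, %c0 : tensor<?x16xf32>
// bufferizes to
//   %d = memref.dim %m, %c0 : memref<?x16xf32>
// where %m is the buffer of %t.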
/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
    Type resultType =
        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }
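    // Illustrative example (not part of the original source): a rank-reducing
    // slice such as
    //   %0 = tensor.extract_slice %t[2, 0][1, 16][1, 1]
    //          : tensor<8x16xf32> to tensor<16xf32>
    // carries full-rank offsets/sizes/strides; only the result type drops the
    // unit dimension. The code below expands the offset/size/stride lists to
    // the source rank and infers a rank-reduced subview type from them.
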
    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}
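// Illustrative example for createStores above (not part of the original
// source): for
//   %t = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
// it emits one memref.store per element, in row-major order:
//   memref.store %a, %buf[%c0, %c0] : memref<2x2xf32>
//   memref.store %b, %buf[%c0, %c1] : memref<2x2xf32>
//   memref.store %c, %buf[%c1, %c0] : memref<2x2xf32>
//   memref.store %d, %buf[%c1, %c1] : memref<2x2xf32>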
/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    MemRefType resultType = getContiguousMemRefType(tensorType);
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, resultType, {});
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult = state.createAlloc(
        rewriter, loc, memrefType, generateOp.dynamicExtents());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }
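    // Illustrative example (not part of the original source): a
    //   %t = tensor.generate %n {
    //   ^bb0(%i: index, %j: index):
    //     ...
    //     tensor.yield %elem : f32
    //   } : tensor<4x?xf32>
    // becomes an scf.parallel over (0, 0) to (4, %n) whose body stores each
    // yielded element into the allocated buffer, as set up below.
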
    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}
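// Illustrative example for areEquivalentExtractSliceOps above (not part of the
// original source): the pair
//   %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1]
//   %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// matches: the extract's source and the insert's dest bufferize to the same
// buffer, and the offsets/sizes/strides are identical.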
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
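  // Illustrative summary of the hook below (not part of the original source):
  // a tiling-style pattern such as
  //   %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1]
  //   %1 = "an in-place computation on"(%0)
  //   %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
  // must not be reported as a read-after-write conflict on %t: the part of %t
  // that the insert_slice overwrites is exactly the part it does not read.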
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);
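    // Illustrative sketch (not part of the original source): for
    //   %r = tensor.insert_slice %s into %t[4][8][1]
    //          : tensor<8xf32> into tensor<16xf32>
    // the rewrite produces, roughly,
    //   %sv = memref.subview %t_buf[4][8][1] : memref<16xf32> to memref<8xf32, ...>
    //   <copy> %s_buf into %sv
    // where the copy is whatever createMemCpy emits for the chosen options
    // (e.g. memref.copy). The copy below realizes the actual data movement.
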
    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<CastOp, CastOpInterface>();
  registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
  registry.addOpInterface<DimOp, DimOpInterface>();
  registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
  registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
  registry.addOpInterface<ExtractOp, ExtractOpInterface>();
  registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
  registry.addOpInterface<GenerateOp, GenerateOpInterface>();
  registry.addOpInterface<InsertOp, InsertOpInterface>();
  registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
  registry.addOpInterface<RankOp, RankOpInterface>();
}
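// Illustrative usage sketch (not part of the original source; dialect list is
// an assumption): a tool that wants these external models available would
// typically register them before creating the context, e.g.
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect,
//                   scf::SCFDialect, arith::ArithmeticDialect>();
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);
// after which One-Shot Bufferize can bufferize tensor dialect ops.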