//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

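// Example (illustrative, not exhaustive): a ranked cast such as
//
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
//
// bufferizes to a cast on the underlying buffer, roughly
//
//   %1 = memref.cast %0 : memref<?x?xf32, #layout> to memref<4x?xf32, #layout>
//
// where the layout and memory space are carried over from the source buffer
// as computed above.
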
/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      auto bufferType = buffer.getType().cast<MemRefType>();
      // Assume identity layout: No offset.
      assert(bufferType.getLayout().isIdentity() &&
             "non-zero offset for 0-d collapse not supported");
      MemRefLayoutAttrInterface layout;
      auto resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                        layout, bufferType.getMemorySpace());
      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

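// Example (illustrative): both interfaces above map their tensor op 1:1 onto
// the memref counterpart, e.g.
//
//   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
//       : tensor<2x3x4xf32> into tensor<6x4xf32>
//   %d = tensor.dim %t, %c0 : tensor<?xf32>
//
// become
//
//   %1 = memref.collapse_shape %0 [[0, 1], [2]]
//       : memref<2x3x4xf32> into memref<6x4xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>
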
/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

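// Example (illustrative): the inverse reshape direction, e.g.
//
//   %1 = tensor.expand_shape %0 [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>
//
// becomes
//
//   %1 = memref.expand_shape %0 [[0, 1]] : memref<6xf32> into memref<2x3xf32>
//
// with the memref result type inferred from the reassociation indices and the
// static result shape.
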
/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         /*forceInPlace=*/true);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            dstTensorType.getRank(), srcMemrefType, mixedOffsets, mixedSizes,
            mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

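// Example (illustrative): for shape [2, 3], createStores visits the elements
// in row-major order and emits one memref.store per element:
//
//   elements[0] -> [0, 0], elements[1] -> [0, 1], ..., elements[5] -> [1, 2]
//
// `constants` caches the index constants 0 .. max(shape_i)-1 so that each
// constant is materialized only once by the caller.
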
/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

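// Example (illustrative): with the interface above,
//
//   %t = tensor.from_elements %a, %b, %c : tensor<3xf32>
//
// bufferizes to an allocation followed by one store per element, roughly
//
//   %m = memref.alloc() : memref<3xf32>
//   memref.store %a, %m[%c0] : memref<3xf32>
//   memref.store %b, %m[%c1] : memref<3xf32>
//   memref.store %c, %m[%c2] : memref<3xf32>
//
// (The actual allocation is created via BufferizationState::createAlloc and
// may differ, e.g. in alignment or deallocation handling.)
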
/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

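// Example (illustrative): the pair
//
//   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
//   %1 = tensor.insert_slice %x into %t[%a][%c][1]
//       : tensor<?xf32> into tensor<?xf32>
//
// matches when the extract's source and the insert's destination bufferize to
// equivalent buffers and both ops use the same offsets, sizes and strides.
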
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

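// Example (illustrative): an insert_slice such as
//
//   %2 = tensor.insert_slice %1 into %t[%a][%c][1]
//       : tensor<?xf32> into tensor<?xf32>
//
// bufferizes to a subview of the destination buffer plus a copy (the exact
// copy op depends on the BufferizationOptions), roughly
//
//   %sv = memref.subview %dst_buf[%a][%c][1] : memref<?xf32> to memref<?xf32, #map>
//   memref.copy %src_buf, %sv
//
// and the copy is expected to fold away when %1 stems from a matching
// tensor.extract_slice, as described in isNotConflicting above.
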
/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
  });
}
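
// Usage sketch (illustrative, assuming a standard pass/tool setup): clients
// that bufferize tensor ops must attach these external models to the tensor
// dialect before processing IR, e.g.
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);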