//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
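/// For example (an illustrative sketch; the SSA names are hypothetical), an
/// in-place bufferization of
///
///   %0 = tensor.collapse_shape %t [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
///
/// roughly becomes
///
///   %0 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>
///
/// where %m is the buffer of %t.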
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      Value buffer = *state.getBuffer(rewriter, srcOperand);
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    Optional<BufferizationState::ForceInPlacability> overrideInPlace =
        canBeCollapsed
            ? None
            : Optional<BufferizationState::ForceInPlacability>(
                  BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
    Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
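/// For example (illustrative; the SSA names are hypothetical):
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// roughly becomes
///
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// where %m is the buffer of %t.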
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         BufferizationState::ForceInPlacability::FORCE_INPLACE);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            dstTensorType.getRank(), srcMemrefType, mixedOffsets, mixedSizes,
            mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
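/// For example (illustrative; the SSA names are hypothetical):
///
///   %e = tensor.extract %t[%i, %j] : tensor<4x4xf32>
///
/// roughly becomes
///
///   %e = memref.load %m[%i, %j] : memref<4x4xf32>
///
/// where %m is the buffer of %t.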
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
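    // For example (illustrative), for a tensor<2x3xf32> the stores are emitted
    // at indices (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), i.e., in the
    // row-major order in which `elements` are listed.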
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
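/// For example (an illustrative sketch; the SSA names are hypothetical), an
/// in-place bufferization of
///
///   %0 = tensor.insert %f into %t[%i] : tensor<8xf32>
///
/// roughly becomes
///
///   memref.store %f, %m[%i] : memref<8xf32>
///
/// where %m is the buffer of %t and %0 is replaced by %m.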
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
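/// For example (an illustrative sketch; the SSA names and the exact subview
/// layout alias are hypothetical):
///
///   %0 = tensor.insert_slice %s into %t[4] [8] [1]
///       : tensor<8xf32> into tensor<16xf32>
///
/// roughly becomes a copy into a subview of the destination buffer %m:
///
///   %sv = memref.subview %m[4] [8] [1]
///       : memref<16xf32> to memref<8xf32, #strided_with_offset_4>
///   memref.copy %sm, %sv
///       : memref<8xf32> to memref<8xf32, #strided_with_offset_4>
///
/// where %sm is the buffer of %s. The actual copy op depends on the
/// bufferization options (see createMemCpy).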
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
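/// For example (illustrative; the SSA names are hypothetical):
///
///   %r = tensor.rank %t : tensor<*xf32>
///
/// roughly becomes
///
///   %r = memref.rank %m : memref<*xf32>
///
/// where %m is the buffer of %t.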
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
  });
}
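// Typical registration (a minimal sketch; the surrounding tool or pass setup
// is assumed and not part of this file):
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);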