//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
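///
/// A rough sketch of the rewrite (illustrative IR, assuming a statically
/// shaped source whose buffer has an identity layout):
///
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x4xf32> into tensor<8xf32>
/// becomes
///   %1 = memref.collapse_shape %m [[0, 1]]
///       : memref<2x4xf32> into memref<8xf32>
/// where %m is the buffer of %0.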
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      Value buffer = *state.getBuffer(rewriter, srcOperand);
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(),
                            resultLayout, bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    Optional<BufferizationState::ForceInPlacability> overrideInPlace =
        canBeCollapsed
            ? None
            : Optional<BufferizationState::ForceInPlacability>(
                  BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
    Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
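///
/// A rough sketch of the rewrite (illustrative IR):
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// becomes
///   %d = memref.dim %m, %c0 : memref<?xf32>
/// where %m is the buffer of %t.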
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    Value buffer =
        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
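///
/// A rough sketch of the in-place rewrite (illustrative IR; offsets, sizes and
/// strides are taken over verbatim, and #strided_map stands for the inferred
/// strided result layout):
///
///   %1 = tensor.extract_slice %0[%o][%s][1] : tensor<?xf32> to tensor<?xf32>
/// becomes
///   %1 = memref.subview %m[%o][%s][1]
///       : memref<?xf32> to memref<?xf32, #strided_map>
/// where %m is the buffer of %0. If the op was decided to bufferize
/// out-of-place and its result is read, the subview is additionally copied
/// into a fresh allocation.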
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    Value srcMemref =
        *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                         BufferizationState::ForceInPlacability::FORCE_INPLACE);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
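///
/// A rough sketch of the rewrite (illustrative IR):
///
///   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
/// becomes
///   %e = memref.load %m[%i, %j] : memref<?x?xf32>
/// where %m is the buffer of %t.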
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref =
        *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
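    // E.g., for a tensor<2x3xf32>, createStores emits six memref.store ops at
    // indices (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), consuming the
    // `elements` operands in that (row-major) order.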
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
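///
/// A rough sketch of the rewrite (illustrative IR):
///
///   %1 = tensor.insert %f into %t[%i] : tensor<?xf32>
/// becomes
///   memref.store %f, %m[%i] : memref<?xf32>
/// where %m is the buffer of %t (or a copy of it, if the insertion must
/// bufferize out-of-place); %1 is then replaced with %m.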
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  // If the extracted source and the insertion dest are not the same value,
  // they must at least be equivalent bufferized values.
  if (st.source() != sti.dest() &&
      !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
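///
/// A rough sketch of the rewrite (illustrative IR; the exact copy op depends
/// on the bufferization options, and #strided_map stands for the inferred
/// subview layout):
///
///   %2 = tensor.insert_slice %1 into %0[%o][%s][1]
///       : tensor<?xf32> into tensor<?xf32>
/// becomes
///   %sv = memref.subview %m[%o][%s][1]
///       : memref<?xf32> to memref<?xf32, #strided_map>
///   memref.copy %b, %sv : memref<?xf32> to memref<?xf32, #strided_map>
/// where %m is the buffer of %0 and %b is the buffer of %1.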
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace = [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace = [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace = [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace = [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getResultBuffer` allocates.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    Value srcMemref =
        *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
                            state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
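///
/// A rough sketch of the rewrite (illustrative IR):
///
///   %r = tensor.rank %t : tensor<*xf32>
/// becomes
///   %r = memref.rank %m : memref<*xf32>
/// where %m is the buffer of %t.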
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
  });
}
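
// A minimal usage sketch (not part of this file): a client would typically
// register these external models before creating the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);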