//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    Value resultBuffer = getBuffer(rewriter, castOp.source(), options);
    auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(resultTensorType, options, layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
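///
/// Example (illustrative sketch; the exact layout attributes depend on the
/// source buffer and the bufferization options):
///
///   %c = tensor.collapse_shape %t [[0, 1], [2]]
///       : tensor<2x3x4xf32> into tensor<6x4xf32>
///
/// roughly becomes
///
///   %c = memref.collapse_shape %t_buf [[0, 1], [2]]
///       : memref<2x3x4xf32> into memref<6x4xf32>
///
/// If the source buffer has a layout map that makes the dimensions
/// non-collapsible, a copy into a freshly allocated identity-layout buffer is
/// inserted first (see the bufferize method below).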
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    Value buffer = getBuffer(rewriter, collapseShapeOp.src(), options);
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      Value tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.src(),
          analysisState.isTensorYielded(collapseShapeOp.result()));
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
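///
/// Example (sketch): `%d = tensor.dim %t, %c0 : tensor<?x4xf32>` bufferizes to
/// `%d = memref.dim %t_buf, %c0 : memref<?x4xf32, ...>`, where %t_buf is the
/// buffer of %t (layout elided here).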
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    auto v = getBuffer(rewriter, dimOp.source(), options);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    auto buffer = getBuffer(rewriter, expandShapeOp.src(), options);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
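///
/// Example (illustrative sketch; the inferred subview layout is elided):
///
///   %0 = tensor.extract_slice %t[4][8][1] : tensor<16xf32> to tensor<8xf32>
///
/// roughly becomes
///
///   %0 = memref.subview %t_buf[4][8][1]
///       : memref<16xf32> to memref<8xf32, #strided>
///
/// The result aliases the source buffer; no copy is inserted here. Any copy
/// required for an out-of-place bufferization decision is created elsewhere.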
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    auto srcMemref = getBuffer(rewriter, extractSliceOp.source(), options);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
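///
/// Example (sketch): `%e = tensor.extract %t[%i, %j] : tensor<4x4xf32>`
/// bufferizes to `%e = memref.load %t_buf[%i, %j] : memref<4x4xf32, ...>`.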
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref = getBuffer(rewriter, extractOp.tensor(), options);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.result(),
        analysisState.isTensorYielded(fromElementsOp.result()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
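    // For example, for a tensor<2x3xf32> the stores are emitted in row-major
    // order: (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), consuming one value of
    // `elements()` per store.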
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.result(),
        analysisState.isTensorYielded(generateOp.result()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
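///
/// Example (sketch): `%r = tensor.insert %f into %t[%i] : tensor<8xf32>`
/// bufferizes to `memref.store %f, %dest_buf[%i] : memref<8xf32, ...>`, and %r
/// is replaced by the (possibly copied) destination buffer.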
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    Value destMemref = getBuffer(rewriter, insertOp.dest(), options);
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
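///
/// Example (illustrative sketch; %t_buf / %s_buf are the buffers of the dest
/// and source tensors, the subview layout is elided):
///
///   %r = tensor.insert_slice %s into %t[4][8][1]
///       : tensor<8xf32> into tensor<16xf32>
///
/// roughly becomes
///
///   %sv = memref.subview %t_buf[4][8][1]
///       : memref<16xf32> to memref<8xf32, #strided>
///   memref.copy %s_buf, %sv
///
/// If %s was produced by a matching tensor.extract_slice of %t, the copy is
/// between identical memory locations and can later fold away.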
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    Value dstMemref = getBuffer(rewriter, insertSliceOp.dest(), options);

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    auto srcMemref = getBuffer(rewriter, insertSliceOp.source(), options);
    if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
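///
/// Example (sketch): `%r = tensor.rank %t : tensor<*xf32>` bufferizes to
/// `%r = memref.rank %t_buf : memref<*xf32>`.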
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    auto v = getBuffer(rewriter, rankOp.tensor(), options);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    Value srcBuffer = getBuffer(rewriter, reshapeOp.source(), options);
    Value shapeBuffer = getBuffer(rewriter, reshapeOp.shape(), options);
    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(resultTensorType, options);
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}
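
// Typical usage (illustrative sketch; the surrounding pass/tool setup is an
// assumption and not part of this file):
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);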