//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
    auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(resultTensorType, options, layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 resultBuffer);

    return success();
  }
};
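
// For illustration only: a rough sketch (not taken from a test) of the rewrite
// performed by CastOpInterface, assuming the source already bufferized to an
// identity-layout memref `%m`:
//
//   %1 = tensor.cast %0 : tensor<?x16xf32> to tensor<4x16xf32>
//
// becomes, roughly:
//
//   %1m = memref.cast %m : memref<?x16xf32> to memref<4x16xf32>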

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    Value buffer = getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     resultLayout,
                                     bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      Value tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()));
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};
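
// For illustration only: a rough sketch of the rewrite performed by
// CollapseShapeOpInterface for a contiguous (identity-layout) source buffer
// `%m`:
//
//   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
//       : tensor<2x3x4xf32> into tensor<6x4xf32>
//
// becomes, roughly:
//
//   %1m = memref.collapse_shape %m [[0, 1], [2]]
//       : memref<2x3x4xf32> into memref<6x4xf32>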

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    auto v = getBuffer(rewriter, dimOp.getSource(), options);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    auto buffer = getBuffer(rewriter, expandShapeOp.getSrc(), options);

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};
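
// For illustration only: a rough sketch of the rewrite performed by
// ExpandShapeOpInterface, assuming the source bufferized to `%m`:
//
//   %1 = tensor.expand_shape %0 [[0, 1], [2]]
//       : tensor<6x4xf32> into tensor<2x3x4xf32>
//
// becomes, roughly:
//
//   %1m = memref.expand_shape %m [[0, 1], [2]]
//       : memref<6x4xf32> into memref<2x3x4xf32>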

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    auto srcMemref = getBuffer(rewriter, extractSliceOp.getSource(), options);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.getResult().getType().cast<RankedTensorType>();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};
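
// For illustration only: a rough sketch of the rewrite performed by
// ExtractSliceOpInterface. `#layout` stands for whatever strided layout
// memref::SubViewOp::inferRankReducedResultType computes:
//
//   %1 = tensor.extract_slice %0[%o, 0] [4, 16] [1, 1]
//       : tensor<?x32xf32> to tensor<4x16xf32>
//
// becomes a non-copying view of the source buffer `%m`, roughly:
//
//   %1m = memref.subview %m[%o, 0] [4, 16] [1, 1]
//       : memref<?x32xf32> to memref<4x16xf32, #layout>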

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    Value srcMemref = getBuffer(rewriter, extractOp.getTensor(), options);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(),
        analysisState.isTensorYielded(fromElementsOp.getResult()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(),
        analysisState.isTensorYielded(generateOp.getResult()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};
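
// For illustration only: a rough sketch of the lowering performed by
// GenerateOpInterface for a tensor with one dynamic extent %n:
//
//   %0 = tensor.generate %n {
//   ^bb0(%i: index, %j: index):
//     ...
//     tensor.yield %v : f32
//   } : tensor<4x?xf32>
//
// becomes an allocation plus a parallel loop nest that stores each yielded
// element into the result buffer `%alloc`, roughly:
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %n) step (%c1, %c1) {
//     ...
//     memref.store %v, %alloc[%i, %j] : memref<4x?xf32>
//   }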

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    Value destMemref = getBuffer(rewriter, insertOp.getDest(), options);
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    Value dstMemref = getBuffer(rewriter, insertSliceOp.getDest(), options);

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    auto srcMemref = getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, dstMemref);
    return success();
  }
};
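
// For illustration only: a rough sketch of the rewrite performed by
// InsertSliceOpInterface, assuming `%tm` and `%1m` are the bufferized
// destination and source, `#layout` is the inferred subview layout, and the
// copy op is whatever options.createMemCpy emits (typically memref.copy):
//
//   %2 = tensor.insert_slice %1 into %t[%o, 0] [4, 16] [1, 1]
//       : tensor<4x16xf32> into tensor<?x32xf32>
//
// becomes, roughly:
//
//   %sv = memref.subview %tm[%o, 0] [4, 16] [1, 1]
//       : memref<?x32xf32> to memref<4x16xf32, #layout>
//   memref.copy %1m, %sv
//       : memref<4x16xf32> to memref<4x16xf32, #layout>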

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    auto v = getBuffer(rewriter, rankOp.getTensor(), options);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    Value srcBuffer = getBuffer(rewriter, reshapeOp.getSource(), options);
    Value shapeBuffer = getBuffer(rewriter, reshapeOp.getShape(), options);
    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(resultTensorType, options);
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}
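
// A minimal usage sketch (assumption, not part of this file): clients make
// these external models available by registering them on their
// DialectRegistry before running bufferization, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);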