//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(resultTensorType, options, layout,
                      sourceMemRefType.getMemorySpaceAsInt());

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

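// Illustrative example (editorial addition, not part of the original source):
//   %1 = tensor.cast %t : tensor<?xf32> to tensor<4xf32>
// becomes a cast of the source buffer, e.g.
//   %m = memref.cast %b : memref<?xf32> to memref<4xf32>
// with the layout and memory space carried over from the source memref type.
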
/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      Value tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()));
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

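// Illustrative example (editorial addition, not part of the original source):
//   %0 = tensor.collapse_shape %t [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
// becomes
//   %m = memref.collapse_shape %b [[0, 1]] : memref<2x3xf32> into memref<6xf32>
// When the source buffer's layout map makes the dims non-collapsible, the data
// is first copied into a fresh identity-layout buffer, as handled above.
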
/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

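// Illustrative example (editorial addition, not part of the original source):
//   %0 = tensor.expand_shape %t [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>
// becomes
//   %m = memref.expand_shape %b [[0, 1]] : memref<6xf32> into memref<2x3xf32>
// with the memref result type inferred from the reassociation indices and the
// result shape.
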
/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.getResult().getType().cast<RankedTensorType>();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

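// Illustrative example (editorial addition, not part of the original source):
//   %0 = tensor.extract_slice %t[4][8][2] : tensor<16xf32> to tensor<8xf32>
// becomes a view into the source buffer, e.g.
//   %v = memref.subview %b[4][8][2] : memref<16xf32> to memref<8xf32, ...>
// The result aliases only part of the source buffer, which is why
// bufferRelation() above reports BufferRelation::None.
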
/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(),
        analysisState.isTensorYielded(fromElementsOp.getResult()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
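    // For example, for a tensor<3x2xf32> the maximum dimension size is 3, so
    // the index constants 0, 1 and 2 are created. (Editorial comment.)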
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(),
        analysisState.isTensorYielded(generateOp.getResult()),
        /*copy=*/false);
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

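// Illustrative example (editorial addition, not part of the original source):
//   %0 = tensor.generate %n {
//   ^bb0(%i: index):
//     %v = ... : f32
//     tensor.yield %v : f32
//   } : tensor<?xf32>
// is lowered to an scf.parallel loop over [0, %n) whose body stores the
// yielded value into the allocated buffer at index %i.
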
/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

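// Illustrative example (editorial addition, not part of the original source):
//   %0 = tensor.extract_slice %t[%a][%c][1]
//   %1 = linalg.fill %cst, %0   // bufferizes in place
//   %2 = tensor.insert_slice %1 into %t[%a][%c][1]
// hasMatchingExtractSliceOp() traces %1 back through its reverse use-def
// chain and reports true because it originates from an extract_slice with the
// same source and the same offsets/sizes/strides as the insert_slice.
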
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

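// Illustrative example (editorial addition, not part of the original source):
//   %r = tensor.insert_slice %s into %d[4][8][1]
//          : tensor<8xf32> into tensor<16xf32>
// bufferizes to a subview of the destination buffer plus a copy, e.g.
//   %v = memref.subview %db[4][8][1] : memref<16xf32> to memref<8xf32, ...>
//   memref.copy %sb, %v
// (the exact copy op depends on options.createMemCpy; with a matching
// extract_slice on the source, the copy can eventually fold away).
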
/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(resultTensorType, options);
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}
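
// Typical usage (illustrative editorial note, not part of the original file):
// these external models must be registered with the context before running
// bufferization, e.g.
//   DialectRegistry registry;
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);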