//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
                                          layout, memorySpace);

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};
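
// For illustration, a sketch of the rewrite performed by CastOpInterface
// above (identifiers are made up; the exact layout and memory space depend on
// the source buffer and the bufferization options):
//
//   %1 = tensor.cast %0 : tensor<?x16xf32> to tensor<4x16xf32>
//
// becomes, with %m being the buffer of %0:
//
//   %1 = memref.cast %m : memref<?x16xf32> to memref<4x16xf32>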

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
    auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      auto buffer = state.getBuffer(rewriter, srcOperand);
      if (failed(buffer))
        return failure();
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, *buffer, collapseShapeOp.reassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    Optional<BufferizationState::ForceInPlacability> overrideInPlace =
        canBeCollapsed
            ? None
            : Optional<BufferizationState::ForceInPlacability>(
                  BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
    auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace);
    if (failed(buffer))
      return failure();

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, *buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};
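
// For illustration, a sketch of the common (collapsible) case handled above;
// layouts and memory spaces may differ depending on the source buffer:
//
//   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
//       : tensor<2x3x4xf32> into tensor<6x4xf32>
//
// becomes, with %m being the buffer of %0:
//
//   %1 = memref.collapse_shape %m [[0, 1], [2]]
//       : memref<2x3x4xf32> into memref<6x4xf32>
//
// If the source layout makes the dims non-collapsible, a copy into a fresh
// buffer with identity layout is inserted first.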

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    auto buffer =
        state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};
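
// For illustration, a sketch of the expand_shape rewrite (the memref result
// layout is inferred by the memref.expand_shape builder from the source
// buffer; identifiers are made up):
//
//   %1 = tensor.expand_shape %0 [[0, 1], [2]]
//       : tensor<6x4xf32> into tensor<2x3x4xf32>
//
// becomes, with %m being the buffer of %0:
//
//   %1 = memref.expand_shape %m [[0, 1], [2]]
//       : memref<6x4xf32> into memref<2x3x4xf32>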

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    auto srcMemref =
        state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
                        BufferizationState::ForceInPlacability::FORCE_INPLACE);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();

    // If not inplaceable, alloc.
    bool inplace =
        state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
    Value alloc;
    if (!inplace) {
      FailureOr<Value> allocOrFailure =
          state.createAlloc(rewriter, loc, extractSliceOp.result());
      if (failed(allocOrFailure))
        return failure();
      alloc = *allocOrFailure;
    }

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // If not inplaceable, copy.
    if (!inplace) {
      // Do not copy if the copied data is never read.
      if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
        if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
                                alloc, state.getOptions())))
          return failure();
      subView = alloc;
    }

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};
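
// For illustration, a sketch of the in-place case (the subview result carries
// a strided layout, shown symbolically as #strided; identifiers are made up):
//
//   %1 = tensor.extract_slice %0[%o, 0] [4, 16] [1, 1]
//       : tensor<?x16xf32> to tensor<4x16xf32>
//
// becomes, with %m being the buffer of %0:
//
//   %1 = memref.subview %m[%o, 0] [4, 16] [1, 1]
//       : memref<?x16xf32> to memref<4x16xf32, #strided>
//
// In the out-of-place case, a new buffer is allocated and the subview is
// copied into it (unless the result is never read).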

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    auto srcMemref =
        state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    FailureOr<Value> maybeBuffer =
        state.createAlloc(rewriter, loc, fromElementsOp.result());
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.elements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(loc, fromElementsOp.elements().front(),
                                       buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
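    // For example (a sketch): for a tensor<2x3xf32>, this emits stores at
    // indices (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), consuming
    // `elements` in row-major order.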
    auto elementIt = fromElementsOp.elements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    MemRefType memrefType =
        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
    FailureOr<Value> maybeResult =
        state.createAlloc(rewriter, loc, generateOp.result());
    if (failed(maybeResult))
      return failure();
    Value result = *maybeResult;

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = memrefType.isDynamicDim(i)
                             ? generateOp.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(generateOp.getBody(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, result);
    return success();
  }
};
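
// For illustration, a sketch of the generate rewrite (names are made up; the
// actual bounds come from the static shape and the dynamicExtents operands,
// and %c0/%c1/%c16 stand for the usual index constants):
//
//   %0 = tensor.generate %d {
//   ^bb0(%i: index, %j: index):
//     ...
//     tensor.yield %elem : f32
//   } : tensor<?x16xf32>
//
// becomes an allocation plus a parallel loop that stores each element:
//
//   %alloc = memref.alloc(%d) : memref<?x16xf32>
//   scf.parallel (%i, %j) = (%c0, %c0) to (%d, %c16) step (%c1, %c1) {
//     ...
//     memref.store %elem, %alloc[%i, %j] : memref<?x16xf32>
//   }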

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     *destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}
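
// For illustration, a matching (ExtractSliceOp, InsertSliceOp) pair as
// detected by the helpers above (a sketch; values are made up):
//
//   %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1]
//   ...
//   %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
//
// Both ops use equivalent source/dest tensors (%t) and the same offsets,
// sizes and strides, so %1 (derived from %0) covers exactly the region that
// the insert_slice overwrites.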

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace = [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace = [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace = [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //      {inplace = [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite         = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.source()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();

    // When bufferizing out-of-place, `getBuffer` allocates a new buffer for
    // the destination.
    FailureOr<Value> dstMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    auto srcMemref =
        state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
    if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref,
                                                 subView, state.getOptions())))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};
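
// For illustration, a sketch of the insert_slice rewrite in the in-place case
// (names are made up, types omitted; the exact copy op depends on the
// bufferization options):
//
//   %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
//
// becomes, with %src and %dst being the buffers of %1 and %t:
//
//   %sv = memref.subview %dst[%a, %b][%c, %d][1, 1]
//   memref.copy %src, %sv
//
// and uses of %2 are replaced with %dst. If a matching tensor.extract_slice
// produced %1, the copy can later fold away.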

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*shape*/)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    auto &srcOperand = reshapeOp->getOpOperand(0);
    auto srcBuffer = state.getBuffer(rewriter, srcOperand);
    if (failed(srcBuffer))
      return failure();

    auto &shapeOperand = reshapeOp->getOpOperand(1);
    auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
    if (failed(shapeBuffer))
      return failure();

    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}
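
// Typical usage (a sketch): clients add these external models to their
// DialectRegistry before running bufferization, e.g.:
//
//   DialectRegistry registry;
//   mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);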