//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(resultTensorType, options, layout,
                      sourceMemRefType.getMemorySpaceAsInt());

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
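///
/// Illustrative example (schematic IR; SSA names are placeholders and layout
/// attributes are omitted):
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// bufferizes to a memref.collapse_shape on the buffer of %0:
///   %m = memref.collapse_shape %b [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>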
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
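///
/// Illustrative example (schematic IR; SSA names are placeholders and the
/// inferred strided layout of the subview result type is elided):
///   %0 = tensor.extract_slice %t[4][8][1] : tensor<16xf32> to tensor<8xf32>
/// bufferizes to a subview of the buffer of %t:
///   %v = memref.subview %m[4][8][1] : memref<16xf32> to memref<8xf32, ...>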
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();

    // Take a subview of the source buffer.
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
            mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
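//
// For example (illustrative), for shape [2, 3] the stores are emitted in
// row-major order, i.e. for indices
//   (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)
// with one element of op.elements() consumed per store.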
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(),
        analysisState.isTensorYielded(fromElementsOp.getResult()), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(),
        analysisState.isTensorYielded(generateOp.getResult()), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
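///
/// Illustrative example (schematic IR; SSA names are placeholders):
///   %1 = tensor.insert %f into %t[%i] : tensor<8xf32>
/// bufferizes to a store into the (possibly copied) buffer of %t:
///   memref.store %f, %m[%i] : memref<8xf32>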
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduces to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
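///
/// Illustrative example (schematic IR; SSA names are placeholders and the
/// subview layout is elided):
///   %1 = tensor.insert_slice %s into %t[4][8][1]
///       : tensor<8xf32> into tensor<16xf32>
/// bufferizes to a copy (created via BufferizationOptions::createMemCpy,
/// typically a memref.copy) into a subview of the destination buffer:
///   %v = memref.subview %m[4][8][1] : memref<16xf32> to memref<8xf32, ...>
///   memref.copy %s_buffer, %v
/// The copy folds away when the source was produced by a matching
/// tensor.extract_slice that bufferized to a subview of the same buffer.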
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
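///
/// Illustrative example (schematic IR; SSA names are placeholders):
///   %r = tensor.rank %t : tensor<*xf32>
/// bufferizes to a rank query on the buffer of %t:
///   %r = memref.rank %m : memref<*xf32>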
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(
        resultTensorType, options, /*layout=*/{},
        srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
/// (i.e. equivalent operand / result and same offset/sizes/strides
/// specification).
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st,
                                         ParallelInsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches
/// the given ParallelInsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      ParallelInsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results, query its tied op results.
    auto insertOp = cast<ParallelInsertSliceOp>(op);
    return {insertOp.getTiedOpResult()};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    // This interface method is overridden because we want to set a custom
    // insertion point for tensor copies. They should be inserted right before
    // the ForeachThreadOp. E.g.:
    //
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
    //   }
    // }
    //
    // After TensorCopyInsertion:
    //
    // %copy = bufferization.alloc_tensor() copy(%d)
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ...
    //     parallel_insert_slice %c into %copy ...
    //   }
    // }

    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Nothing to do if the destination tensor is inplace.
    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
           "source is always in-place");
    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
      return success();

    // Find corresponding OpResult.
    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();

    // Insert tensor allocation right before the ForeachThreadOp.
    rewriter.setInsertionPoint(parallelIteratingOp);
    bool isYielded = state.isTensorYielded(opResult);
    FailureOr<Value> alloc = allocateTensorForShapedValue(
        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
        /*escape=*/isYielded, state.getOptions());
    if (failed(alloc))
      return failure();

    // Update destination operand.
    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
      parallelInsertSliceOp.getDestMutable().assign(*alloc);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Get destination buffer.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();

    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
    rewriter.setInsertionPoint(parallelCombiningParent);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = destBuffer->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides())
            .cast<MemRefType>();
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // Replace all uses of parallelIteratingOp (just the corresponding result).
    rewriter.setInsertionPointAfter(parallelIteratingOp);
    Value toTensorOp =
        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
    SmallVector<OpOperand *> resultUses = llvm::to_vector(
        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
                        [](OpOperand &use) { return &use; }));
    for (OpOperand *use : resultUses) {
      rewriter.updateRootInPlace(use->getOwner(),
                                 [&]() { use->set(toTensorOp); });
    }
    rewriter.eraseOp(op);
    return success();
  }

  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
  // the code.
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}