//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

/// Bufferization of tensor.cast. Replace with memref.cast.
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
    TensorType resultTensorType =
        castOp.getResult().getType().cast<TensorType>();
    MemRefLayoutAttrInterface layout;

    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
      if (resultTensorType.isa<RankedTensorType>())
        layout = rankedMemRefType.getLayout();

    // Compute the new memref type.
    Type resultMemRefType =
        getMemRefType(resultTensorType, options, layout,
                      sourceMemRefType.getMemorySpaceAsInt());

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
                                                 *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
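///
/// For illustration only (a sketch, not taken from a test): a collapse such as
///
///   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
///       : tensor<2x3x4xf32> into tensor<6x4xf32>
///
/// bufferizes roughly to the same reassociation applied to the source buffer:
///
///   %1 = memref.collapse_shape %buf [[0, 1], [2]]
///       : memref<2x3x4xf32> into memref<6x4xf32>
///
/// where %buf is the bufferized source (assuming an identity layout).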
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = buffer.getType().cast<MemRefType>();

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        AffineMap resultLayout =
            makeStridedLinearLayoutMap({}, offset, op->getContext());
        resultType =
            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
                            bufferType.getMemorySpaceAsInt());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(),
          analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpaceAsInt());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.index());
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*src*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
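///
/// For illustration only (a sketch, not taken from a test):
///
///   %1 = tensor.extract_slice %0[4][8][1] : tensor<16xf32> to tensor<8xf32>
///
/// bufferizes roughly to a subview of the source buffer:
///
///   %1 = memref.subview %buf[4][8][1]
///       : memref<16xf32> to memref<8xf32, #strided_layout>
///
/// where %buf is the bufferized source and #strided_layout stands for the
/// strided layout (offset 4, stride 1) inferred by the subview builder.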
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(0) /*source*/)
      return {op->getOpResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::None;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();

    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.getResult().getType().cast<RankedTensorType>();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Bufferize to subview.
    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                                 dstTensorType.getRank(), srcMemrefType,
                                 mixedOffsets, mixedSizes, mixedStrides)
                                 .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
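///
/// For illustration only (a sketch, not taken from a test):
///
///   %c = tensor.extract %t[%i, %j] : tensor<4x4xf32>
///
/// bufferizes roughly to a load from the source buffer:
///
///   %c = memref.load %buf[%i, %j] : memref<4x4xf32>
///
/// where %buf is the bufferized source.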
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.indices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(),
        analysisState.isTensorYielded(fromElementsOp.getResult()), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
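    // For illustration only: "tensor.from_elements %a : tensor<f32>" lowers to
    // a single store into the rank-0 buffer, roughly
    // "memref.store %a, %buf[] : memref<f32>".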
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *std::max_element(shape.begin(), shape.end());
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != static_cast<unsigned>(0))
      return op->emitError("memory space not implemented yet");

    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    AnalysisState analysisState(options);
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(),
        analysisState.isTensorYielded(generateOp.getResult()), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Collect loop bounds.
    int64_t rank = memrefType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          memrefType.isDynamicDim(i)
              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
              : rewriter.create<arith::ConstantIndexOp>(
                    loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref. We use mergeBlockBefore to "move"
    // this op's body into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
                              parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], buffer,
        parallelBody->getArguments());

    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
                                                    tensor::InsertOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
           "expected dest OpOperand");
    return {op->getOpResult(0)};
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    return {&op->getOpOperand(1) /*dest*/};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches (i.e.,
/// equivalent operand / result and same offset/sizes/strides specification).
///
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st, InsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      InsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
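///
/// For illustration only (a sketch, not taken from a test):
///
///   %1 = tensor.insert_slice %src into %dst[4][8][1]
///       : tensor<8xf32> into tensor<16xf32>
///
/// bufferizes roughly to a copy into a subview of the destination buffer
/// (assuming identity layouts; the copy is emitted via options.createMemCpy,
/// which produces memref.copy by default):
///
///   %sv = memref.subview %dst_buf[4][8][1]
///       : memref<16xf32> to memref<8xf32, #strided_layout>
///   memref.copy %src_buf, %sv : memref<8xf32> to memref<8xf32, #strided_layout>
///
/// and %1 is replaced by the destination buffer.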
struct InsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                    tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /*dest*/)
      return {op->getResult(0)};
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling, and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    OffsetSizeAndStrideOpInterface::expandToRank(
        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
        [&](Value target, int64_t dim) -> OpFoldResult {
          auto shapedType = target.getType().cast<ShapedType>();
          if (shapedType.isDynamicDim(dim))
            return rewriter.create<memref::DimOp>(loc, target, dim).result();
          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
        });
    // Take a subview of the dst.
    auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
    auto subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getRank(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides)
            .cast<MemRefType>();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
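///
/// For illustration only (a sketch, not taken from a test):
///
///   %r = tensor.rank %t : tensor<*xf32>
///
/// bufferizes roughly to:
///
///   %r = memref.rank %buf : memref<*xf32>
///
/// where %buf is the bufferized source.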
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    if (&opOperand == &op->getOpOperand(1) /* shape */)
      return true;
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getOpResult(0)};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
    auto resultMemRefType = getMemRefType(
        resultTensorType, options, /*layout=*/{},
        srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
    return success();
  }
};

/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
/// (i.e., equivalent operand / result and same offset/sizes/strides
/// specification).
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
                                         ExtractSliceOp st,
                                         ParallelInsertSliceOp sti) {
  if (!st || !sti)
    return false;
  if (st != sti &&
      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
    return false;
  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
    return false;
  return true;
}

/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
                                      ParallelInsertSliceOp insertOp) {
  auto condition = [&](Value val) {
    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
        return true;
    return false;
  };

  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
                      condition);
}

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (&opOperand != &op->getOpOperand(1) /*dest*/)
      return {};

    // ParallelInsertSliceOp itself has no results, query its tied op results.
    auto insertOp = cast<ParallelInsertSliceOp>(op);
    return {insertOp.getTiedOpResult()};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return &opOperand == &op->getOpOperand(1) /*dest*/;
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    // This interface method is overridden because we want to set a custom
    // insertion point for tensor copies. They should be inserted right before
    // the ForeachThreadOp. E.g.:
    //
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
    //   }
    // }
    //
    // After TensorCopyInsertion:
    //
    // %copy = bufferization.alloc_tensor() copy(%d)
    // %r0, %r1 = foreach_thread ... {
    //   ...
    //   perform_concurrently {
    //     parallel_insert_slice %a into %b ...
    //     parallel_insert_slice %c into %copy ...
    //   }
    // }

    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Nothing to do if the destination tensor is inplace.
    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
           "source is always in-place");
    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
      return success();

    // Find corresponding OpResult.
    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();

    // Insert tensor allocation right before the ForeachThreadOp.
    rewriter.setInsertionPoint(parallelIteratingOp);
    bool isYielded = state.isTensorYielded(opResult);
    FailureOr<Value> alloc = allocateTensorForShapedValue(
        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
        /*escape=*/isYielded, state.getOptions());
    if (failed(alloc))
      return failure();

    // Update destination operand.
    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
      parallelInsertSliceOp.getDestMutable().assign(*alloc);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();
    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();

    // Get destination buffer.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();

    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
    rewriter.setInsertionPoint(parallelCombiningParent);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());
    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // Replace all uses of parallelIteratingOp (just the corresponding result).
    rewriter.setInsertionPointAfter(parallelIteratingOp);
    Value toTensorOp =
        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
    SmallVector<OpOperand *> resultUses = llvm::to_vector(
        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
                        [](OpOperand &use) { return &use; }));
    for (OpOperand *use : resultUses) {
      rewriter.updateRootInPlace(use->getOwner(),
                                 [&]() { use->set(toTensorOp); });
    }
    rewriter.eraseOp(op);
    return success();
  }

  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
  // the code.
  bool isNotConflicting(Operation *op, OpOperand *uRead,
                        OpOperand *uConflictingWrite,
                        const AnalysisState &state) const {
    Operation *readingOp = uRead->getOwner();
    Operation *conflictingWritingOp = uConflictingWrite->getOwner();

    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
    // uRead is an InsertSliceOp...
    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }

      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                    insertSliceOp))
        // Case 1: The main insight is that InsertSliceOp reads only part of
        // the destination tensor. The overwritten area is not read. If
        // uConflictingWrite writes into exactly the memory location that is
        // being read by uRead, this is not a conflict.
        //
        // In the above example:
        // uRead = OpOperand 1 (%t) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
        //
        // The read of %t does not conflict with the write of the FillOp
        // (same aliases!) because the area that the FillOp operates on is
        // exactly the one that is *not* read via %t.
        return true;

      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
        // Case 2: The read of the source tensor and the write to the dest
        // tensor via an InsertSliceOp is not a conflict if the read is
        // reading exactly that part of an equivalent tensor that the
        // InsertSliceOp is writing.
        //
        // In the above example:
        // uRead = OpOperand 0 (%1) of tensor.insert_slice
        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
        return true;
    }

    // If uConflictingWrite is an InsertSliceOp...
    if (auto insertSliceOp =
            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
      // As an example, consider the following IR.
      //
      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
      // %1 = linalg.fill %cst, %0 {inplace= [true] }
      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
      //     {inplace= [true] }
      // %3 = vector.transfer_read %1, %cst
      //
      // In the above example:
      // uRead = OpOperand 0 (%1) of vector.transfer_read
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      // lastWrite = %1
      //
      // This is not a conflict because the InsertSliceOp overwrites the
      // memory segment of %1 with the exact same data. (Effectively, there
      // is no memory write here.)
      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
          state.areEquivalentBufferizedValues(uRead->get(),
                                              insertSliceOp.getSource()) &&
          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
                                    insertSliceOp))
        return true;

    return false;
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
  });
}