//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}
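
// Example (illustrative only): relaxing a static shape to a dynamic one only
// needs a memref.cast, because the strides and offset stay static:
//
//   %0 = memref.cast %src : memref<4xf32> to memref<?xf32>
//
// If the destination type requires a static offset or stride where the source
// has a dynamic one, the helper above cannot guarantee that the cast succeeds
// at runtime; it then allocates a buffer of the destination type and copies
// the data into it instead.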

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}
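
// Example (illustrative only): if the memref types match exactly, the pair
// folds away and all uses of the to_memref result are replaced with the
// original memref:
//
//   %t = bufferization.to_tensor %m : memref<4xf32>
//   %r = bufferization.to_memref %t : memref<4xf32>
// -->
//   (%r is replaced with %m)
//
// If the two memref types differ, a memref.cast is inserted instead (or an
// alloc + copy when the cast cannot be guaranteed to succeed at runtime).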

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Create buffer allocation.
  Value copyBuffer;
  if (getCopy())
    copyBuffer = getBuffer(rewriter, getCopy(), options);
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType());
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(options);
  bool dealloc;
  if (getEscape().hasValue()) {
    dealloc = !*getEscape();
  } else {
    // No "escape" annotation found.
    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}
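
// Example (illustrative only): with the default allocation callbacks, an
// alloc_tensor with one dynamic size bufferizes roughly to:
//
//   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
// -->
//   %buf = memref.alloc(%sz) : memref<?xf32>
//
// plus a memref.copy if `copy` is set, and a memref.dealloc before the block
// terminator if the buffer is not yielded / does not escape.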

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");
  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*escape=*/BoolAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*escape=*/BoolAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy, bool escape) {
  build(builder, result, type, dynamicSizes, copy, builder.getBoolAttr(escape));
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was given by a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5 : index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value(),
        /*escape=*/op.getEscapeAttr());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
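
/// Fold `tensor.dim` of a dynamic dimension of an `alloc_tensor` to the
/// corresponding dynamic size operand (or, if the op has a `copy` operand, to
/// a `tensor.dim` of that operand).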
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}
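
// Example (illustrative only) of the custom assembly format implemented by
// the parser/printer above:
//
//   %0 = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) : tensor<8xf32>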

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either value has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}
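
// Example (illustrative only) of the SimplifyClones canonicalization above:
// if the clone's result is deallocated in the same block and no other buffer
// is freed in between, the clone degenerates to a cast and its dealloc
// becomes redundant:
//
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   "use"(%1) : (memref<?xf32>) -> ()
//   memref.dealloc %1 : memref<?xf32>
// -->
//   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
//   "use"(%1) : (memref<?xf32>) -> ()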

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.indices());
    return success();
  }
};
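
// Example (illustrative only) of the LoadOfToMemref fold:
//
//   %m = bufferization.to_memref %t : memref<?xf32>
//   %v = memref.load %m[%i] : memref<?xf32>
// -->
//   %v = tensor.extract %t[%i] : tensor<?xf32>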

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an
  // error or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"