//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
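///
/// Illustrative sketch (value names and shapes are made up for this comment):
///   %0 = bufferization.to_tensor %m : memref<?xf32>
///   %1 = bufferization.to_memref %0 : memref<4xf32>
/// may be rewritten into a memref.cast of %m, or, if the cast is not
/// guaranteed to succeed at runtime, into an alloc + copy (see
/// `castOrReallocMemRefValue`).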
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//
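
// Rough sketch of the bufferization below (names and types are illustrative,
// not taken from a test):
//   %0 = bufferization.alloc_tensor(%d) : tensor<?xf32>
// is replaced by a buffer allocation such as
//   %m = memref.alloc(%d) : memref<?xf32>
// (obtained through BufferizationOptions::createAlloc), plus an optional copy
// of the `copy` operand and, depending on the escape analysis, a dealloc.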
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Create buffer allocation.
  Value copyBuffer;
  if (getCopy())
    copyBuffer = getBuffer(rewriter, getCopy(), options);
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType());
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(options);
  bool dealloc;
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    dealloc = !escapeAttr[0].cast<BoolAttr>().getValue();
  } else {
    // No "escape" annotation found.
    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its uses escapes
  // the function boundary directly.
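  // For instance (SSA names and encoding alias below are illustrative), IR
  // along the lines of
  //   %0 = bufferization.alloc_tensor() : tensor<8xf32, #SparseVector>
  //   return %0 : tensor<8xf32, #SparseVector>
  // is rejected by this verifier.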
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value());
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined by a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5 : index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
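
// Custom assembly format handled by the parser/printer below, e.g. (operand
// names are illustrative):
//   %0 = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) : tensor<4xf32>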
ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
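///
/// Illustrative sketch (value names and types made up for this comment),
/// assuming the source is not deallocated in the same block:
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   ...                          // no other deallocation in between
///   memref.dealloc %1 : memref<?xf32>
/// can be simplified to a memref.cast of %0, dropping the now-redundant
/// deallocation.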
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//
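
// to_tensor(to_memref(x)) folds to x when the two ops are adjacent in the same
// block (a conservative stand-in for alias analysis), e.g. with illustrative
// names and types:
//   %m = bufferization.to_memref %t : memref<?xf32>
//   %0 = bufferization.to_tensor %m : memref<?xf32>
// folds %0 to %t.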
OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
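///
/// Illustrative sketch (value names and types made up for this comment):
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
/// becomes
///   %d = tensor.dim %t, %c0 : tensor<?xf32>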
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"