//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
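  // No cast that is guaranteed to succeed is available: fall back to
  // allocating a new buffer of the destination type and copying the data.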
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Compute memory space of this allocation.
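  // Precedence: an explicit `memory_space` attribute, then the memory space
  // of the `copy` buffer, then the default memory space from the options.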
  unsigned memorySpace;
  if (getMemorySpace().hasValue()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    memorySpace =
        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
  } else if (options.defaultMemorySpace.hasValue()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return op->emitError("could not infer memory space");
  }

  // Create memory allocation.
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType(),
                      AffineMap(), memorySpace);
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(options);
  bool dealloc;
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    dealloc = !escapeAttr[0].cast<BoolAttr>().getValue();
  } else {
    // No "escape" annotation found.
    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
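  // "Directly" means being used as an operand of a return or call op;
  // escapes through aliasing ops are not detected here.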
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy,
        /*memory_space=*/IntegerAttr());
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size was defined by a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5: index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
///   to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

/// Fold `tensor.dim` of an `alloc_tensor` result with a constant index to the
/// corresponding dynamic size.
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
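/// A clone without uses is erased. Otherwise, the clone is turned into a
/// `memref.cast` and the deallocation (of either the source or the clone's
/// result) that becomes redundant in the clone's block is erased as well.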
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {

/// Fold `tensor.dim` of a `to_tensor` op into a `memref.dim` of the
/// underlying memref.
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
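/// For example (illustrative IR):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
///   to
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>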
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"