//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
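  // Illustrative example: casting a memref with a dynamic offset or dynamic
  // strides to a destination type with static offset/strides may be cast
  // compatible but is not guaranteed to succeed at runtime, so the value is
  // copied into a freshly allocated buffer of the destination type instead.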
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    // Function can be configured to only handle cases where a cast is needed.
    if (!allowSameType)
      return failure();
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy())
    copyBuffer = getBuffer(rewriter, getCopy(), options);

  // Compute memory space of this allocation.
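  // The op's `memory_space` attribute takes precedence; otherwise fall back to
  // the default memory space from the bufferization options. If neither is
  // set, bufferization fails.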
  unsigned memorySpace;
  if (getMemorySpace().hasValue()) {
    memorySpace = *getMemorySpace();
  } else if (options.defaultMemorySpace.hasValue()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return op->emitError("could not infer memory space");
  }

  // Create memory allocation.
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType(),
                      AffineMap(), memorySpace);
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(options);
  bool dealloc;
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    dealloc = !escapeAttr[0].cast<BoolAttr>().getValue();
  } else {
    // No "escape" annotation found.
    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
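  // For example (illustrative IR, assuming a `#SparseVector` encoding alias),
  // the following would be rejected because the sparse allocation is returned
  // from the function:
  //
  //   %0 = bufferization.alloc_tensor() : tensor<16xf32, #SparseVector>
  //   return %0 : tensor<16xf32, #SparseVector>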
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError(
            "sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*memory_space=*/BoolAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*memory_space=*/BoolAttr());
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation, but whose size is defined by a
/// `constant` op. For example:
///
///   %c5 = arith.constant 5: index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.source().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
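/// For example (illustrative pseudo-IR), when the clone's buffer is the only
/// one deallocated in the clone's block:
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   "use"(%1) : (memref<?xf32>) -> ()
///   memref.dealloc %1 : memref<?xf32>
///
/// the clone can be replaced by a memref.cast of %0 and the dealloc of %1
/// removed, assuming %0 is deallocated elsewhere (or not at all).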
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
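    // E.g., to_tensor(to_memref(%t)) folds to %t only if the to_memref is the
    // immediately preceding operation in the same block, so no other op could
    // have written to the buffer in between.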
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.source().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.index());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` from kicking
/// in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    // Only handle cases where a cast is needed. The other case is handled by
    // the folder.
    return foldToMemrefToTensorPair(rewriter, toMemref,
                                    /*allowSameType=*/false);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.memref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.indices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
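/// For example (illustrative IR):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// folds to
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>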
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
      context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"