//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = value.getType().cast<MemRefType>();

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return a == MemRefType::getDynamicStrideOrOffset() &&
             b != MemRefType::getDynamicStrideOrOffset();
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If the types are cast compatible according to `areCastCompatible`,
  // the cast is valid but may still fail at runtime. To ensure that we only
  // generate casts that always succeed at runtime, we check a few extra
  // conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamicSize)
      continue;
    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
    Value size = b.create<memref::DimOp>(loc, value, index);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}
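// Illustrative sketch (not taken from a test) of the two code paths in
// `castOrReallocMemRefValue` above: if the destination type only relaxes
// information (e.g., memref<4xf32> -> memref<?xf32>), a single memref.cast is
// emitted. If a dynamic offset or stride in the source would have to become
// static in the destination, the cast is not guaranteed to succeed at runtime,
// so the helper conservatively allocates a new buffer with memref.alloc and
// fills it with memref.copy instead.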
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
  auto rankedDestType = destType.dyn_cast<MemRefType>();
  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = shapedValue.getType().cast<ShapedType>();
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (shapedType.isa<MemRefType>()) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(shapedType.isa<RankedTensorType>() && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//
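// Illustrative sketch (assuming the default alloc/dealloc/memcpy callbacks in
// BufferizationOptions) of what the bufferization below produces:
//   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
// becomes roughly
//   %m = memref.alloc(%sz) : memref<?xf32>
// plus a memref.copy when a `copy` operand is present, and a memref.dealloc
// before the block terminator when the buffer is not yielded/escaping.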
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = this->getOperation();
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Compute memory space of this allocation.
  unsigned memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    memorySpace =
        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
  } else if (options.defaultMemorySpace.has_value()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return op->emitError("could not infer memory space");
  }

  // Create memory allocation.
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType(),
                      AffineMap(), memorySpace);
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Should the buffer be deallocated?
  AnalysisState analysisState(options);
  bool dealloc;
  if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
    // AllocTensorOp has one result.
    ArrayAttr escapeAttr =
        op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
    dealloc = !escapeAttr[0].cast<BoolAttr>().getValue();
  } else {
    // No "escape" annotation found.
    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
      dealloc = false;
    }
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  // Create buffer deallocation (if requested).
  if (!dealloc)
    return success();

  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}

bool AllocTensorOp::isMemoryWrite(OpResult opResult,
                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

SmallVector<OpResult>
AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}
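// Illustrative example (not from a test) of the interface methods above: in
//   %0 = bufferization.alloc_tensor() copy(%t) : tensor<4xf32>
// the optional `copy` operand %t is only read, never written in place, and
// the result %0 is a fresh allocation that does not alias any other buffer.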
LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy,
        /*memory_space=*/IntegerAttr());
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that were defined as
/// dynamic in the original operation but whose size was given by a `constant`
/// op. For example:
///
///   %c5 = arith.constant 5: index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        newShape[i] = intVal.getSExtValue();
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
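/// Fold `tensor.dim` of a dynamic dimension of an `alloc_tensor` to the
/// corresponding size operand. Illustrative example (not from a test):
///
///   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
///   %d = tensor.dim %0, %c0 : tensor<?xf32>
///
/// folds %d to %sz (or, if the op has a `copy` operand, to a `tensor.dim` of
/// that operand).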
struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicDim(dim))
          return getDynamicSize(builder, dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = type.dyn_cast<::mlir::TensorType>())
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}
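// Examples (illustrative) of the custom assembly format handled by the
// parse/print methods above:
//   %0 = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) : tensor<4xf32>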
Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), getInput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), getOutput(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), getOutput(),
                       SideEffects::DefaultResource::get());
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has more than one deallocation.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(source);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace
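// Illustrative example (not from a test) of the SimplifyClones rewrite
// registered below:
//   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
//   ... uses of %1 ...
//   memref.dealloc %1 : memref<?xf32>  // only dealloc of %1, same block
// can be rewritten to
//   %1 = memref.cast %0 : memref<?xf32> to memref<?xf32>
// with the now-redundant dealloc erased, provided no other deallocation may
// run between the clone and that dealloc.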
void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  if (failed(options.createDealloc(rewriter, getLoc(), *buffer)))
    return failure();
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};
} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(ArrayRef<Attribute>) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}
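// Illustrative example (not from a test) of the fold above:
//   %t = bufferization.to_tensor %m : memref<4xf32>
//   %m2 = bufferization.to_memref %t : memref<4xf32>
// folds %m2 to %m because the memref types match exactly; mismatching types
// are handled by the ToMemrefToTensorFolding pattern below instead.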
namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    return foldToMemrefToTensorPair(rewriter, toMemref);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}
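// Illustrative examples (not from tests) of two of the canonicalizations
// registered above:
//  * ToMemrefOfCast:
//      %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
//      %m = bufferization.to_memref %c : memref<?xf32>
//    becomes
//      %m0 = bufferization.to_memref %t : memref<4xf32>
//      %m = memref.cast %m0 : memref<4xf32> to memref<?xf32>
//  * LoadOfToMemref:
//      %m = bufferization.to_memref %t : memref<4xf32>
//      %v = memref.load %m[%i] : memref<4xf32>
//    rewrites the load to
//      %v = tensor.extract %t[%i] : tensor<4xf32>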
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"