//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
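///
/// For example (an illustrative sketch, not part of the upstream docs):
///
///   %0 = memref.cast %arg0 : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
///
/// folds to:
///
///   memref.dealloc %arg0 : memref<8xf32>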
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
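///
/// An illustrative example (hypothetical IR, for exposition only):
///
///   %c8 = constant 8 : index
///   %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
///
/// is rewritten into an alloc with a more static type plus a cast back to the
/// original type:
///
///   %0 = memref.alloc(%n) : memref<8x?xf32>
///   %1 = memref.cast %0 : memref<8x?xf32> to memref<?x?xf32>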
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << AllocaScopeOp::getOperationName() << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
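///
/// An illustrative example (hypothetical IR, for exposition only):
///
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
///
/// is rewritten into:
///
///   %0 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>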
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
373 /// ``` 374 /// 375 /// may fold into: 376 /// 377 /// ``` 378 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 379 /// ``` 380 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 381 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 382 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 383 384 // Requires ranked MemRefType. 385 if (!sourceType || !resultType) 386 return false; 387 388 // Requires same elemental type. 389 if (sourceType.getElementType() != resultType.getElementType()) 390 return false; 391 392 // Requires same rank. 393 if (sourceType.getRank() != resultType.getRank()) 394 return false; 395 396 // Only fold casts between strided memref forms. 397 int64_t sourceOffset, resultOffset; 398 SmallVector<int64_t, 4> sourceStrides, resultStrides; 399 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 400 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 401 return false; 402 403 // If cast is towards more static sizes along any dimension, don't fold. 404 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 405 auto ss = std::get<0>(it), st = std::get<1>(it); 406 if (ss != st) 407 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 408 return false; 409 } 410 411 // If cast is towards more static offset along any dimension, don't fold. 412 if (sourceOffset != resultOffset) 413 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 414 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 415 return false; 416 417 // If cast is towards more static strides along any dimension, don't fold. 418 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 419 auto ss = std::get<0>(it), st = std::get<1>(it); 420 if (ss != st) 421 if (MemRefType::isDynamicStrideOrOffset(ss) && 422 !MemRefType::isDynamicStrideOrOffset(st)) 423 return false; 424 } 425 426 return true; 427 } 428 429 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 430 if (inputs.size() != 1 || outputs.size() != 1) 431 return false; 432 Type a = inputs.front(), b = outputs.front(); 433 auto aT = a.dyn_cast<MemRefType>(); 434 auto bT = b.dyn_cast<MemRefType>(); 435 436 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 437 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 438 439 if (aT && bT) { 440 if (aT.getElementType() != bT.getElementType()) 441 return false; 442 if (aT.getAffineMaps() != bT.getAffineMaps()) { 443 int64_t aOffset, bOffset; 444 SmallVector<int64_t, 4> aStrides, bStrides; 445 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 446 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 447 aStrides.size() != bStrides.size()) 448 return false; 449 450 // Strides along a dimension/offset are compatible if the value in the 451 // source memref is static and the value in the target memref is the 452 // same. They are also compatible if either one is dynamic (see 453 // description of MemRefCastOp for details). 454 auto checkCompatible = [](int64_t a, int64_t b) { 455 return (a == MemRefType::getDynamicStrideOrOffset() || 456 b == MemRefType::getDynamicStrideOrOffset() || a == b); 457 }; 458 if (!checkCompatible(aOffset, bOffset)) 459 return false; 460 for (auto aStride : enumerate(aStrides)) 461 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 462 return false; 463 } 464 if (aT.getMemorySpace() != bT.getMemorySpace()) 465 return false; 466 467 // They must have the same rank, and any specified dimensions must match. 
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Replace a clone with a cast to its type when either the clone or its
/// source is deallocated in the enclosing block, and erase the now-redundant
/// dealloc.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
    Operation *sourceDeallocOp = findDealloc(source);

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
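    // E.g. (a hypothetical hazard): a `memref.dealloc %alias` sitting between
    // the clone and `redundantDealloc` could free memory that remaining uses
    // of the clone still read, so any intervening op with a Free effect
    // aborts the rewrite.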
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, memref, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, memref, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.memrefOrTensor().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref or tensor type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();

  // All forms of folding require a known index.
  if (!index)
    return {};

  auto argTy = memrefOrTensor().getType();
  // Fold if the shape extent along the given index is known.
  if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked types (UnrankedMemRefType) is not supported.
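    // E.g. (illustrative): `memref.dim %m, %c1 : memref<4x8xf32>` folds to
    // `constant 8 : index`, since the extent of dimension 1 is static.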
    if (!shapedTy.hasRank())
      return {};
    if (!shapedTy.isDynamicDim(index.getInt())) {
      Builder builder(getContext());
      return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
    }
  }

  Operation *definingOp = memrefOrTensor().getDefiningOp();

  // dim(memref.tensor_load(memref)) -> dim(memref)
  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
    setOperand(0, tensorLoadOp.memref());
    return getResult();
  }

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    assert(sliceOp.isDynamicSize(unsignedIndex) &&
           "Expected dynamic slice size");
    return sliceOp.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load from the reshape's shape
/// operand.
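///
/// A hypothetical example:
///
///   %r = memref.reshape %m(%shape)
///       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %idx : memref<?x?xf32>
///
/// is rewritten to read the extent directly from the shape operand:
///
///   %d = memref.load %shape[%idx] : memref<2xindex>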
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    if (load.getType() != dim.getType())
      load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the cast.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
              DimOfCastOp<tensor::CastOp>>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}
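
// An illustrative example of a global and a matching access that this verifier
// accepts (hypothetical IR):
//
//   memref.global "private" @x : memref<2xf32> = dense<[0.0, 1.0]>
//   %0 = memref.get_global @x : memref<2xf32>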

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}
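
// Example of the syntax the parser below accepts (illustrative):
//
//   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<8x16xf32>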
"data" : "instr"); 1176 p.printOptionalAttrDict( 1177 op->getAttrs(), 1178 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1179 p << " : " << op.getMemRefType(); 1180 } 1181 1182 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1183 OperationState &result) { 1184 OpAsmParser::OperandType memrefInfo; 1185 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1186 IntegerAttr localityHint; 1187 MemRefType type; 1188 StringRef readOrWrite, cacheType; 1189 1190 auto indexTy = parser.getBuilder().getIndexType(); 1191 auto i32Type = parser.getBuilder().getIntegerType(32); 1192 if (parser.parseOperand(memrefInfo) || 1193 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1194 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1195 parser.parseComma() || parser.parseKeyword("locality") || 1196 parser.parseLess() || 1197 parser.parseAttribute(localityHint, i32Type, "localityHint", 1198 result.attributes) || 1199 parser.parseGreater() || parser.parseComma() || 1200 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1201 parser.resolveOperand(memrefInfo, type, result.operands) || 1202 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1203 return failure(); 1204 1205 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1206 return parser.emitError(parser.getNameLoc(), 1207 "rw specifier has to be 'read' or 'write'"); 1208 result.addAttribute( 1209 PrefetchOp::getIsWriteAttrName(), 1210 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1211 1212 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1213 return parser.emitError(parser.getNameLoc(), 1214 "cache type has to be 'data' or 'instr'"); 1215 1216 result.addAttribute( 1217 PrefetchOp::getIsDataCacheAttrName(), 1218 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1219 1220 return success(); 1221 } 1222 1223 static LogicalResult verify(PrefetchOp op) { 1224 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1225 return op.emitOpError("too few indices"); 1226 1227 return success(); 1228 } 1229 1230 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1231 SmallVectorImpl<OpFoldResult> &results) { 1232 // prefetch(memrefcast) -> prefetch 1233 return foldMemRefCast(*this); 1234 } 1235 1236 //===----------------------------------------------------------------------===// 1237 // ReinterpretCastOp 1238 //===----------------------------------------------------------------------===// 1239 1240 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1241 /// `staticSizes` and `staticStrides` are automatically filled with 1242 /// source-memref-rank sentinel values that encode dynamic entries. 
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

// TODO: ponder whether we want to allow missing trailing sizes/strides that
// are completed automatically, like we have for subview and extract_slice.
static LogicalResult verify(ReinterpretCastOp op) {
  // The source and result memrefs should be in the same memory space.
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto resultType = op.getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return op.emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return op.emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offset and static_strides attributes if
  // result memref type has an affine map specified.
  if (!resultType.getAffineMaps().empty()) {
    int64_t resultOffset;
    SmallVector<int64_t, 4> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return failure();

    // Match offset in result memref type and in static_offsets attribute.
    int64_t expectedOffset =
        extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;

    // Match strides in result memref type and in static_strides attribute.
    for (auto &en : llvm::enumerate(llvm::zip(
             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
      int64_t resultStride = std::get<0>(en.value());
      int64_t expectedStride = std::get<1>(en.value());
      if (resultStride != expectedStride)
        return op.emitError("expected result type with stride = ")
               << expectedStride << " instead of " << resultStride
               << " in dim = " << en.index();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "source memref type should have identity affine map");

  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(StoreOp op) {
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
    return op.emitOpError(
        "store index operand count not equal to memref rank");
return op.emitOpError("store index operand count not equal to memref rank"); 1386 1387 return success(); 1388 } 1389 1390 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1391 SmallVectorImpl<OpFoldResult> &results) { 1392 /// store(memrefcast) -> store 1393 return foldMemRefCast(*this, getValueToStore()); 1394 } 1395 1396 //===----------------------------------------------------------------------===// 1397 // SubViewOp 1398 //===----------------------------------------------------------------------===// 1399 1400 namespace { 1401 /// Helpers to write more idiomatic operations. 1402 namespace saturated_arith { 1403 struct Wrapper { 1404 explicit Wrapper(int64_t v) : v(v) {} 1405 operator int64_t() { return v; } 1406 int64_t v; 1407 }; 1408 Wrapper operator+(Wrapper a, int64_t b) { 1409 if (ShapedType::isDynamicStrideOrOffset(a) || 1410 ShapedType::isDynamicStrideOrOffset(b)) 1411 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1412 return Wrapper(a.v + b); 1413 } 1414 Wrapper operator*(Wrapper a, int64_t b) { 1415 if (ShapedType::isDynamicStrideOrOffset(a) || 1416 ShapedType::isDynamicStrideOrOffset(b)) 1417 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1418 return Wrapper(a.v * b); 1419 } 1420 } // end namespace saturated_arith 1421 } // end namespace 1422 1423 /// A subview result type can be fully inferred from the source type and the 1424 /// static representation of offsets, sizes and strides. Special sentinels 1425 /// encode the dynamic case. 1426 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1427 ArrayRef<int64_t> leadingStaticOffsets, 1428 ArrayRef<int64_t> leadingStaticSizes, 1429 ArrayRef<int64_t> leadingStaticStrides) { 1430 // A subview may specify only a leading subset of offset/sizes/strides in 1431 // which case we complete with offset=0, sizes from memref type and strides=1. 1432 unsigned rank = sourceMemRefType.getRank(); 1433 assert(leadingStaticOffsets.size() <= rank && 1434 "unexpected leadingStaticOffsets overflow"); 1435 assert(leadingStaticSizes.size() <= rank && 1436 "unexpected leadingStaticSizes overflow"); 1437 assert(leadingStaticStrides.size() <= rank && 1438 "unexpected leadingStaticStrides overflow"); 1439 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1440 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1441 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1442 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1443 unsigned numTrailingSizes = rank - staticSizes.size(); 1444 unsigned numTrailingStrides = rank - staticStrides.size(); 1445 staticOffsets.append(numTrailingOffsets, 0); 1446 llvm::append_range(staticSizes, 1447 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1448 staticStrides.append(numTrailingStrides, 1); 1449 1450 // Extract source offset and strides. 1451 int64_t sourceOffset; 1452 SmallVector<int64_t, 4> sourceStrides; 1453 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1454 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1455 (void)res; 1456 1457 // Compute target offset whose value is: 1458 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> leadingStaticOffsets,
                                ArrayRef<int64_t> leadingStaticSizes,
                                ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offsets/sizes/strides in
  // which case we complete with offset=0, sizes from memref type and
  // strides=1.
  unsigned rank = sourceMemRefType.getRank();
  assert(leadingStaticOffsets.size() <= rank &&
         "unexpected leadingStaticOffsets overflow");
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  assert(leadingStaticStrides.size() <= rank &&
         "unexpected leadingStaticStrides overflow");
  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
  unsigned numTrailingOffsets = rank - staticOffsets.size();
  unsigned numTrailingSizes = rank - staticSizes.size();
  unsigned numTrailingStrides = rank - staticStrides.size();
  staticOffsets.append(numTrailingOffsets, 0);
  llvm::append_range(staticSizes,
                     sourceMemRefType.getShape().take_back(numTrailingSizes));
  staticStrides.append(numTrailingStrides, 1);

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticStrides.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}
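
// A worked example of the inference above (values chosen for this sketch):
// for a memref<8x16x4xf32> source with the identity layout (strides
// [64, 4, 1], offset 0), static offsets [3, 4, 2], sizes [1, 6, 3] and
// strides [1, 1, 1] yield:
//   targetOffset  = 0 + 3 * 64 + 4 * 4 + 2 * 1 = 210
//   targetStrides = [64 * 1, 4 * 1, 1 * 1]     = [64, 4, 1]
// i.e. memref<1x6x3xf32, offset: 210, strides: [64, 4, 1]>. Any dynamic
// ingredient would instead saturate the corresponding entry to the dynamic
// sentinel.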

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceMemRefType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the result rank to be no larger than the inferred rank");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map;
    auto maps = inferredType.getAffineMaps();
    if (!maps.empty() && maps.front())
      map = getProjectedMap(maps.front(), dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceMemRefType, staticOffsets, staticSizes, staticStrides);
}
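
// Continuing the sketch above: asking for resultRank 2 on the inferred
// memref<1x6x3xf32, offset: 210, strides: [64, 4, 1]> projects out the
// leading unit dimension (and the corresponding input of the layout map),
// producing memref<6x3xf32, offset: 210, strides: [4, 1]>.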

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  // Forward `attrs` so this overload does not silently drop them.
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }
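
// A minimal builder usage sketch (all names hypothetical): with the
// static-entries overload above, something like
//
//   auto sv = b.create<SubViewOp>(loc, source,
//                                 /*offsets=*/ArrayRef<int64_t>{0, 0},
//                                 /*sizes=*/ArrayRef<int64_t>{3, 4},
//                                 /*strides=*/ArrayRef<int64_t>{1, 1});
//
// creates a subview whose result type is inferred via inferResultType.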

enum SubViewVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  MemSpaceMismatch,
  AffineMapMismatch
};

/// Checks if the `original` type can be rank-reduced to the `reduced` type.
/// This function is a slight variant of an "is subsequence" algorithm where
/// a non-matching dimension must be 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() &&
      !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  // Project the unused dimensions out of the original strided layout map and
  // compare the result against the candidate's layout.
  auto inferredLayout =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredLayout.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredLayout.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredLayout.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred layout map: " << inferredLayout;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredLayout.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredLayout.getNumDims(),
                                inferredLayout.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred layout map: " << inferredLayout;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}
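
// Tying this back to the earlier sketch: memref<6x3xf32, offset: 210,
// strides: [4, 1]> is accepted as a rank-reduced version of
// memref<1x6x3xf32, offset: 210, strides: [64, 4, 1]> because projecting the
// unit dimension out of the original layout yields exactly the candidate
// layout, so the difference of the two maps simplifies to 0.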

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller than or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offsets and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
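
// For instance (a sketch, not a test case): a subview with mixed entries
// offsets [%i, 0], sizes [4, 4] and strides [1, 1] produces two Ranges,
// {%i, c4, c1} and {c0, c4, c1}, where c0/c1/c4 stand for ConstantIndexOps
// materialized with `b` at `loc`.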

/// Infer the canonical type of the result of a subview operation. Returns the
/// rank-reduced type if its rank is `resultRank`; otherwise returns the
/// non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V : memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0 : memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let the constant-argument
    // folder registered below kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Deduce the result type of the SubViewOp by applying
    // `getCanonicalSubViewResultType` to the cast source operand type and the
    // SubViewOp static information. This is the resulting type if the
    // MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}
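
// An illustrative fold (types chosen for this sketch):
//
//   %m = memref.buffer_cast %t : memref<4xf32>
//   %u = memref.tensor_load %m : memref<4xf32>
//
// folds %u to %t, but only when the buffer_cast immediately precedes the
// tensor_load in the same block, so no intervening op can touch %m.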

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", expected " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
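
// An illustrative round-trip for the printer/parser and verifier above
// (layout spelled out by hand; treat it as a sketch):
//
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//       : memref<?x?xf32> to
//         memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
//
// The result layout is the source's strided map composed with the
// permutation, as computed by inferTransposeResultType.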

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }
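
// For reference, IR this parser/verifier pair accepts looks like the
// following sketch: one byte-shift operand in the first bracket and one size
// operand per dynamic result dimension in the second:
//
//   %view = memref.view %buf[%offset][%size]
//       : memref<2048xi8> to memref<?x4xf32>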

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; keep the dynamic sentinel and
        // the corresponding size operand.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"