//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  // Constants for this dialect are materialized as std.constant ops.
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
43 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) { 44 bool folded = false; 45 for (OpOperand &operand : op->getOpOperands()) { 46 auto cast = operand.get().getDefiningOp<CastOp>(); 47 if (cast && operand.get() != inner && 48 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 49 operand.set(cast.getOperand()); 50 folded = true; 51 } 52 } 53 return success(folded); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // Helpers for GlobalOp 58 //===----------------------------------------------------------------------===// 59 60 static Type getTensorTypeFromMemRefType(Type type) { 61 if (auto memref = type.dyn_cast<MemRefType>()) 62 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 63 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 64 return UnrankedTensorType::get(memref.getElementType()); 65 return NoneType::get(type.getContext()); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // AllocOp / AllocaOp 70 //===----------------------------------------------------------------------===// 71 72 template <typename AllocLikeOp> 73 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 74 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 75 "applies to only alloc or alloca"); 76 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 77 if (!memRefType) 78 return op.emitOpError("result must be a memref"); 79 80 if (static_cast<int64_t>(op.dynamicSizes().size()) != 81 memRefType.getNumDynamicDims()) 82 return op.emitOpError("dimension operand count does not equal memref " 83 "dynamic dimension count"); 84 85 unsigned numSymbols = 0; 86 if (!memRefType.getAffineMaps().empty()) 87 numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); 88 if (op.symbolOperands().size() != numSymbols) 89 return op.emitOpError("symbol operand count does not equal memref symbol " 90 "count: 
expected ") 91 << numSymbols << ", got " << op.symbolOperands().size(); 92 93 return success(); 94 } 95 96 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 97 98 static LogicalResult verify(AllocaOp op) { 99 // An alloca op needs to have an ancestor with an allocation scope trait. 100 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 101 return op.emitOpError( 102 "requires an ancestor op with AutomaticAllocationScope trait"); 103 104 return verifyAllocLikeOp(op); 105 } 106 107 namespace { 108 /// Fold constant dimensions into an alloc like operation. 109 template <typename AllocLikeOp> 110 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 111 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 112 113 LogicalResult matchAndRewrite(AllocLikeOp alloc, 114 PatternRewriter &rewriter) const override { 115 // Check to see if any dimensions operands are constants. If so, we can 116 // substitute and drop them. 117 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 118 return matchPattern(operand, matchConstantIndex()); 119 })) 120 return failure(); 121 122 auto memrefType = alloc.getType(); 123 124 // Ok, we have one or more constant operands. Collect the non-constant ones 125 // and keep track of the resultant memref type to build. 126 SmallVector<int64_t, 4> newShapeConstants; 127 newShapeConstants.reserve(memrefType.getRank()); 128 SmallVector<Value, 4> dynamicSizes; 129 130 unsigned dynamicDimPos = 0; 131 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 132 int64_t dimSize = memrefType.getDimSize(dim); 133 // If this is already static dimension, keep it. 134 if (dimSize != -1) { 135 newShapeConstants.push_back(dimSize); 136 continue; 137 } 138 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 139 auto *defOp = dynamicSize.getDefiningOp(); 140 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 141 // Dynamic shape dimension will be folded. 
142 newShapeConstants.push_back(constantIndexOp.getValue()); 143 } else { 144 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 145 newShapeConstants.push_back(-1); 146 dynamicSizes.push_back(dynamicSize); 147 } 148 dynamicDimPos++; 149 } 150 151 // Create new memref type (which will have fewer dynamic dimensions). 152 MemRefType newMemRefType = 153 MemRefType::Builder(memrefType).setShape(newShapeConstants); 154 assert(static_cast<int64_t>(dynamicSizes.size()) == 155 newMemRefType.getNumDynamicDims()); 156 157 // Create and insert the alloc op for the new memref. 158 auto newAlloc = rewriter.create<AllocLikeOp>( 159 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 160 alloc.alignmentAttr()); 161 // Insert a cast so we have the same type as the old alloc. 162 auto resultCast = 163 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 164 165 rewriter.replaceOp(alloc, {resultCast}); 166 return success(); 167 } 168 }; 169 170 /// Fold alloc operations with no users or only store and dealloc uses. 171 template <typename T> 172 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 173 using OpRewritePattern<T>::OpRewritePattern; 174 175 LogicalResult matchAndRewrite(T alloc, 176 PatternRewriter &rewriter) const override { 177 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 178 if (auto storeOp = dyn_cast<StoreOp>(op)) 179 return storeOp.value() == alloc; 180 return !isa<DeallocOp>(op); 181 })) 182 return failure(); 183 184 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 185 rewriter.eraseOp(user); 186 187 rewriter.eraseOp(alloc); 188 return success(); 189 } 190 }; 191 } // end anonymous namespace. 

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

/// Print an alloca_scope op: optional result type list followed by the body
/// region and the attribute dictionary.
static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << AllocaScopeOp::getOperationName() << " ";
  if (!op.results().empty()) {
    // NOTE(review): combined with the space printed above this emits a double
    // space before "->"; confirm against the parser before changing.
    p << " -> (" << op.getResultTypes() << ")";
    // When there are results, terminators carry the yielded values and must
    // be printed.
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

/// Parse an alloca_scope op: `(-> type-list)? region attr-dict?`.
static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Insert the implicit terminator if it was elided in the custom syntax.
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  // Delegate to the region-branch interface for inter-region type checks.
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // From inside the body region (`index` set), control flows to the parent
  // op's results; from the parent op, control enters the body region.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  // The alignment attribute must be a power of two.
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  // buffer_cast(tensor_load(m)) folds to m when the types match exactly; the
  // mismatching-type case is handled by the TensorLoadToMemRef pattern below.
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    // Only ranked source tensors can be turned into a MemRefType directly.
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    // buffer_cast the pre-cast tensor, then memref.cast to the original
    // result type.
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` to kick in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      // Dynamic source size paired with a static result size means the cast
      // adds static information; folding would lose the check.
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // memref.cast is a one-to-one cast.
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    // Ranked -> ranked case.
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      // Differing layout maps are only allowed between strided forms with
      // pairwise-compatible offsets and strides.
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    // Ranked <-> unranked case.
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  // cast(memrefcast) -> cast: look through a producing cast where legal.
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // A clone reads its input and writes its output.
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Simplify a memref.clone by replacing it with a cast of its source when a
/// dealloc in the same block makes the copy redundant; that dealloc is erased
/// along the way. Dead clones (no users) are erased outright.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    // A clone with no users is trivially dead.
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
    Operation *sourceDeallocOp = findDealloc(source);

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    // Pick whichever dealloc lives in the clone's own block; that one becomes
    // redundant once the clone is turned into a cast.
    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations inbetween
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail of the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    // Replace the clone with a cast of its source and drop the now-redundant
    // dealloc.
    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

// NOTE(review): uses OwningRewritePatternList while the rest of the file uses
// RewritePatternSet; confirm both spellings are available before unifying.
void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  // clone(memrefcast) -> clone.
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

/// Build a dim op with a constant index, materialized as a ConstantIndexOp.
void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, memref, indexValue);
}

/// Build a dim op with an SSA index; the result type is always `index`.
void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, memref, index);
}

/// Return the index as a constant integer if it is one, None otherwise.
Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.memrefOrTensor().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();

  // All forms of folding require a known index.
  if (!index)
    return {};

  auto argTy = memrefOrTensor().getType();
  // Fold if the shape extent along the given index is known.
  if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked types (UnrankedMemRefType) is not supported.
    if (!shapedTy.hasRank())
      return {};
    if (!shapedTy.isDynamicDim(index.getInt())) {
      Builder builder(getContext());
      return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
    }
  }

  Operation *definingOp = memrefOrTensor().getDefiningOp();

  // dim(memref.tensor_load(memref)) -> dim(memref)
  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
    setOperand(0, tensorLoadOp.memref());
    return getResult();
  }

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    assert(sliceOp.isDynamicSize(unsignedIndex) &&
           "Expected dynamic slice size");
    return sliceOp.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // For alloc/alloca/view, map the (dynamic) dimension index to the matching
  // dynamic-size operand.
  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    // The shape element type may differ from `index`; insert a cast if so.
    if (load.getType() != dim.getType())
      load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

/// Fold dim of a cast (memref.buffer_cast or tensor.cast) into a dim of the
/// cast's source.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    // Query the dimension on the cast's source instead.
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
              DimOfCastOp<tensor::CastOp>>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

/// Build a dma_start op. `stride`/`elementsPerStride` are optional; when
/// `stride` is null neither is added as an operand.
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer (resolved as an index below).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
844 if (parser.parseTrailingOperandList(strideInfo)) 845 return failure(); 846 847 bool isStrided = strideInfo.size() == 2; 848 if (!strideInfo.empty() && !isStrided) { 849 return parser.emitError(parser.getNameLoc(), 850 "expected two stride related operands"); 851 } 852 853 if (parser.parseColonTypeList(types)) 854 return failure(); 855 if (types.size() != 3) 856 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 857 858 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 859 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 860 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 861 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 862 // size should be an index. 863 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 864 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 865 // tag indices should be index. 866 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 867 return failure(); 868 869 if (isStrided) { 870 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 871 return failure(); 872 } 873 874 return success(); 875 } 876 877 LogicalResult DmaStartOp::verify() { 878 unsigned numOperands = getNumOperands(); 879 880 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 881 // the number of elements. 882 if (numOperands < 4) 883 return emitOpError("expected at least 4 operands"); 884 885 // Check types of operands. The order of these calls is important: the later 886 // calls rely on some type properties to compute the operand position. 887 // 1. Source memref. 
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

/// Canonicalizes away memref.cast producers feeding DMA operands.
LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  // Operand order: tag memref, its indices, then the number of elements.
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

/// Canonicalizes away memref.cast producers feeding the wait operands.
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  // Total operand count: tag memref + one index per tag dim + num elements.
  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  // External globals have no initializer; uninitialized ones print a keyword.
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // No '=' clause: this is an external global with no initial value.
  if (parser.parseOptionalEqual())
    return success();

  // 'uninitialized' is encoded as a UnitAttr initial value.
  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  // Otherwise the initial value is an elements attribute of the tensor type
  // matching the memref's shape and element type.
  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  // A load takes the memref plus exactly one index per memref dimension.
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  // folded in place (operand replaced), so the op's own result is returned.
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into an tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    // Read directly from the tensor the buffer was created from.
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.
1158 1159 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 1160 MLIRContext *context) { 1161 results.add<LoadOfBufferCast>(context); 1162 } 1163 1164 //===----------------------------------------------------------------------===// 1165 // PrefetchOp 1166 //===----------------------------------------------------------------------===// 1167 1168 static void print(OpAsmPrinter &p, PrefetchOp op) { 1169 p << PrefetchOp::getOperationName() << " " << op.memref() << '['; 1170 p.printOperands(op.indices()); 1171 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1172 p << ", locality<" << op.localityHint(); 1173 p << ">, " << (op.isDataCache() ? "data" : "instr"); 1174 p.printOptionalAttrDict( 1175 op->getAttrs(), 1176 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1177 p << " : " << op.getMemRefType(); 1178 } 1179 1180 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1181 OperationState &result) { 1182 OpAsmParser::OperandType memrefInfo; 1183 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1184 IntegerAttr localityHint; 1185 MemRefType type; 1186 StringRef readOrWrite, cacheType; 1187 1188 auto indexTy = parser.getBuilder().getIndexType(); 1189 auto i32Type = parser.getBuilder().getIntegerType(32); 1190 if (parser.parseOperand(memrefInfo) || 1191 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1192 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1193 parser.parseComma() || parser.parseKeyword("locality") || 1194 parser.parseLess() || 1195 parser.parseAttribute(localityHint, i32Type, "localityHint", 1196 result.attributes) || 1197 parser.parseGreater() || parser.parseComma() || 1198 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1199 parser.resolveOperand(memrefInfo, type, result.operands) || 1200 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1201 return failure(); 1202 1203 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1204 return 
parser.emitError(parser.getNameLoc(), 1205 "rw specifier has to be 'read' or 'write'"); 1206 result.addAttribute( 1207 PrefetchOp::getIsWriteAttrName(), 1208 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1209 1210 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1211 return parser.emitError(parser.getNameLoc(), 1212 "cache type has to be 'data' or 'instr'"); 1213 1214 result.addAttribute( 1215 PrefetchOp::getIsDataCacheAttrName(), 1216 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1217 1218 return success(); 1219 } 1220 1221 static LogicalResult verify(PrefetchOp op) { 1222 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1223 return op.emitOpError("too few indices"); 1224 1225 return success(); 1226 } 1227 1228 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1229 SmallVectorImpl<OpFoldResult> &results) { 1230 // prefetch(memrefcast) -> prefetch 1231 return foldMemRefCast(*this); 1232 } 1233 1234 //===----------------------------------------------------------------------===// 1235 // ReinterpretCastOp 1236 //===----------------------------------------------------------------------===// 1237 1238 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1239 /// `staticSizes` and `staticStrides` are automatically filled with 1240 /// source-memref-rank sentinel values that encode dynamic entries. 
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  // Split each mixed static/dynamic list into a static i64 array (dynamic
  // entries become sentinels) plus the list of dynamic SSA values.
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

/// Build with all-static offset/sizes/strides given as plain integers.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

/// Build with all-dynamic offset/sizes/strides given as SSA values.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and extract_slice.
static LogicalResult verify(ReinterpretCastOp op) {
  // The source and result memrefs should be in the same memory space.
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto resultType = op.getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return op.emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return op.emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offset and static_strides attributes if
  // result memref type has an affine map specified.
  if (!resultType.getAffineMaps().empty()) {
    int64_t resultOffset;
    SmallVector<int64_t, 4> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return failure();

    // Match offset in result memref type and in static_offsets attribute.
    int64_t expectedOffset =
        extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << resultOffset << " instead of " << expectedOffset;

    // Match strides in result memref type and in static_strides attribute.
    for (auto &en : llvm::enumerate(llvm::zip(
             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
      int64_t resultStride = std::get<0>(en.value());
      int64_t expectedStride = std::get<1>(en.value());
      if (resultStride != expectedStride)
        return op.emitError("expected result type with stride = ")
               << expectedStride << " instead of " << resultStride
               << " in dim = " << en.index();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  // Element type must be preserved across the reshape.
  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "source memref type should have identity affine map");

  // The 1-D shape operand's length must match the ranked result's rank.
  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(StoreOp op) {
  // Operands are: value to store, memref, then one index per memref dim.
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
    return op.emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  // The value-to-store operand is excluded from cast folding.
  return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

namespace {
/// Helpers to write more idiomatic operations.
namespace saturated_arith {
/// Thin int64_t wrapper whose arithmetic saturates to the dynamic
/// stride/offset sentinel: if either operand is dynamic, so is the result.
struct Wrapper {
  explicit Wrapper(int64_t v) : v(v) {}
  operator int64_t() { return v; }
  int64_t v;
};
Wrapper operator+(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v + b);
}
Wrapper operator*(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v * b);
}
} // end namespace saturated_arith
} // end namespace

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> leadingStaticOffsets,
                                ArrayRef<int64_t> leadingStaticSizes,
                                ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offset/sizes/strides in
  // which case we complete with offset=0, sizes from memref type and strides=1.
  unsigned rank = sourceMemRefType.getRank();
  assert(leadingStaticOffsets.size() <= rank &&
         "unexpected leadingStaticOffsets overflow");
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  assert(leadingStaticStrides.size() <= rank &&
         "unexpected leadingStaticStrides overflow");
  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
  unsigned numTrailingOffsets = rank - staticOffsets.size();
  unsigned numTrailingSizes = rank - staticSizes.size();
  unsigned numTrailingStrides = rank - staticStrides.size();
  staticOffsets.append(numTrailingOffsets, 0);
  llvm::append_range(staticSizes,
                     sourceMemRefType.getShape().take_back(numTrailingSizes));
  staticStrides.append(numTrailingStrides, 1);

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
    // Saturating arithmetic: any dynamic term makes the offset dynamic.
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    // Saturating multiply: any dynamic factor makes the stride dynamic.
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

/// OpFoldResult overload: splits into static/dynamic parts and defers to the
/// static-int64 overload above.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

/// Infers the result type, then drops unit dimensions until the requested
/// `resultRank` is reached, projecting the layout map accordingly.
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank && "expected ");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    // Pick `rankDiff` size-1 dimensions to project away.
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map;
    auto maps = inferredType.getAffineMaps();
    if (!maps.empty() && maps.front())
      map = getProjectedMap(maps.front(), dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

/// OpFoldResult overload of the rank-reducing inference above.
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
1549 void SubViewOp::build(OpBuilder &b, OperationState &result, 1550 MemRefType resultType, Value source, 1551 ArrayRef<OpFoldResult> offsets, 1552 ArrayRef<OpFoldResult> sizes, 1553 ArrayRef<OpFoldResult> strides, 1554 ArrayRef<NamedAttribute> attrs) { 1555 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1556 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1557 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1558 ShapedType::kDynamicStrideOrOffset); 1559 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1560 ShapedType::kDynamicSize); 1561 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1562 ShapedType::kDynamicStrideOrOffset); 1563 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1564 // Structuring implementation this way avoids duplication between builders. 1565 if (!resultType) { 1566 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1567 staticSizes, staticStrides) 1568 .cast<MemRefType>(); 1569 } 1570 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1571 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1572 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1573 result.addAttributes(attrs); 1574 } 1575 1576 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1577 // type. 1578 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1579 ArrayRef<OpFoldResult> offsets, 1580 ArrayRef<OpFoldResult> sizes, 1581 ArrayRef<OpFoldResult> strides, 1582 ArrayRef<NamedAttribute> attrs) { 1583 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1584 } 1585 1586 // Build a SubViewOp with static entries and inferred result type. 
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  // Wrap each static entry as an I64 attribute and delegate to the mixed
  // static/dynamic builder.
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
1631 void SubViewOp::build(OpBuilder &b, OperationState &result, 1632 MemRefType resultType, Value source, ValueRange offsets, 1633 ValueRange sizes, ValueRange strides, 1634 ArrayRef<NamedAttribute> attrs) { 1635 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1636 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1637 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1638 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1639 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1640 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1641 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1642 } 1643 1644 // Build a SubViewOp with dynamic entries and inferred result type. 1645 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1646 ValueRange offsets, ValueRange sizes, ValueRange strides, 1647 ArrayRef<NamedAttribute> attrs) { 1648 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1649 } 1650 1651 /// For ViewLikeOpInterface. 1652 Value SubViewOp::getViewSource() { return source(); } 1653 1654 enum SubViewVerificationResult { 1655 Success, 1656 RankTooLarge, 1657 SizeMismatch, 1658 ElemTypeMismatch, 1659 MemSpaceMismatch, 1660 AffineMapMismatch 1661 }; 1662 1663 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1664 /// This function is slight variant of `is subsequence` algorithm where 1665 /// not matching dimension must be 1. 
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  // Identical types are trivially compatible.
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  // Only memref originals are checked further; anything else is accepted.
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  // NOTE(review): a memref original paired with a non-memref candidate is
  // also accepted here -- presumably callers never produce this pairing;
  // confirm before relying on it.
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  // A rank-reduced type can only drop dimensions, never add them.
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  // Mask of original dimensions that were dropped to obtain the reduced
  // shape; llvm::None means no such mapping exists.
  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  // Project the original strided layout onto the kept dimensions and compare
  // it against the candidate's layout (explicit, or derived when absent).
  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  // Strided layouts are single-result affine maps by construction.
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

/// Translate a SubViewVerificationResult into a diagnostic emitted against
/// `op`. `expectedType` is the type inferred from the op's static operands
/// and `errMsg` carries extra detail produced by `isRankReducedType`.
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type. The declared result type may
  // still be a rank-reduced version of the inferred one.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

/// Print a Range as "range <offset>:<size>:<stride>" (debugging aid).
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // For each position, use the dynamic SSA value when present, otherwise
    // materialize the static value as a ConstantIndexOp.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  // Prefer the rank-reduced inference; fall back to the full-rank type when
  // the reduced rank does not match the requested `resultRank`.
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///   memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in; running both patterns on the same op would fight each other.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
    /// the cast source operand type and the SubViewOp static information. This
    /// is the resulting type if the MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    // Rebuild the subview on the cast's source, then cast the result back to
    // the original subview type so existing users remain valid.
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
1917 struct SubViewCanonicalizer { 1918 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 1919 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType()); 1920 } 1921 }; 1922 1923 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 1924 MLIRContext *context) { 1925 results 1926 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 1927 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 1928 SubViewOpMemRefCastFolder>(context); 1929 } 1930 1931 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 1932 auto resultShapedType = getResult().getType().cast<ShapedType>(); 1933 auto sourceShapedType = source().getType().cast<ShapedType>(); 1934 1935 if (resultShapedType.hasStaticShape() && 1936 resultShapedType == sourceShapedType) { 1937 return getViewSource(); 1938 } 1939 1940 return {}; 1941 } 1942 1943 //===----------------------------------------------------------------------===// 1944 // TensorLoadOp 1945 //===----------------------------------------------------------------------===// 1946 1947 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) { 1948 if (auto bufferCast = memref().getDefiningOp<BufferCastOp>()) 1949 // Approximate alias analysis by conservatively folding only when no there 1950 // is no interleaved operation. 1951 if (bufferCast->getBlock() == this->getOperation()->getBlock() && 1952 bufferCast->getNextNode() == this->getOperation()) 1953 return bufferCast.tensor(); 1954 return {}; 1955 } 1956 1957 //===----------------------------------------------------------------------===// 1958 // TransposeOp 1959 //===----------------------------------------------------------------------===// 1960 1961 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  // Compose the source's strided layout with the permutation to obtain the
  // layout of the transposed memref.
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

/// Build a TransposeOp on `in` with the given `permutation` attribute; the
/// result type is inferred from the input type and the permutation.
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  // The permutation is printed inline above, so elide it from the attr-dict.
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

/// Parse the custom transpose form printed above and reconstruct the
/// permutation attribute from the inline affine map.
static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

/// Verify that the permutation is a true permutation of the input rank and
/// that the declared result type matches the inferred transposed type.
static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

/// Fold producing memref.cast ops into the transpose's operand.
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

/// Parse a view op of the form:
///   %src `[` %byte_shift `]` `[` %sizes `]` attr-dict `:` src-type `to` type
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  // Exactly one byte-shift operand is accepted.
  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

/// For ViewLikeOpInterface.
Value ViewOp::getViewSource() { return source(); }

namespace {

/// Fold constant size operands of a view into its result type, inserting a
/// cast back to the original type for existing users.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    // Walk dimensions in order; `dynamicDimPos` indexes the size operands,
    // which correspond 1:1 to the dynamic dimensions of the result type.
    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Bypass a memref.cast between an alloc and a view by rebuilding the view
/// directly on the alloc's result.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"