1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 12 #include "mlir/Dialect/StandardOps/IR/Ops.h" 13 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 14 #include "mlir/Dialect/Utils/StaticValueUtils.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/Matchers.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Interfaces/InferTypeOpInterface.h" 22 #include "mlir/Interfaces/ViewLikeInterface.h" 23 #include "llvm/ADT/STLExtras.h" 24 25 using namespace mlir; 26 using namespace mlir::memref; 27 28 /// Materialize a single constant operation from a given attribute value with 29 /// the desired resultant type. 30 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 31 Attribute value, Type type, 32 Location loc) { 33 if (arith::ConstantOp::isBuildableWith(value, type)) 34 return builder.create<arith::ConstantOp>(loc, value, type); 35 if (ConstantOp::isBuildableWith(value, type)) 36 return builder.create<ConstantOp>(loc, value, type); 37 return nullptr; 38 } 39 40 //===----------------------------------------------------------------------===// 41 // Common canonicalization pattern support logic 42 //===----------------------------------------------------------------------===// 43 44 /// This is a common class used for patterns of the form 45 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 46 /// into the root operation directly. 47 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { 48 bool folded = false; 49 for (OpOperand &operand : op->getOpOperands()) { 50 auto cast = operand.get().getDefiningOp<CastOp>(); 51 if (cast && operand.get() != inner && 52 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 53 operand.set(cast.getOperand()); 54 folded = true; 55 } 56 } 57 return success(folded); 58 } 59 60 /// Return an unranked/ranked tensor type for the given unranked/ranked memref 61 /// type. 62 Type mlir::memref::getTensorTypeFromMemRefType(Type type) { 63 if (auto memref = type.dyn_cast<MemRefType>()) 64 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 65 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 66 return UnrankedTensorType::get(memref.getElementType()); 67 return NoneType::get(type.getContext()); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // AllocOp / AllocaOp 72 //===----------------------------------------------------------------------===// 73 74 template <typename AllocLikeOp> 75 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 76 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 77 "applies to only alloc or alloca"); 78 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 79 if (!memRefType) 80 return op.emitOpError("result must be a memref"); 81 82 if (static_cast<int64_t>(op.dynamicSizes().size()) != 83 memRefType.getNumDynamicDims()) 84 return op.emitOpError("dimension operand count does not equal memref " 85 "dynamic dimension count"); 86 87 unsigned numSymbols = 0; 88 if (!memRefType.getLayout().isIdentity()) 89 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 90 if (op.symbolOperands().size() != numSymbols) 91 return op.emitOpError("symbol operand count does not equal memref symbol " 92 "count: expected ") 93 << numSymbols << ", got " << op.symbolOperands().size(); 94 95 return success(); 96 } 97 98 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 99 100 static LogicalResult verify(AllocaOp op) { 101 // An alloca op needs to have an ancestor with an allocation scope trait. 102 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 103 return op.emitOpError( 104 "requires an ancestor op with AutomaticAllocationScope trait"); 105 106 return verifyAllocLikeOp(op); 107 } 108 109 namespace { 110 /// Fold constant dimensions into an alloc like operation. 111 template <typename AllocLikeOp> 112 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 113 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 114 115 LogicalResult matchAndRewrite(AllocLikeOp alloc, 116 PatternRewriter &rewriter) const override { 117 // Check to see if any dimensions operands are constants. If so, we can 118 // substitute and drop them. 119 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 120 return matchPattern(operand, matchConstantIndex()); 121 })) 122 return failure(); 123 124 auto memrefType = alloc.getType(); 125 126 // Ok, we have one or more constant operands. Collect the non-constant ones 127 // and keep track of the resultant memref type to build. 128 SmallVector<int64_t, 4> newShapeConstants; 129 newShapeConstants.reserve(memrefType.getRank()); 130 SmallVector<Value, 4> dynamicSizes; 131 132 unsigned dynamicDimPos = 0; 133 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 134 int64_t dimSize = memrefType.getDimSize(dim); 135 // If this is already static dimension, keep it. 136 if (dimSize != -1) { 137 newShapeConstants.push_back(dimSize); 138 continue; 139 } 140 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 141 auto *defOp = dynamicSize.getDefiningOp(); 142 if (auto constantIndexOp = 143 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 144 // Dynamic shape dimension will be folded. 145 newShapeConstants.push_back(constantIndexOp.value()); 146 } else { 147 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 148 newShapeConstants.push_back(-1); 149 dynamicSizes.push_back(dynamicSize); 150 } 151 dynamicDimPos++; 152 } 153 154 // Create new memref type (which will have fewer dynamic dimensions). 155 MemRefType newMemRefType = 156 MemRefType::Builder(memrefType).setShape(newShapeConstants); 157 assert(static_cast<int64_t>(dynamicSizes.size()) == 158 newMemRefType.getNumDynamicDims()); 159 160 // Create and insert the alloc op for the new memref. 161 auto newAlloc = rewriter.create<AllocLikeOp>( 162 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 163 alloc.alignmentAttr()); 164 // Insert a cast so we have the same type as the old alloc. 165 auto resultCast = 166 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 167 168 rewriter.replaceOp(alloc, {resultCast}); 169 return success(); 170 } 171 }; 172 173 /// Fold alloc operations with no users or only store and dealloc uses. 174 template <typename T> 175 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 176 using OpRewritePattern<T>::OpRewritePattern; 177 178 LogicalResult matchAndRewrite(T alloc, 179 PatternRewriter &rewriter) const override { 180 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 181 if (auto storeOp = dyn_cast<StoreOp>(op)) 182 return storeOp.value() == alloc; 183 return !isa<DeallocOp>(op); 184 })) 185 return failure(); 186 187 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 188 rewriter.eraseOp(user); 189 190 rewriter.eraseOp(alloc); 191 return success(); 192 } 193 }; 194 } // namespace 195 196 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 197 MLIRContext *context) { 198 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 199 } 200 201 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 202 MLIRContext *context) { 203 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 204 context); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // AllocaScopeOp 209 //===----------------------------------------------------------------------===// 210 211 static void print(OpAsmPrinter &p, AllocaScopeOp &op) { 212 bool printBlockTerminators = false; 213 214 p << " "; 215 if (!op.results().empty()) { 216 p << " -> (" << op.getResultTypes() << ")"; 217 printBlockTerminators = true; 218 } 219 p.printRegion(op.bodyRegion(), 220 /*printEntryBlockArgs=*/false, 221 /*printBlockTerminators=*/printBlockTerminators); 222 p.printOptionalAttrDict(op->getAttrs()); 223 } 224 225 static ParseResult parseAllocaScopeOp(OpAsmParser &parser, 226 OperationState &result) { 227 // Create a region for the body. 228 result.regions.reserve(1); 229 Region *bodyRegion = result.addRegion(); 230 231 // Parse optional results type list. 232 if (parser.parseOptionalArrowTypeList(result.types)) 233 return failure(); 234 235 // Parse the body region. 236 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 237 return failure(); 238 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 239 result.location); 240 241 // Parse the optional attribute list. 242 if (parser.parseOptionalAttrDict(result.attributes)) 243 return failure(); 244 245 return success(); 246 } 247 248 static LogicalResult verify(AllocaScopeOp op) { 249 if (failed(RegionBranchOpInterface::verifyTypes(op))) 250 return failure(); 251 252 return success(); 253 } 254 255 void AllocaScopeOp::getSuccessorRegions( 256 Optional<unsigned> index, ArrayRef<Attribute> operands, 257 SmallVectorImpl<RegionSuccessor> ®ions) { 258 if (index.hasValue()) { 259 regions.push_back(RegionSuccessor(getResults())); 260 return; 261 } 262 263 regions.push_back(RegionSuccessor(&bodyRegion())); 264 } 265 266 //===----------------------------------------------------------------------===// 267 // AssumeAlignmentOp 268 //===----------------------------------------------------------------------===// 269 270 static LogicalResult verify(AssumeAlignmentOp op) { 271 unsigned alignment = op.alignment(); 272 if (!llvm::isPowerOf2_32(alignment)) 273 return op.emitOpError("alignment must be power of 2"); 274 return success(); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // CastOp 279 //===----------------------------------------------------------------------===// 280 281 /// Determines whether MemRef_CastOp casts to a more dynamic version of the 282 /// source memref. This is useful to to fold a memref.cast into a consuming op 283 /// and implement canonicalization patterns for ops in different dialects that 284 /// may consume the results of memref.cast operations. Such foldable memref.cast 285 /// operations are typically inserted as `view` and `subview` ops are 286 /// canonicalized, to preserve the type compatibility of their uses. 287 /// 288 /// Returns true when all conditions are met: 289 /// 1. source and result are ranked memrefs with strided semantics and same 290 /// element type and rank. 291 /// 2. each of the source's size, offset or stride has more static information 292 /// than the corresponding result's size, offset or stride. 293 /// 294 /// Example 1: 295 /// ```mlir 296 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 297 /// %2 = consumer %1 ... : memref<?x?xf32> ... 298 /// ``` 299 /// 300 /// may fold into: 301 /// 302 /// ```mlir 303 /// %2 = consumer %0 ... : memref<8x16xf32> ... 304 /// ``` 305 /// 306 /// Example 2: 307 /// ``` 308 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 309 /// to memref<?x?xf32> 310 /// consumer %1 : memref<?x?xf32> ... 311 /// ``` 312 /// 313 /// may fold into: 314 /// 315 /// ``` 316 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 317 /// ``` 318 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 319 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 320 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 321 322 // Requires ranked MemRefType. 323 if (!sourceType || !resultType) 324 return false; 325 326 // Requires same elemental type. 327 if (sourceType.getElementType() != resultType.getElementType()) 328 return false; 329 330 // Requires same rank. 331 if (sourceType.getRank() != resultType.getRank()) 332 return false; 333 334 // Only fold casts between strided memref forms. 335 int64_t sourceOffset, resultOffset; 336 SmallVector<int64_t, 4> sourceStrides, resultStrides; 337 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 338 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 339 return false; 340 341 // If cast is towards more static sizes along any dimension, don't fold. 342 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 343 auto ss = std::get<0>(it), st = std::get<1>(it); 344 if (ss != st) 345 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 346 return false; 347 } 348 349 // If cast is towards more static offset along any dimension, don't fold. 350 if (sourceOffset != resultOffset) 351 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 352 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 353 return false; 354 355 // If cast is towards more static strides along any dimension, don't fold. 356 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 357 auto ss = std::get<0>(it), st = std::get<1>(it); 358 if (ss != st) 359 if (MemRefType::isDynamicStrideOrOffset(ss) && 360 !MemRefType::isDynamicStrideOrOffset(st)) 361 return false; 362 } 363 364 return true; 365 } 366 367 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 368 if (inputs.size() != 1 || outputs.size() != 1) 369 return false; 370 Type a = inputs.front(), b = outputs.front(); 371 auto aT = a.dyn_cast<MemRefType>(); 372 auto bT = b.dyn_cast<MemRefType>(); 373 374 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 375 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 376 377 if (aT && bT) { 378 if (aT.getElementType() != bT.getElementType()) 379 return false; 380 if (aT.getLayout() != bT.getLayout()) { 381 int64_t aOffset, bOffset; 382 SmallVector<int64_t, 4> aStrides, bStrides; 383 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 384 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 385 aStrides.size() != bStrides.size()) 386 return false; 387 388 // Strides along a dimension/offset are compatible if the value in the 389 // source memref is static and the value in the target memref is the 390 // same. They are also compatible if either one is dynamic (see 391 // description of MemRefCastOp for details). 392 auto checkCompatible = [](int64_t a, int64_t b) { 393 return (a == MemRefType::getDynamicStrideOrOffset() || 394 b == MemRefType::getDynamicStrideOrOffset() || a == b); 395 }; 396 if (!checkCompatible(aOffset, bOffset)) 397 return false; 398 for (auto aStride : enumerate(aStrides)) 399 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 400 return false; 401 } 402 if (aT.getMemorySpace() != bT.getMemorySpace()) 403 return false; 404 405 // They must have the same rank, and any specified dimensions must match. 406 if (aT.getRank() != bT.getRank()) 407 return false; 408 409 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 410 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 411 if (aDim != -1 && bDim != -1 && aDim != bDim) 412 return false; 413 } 414 return true; 415 } else { 416 if (!aT && !uaT) 417 return false; 418 if (!bT && !ubT) 419 return false; 420 // Unranked to unranked casting is unsupported 421 if (uaT && ubT) 422 return false; 423 424 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 425 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 426 if (aEltType != bEltType) 427 return false; 428 429 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 430 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 431 if (aMemSpace != bMemSpace) 432 return false; 433 434 return true; 435 } 436 437 return false; 438 } 439 440 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 441 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 442 } 443 444 //===----------------------------------------------------------------------===// 445 // DeallocOp 446 //===----------------------------------------------------------------------===// 447 448 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 449 SmallVectorImpl<OpFoldResult> &results) { 450 /// dealloc(memrefcast) -> dealloc 451 return foldMemRefCast(*this); 452 } 453 454 //===----------------------------------------------------------------------===// 455 // DimOp 456 //===----------------------------------------------------------------------===// 457 458 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 459 int64_t index) { 460 auto loc = result.location; 461 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index); 462 build(builder, result, source, indexValue); 463 } 464 465 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 466 Value index) { 467 auto indexTy = builder.getIndexType(); 468 build(builder, result, indexTy, source, index); 469 } 470 471 Optional<int64_t> DimOp::getConstantIndex() { 472 if (auto constantOp = index().getDefiningOp<arith::ConstantOp>()) 473 return constantOp.getValue().cast<IntegerAttr>().getInt(); 474 return {}; 475 } 476 477 static LogicalResult verify(DimOp op) { 478 // Assume unknown index to be in range. 479 Optional<int64_t> index = op.getConstantIndex(); 480 if (!index.hasValue()) 481 return success(); 482 483 // Check that constant index is not knowingly out of range. 484 auto type = op.source().getType(); 485 if (auto memrefType = type.dyn_cast<MemRefType>()) { 486 if (index.getValue() >= memrefType.getRank()) 487 return op.emitOpError("index is out of range"); 488 } else if (type.isa<UnrankedMemRefType>()) { 489 // Assume index to be in range. 490 } else { 491 llvm_unreachable("expected operand with memref type"); 492 } 493 return success(); 494 } 495 496 /// Return a map with key being elements in `vals` and data being number of 497 /// occurences of it. Use std::map, since the `vals` here are strides and the 498 /// dynamic stride value is the same as the tombstone value for 499 /// `DenseMap<int64_t>`. 500 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) { 501 std::map<int64_t, unsigned> numOccurences; 502 for (auto val : vals) 503 numOccurences[val]++; 504 return numOccurences; 505 } 506 507 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed 508 /// to be a subset of `originalType` with some `1` entries erased, return the 509 /// set of indices that specifies which of the entries of `originalShape` are 510 /// dropped to obtain `reducedShape`. 511 /// This accounts for cases where there are multiple unit-dims, but only a 512 /// subset of those are dropped. For MemRefTypes these can be disambiguated 513 /// using the strides. If a dimension is dropped the stride must be dropped too. 514 static llvm::Optional<llvm::SmallDenseSet<unsigned>> 515 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, 516 ArrayRef<OpFoldResult> sizes) { 517 llvm::SmallDenseSet<unsigned> unusedDims; 518 if (originalType.getRank() == reducedType.getRank()) 519 return unusedDims; 520 521 for (auto dim : llvm::enumerate(sizes)) 522 if (auto attr = dim.value().dyn_cast<Attribute>()) 523 if (attr.cast<IntegerAttr>().getInt() == 1) 524 unusedDims.insert(dim.index()); 525 526 SmallVector<int64_t> originalStrides, candidateStrides; 527 int64_t originalOffset, candidateOffset; 528 if (failed( 529 getStridesAndOffset(originalType, originalStrides, originalOffset)) || 530 failed( 531 getStridesAndOffset(reducedType, candidateStrides, candidateOffset))) 532 return llvm::None; 533 534 // For memrefs, a dimension is truly dropped if its corresponding stride is 535 // also dropped. This is particularly important when more than one of the dims 536 // is 1. Track the number of occurences of the strides in the original type 537 // and the candidate type. For each unused dim that stride should not be 538 // present in the candidate type. Note that there could be multiple dimensions 539 // that have the same size. We dont need to exactly figure out which dim 540 // corresponds to which stride, we just need to verify that the number of 541 // reptitions of a stride in the original + number of unused dims with that 542 // stride == number of repititions of a stride in the candidate. 543 std::map<int64_t, unsigned> currUnaccountedStrides = 544 getNumOccurences(originalStrides); 545 std::map<int64_t, unsigned> candidateStridesNumOccurences = 546 getNumOccurences(candidateStrides); 547 llvm::SmallDenseSet<unsigned> prunedUnusedDims; 548 for (unsigned dim : unusedDims) { 549 int64_t originalStride = originalStrides[dim]; 550 if (currUnaccountedStrides[originalStride] > 551 candidateStridesNumOccurences[originalStride]) { 552 // This dim can be treated as dropped. 553 currUnaccountedStrides[originalStride]--; 554 continue; 555 } 556 if (currUnaccountedStrides[originalStride] == 557 candidateStridesNumOccurences[originalStride]) { 558 // The stride for this is not dropped. Keep as is. 559 prunedUnusedDims.insert(dim); 560 continue; 561 } 562 if (currUnaccountedStrides[originalStride] < 563 candidateStridesNumOccurences[originalStride]) { 564 // This should never happen. Cant have a stride in the reduced rank type 565 // that wasnt in the original one. 566 return llvm::None; 567 } 568 } 569 570 for (auto prunedDim : prunedUnusedDims) 571 unusedDims.erase(prunedDim); 572 if (unusedDims.size() + reducedType.getRank() != originalType.getRank()) 573 return llvm::None; 574 return unusedDims; 575 } 576 577 llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() { 578 MemRefType sourceType = getSourceType(); 579 MemRefType resultType = getType(); 580 llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims = 581 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes()); 582 assert(unusedDims && "unable to find unused dims of subview"); 583 return *unusedDims; 584 } 585 586 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 587 // All forms of folding require a known index. 588 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 589 if (!index) 590 return {}; 591 592 // Folding for unranked types (UnrankedMemRefType) is not supported. 593 auto memrefType = source().getType().dyn_cast<MemRefType>(); 594 if (!memrefType) 595 return {}; 596 597 // Fold if the shape extent along the given index is known. 598 if (!memrefType.isDynamicDim(index.getInt())) { 599 Builder builder(getContext()); 600 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); 601 } 602 603 // The size at the given index is now known to be a dynamic size. 604 unsigned unsignedIndex = index.getValue().getZExtValue(); 605 606 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 607 Operation *definingOp = source().getDefiningOp(); 608 609 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 610 return *(alloc.getDynamicSizes().begin() + 611 memrefType.getDynamicDimIndex(unsignedIndex)); 612 613 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 614 return *(alloca.getDynamicSizes().begin() + 615 memrefType.getDynamicDimIndex(unsignedIndex)); 616 617 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 618 return *(view.getDynamicSizes().begin() + 619 memrefType.getDynamicDimIndex(unsignedIndex)); 620 621 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 622 llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims(); 623 unsigned resultIndex = 0; 624 unsigned sourceRank = subview.getSourceType().getRank(); 625 unsigned sourceIndex = 0; 626 for (auto i : llvm::seq<unsigned>(0, sourceRank)) { 627 if (unusedDims.count(i)) 628 continue; 629 if (resultIndex == unsignedIndex) { 630 sourceIndex = i; 631 break; 632 } 633 resultIndex++; 634 } 635 assert(subview.isDynamicSize(sourceIndex) && 636 "expected dynamic subview size"); 637 return subview.getDynamicSize(sourceIndex); 638 } 639 640 if (auto sizeInterface = 641 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 642 assert(sizeInterface.isDynamicSize(unsignedIndex) && 643 "Expected dynamic subview size"); 644 return sizeInterface.getDynamicSize(unsignedIndex); 645 } 646 647 // dim(memrefcast) -> dim 648 if (succeeded(foldMemRefCast(*this))) 649 return getResult(); 650 651 return {}; 652 } 653 654 namespace { 655 /// Fold dim of a memref reshape operation to a load into the reshape's shape 656 /// operand. 657 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 658 using OpRewritePattern<DimOp>::OpRewritePattern; 659 660 LogicalResult matchAndRewrite(DimOp dim, 661 PatternRewriter &rewriter) const override { 662 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 663 664 if (!reshape) 665 return failure(); 666 667 // Place the load directly after the reshape to ensure that the shape memref 668 // was not mutated. 669 rewriter.setInsertionPointAfter(reshape); 670 Location loc = dim.getLoc(); 671 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 672 if (load.getType() != dim.getType()) 673 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); 674 rewriter.replaceOp(dim, load); 675 return success(); 676 } 677 }; 678 679 } // namespace 680 681 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 682 MLIRContext *context) { 683 results.add<DimOfMemRefReshape>(context); 684 } 685 686 // --------------------------------------------------------------------------- 687 // DmaStartOp 688 // --------------------------------------------------------------------------- 689 690 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 691 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 692 ValueRange destIndices, Value numElements, 693 Value tagMemRef, ValueRange tagIndices, Value stride, 694 Value elementsPerStride) { 695 result.addOperands(srcMemRef); 696 result.addOperands(srcIndices); 697 result.addOperands(destMemRef); 698 result.addOperands(destIndices); 699 result.addOperands({numElements, tagMemRef}); 700 result.addOperands(tagIndices); 701 if (stride) 702 result.addOperands({stride, elementsPerStride}); 703 } 704 705 static void print(OpAsmPrinter &p, DmaStartOp op) { 706 p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], " 707 << op.getDstMemRef() << '[' << op.getDstIndices() << "], " 708 << op.getNumElements() << ", " << op.getTagMemRef() << '[' 709 << op.getTagIndices() << ']'; 710 if (op.isStrided()) 711 p << ", " << op.getStride() << ", " << op.getNumElementsPerStride(); 712 713 p.printOptionalAttrDict(op->getAttrs()); 714 p << " : " << op.getSrcMemRef().getType() << ", " 715 << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType(); 716 } 717 718 // Parse DmaStartOp. 719 // Ex: 720 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 721 // %tag[%index], %stride, %num_elt_per_stride : 722 // : memref<3076 x f32, 0>, 723 // memref<1024 x f32, 2>, 724 // memref<1 x i32> 725 // 726 static ParseResult parseDmaStartOp(OpAsmParser &parser, 727 OperationState &result) { 728 OpAsmParser::OperandType srcMemRefInfo; 729 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 730 OpAsmParser::OperandType dstMemRefInfo; 731 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 732 OpAsmParser::OperandType numElementsInfo; 733 OpAsmParser::OperandType tagMemrefInfo; 734 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 735 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 736 737 SmallVector<Type, 3> types; 738 auto indexType = parser.getBuilder().getIndexType(); 739 740 // Parse and resolve the following list of operands: 741 // *) source memref followed by its indices (in square brackets). 742 // *) destination memref followed by its indices (in square brackets). 743 // *) dma size in KiB. 744 if (parser.parseOperand(srcMemRefInfo) || 745 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 746 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 747 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 748 parser.parseComma() || parser.parseOperand(numElementsInfo) || 749 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 750 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 751 return failure(); 752 753 // Parse optional stride and elements per stride. 754 if (parser.parseTrailingOperandList(strideInfo)) 755 return failure(); 756 757 bool isStrided = strideInfo.size() == 2; 758 if (!strideInfo.empty() && !isStrided) { 759 return parser.emitError(parser.getNameLoc(), 760 "expected two stride related operands"); 761 } 762 763 if (parser.parseColonTypeList(types)) 764 return failure(); 765 if (types.size() != 3) 766 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 767 768 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 769 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 770 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 771 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 772 // size should be an index. 773 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 774 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 775 // tag indices should be index. 776 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 777 return failure(); 778 779 if (isStrided) { 780 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 781 return failure(); 782 } 783 784 return success(); 785 } 786 787 static LogicalResult verify(DmaStartOp op) { 788 unsigned numOperands = op.getNumOperands(); 789 790 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 791 // the number of elements. 792 if (numOperands < 4) 793 return op.emitOpError("expected at least 4 operands"); 794 795 // Check types of operands. The order of these calls is important: the later 796 // calls rely on some type properties to compute the operand position. 797 // 1. Source memref. 798 if (!op.getSrcMemRef().getType().isa<MemRefType>()) 799 return op.emitOpError("expected source to be of memref type"); 800 if (numOperands < op.getSrcMemRefRank() + 4) 801 return op.emitOpError() 802 << "expected at least " << op.getSrcMemRefRank() + 4 << " operands"; 803 if (!op.getSrcIndices().empty() && 804 !llvm::all_of(op.getSrcIndices().getTypes(), 805 [](Type t) { return t.isIndex(); })) 806 return op.emitOpError("expected source indices to be of index type"); 807 808 // 2. Destination memref. 809 if (!op.getDstMemRef().getType().isa<MemRefType>()) 810 return op.emitOpError("expected destination to be of memref type"); 811 unsigned numExpectedOperands = 812 op.getSrcMemRefRank() + op.getDstMemRefRank() + 4; 813 if (numOperands < numExpectedOperands) 814 return op.emitOpError() 815 << "expected at least " << numExpectedOperands << " operands"; 816 if (!op.getDstIndices().empty() && 817 !llvm::all_of(op.getDstIndices().getTypes(), 818 [](Type t) { return t.isIndex(); })) 819 return op.emitOpError("expected destination indices to be of index type"); 820 821 // 3. Number of elements. 822 if (!op.getNumElements().getType().isIndex()) 823 return op.emitOpError("expected num elements to be of index type"); 824 825 // 4. Tag memref. 826 if (!op.getTagMemRef().getType().isa<MemRefType>()) 827 return op.emitOpError("expected tag to be of memref type"); 828 numExpectedOperands += op.getTagMemRefRank(); 829 if (numOperands < numExpectedOperands) 830 return op.emitOpError() 831 << "expected at least " << numExpectedOperands << " operands"; 832 if (!op.getTagIndices().empty() && 833 !llvm::all_of(op.getTagIndices().getTypes(), 834 [](Type t) { return t.isIndex(); })) 835 return op.emitOpError("expected tag indices to be of index type"); 836 837 // Optional stride-related operands must be either both present or both 838 // absent. 839 if (numOperands != numExpectedOperands && 840 numOperands != numExpectedOperands + 2) 841 return op.emitOpError("incorrect number of operands"); 842 843 // 5. Strides. 844 if (op.isStrided()) { 845 if (!op.getStride().getType().isIndex() || 846 !op.getNumElementsPerStride().getType().isIndex()) 847 return op.emitOpError( 848 "expected stride and num elements per stride to be of type index"); 849 } 850 851 return success(); 852 } 853 854 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 855 SmallVectorImpl<OpFoldResult> &results) { 856 /// dma_start(memrefcast) -> dma_start 857 return foldMemRefCast(*this); 858 } 859 860 // --------------------------------------------------------------------------- 861 // DmaWaitOp 862 // --------------------------------------------------------------------------- 863 864 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 865 SmallVectorImpl<OpFoldResult> &results) { 866 /// dma_wait(memrefcast) -> dma_wait 867 return foldMemRefCast(*this); 868 } 869 870 static LogicalResult verify(DmaWaitOp op) { 871 // Check that the number of tag indices matches the tagMemRef rank. 872 unsigned numTagIndices = op.tagIndices().size(); 873 unsigned tagMemRefRank = op.getTagMemRefRank(); 874 if (numTagIndices != tagMemRefRank) 875 return op.emitOpError() << "expected tagIndices to have the same number of " 876 "elements as the tagMemRef rank, expected " 877 << tagMemRefRank << ", but got " << numTagIndices; 878 return success(); 879 } 880 881 //===----------------------------------------------------------------------===// 882 // GlobalOp 883 //===----------------------------------------------------------------------===// 884 885 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 886 TypeAttr type, 887 Attribute initialValue) { 888 p << type; 889 if (!op.isExternal()) { 890 p << " = "; 891 if (op.isUninitialized()) 892 p << "uninitialized"; 893 else 894 p.printAttributeWithoutType(initialValue); 895 } 896 } 897 898 static ParseResult 899 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 900 Attribute &initialValue) { 901 Type type; 902 if (parser.parseType(type)) 903 return failure(); 904 905 auto memrefType = type.dyn_cast<MemRefType>(); 906 if (!memrefType || !memrefType.hasStaticShape()) 907 return parser.emitError(parser.getNameLoc()) 908 << "type should be static shaped memref, but got " << type; 909 typeAttr = TypeAttr::get(type); 910 911 if (parser.parseOptionalEqual()) 912 return success(); 913 914 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 915 initialValue = UnitAttr::get(parser.getContext()); 916 return success(); 917 } 918 919 Type tensorType = getTensorTypeFromMemRefType(memrefType); 920 if (parser.parseAttribute(initialValue, tensorType)) 921 return failure(); 922 if (!initialValue.isa<ElementsAttr>()) 923 return parser.emitError(parser.getNameLoc()) 924 << "initial value should be a unit or elements attribute"; 925 return success(); 926 } 927 928 static LogicalResult verify(GlobalOp op) { 929 auto memrefType = op.type().dyn_cast<MemRefType>(); 930 if (!memrefType || !memrefType.hasStaticShape()) 931 return op.emitOpError("type should be static shaped memref, but got ") 932 << op.type(); 933 934 // Verify that the initial value, if present, is either a unit attribute or 935 // an elements attribute. 936 if (op.initial_value().hasValue()) { 937 Attribute initValue = op.initial_value().getValue(); 938 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 939 return op.emitOpError("initial value should be a unit or elements " 940 "attribute, but got ") 941 << initValue; 942 943 // Check that the type of the initial value is compatible with the type of 944 // the global variable. 945 if (initValue.isa<ElementsAttr>()) { 946 Type initType = initValue.getType(); 947 Type tensorType = getTensorTypeFromMemRefType(memrefType); 948 if (initType != tensorType) 949 return op.emitOpError("initial value expected to be of type ") 950 << tensorType << ", but was of type " << initType; 951 } 952 } 953 954 if (Optional<uint64_t> alignAttr = op.alignment()) { 955 uint64_t alignment = alignAttr.getValue(); 956 957 if (!llvm::isPowerOf2_64(alignment)) 958 return op->emitError() << "alignment attribute value " << alignment 959 << " is not a power of 2"; 960 } 961 962 // TODO: verify visibility for declarations. 963 return success(); 964 } 965 966 //===----------------------------------------------------------------------===// 967 // GetGlobalOp 968 //===----------------------------------------------------------------------===// 969 970 LogicalResult 971 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 972 // Verify that the result type is same as the type of the referenced 973 // memref.global op. 974 auto global = 975 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 976 if (!global) 977 return emitOpError("'") 978 << name() << "' does not reference a valid global memref"; 979 980 Type resultType = result().getType(); 981 if (global.type() != resultType) 982 return emitOpError("result type ") 983 << resultType << " does not match type " << global.type() 984 << " of the global memref @" << name(); 985 return success(); 986 } 987 988 //===----------------------------------------------------------------------===// 989 // LoadOp 990 //===----------------------------------------------------------------------===// 991 992 static LogicalResult verify(LoadOp op) { 993 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 994 return op.emitOpError("incorrect number of indices for load"); 995 return success(); 996 } 997 998 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 999 /// load(memrefcast) -> load 1000 if (succeeded(foldMemRefCast(*this))) 1001 return getResult(); 1002 return OpFoldResult(); 1003 } 1004 1005 //===----------------------------------------------------------------------===// 1006 // PrefetchOp 1007 //===----------------------------------------------------------------------===// 1008 1009 static void print(OpAsmPrinter &p, PrefetchOp op) { 1010 p << " " << op.memref() << '['; 1011 p.printOperands(op.indices()); 1012 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1013 p << ", locality<" << op.localityHint(); 1014 p << ">, " << (op.isDataCache() ? "data" : "instr"); 1015 p.printOptionalAttrDict( 1016 op->getAttrs(), 1017 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1018 p << " : " << op.getMemRefType(); 1019 } 1020 1021 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1022 OperationState &result) { 1023 OpAsmParser::OperandType memrefInfo; 1024 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1025 IntegerAttr localityHint; 1026 MemRefType type; 1027 StringRef readOrWrite, cacheType; 1028 1029 auto indexTy = parser.getBuilder().getIndexType(); 1030 auto i32Type = parser.getBuilder().getIntegerType(32); 1031 if (parser.parseOperand(memrefInfo) || 1032 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1033 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1034 parser.parseComma() || parser.parseKeyword("locality") || 1035 parser.parseLess() || 1036 parser.parseAttribute(localityHint, i32Type, "localityHint", 1037 result.attributes) || 1038 parser.parseGreater() || parser.parseComma() || 1039 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1040 parser.resolveOperand(memrefInfo, type, result.operands) || 1041 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1042 return failure(); 1043 1044 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1045 return parser.emitError(parser.getNameLoc(), 1046 "rw specifier has to be 'read' or 'write'"); 1047 result.addAttribute( 1048 PrefetchOp::getIsWriteAttrName(), 1049 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1050 1051 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1052 return parser.emitError(parser.getNameLoc(), 1053 "cache type has to be 'data' or 'instr'"); 1054 1055 result.addAttribute( 1056 PrefetchOp::getIsDataCacheAttrName(), 1057 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1058 1059 return success(); 1060 } 1061 1062 static LogicalResult verify(PrefetchOp op) { 1063 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1064 return op.emitOpError("too few indices"); 1065 1066 return success(); 1067 } 1068 1069 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1070 SmallVectorImpl<OpFoldResult> &results) { 1071 // prefetch(memrefcast) -> prefetch 1072 return foldMemRefCast(*this); 1073 } 1074 1075 //===----------------------------------------------------------------------===// 1076 // RankOp 1077 //===----------------------------------------------------------------------===// 1078 1079 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1080 // Constant fold rank when the rank of the operand is known. 1081 auto type = getOperand().getType(); 1082 auto shapedType = type.dyn_cast<ShapedType>(); 1083 if (shapedType && shapedType.hasRank()) 1084 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1085 return IntegerAttr(); 1086 } 1087 1088 //===----------------------------------------------------------------------===// 1089 // ReinterpretCastOp 1090 //===----------------------------------------------------------------------===// 1091 1092 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1093 /// `staticSizes` and `staticStrides` are automatically filled with 1094 /// source-memref-rank sentinel values that encode dynamic entries. 1095 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1096 MemRefType resultType, Value source, 1097 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1098 ArrayRef<OpFoldResult> strides, 1099 ArrayRef<NamedAttribute> attrs) { 1100 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1101 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1102 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1103 ShapedType::kDynamicStrideOrOffset); 1104 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1105 ShapedType::kDynamicSize); 1106 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1107 ShapedType::kDynamicStrideOrOffset); 1108 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1109 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1110 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1111 result.addAttributes(attrs); 1112 } 1113 1114 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1115 MemRefType resultType, Value source, 1116 int64_t offset, ArrayRef<int64_t> sizes, 1117 ArrayRef<int64_t> strides, 1118 ArrayRef<NamedAttribute> attrs) { 1119 SmallVector<OpFoldResult> sizeValues = 1120 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1121 return b.getI64IntegerAttr(v); 1122 })); 1123 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1124 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1125 return b.getI64IntegerAttr(v); 1126 })); 1127 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1128 strideValues, attrs); 1129 } 1130 1131 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1132 MemRefType resultType, Value source, Value offset, 1133 ValueRange sizes, ValueRange strides, 1134 ArrayRef<NamedAttribute> attrs) { 1135 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1136 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1137 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1138 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1139 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1140 } 1141 1142 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1143 // completed automatically, like we have for subview and extract_slice. 1144 static LogicalResult verify(ReinterpretCastOp op) { 1145 // The source and result memrefs should be in the same memory space. 1146 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1147 auto resultType = op.getType().cast<MemRefType>(); 1148 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1149 return op.emitError("different memory spaces specified for source type ") 1150 << srcType << " and result memref type " << resultType; 1151 if (srcType.getElementType() != resultType.getElementType()) 1152 return op.emitError("different element types specified for source type ") 1153 << srcType << " and result memref type " << resultType; 1154 1155 // Match sizes in result memref type and in static_sizes attribute. 1156 for (auto &en : 1157 llvm::enumerate(llvm::zip(resultType.getShape(), 1158 extractFromI64ArrayAttr(op.static_sizes())))) { 1159 int64_t resultSize = std::get<0>(en.value()); 1160 int64_t expectedSize = std::get<1>(en.value()); 1161 if (resultSize != expectedSize) 1162 return op.emitError("expected result type with size = ") 1163 << expectedSize << " instead of " << resultSize 1164 << " in dim = " << en.index(); 1165 } 1166 1167 // Match offset and strides in static_offset and static_strides attributes if 1168 // result memref type has an affine map specified. 1169 if (!resultType.getLayout().isIdentity()) { 1170 int64_t resultOffset; 1171 SmallVector<int64_t, 4> resultStrides; 1172 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1173 return failure(); 1174 1175 // Match offset in result memref type and in static_offsets attribute. 1176 int64_t expectedOffset = 1177 extractFromI64ArrayAttr(op.static_offsets()).front(); 1178 if (resultOffset != expectedOffset) 1179 return op.emitError("expected result type with offset = ") 1180 << resultOffset << " instead of " << expectedOffset; 1181 1182 // Match strides in result memref type and in static_strides attribute. 1183 for (auto &en : llvm::enumerate(llvm::zip( 1184 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1185 int64_t resultStride = std::get<0>(en.value()); 1186 int64_t expectedStride = std::get<1>(en.value()); 1187 if (resultStride != expectedStride) 1188 return op.emitError("expected result type with stride = ") 1189 << expectedStride << " instead of " << resultStride 1190 << " in dim = " << en.index(); 1191 } 1192 } 1193 return success(); 1194 } 1195 1196 //===----------------------------------------------------------------------===// 1197 // Reassociative reshape ops 1198 //===----------------------------------------------------------------------===// 1199 1200 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1201 return getSymbolLessAffineMaps(getReassociationExprs()); 1202 } 1203 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1204 return convertReassociationIndicesToExprs(getContext(), 1205 getReassociationIndices()); 1206 } 1207 1208 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1209 return getSymbolLessAffineMaps(getReassociationExprs()); 1210 } 1211 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1212 return convertReassociationIndicesToExprs(getContext(), 1213 getReassociationIndices()); 1214 } 1215 1216 static void print(OpAsmPrinter &p, ExpandShapeOp op) { 1217 ::mlir::printReshapeOp<ExpandShapeOp>(p, op); 1218 } 1219 1220 static void print(OpAsmPrinter &p, CollapseShapeOp op) { 1221 ::mlir::printReshapeOp<CollapseShapeOp>(p, op); 1222 } 1223 1224 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1225 /// copies. 1226 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1227 ArrayRef<int64_t> sizes, 1228 ArrayRef<AffineExpr> strides) { 1229 // Bands of extent one can be reshaped, as they are not reshaped at all. 1230 if (extent == 1) 1231 return true; 1232 // Otherwise, the size of the first dimension needs to be known. 1233 if (ShapedType::isDynamic(sizes[dim])) 1234 return false; 1235 assert(sizes.size() == strides.size() && "mismatched ranks"); 1236 // off by 1 indexing to avoid out of bounds 1237 // V 1238 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1239 // Only bands of static shapes are reshapable. This is due to the fact that 1240 // there is no relation between dynamic sizes and dynamic strides: we do not 1241 // have enough information to know whether a "-1" size corresponds to the 1242 // proper symbol in the AffineExpr of a stride. 1243 if (ShapedType::isDynamic(sizes[idx + 1])) 1244 return false; 1245 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1246 // simplify on the fly and catch more reshapable cases. 1247 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1248 return false; 1249 } 1250 return true; 1251 } 1252 1253 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1254 /// expected to be valid) to `type`. 1255 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1256 /// MemRefType. 1257 static MemRefType 1258 computeReshapeCollapsedType(MemRefType type, 1259 ArrayRef<AffineMap> reassociation) { 1260 auto sizes = type.getShape(); 1261 AffineExpr offset; 1262 SmallVector<AffineExpr, 4> strides; 1263 auto status = getStridesAndOffset(type, strides, offset); 1264 (void)status; 1265 assert(succeeded(status) && "expected strided memref"); 1266 1267 SmallVector<int64_t, 4> newSizes; 1268 newSizes.reserve(reassociation.size()); 1269 SmallVector<AffineExpr, 4> newStrides; 1270 newStrides.reserve(reassociation.size()); 1271 1272 // Use the fact that reassociation is valid to simplify the logic: only use 1273 // each map's rank. 1274 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1275 unsigned currentDim = 0; 1276 for (AffineMap m : reassociation) { 1277 unsigned dim = m.getNumResults(); 1278 int64_t size = 1; 1279 AffineExpr stride = strides[currentDim + dim - 1]; 1280 if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { 1281 size = ShapedType::kDynamicSize; 1282 stride = AffineExpr(); 1283 } else { 1284 for (unsigned d = 0; d < dim; ++d) 1285 size *= sizes[currentDim + d]; 1286 } 1287 newSizes.push_back(size); 1288 newStrides.push_back(stride); 1289 currentDim += dim; 1290 } 1291 1292 // Early-exit: if `type` is contiguous, the result must be contiguous. 1293 if (canonicalizeStridedLayout(type).getLayout().isIdentity()) 1294 return MemRefType::Builder(type).setShape(newSizes).setLayout({}); 1295 1296 // Convert back to int64_t because we don't have enough information to create 1297 // new strided layouts from AffineExpr only. This corresponds to a case where 1298 // copies may be necessary. 1299 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1300 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1301 intOffset = o.getValue(); 1302 SmallVector<int64_t, 4> intStrides; 1303 intStrides.reserve(strides.size()); 1304 for (auto stride : newStrides) { 1305 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1306 intStrides.push_back(cst.getValue()); 1307 else 1308 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1309 } 1310 auto layout = 1311 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1312 return canonicalizeStridedLayout( 1313 MemRefType::Builder(type).setShape(newSizes).setLayout( 1314 AffineMapAttr::get(layout))); 1315 } 1316 1317 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1318 ArrayRef<ReassociationIndices> reassociation, 1319 ArrayRef<NamedAttribute> attrs) { 1320 auto memRefType = src.getType().cast<MemRefType>(); 1321 auto resultType = computeReshapeCollapsedType( 1322 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1323 b.getContext(), reassociation))); 1324 build(b, result, resultType, src, attrs); 1325 result.addAttribute(getReassociationAttrName(), 1326 getReassociationIndicesAttribute(b, reassociation)); 1327 } 1328 1329 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1330 ArrayRef<ReassociationIndices> reassociation, 1331 ArrayRef<NamedAttribute> attrs) { 1332 auto memRefType = src.getType().cast<MemRefType>(); 1333 auto resultType = computeReshapeCollapsedType( 1334 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1335 b.getContext(), reassociation))); 1336 build(b, result, resultType, src, attrs); 1337 result.addAttribute(getReassociationAttrName(), 1338 getReassociationIndicesAttribute(b, reassociation)); 1339 } 1340 1341 template <typename ReshapeOp, 1342 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1343 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1344 MemRefType collapsedType) { 1345 if (failed( 1346 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1347 return failure(); 1348 auto maps = op.getReassociationMaps(); 1349 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1350 if (collapsedType != expectedType) 1351 return op.emitOpError("expected collapsed type to be ") 1352 << expectedType << ", but got " << collapsedType; 1353 return success(); 1354 } 1355 1356 static LogicalResult verify(ExpandShapeOp op) { 1357 return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); 1358 } 1359 1360 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1361 MLIRContext *context) { 1362 results.add<CollapseReshapeOps<ExpandShapeOp>, 1363 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1364 } 1365 1366 static LogicalResult verify(CollapseShapeOp op) { 1367 return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); 1368 } 1369 1370 struct CollapseShapeOpMemRefCastFolder 1371 : public OpRewritePattern<CollapseShapeOp> { 1372 public: 1373 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1374 1375 LogicalResult matchAndRewrite(CollapseShapeOp op, 1376 PatternRewriter &rewriter) const override { 1377 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1378 if (!cast) 1379 return failure(); 1380 1381 if (!CastOp::canFoldIntoConsumerOp(cast)) 1382 return failure(); 1383 1384 Type newResultType = computeReshapeCollapsedType( 1385 cast.getOperand().getType().cast<MemRefType>(), 1386 op.getReassociationMaps()); 1387 1388 if (newResultType == op.getResultType()) { 1389 rewriter.updateRootInPlace( 1390 op, [&]() { op.srcMutable().assign(cast.source()); }); 1391 } else { 1392 Value newOp = rewriter.create<CollapseShapeOp>( 1393 op->getLoc(), cast.source(), op.getReassociationIndices()); 1394 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1395 } 1396 return success(); 1397 } 1398 }; 1399 1400 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1401 MLIRContext *context) { 1402 results.add<CollapseReshapeOps<CollapseShapeOp>, 1403 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1404 CollapseShapeOpMemRefCastFolder>(context); 1405 } 1406 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1407 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1408 } 1409 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1410 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1411 } 1412 1413 //===----------------------------------------------------------------------===// 1414 // ReshapeOp 1415 //===----------------------------------------------------------------------===// 1416 1417 static LogicalResult verify(ReshapeOp op) { 1418 Type operandType = op.source().getType(); 1419 Type resultType = op.result().getType(); 1420 1421 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1422 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1423 if (operandElementType != resultElementType) 1424 return op.emitOpError("element types of source and destination memref " 1425 "types should be the same"); 1426 1427 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1428 if (!operandMemRefType.getLayout().isIdentity()) 1429 return op.emitOpError( 1430 "source memref type should have identity affine map"); 1431 1432 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1433 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1434 if (resultMemRefType) { 1435 if (!resultMemRefType.getLayout().isIdentity()) 1436 return op.emitOpError( 1437 "result memref type should have identity affine map"); 1438 if (shapeSize == ShapedType::kDynamicSize) 1439 return op.emitOpError("cannot use shape operand with dynamic length to " 1440 "reshape to statically-ranked memref type"); 1441 if (shapeSize != resultMemRefType.getRank()) 1442 return op.emitOpError( 1443 "length of shape operand differs from the result's memref rank"); 1444 } 1445 return success(); 1446 } 1447 1448 //===----------------------------------------------------------------------===// 1449 // StoreOp 1450 //===----------------------------------------------------------------------===// 1451 1452 static LogicalResult verify(StoreOp op) { 1453 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1454 return op.emitOpError("store index operand count not equal to memref rank"); 1455 1456 return success(); 1457 } 1458 1459 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1460 SmallVectorImpl<OpFoldResult> &results) { 1461 /// store(memrefcast) -> store 1462 return foldMemRefCast(*this, getValueToStore()); 1463 } 1464 1465 //===----------------------------------------------------------------------===// 1466 // SubViewOp 1467 //===----------------------------------------------------------------------===// 1468 1469 namespace { 1470 /// Helpers to write more idiomatic operations. 1471 namespace saturated_arith { 1472 struct Wrapper { 1473 explicit Wrapper(int64_t v) : v(v) {} 1474 operator int64_t() { return v; } 1475 int64_t v; 1476 }; 1477 Wrapper operator+(Wrapper a, int64_t b) { 1478 if (ShapedType::isDynamicStrideOrOffset(a) || 1479 ShapedType::isDynamicStrideOrOffset(b)) 1480 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1481 return Wrapper(a.v + b); 1482 } 1483 Wrapper operator*(Wrapper a, int64_t b) { 1484 if (ShapedType::isDynamicStrideOrOffset(a) || 1485 ShapedType::isDynamicStrideOrOffset(b)) 1486 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1487 return Wrapper(a.v * b); 1488 } 1489 } // namespace saturated_arith 1490 } // namespace 1491 1492 /// A subview result type can be fully inferred from the source type and the 1493 /// static representation of offsets, sizes and strides. Special sentinels 1494 /// encode the dynamic case. 1495 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1496 ArrayRef<int64_t> leadingStaticOffsets, 1497 ArrayRef<int64_t> leadingStaticSizes, 1498 ArrayRef<int64_t> leadingStaticStrides) { 1499 // A subview may specify only a leading subset of offset/sizes/strides in 1500 // which case we complete with offset=0, sizes from memref type and strides=1. 1501 unsigned rank = sourceMemRefType.getRank(); 1502 assert(leadingStaticOffsets.size() <= rank && 1503 "unexpected leadingStaticOffsets overflow"); 1504 assert(leadingStaticSizes.size() <= rank && 1505 "unexpected leadingStaticSizes overflow"); 1506 assert(leadingStaticStrides.size() <= rank && 1507 "unexpected leadingStaticStrides overflow"); 1508 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1509 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1510 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1511 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1512 unsigned numTrailingSizes = rank - staticSizes.size(); 1513 unsigned numTrailingStrides = rank - staticStrides.size(); 1514 staticOffsets.append(numTrailingOffsets, 0); 1515 llvm::append_range(staticSizes, 1516 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1517 staticStrides.append(numTrailingStrides, 1); 1518 1519 // Extract source offset and strides. 1520 int64_t sourceOffset; 1521 SmallVector<int64_t, 4> sourceStrides; 1522 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1523 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1524 (void)res; 1525 1526 // Compute target offset whose value is: 1527 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1528 int64_t targetOffset = sourceOffset; 1529 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1530 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1531 using namespace saturated_arith; 1532 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1533 } 1534 1535 // Compute target stride whose value is: 1536 // `sourceStrides_i * staticStrides_i`. 1537 SmallVector<int64_t, 4> targetStrides; 1538 targetStrides.reserve(staticOffsets.size()); 1539 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1540 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1541 using namespace saturated_arith; 1542 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1543 } 1544 1545 // The type is now known. 1546 return MemRefType::get( 1547 staticSizes, sourceMemRefType.getElementType(), 1548 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1549 sourceMemRefType.getContext()), 1550 sourceMemRefType.getMemorySpace()); 1551 } 1552 1553 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1554 ArrayRef<OpFoldResult> leadingStaticOffsets, 1555 ArrayRef<OpFoldResult> leadingStaticSizes, 1556 ArrayRef<OpFoldResult> leadingStaticStrides) { 1557 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1558 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1559 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1560 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1561 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1562 ShapedType::kDynamicSize); 1563 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1564 staticStrides, ShapedType::kDynamicStrideOrOffset); 1565 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1566 staticSizes, staticStrides); 1567 } 1568 1569 Type SubViewOp::inferRankReducedResultType( 1570 unsigned resultRank, MemRefType sourceRankedTensorType, 1571 ArrayRef<int64_t> leadingStaticOffsets, 1572 ArrayRef<int64_t> leadingStaticSizes, 1573 ArrayRef<int64_t> leadingStaticStrides) { 1574 auto inferredType = 1575 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1576 leadingStaticSizes, leadingStaticStrides) 1577 .cast<MemRefType>(); 1578 assert(inferredType.getRank() >= resultRank && "expected "); 1579 int rankDiff = inferredType.getRank() - resultRank; 1580 if (rankDiff > 0) { 1581 auto shape = inferredType.getShape(); 1582 llvm::SmallDenseSet<unsigned> dimsToProject; 1583 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1584 SmallVector<int64_t> projectedShape; 1585 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1586 if (!dimsToProject.contains(pos)) 1587 projectedShape.push_back(shape[pos]); 1588 1589 AffineMap map = inferredType.getLayout().getAffineMap(); 1590 if (!map.isIdentity()) 1591 map = getProjectedMap(map, dimsToProject); 1592 inferredType = 1593 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1594 inferredType.getMemorySpace()); 1595 } 1596 return inferredType; 1597 } 1598 1599 Type SubViewOp::inferRankReducedResultType( 1600 unsigned resultRank, MemRefType sourceRankedTensorType, 1601 ArrayRef<OpFoldResult> leadingStaticOffsets, 1602 ArrayRef<OpFoldResult> leadingStaticSizes, 1603 ArrayRef<OpFoldResult> leadingStaticStrides) { 1604 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1605 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1606 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1607 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1608 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1609 ShapedType::kDynamicSize); 1610 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1611 staticStrides, ShapedType::kDynamicStrideOrOffset); 1612 return SubViewOp::inferRankReducedResultType( 1613 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1614 staticStrides); 1615 } 1616 // Build a SubViewOp with mixed static and dynamic entries and custom result 1617 // type. If the type passed is nullptr, it is inferred. 1618 void SubViewOp::build(OpBuilder &b, OperationState &result, 1619 MemRefType resultType, Value source, 1620 ArrayRef<OpFoldResult> offsets, 1621 ArrayRef<OpFoldResult> sizes, 1622 ArrayRef<OpFoldResult> strides, 1623 ArrayRef<NamedAttribute> attrs) { 1624 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1625 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1626 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1627 ShapedType::kDynamicStrideOrOffset); 1628 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1629 ShapedType::kDynamicSize); 1630 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1631 ShapedType::kDynamicStrideOrOffset); 1632 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1633 // Structuring implementation this way avoids duplication between builders. 1634 if (!resultType) { 1635 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1636 staticSizes, staticStrides) 1637 .cast<MemRefType>(); 1638 } 1639 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1640 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1641 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1642 result.addAttributes(attrs); 1643 } 1644 1645 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1646 // type. 1647 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1648 ArrayRef<OpFoldResult> offsets, 1649 ArrayRef<OpFoldResult> sizes, 1650 ArrayRef<OpFoldResult> strides, 1651 ArrayRef<NamedAttribute> attrs) { 1652 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1653 } 1654 1655 // Build a SubViewOp with static entries and inferred result type. 1656 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1657 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1658 ArrayRef<int64_t> strides, 1659 ArrayRef<NamedAttribute> attrs) { 1660 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1661 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1662 return b.getI64IntegerAttr(v); 1663 })); 1664 SmallVector<OpFoldResult> sizeValues = 1665 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1666 return b.getI64IntegerAttr(v); 1667 })); 1668 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1669 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1670 return b.getI64IntegerAttr(v); 1671 })); 1672 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1673 } 1674 1675 // Build a SubViewOp with dynamic entries and custom result type. If the 1676 // type passed is nullptr, it is inferred. 1677 void SubViewOp::build(OpBuilder &b, OperationState &result, 1678 MemRefType resultType, Value source, 1679 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1680 ArrayRef<int64_t> strides, 1681 ArrayRef<NamedAttribute> attrs) { 1682 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1683 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1684 return b.getI64IntegerAttr(v); 1685 })); 1686 SmallVector<OpFoldResult> sizeValues = 1687 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1688 return b.getI64IntegerAttr(v); 1689 })); 1690 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1691 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1692 return b.getI64IntegerAttr(v); 1693 })); 1694 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1695 attrs); 1696 } 1697 1698 // Build a SubViewOp with dynamic entries and custom result type. If the type 1699 // passed is nullptr, it is inferred. 1700 void SubViewOp::build(OpBuilder &b, OperationState &result, 1701 MemRefType resultType, Value source, ValueRange offsets, 1702 ValueRange sizes, ValueRange strides, 1703 ArrayRef<NamedAttribute> attrs) { 1704 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1705 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1706 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1707 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1708 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1709 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1710 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1711 } 1712 1713 // Build a SubViewOp with dynamic entries and inferred result type. 1714 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1715 ValueRange offsets, ValueRange sizes, ValueRange strides, 1716 ArrayRef<NamedAttribute> attrs) { 1717 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1718 } 1719 1720 /// For ViewLikeOpInterface. 1721 Value SubViewOp::getViewSource() { return source(); } 1722 1723 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static 1724 /// value). 1725 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { 1726 AffineExpr t1Offset, t2Offset; 1727 SmallVector<AffineExpr> t1Strides, t2Strides; 1728 auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset); 1729 auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset); 1730 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset; 1731 } 1732 1733 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1734 /// This function is slight variant of `is subsequence` algorithm where 1735 /// not matching dimension must be 1. 1736 static SliceVerificationResult 1737 isRankReducedMemRefType(MemRefType originalType, 1738 MemRefType candidateRankReducedType, 1739 ArrayRef<OpFoldResult> sizes) { 1740 auto partialRes = isRankReducedType(originalType, candidateRankReducedType); 1741 if (partialRes != SliceVerificationResult::Success) 1742 return partialRes; 1743 1744 auto optionalUnusedDimsMask = computeMemRefRankReductionMask( 1745 originalType, candidateRankReducedType, sizes); 1746 1747 // Sizes cannot be matched in case empty vector is returned. 1748 if (!optionalUnusedDimsMask.hasValue()) 1749 return SliceVerificationResult::LayoutMismatch; 1750 1751 if (originalType.getMemorySpace() != 1752 candidateRankReducedType.getMemorySpace()) 1753 return SliceVerificationResult::MemSpaceMismatch; 1754 1755 // No amount of stride dropping can reconcile incompatible offsets. 1756 if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) 1757 return SliceVerificationResult::LayoutMismatch; 1758 1759 return SliceVerificationResult::Success; 1760 } 1761 1762 template <typename OpTy> 1763 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 1764 OpTy op, Type expectedType) { 1765 auto memrefType = expectedType.cast<ShapedType>(); 1766 switch (result) { 1767 case SliceVerificationResult::Success: 1768 return success(); 1769 case SliceVerificationResult::RankTooLarge: 1770 return op.emitError("expected result rank to be smaller or equal to ") 1771 << "the source rank. "; 1772 case SliceVerificationResult::SizeMismatch: 1773 return op.emitError("expected result type to be ") 1774 << expectedType 1775 << " or a rank-reduced version. (mismatch of result sizes) "; 1776 case SliceVerificationResult::ElemTypeMismatch: 1777 return op.emitError("expected result element type to be ") 1778 << memrefType.getElementType(); 1779 case SliceVerificationResult::MemSpaceMismatch: 1780 return op.emitError("expected result and source memory spaces to match."); 1781 case SliceVerificationResult::LayoutMismatch: 1782 return op.emitError("expected result type to be ") 1783 << expectedType 1784 << " or a rank-reduced version. (mismatch of result layout) "; 1785 } 1786 llvm_unreachable("unexpected subview verification result"); 1787 } 1788 1789 /// Verifier for SubViewOp. 1790 static LogicalResult verify(SubViewOp op) { 1791 MemRefType baseType = op.getSourceType(); 1792 MemRefType subViewType = op.getType(); 1793 1794 // The base memref and the view memref should be in the same memory space. 1795 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 1796 return op.emitError("different memory spaces specified for base memref " 1797 "type ") 1798 << baseType << " and subview memref type " << subViewType; 1799 1800 // Verify that the base memref type has a strided layout map. 1801 if (!isStrided(baseType)) 1802 return op.emitError("base type ") << baseType << " is not strided"; 1803 1804 // Verify result type against inferred type. 1805 auto expectedType = SubViewOp::inferResultType( 1806 baseType, extractFromI64ArrayAttr(op.static_offsets()), 1807 extractFromI64ArrayAttr(op.static_sizes()), 1808 extractFromI64ArrayAttr(op.static_strides())); 1809 1810 auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), 1811 subViewType, op.getMixedSizes()); 1812 return produceSubViewErrorMsg(result, op, expectedType); 1813 } 1814 1815 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 1816 return os << "range " << range.offset << ":" << range.size << ":" 1817 << range.stride; 1818 } 1819 1820 /// Return the list of Range (i.e. offset, size, stride). Each Range 1821 /// entry contains either the dynamic value or a ConstantIndexOp constructed 1822 /// with `b` at location `loc`. 1823 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 1824 OpBuilder &b, Location loc) { 1825 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 1826 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 1827 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 1828 SmallVector<Range, 8> res; 1829 unsigned rank = ranks[0]; 1830 res.reserve(rank); 1831 for (unsigned idx = 0; idx < rank; ++idx) { 1832 Value offset = 1833 op.isDynamicOffset(idx) 1834 ? op.getDynamicOffset(idx) 1835 : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); 1836 Value size = 1837 op.isDynamicSize(idx) 1838 ? op.getDynamicSize(idx) 1839 : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); 1840 Value stride = 1841 op.isDynamicStride(idx) 1842 ? op.getDynamicStride(idx) 1843 : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); 1844 res.emplace_back(Range{offset, size, stride}); 1845 } 1846 return res; 1847 } 1848 1849 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 1850 /// deduce the result type for the given `sourceType`. Additionally, reduce the 1851 /// rank of the inferred result type if `currentResultType` is lower rank than 1852 /// `currentSourceType`. Use this signature if `sourceType` is updated together 1853 /// with the result type. In this case, it is important to compute the dropped 1854 /// dimensions using `currentSourceType` whose strides align with 1855 /// `currentResultType`. 1856 static MemRefType getCanonicalSubViewResultType( 1857 MemRefType currentResultType, MemRefType currentSourceType, 1858 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, 1859 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { 1860 auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, 1861 mixedSizes, mixedStrides) 1862 .cast<MemRefType>(); 1863 llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims = 1864 computeMemRefRankReductionMask(currentSourceType, currentResultType, 1865 mixedSizes); 1866 // Return nullptr as failure mode. 1867 if (!unusedDims) 1868 return nullptr; 1869 SmallVector<int64_t> shape; 1870 for (auto sizes : llvm::enumerate(nonRankReducedType.getShape())) { 1871 if (unusedDims->count(sizes.index())) 1872 continue; 1873 shape.push_back(sizes.value()); 1874 } 1875 AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap(); 1876 if (!layoutMap.isIdentity()) 1877 layoutMap = getProjectedMap(layoutMap, unusedDims.getValue()); 1878 return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap, 1879 nonRankReducedType.getMemorySpace()); 1880 } 1881 1882 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 1883 /// deduce the result type. Additionally, reduce the rank of the inferred result 1884 /// type if `currentResultType` is lower rank than `sourceType`. 1885 static MemRefType getCanonicalSubViewResultType( 1886 MemRefType currentResultType, MemRefType sourceType, 1887 ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, 1888 ArrayRef<OpFoldResult> mixedStrides) { 1889 return getCanonicalSubViewResultType(currentResultType, sourceType, 1890 sourceType, mixedOffsets, mixedSizes, 1891 mixedStrides); 1892 } 1893 1894 namespace { 1895 /// Pattern to rewrite a subview op with MemRefCast arguments. 1896 /// This essentially pushes memref.cast past its consuming subview when 1897 /// `canFoldIntoConsumerOp` is true. 1898 /// 1899 /// Example: 1900 /// ``` 1901 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 1902 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 1903 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 1904 /// ``` 1905 /// is rewritten into: 1906 /// ``` 1907 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 1908 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 1909 /// memref<3x4xf32, offset:?, strides:[?, 1]> 1910 /// ``` 1911 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 1912 public: 1913 using OpRewritePattern<SubViewOp>::OpRewritePattern; 1914 1915 LogicalResult matchAndRewrite(SubViewOp subViewOp, 1916 PatternRewriter &rewriter) const override { 1917 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 1918 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 1919 return matchPattern(operand, matchConstantIndex()); 1920 })) 1921 return failure(); 1922 1923 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 1924 if (!castOp) 1925 return failure(); 1926 1927 if (!CastOp::canFoldIntoConsumerOp(castOp)) 1928 return failure(); 1929 1930 // Compute the SubViewOp result type after folding the MemRefCastOp. Use the 1931 // MemRefCastOp source operand type to infer the result type and the current 1932 // SubViewOp source operand type to compute the dropped dimensions if the 1933 // operation is rank-reducing. 1934 auto resultType = getCanonicalSubViewResultType( 1935 subViewOp.getType(), subViewOp.getSourceType(), 1936 castOp.source().getType().cast<MemRefType>(), 1937 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), 1938 subViewOp.getMixedStrides()); 1939 if (!resultType) 1940 return failure(); 1941 1942 Value newSubView = rewriter.create<SubViewOp>( 1943 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), 1944 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), 1945 subViewOp.static_sizes(), subViewOp.static_strides()); 1946 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(), 1947 newSubView); 1948 return success(); 1949 } 1950 }; 1951 } // namespace 1952 1953 /// Return the canonical type of the result of a subview. 1954 struct SubViewReturnTypeCanonicalizer { 1955 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets, 1956 ArrayRef<OpFoldResult> mixedSizes, 1957 ArrayRef<OpFoldResult> mixedStrides) { 1958 return getCanonicalSubViewResultType(op.getType(), op.getSourceType(), 1959 mixedOffsets, mixedSizes, 1960 mixedStrides); 1961 } 1962 }; 1963 1964 /// A canonicalizer wrapper to replace SubViewOps. 1965 struct SubViewCanonicalizer { 1966 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 1967 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType()); 1968 } 1969 }; 1970 1971 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 1972 MLIRContext *context) { 1973 results 1974 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 1975 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 1976 SubViewOpMemRefCastFolder>(context); 1977 } 1978 1979 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 1980 auto resultShapedType = getResult().getType().cast<ShapedType>(); 1981 auto sourceShapedType = source().getType().cast<ShapedType>(); 1982 1983 if (resultShapedType.hasStaticShape() && 1984 resultShapedType == sourceShapedType) { 1985 return getViewSource(); 1986 } 1987 1988 return {}; 1989 } 1990 1991 //===----------------------------------------------------------------------===// 1992 // TransposeOp 1993 //===----------------------------------------------------------------------===// 1994 1995 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 1996 static MemRefType inferTransposeResultType(MemRefType memRefType, 1997 AffineMap permutationMap) { 1998 auto rank = memRefType.getRank(); 1999 auto originalSizes = memRefType.getShape(); 2000 // Compute permuted sizes. 2001 SmallVector<int64_t, 4> sizes(rank, 0); 2002 for (auto en : llvm::enumerate(permutationMap.getResults())) 2003 sizes[en.index()] = 2004 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2005 2006 // Compute permuted strides. 2007 int64_t offset; 2008 SmallVector<int64_t, 4> strides; 2009 auto res = getStridesAndOffset(memRefType, strides, offset); 2010 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2011 (void)res; 2012 auto map = 2013 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2014 map = permutationMap ? map.compose(permutationMap) : map; 2015 return MemRefType::Builder(memRefType) 2016 .setShape(sizes) 2017 .setLayout(AffineMapAttr::get(map)); 2018 } 2019 2020 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2021 AffineMapAttr permutation, 2022 ArrayRef<NamedAttribute> attrs) { 2023 auto permutationMap = permutation.getValue(); 2024 assert(permutationMap); 2025 2026 auto memRefType = in.getType().cast<MemRefType>(); 2027 // Compute result type. 2028 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2029 2030 build(b, result, resultType, in, attrs); 2031 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2032 } 2033 2034 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2035 static void print(OpAsmPrinter &p, TransposeOp op) { 2036 p << " " << op.in() << " " << op.permutation(); 2037 p.printOptionalAttrDict(op->getAttrs(), 2038 {TransposeOp::getPermutationAttrName()}); 2039 p << " : " << op.in().getType() << " to " << op.getType(); 2040 } 2041 2042 static ParseResult parseTransposeOp(OpAsmParser &parser, 2043 OperationState &result) { 2044 OpAsmParser::OperandType in; 2045 AffineMap permutation; 2046 MemRefType srcType, dstType; 2047 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2048 parser.parseOptionalAttrDict(result.attributes) || 2049 parser.parseColonType(srcType) || 2050 parser.resolveOperand(in, srcType, result.operands) || 2051 parser.parseKeywordType("to", dstType) || 2052 parser.addTypeToList(dstType, result.types)) 2053 return failure(); 2054 2055 result.addAttribute(TransposeOp::getPermutationAttrName(), 2056 AffineMapAttr::get(permutation)); 2057 return success(); 2058 } 2059 2060 static LogicalResult verify(TransposeOp op) { 2061 if (!op.permutation().isPermutation()) 2062 return op.emitOpError("expected a permutation map"); 2063 if (op.permutation().getNumDims() != op.getShapedType().getRank()) 2064 return op.emitOpError( 2065 "expected a permutation map of same rank as the input"); 2066 2067 auto srcType = op.in().getType().cast<MemRefType>(); 2068 auto dstType = op.getType().cast<MemRefType>(); 2069 auto transposedType = inferTransposeResultType(srcType, op.permutation()); 2070 if (dstType != transposedType) 2071 return op.emitOpError("output type ") 2072 << dstType << " does not match transposed input type " << srcType 2073 << ", " << transposedType; 2074 return success(); 2075 } 2076 2077 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2078 if (succeeded(foldMemRefCast(*this))) 2079 return getResult(); 2080 return {}; 2081 } 2082 2083 //===----------------------------------------------------------------------===// 2084 // ViewOp 2085 //===----------------------------------------------------------------------===// 2086 2087 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { 2088 OpAsmParser::OperandType srcInfo; 2089 SmallVector<OpAsmParser::OperandType, 1> offsetInfo; 2090 SmallVector<OpAsmParser::OperandType, 4> sizesInfo; 2091 auto indexType = parser.getBuilder().getIndexType(); 2092 Type srcType, dstType; 2093 llvm::SMLoc offsetLoc; 2094 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || 2095 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) 2096 return failure(); 2097 2098 if (offsetInfo.size() != 1) 2099 return parser.emitError(offsetLoc) << "expects 1 offset operand"; 2100 2101 return failure( 2102 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || 2103 parser.parseOptionalAttrDict(result.attributes) || 2104 parser.parseColonType(srcType) || 2105 parser.resolveOperand(srcInfo, srcType, result.operands) || 2106 parser.resolveOperands(offsetInfo, indexType, result.operands) || 2107 parser.resolveOperands(sizesInfo, indexType, result.operands) || 2108 parser.parseKeywordType("to", dstType) || 2109 parser.addTypeToList(dstType, result.types)); 2110 } 2111 2112 static void print(OpAsmPrinter &p, ViewOp op) { 2113 p << ' ' << op.getOperand(0) << '['; 2114 p.printOperand(op.byte_shift()); 2115 p << "][" << op.sizes() << ']'; 2116 p.printOptionalAttrDict(op->getAttrs()); 2117 p << " : " << op.getOperand(0).getType() << " to " << op.getType(); 2118 } 2119 2120 static LogicalResult verify(ViewOp op) { 2121 auto baseType = op.getOperand(0).getType().cast<MemRefType>(); 2122 auto viewType = op.getType(); 2123 2124 // The base memref should have identity layout map (or none). 2125 if (!baseType.getLayout().isIdentity()) 2126 return op.emitError("unsupported map for base memref type ") << baseType; 2127 2128 // The result memref should have identity layout map (or none). 2129 if (!viewType.getLayout().isIdentity()) 2130 return op.emitError("unsupported map for result memref type ") << viewType; 2131 2132 // The base memref and the view memref should be in the same memory space. 2133 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2134 return op.emitError("different memory spaces specified for base memref " 2135 "type ") 2136 << baseType << " and view memref type " << viewType; 2137 2138 // Verify that we have the correct number of sizes for the result type. 2139 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2140 if (op.sizes().size() != numDynamicDims) 2141 return op.emitError("incorrect number of size operands for type ") 2142 << viewType; 2143 2144 return success(); 2145 } 2146 2147 Value ViewOp::getViewSource() { return source(); } 2148 2149 namespace { 2150 2151 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2152 using OpRewritePattern<ViewOp>::OpRewritePattern; 2153 2154 LogicalResult matchAndRewrite(ViewOp viewOp, 2155 PatternRewriter &rewriter) const override { 2156 // Return if none of the operands are constants. 2157 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2158 return matchPattern(operand, matchConstantIndex()); 2159 })) 2160 return failure(); 2161 2162 // Get result memref type. 2163 auto memrefType = viewOp.getType(); 2164 2165 // Get offset from old memref view type 'memRefType'. 2166 int64_t oldOffset; 2167 SmallVector<int64_t, 4> oldStrides; 2168 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2169 return failure(); 2170 assert(oldOffset == 0 && "Expected 0 offset"); 2171 2172 SmallVector<Value, 4> newOperands; 2173 2174 // Offset cannot be folded into result type. 2175 2176 // Fold any dynamic dim operands which are produced by a constant. 2177 SmallVector<int64_t, 4> newShapeConstants; 2178 newShapeConstants.reserve(memrefType.getRank()); 2179 2180 unsigned dynamicDimPos = 0; 2181 unsigned rank = memrefType.getRank(); 2182 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2183 int64_t dimSize = memrefType.getDimSize(dim); 2184 // If this is already static dimension, keep it. 2185 if (!ShapedType::isDynamic(dimSize)) { 2186 newShapeConstants.push_back(dimSize); 2187 continue; 2188 } 2189 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2190 if (auto constantIndexOp = 2191 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2192 // Dynamic shape dimension will be folded. 2193 newShapeConstants.push_back(constantIndexOp.value()); 2194 } else { 2195 // Dynamic shape dimension not folded; copy operand from old memref. 2196 newShapeConstants.push_back(dimSize); 2197 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2198 } 2199 dynamicDimPos++; 2200 } 2201 2202 // Create new memref type with constant folded dims. 2203 MemRefType newMemRefType = 2204 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2205 // Nothing new, don't fold. 2206 if (newMemRefType == memrefType) 2207 return failure(); 2208 2209 // Create new ViewOp. 2210 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2211 viewOp.getOperand(0), 2212 viewOp.byte_shift(), newOperands); 2213 // Insert a cast so we have the same type as the old memref type. 2214 rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType()); 2215 return success(); 2216 } 2217 }; 2218 2219 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2220 using OpRewritePattern<ViewOp>::OpRewritePattern; 2221 2222 LogicalResult matchAndRewrite(ViewOp viewOp, 2223 PatternRewriter &rewriter) const override { 2224 Value memrefOperand = viewOp.getOperand(0); 2225 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2226 if (!memrefCastOp) 2227 return failure(); 2228 Value allocOperand = memrefCastOp.getOperand(); 2229 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2230 if (!allocOp) 2231 return failure(); 2232 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2233 viewOp.byte_shift(), viewOp.sizes()); 2234 return success(); 2235 } 2236 }; 2237 2238 } // namespace 2239 2240 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2241 MLIRContext *context) { 2242 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2243 } 2244 2245 //===----------------------------------------------------------------------===// 2246 // TableGen'd op method definitions 2247 //===----------------------------------------------------------------------===// 2248 2249 #define GET_OP_CLASSES 2250 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2251