1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 12 #include "mlir/Dialect/StandardOps/IR/Ops.h" 13 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 14 #include "mlir/Dialect/Utils/StaticValueUtils.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/Matchers.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Interfaces/InferTypeOpInterface.h" 22 #include "mlir/Interfaces/ViewLikeInterface.h" 23 #include "llvm/ADT/STLExtras.h" 24 25 using namespace mlir; 26 using namespace mlir::memref; 27 28 /// Materialize a single constant operation from a given attribute value with 29 /// the desired resultant type. 30 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 31 Attribute value, Type type, 32 Location loc) { 33 if (arith::ConstantOp::isBuildableWith(value, type)) 34 return builder.create<arith::ConstantOp>(loc, value, type); 35 if (ConstantOp::isBuildableWith(value, type)) 36 return builder.create<ConstantOp>(loc, value, type); 37 return nullptr; 38 } 39 40 //===----------------------------------------------------------------------===// 41 // Common canonicalization pattern support logic 42 //===----------------------------------------------------------------------===// 43 44 /// This is a common class used for patterns of the form 45 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 46 /// into the root operation directly. 47 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { 48 bool folded = false; 49 for (OpOperand &operand : op->getOpOperands()) { 50 auto cast = operand.get().getDefiningOp<CastOp>(); 51 if (cast && operand.get() != inner && 52 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 53 operand.set(cast.getOperand()); 54 folded = true; 55 } 56 } 57 return success(folded); 58 } 59 60 /// Return an unranked/ranked tensor type for the given unranked/ranked memref 61 /// type. 62 Type mlir::memref::getTensorTypeFromMemRefType(Type type) { 63 if (auto memref = type.dyn_cast<MemRefType>()) 64 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 65 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 66 return UnrankedTensorType::get(memref.getElementType()); 67 return NoneType::get(type.getContext()); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // AllocOp / AllocaOp 72 //===----------------------------------------------------------------------===// 73 74 template <typename AllocLikeOp> 75 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 76 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 77 "applies to only alloc or alloca"); 78 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 79 if (!memRefType) 80 return op.emitOpError("result must be a memref"); 81 82 if (static_cast<int64_t>(op.dynamicSizes().size()) != 83 memRefType.getNumDynamicDims()) 84 return op.emitOpError("dimension operand count does not equal memref " 85 "dynamic dimension count"); 86 87 unsigned numSymbols = 0; 88 if (!memRefType.getLayout().isIdentity()) 89 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 90 if (op.symbolOperands().size() != numSymbols) 91 return op.emitOpError("symbol operand count does not equal memref symbol " 92 "count: expected ") 93 << numSymbols << ", got " << op.symbolOperands().size(); 94 95 return success(); 96 } 97 98 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 99 100 static LogicalResult verify(AllocaOp op) { 101 // An alloca op needs to have an ancestor with an allocation scope trait. 102 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 103 return op.emitOpError( 104 "requires an ancestor op with AutomaticAllocationScope trait"); 105 106 return verifyAllocLikeOp(op); 107 } 108 109 namespace { 110 /// Fold constant dimensions into an alloc like operation. 111 template <typename AllocLikeOp> 112 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 113 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 114 115 LogicalResult matchAndRewrite(AllocLikeOp alloc, 116 PatternRewriter &rewriter) const override { 117 // Check to see if any dimensions operands are constants. If so, we can 118 // substitute and drop them. 119 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 120 return matchPattern(operand, matchConstantIndex()); 121 })) 122 return failure(); 123 124 auto memrefType = alloc.getType(); 125 126 // Ok, we have one or more constant operands. Collect the non-constant ones 127 // and keep track of the resultant memref type to build. 128 SmallVector<int64_t, 4> newShapeConstants; 129 newShapeConstants.reserve(memrefType.getRank()); 130 SmallVector<Value, 4> dynamicSizes; 131 132 unsigned dynamicDimPos = 0; 133 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 134 int64_t dimSize = memrefType.getDimSize(dim); 135 // If this is already static dimension, keep it. 136 if (dimSize != -1) { 137 newShapeConstants.push_back(dimSize); 138 continue; 139 } 140 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 141 auto *defOp = dynamicSize.getDefiningOp(); 142 if (auto constantIndexOp = 143 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 144 // Dynamic shape dimension will be folded. 145 newShapeConstants.push_back(constantIndexOp.value()); 146 } else { 147 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 148 newShapeConstants.push_back(-1); 149 dynamicSizes.push_back(dynamicSize); 150 } 151 dynamicDimPos++; 152 } 153 154 // Create new memref type (which will have fewer dynamic dimensions). 155 MemRefType newMemRefType = 156 MemRefType::Builder(memrefType).setShape(newShapeConstants); 157 assert(static_cast<int64_t>(dynamicSizes.size()) == 158 newMemRefType.getNumDynamicDims()); 159 160 // Create and insert the alloc op for the new memref. 161 auto newAlloc = rewriter.create<AllocLikeOp>( 162 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 163 alloc.alignmentAttr()); 164 // Insert a cast so we have the same type as the old alloc. 165 auto resultCast = 166 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 167 168 rewriter.replaceOp(alloc, {resultCast}); 169 return success(); 170 } 171 }; 172 173 /// Fold alloc operations with no users or only store and dealloc uses. 174 template <typename T> 175 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 176 using OpRewritePattern<T>::OpRewritePattern; 177 178 LogicalResult matchAndRewrite(T alloc, 179 PatternRewriter &rewriter) const override { 180 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 181 if (auto storeOp = dyn_cast<StoreOp>(op)) 182 return storeOp.value() == alloc; 183 return !isa<DeallocOp>(op); 184 })) 185 return failure(); 186 187 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 188 rewriter.eraseOp(user); 189 190 rewriter.eraseOp(alloc); 191 return success(); 192 } 193 }; 194 } // namespace 195 196 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 197 MLIRContext *context) { 198 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 199 } 200 201 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 202 MLIRContext *context) { 203 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 204 context); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // AllocaScopeOp 209 //===----------------------------------------------------------------------===// 210 211 static void print(OpAsmPrinter &p, AllocaScopeOp &op) { 212 bool printBlockTerminators = false; 213 214 p << " "; 215 if (!op.results().empty()) { 216 p << " -> (" << op.getResultTypes() << ")"; 217 printBlockTerminators = true; 218 } 219 p.printRegion(op.bodyRegion(), 220 /*printEntryBlockArgs=*/false, 221 /*printBlockTerminators=*/printBlockTerminators); 222 p.printOptionalAttrDict(op->getAttrs()); 223 } 224 225 static ParseResult parseAllocaScopeOp(OpAsmParser &parser, 226 OperationState &result) { 227 // Create a region for the body. 228 result.regions.reserve(1); 229 Region *bodyRegion = result.addRegion(); 230 231 // Parse optional results type list. 232 if (parser.parseOptionalArrowTypeList(result.types)) 233 return failure(); 234 235 // Parse the body region. 236 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 237 return failure(); 238 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 239 result.location); 240 241 // Parse the optional attribute list. 242 if (parser.parseOptionalAttrDict(result.attributes)) 243 return failure(); 244 245 return success(); 246 } 247 248 static LogicalResult verify(AllocaScopeOp op) { 249 if (failed(RegionBranchOpInterface::verifyTypes(op))) 250 return failure(); 251 252 return success(); 253 } 254 255 void AllocaScopeOp::getSuccessorRegions( 256 Optional<unsigned> index, ArrayRef<Attribute> operands, 257 SmallVectorImpl<RegionSuccessor> ®ions) { 258 if (index.hasValue()) { 259 regions.push_back(RegionSuccessor(getResults())); 260 return; 261 } 262 263 regions.push_back(RegionSuccessor(&bodyRegion())); 264 } 265 266 //===----------------------------------------------------------------------===// 267 // AssumeAlignmentOp 268 //===----------------------------------------------------------------------===// 269 270 static LogicalResult verify(AssumeAlignmentOp op) { 271 unsigned alignment = op.alignment(); 272 if (!llvm::isPowerOf2_32(alignment)) 273 return op.emitOpError("alignment must be power of 2"); 274 return success(); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // CastOp 279 //===----------------------------------------------------------------------===// 280 281 /// Determines whether MemRef_CastOp casts to a more dynamic version of the 282 /// source memref. This is useful to to fold a memref.cast into a consuming op 283 /// and implement canonicalization patterns for ops in different dialects that 284 /// may consume the results of memref.cast operations. Such foldable memref.cast 285 /// operations are typically inserted as `view` and `subview` ops are 286 /// canonicalized, to preserve the type compatibility of their uses. 287 /// 288 /// Returns true when all conditions are met: 289 /// 1. source and result are ranked memrefs with strided semantics and same 290 /// element type and rank. 291 /// 2. each of the source's size, offset or stride has more static information 292 /// than the corresponding result's size, offset or stride. 293 /// 294 /// Example 1: 295 /// ```mlir 296 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 297 /// %2 = consumer %1 ... : memref<?x?xf32> ... 298 /// ``` 299 /// 300 /// may fold into: 301 /// 302 /// ```mlir 303 /// %2 = consumer %0 ... : memref<8x16xf32> ... 304 /// ``` 305 /// 306 /// Example 2: 307 /// ``` 308 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 309 /// to memref<?x?xf32> 310 /// consumer %1 : memref<?x?xf32> ... 311 /// ``` 312 /// 313 /// may fold into: 314 /// 315 /// ``` 316 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 317 /// ``` 318 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 319 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 320 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 321 322 // Requires ranked MemRefType. 323 if (!sourceType || !resultType) 324 return false; 325 326 // Requires same elemental type. 327 if (sourceType.getElementType() != resultType.getElementType()) 328 return false; 329 330 // Requires same rank. 331 if (sourceType.getRank() != resultType.getRank()) 332 return false; 333 334 // Only fold casts between strided memref forms. 335 int64_t sourceOffset, resultOffset; 336 SmallVector<int64_t, 4> sourceStrides, resultStrides; 337 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 338 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 339 return false; 340 341 // If cast is towards more static sizes along any dimension, don't fold. 342 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 343 auto ss = std::get<0>(it), st = std::get<1>(it); 344 if (ss != st) 345 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) 346 return false; 347 } 348 349 // If cast is towards more static offset along any dimension, don't fold. 350 if (sourceOffset != resultOffset) 351 if (ShapedType::isDynamicStrideOrOffset(sourceOffset) && 352 !ShapedType::isDynamicStrideOrOffset(resultOffset)) 353 return false; 354 355 // If cast is towards more static strides along any dimension, don't fold. 356 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 357 auto ss = std::get<0>(it), st = std::get<1>(it); 358 if (ss != st) 359 if (ShapedType::isDynamicStrideOrOffset(ss) && 360 !ShapedType::isDynamicStrideOrOffset(st)) 361 return false; 362 } 363 364 return true; 365 } 366 367 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 368 if (inputs.size() != 1 || outputs.size() != 1) 369 return false; 370 Type a = inputs.front(), b = outputs.front(); 371 auto aT = a.dyn_cast<MemRefType>(); 372 auto bT = b.dyn_cast<MemRefType>(); 373 374 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 375 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 376 377 if (aT && bT) { 378 if (aT.getElementType() != bT.getElementType()) 379 return false; 380 if (aT.getLayout() != bT.getLayout()) { 381 int64_t aOffset, bOffset; 382 SmallVector<int64_t, 4> aStrides, bStrides; 383 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 384 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 385 aStrides.size() != bStrides.size()) 386 return false; 387 388 // Strides along a dimension/offset are compatible if the value in the 389 // source memref is static and the value in the target memref is the 390 // same. They are also compatible if either one is dynamic (see 391 // description of MemRefCastOp for details). 392 auto checkCompatible = [](int64_t a, int64_t b) { 393 return (a == MemRefType::getDynamicStrideOrOffset() || 394 b == MemRefType::getDynamicStrideOrOffset() || a == b); 395 }; 396 if (!checkCompatible(aOffset, bOffset)) 397 return false; 398 for (const auto &aStride : enumerate(aStrides)) 399 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 400 return false; 401 } 402 if (aT.getMemorySpace() != bT.getMemorySpace()) 403 return false; 404 405 // They must have the same rank, and any specified dimensions must match. 406 if (aT.getRank() != bT.getRank()) 407 return false; 408 409 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 410 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 411 if (aDim != -1 && bDim != -1 && aDim != bDim) 412 return false; 413 } 414 return true; 415 } else { 416 if (!aT && !uaT) 417 return false; 418 if (!bT && !ubT) 419 return false; 420 // Unranked to unranked casting is unsupported 421 if (uaT && ubT) 422 return false; 423 424 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 425 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 426 if (aEltType != bEltType) 427 return false; 428 429 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 430 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 431 return aMemSpace == bMemSpace; 432 } 433 434 return false; 435 } 436 437 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 438 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // CopyOp 443 //===----------------------------------------------------------------------===// 444 445 namespace { 446 /// If the source/target of a CopyOp is a CastOp that does not modify the shape 447 /// and element type, the cast can be skipped. Such CastOps only cast the layout 448 /// of the type. 449 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> { 450 using OpRewritePattern<CopyOp>::OpRewritePattern; 451 452 LogicalResult matchAndRewrite(CopyOp copyOp, 453 PatternRewriter &rewriter) const override { 454 bool modified = false; 455 456 // Check source. 457 if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) { 458 auto fromType = castOp.source().getType().dyn_cast<MemRefType>(); 459 auto toType = castOp.source().getType().dyn_cast<MemRefType>(); 460 461 if (fromType && toType) { 462 if (fromType.getShape() == toType.getShape() && 463 fromType.getElementType() == toType.getElementType()) { 464 rewriter.updateRootInPlace( 465 copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); }); 466 modified = true; 467 } 468 } 469 } 470 471 // Check target. 472 if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) { 473 auto fromType = castOp.source().getType().dyn_cast<MemRefType>(); 474 auto toType = castOp.source().getType().dyn_cast<MemRefType>(); 475 476 if (fromType && toType) { 477 if (fromType.getShape() == toType.getShape() && 478 fromType.getElementType() == toType.getElementType()) { 479 rewriter.updateRootInPlace( 480 copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); }); 481 modified = true; 482 } 483 } 484 } 485 486 return success(modified); 487 } 488 }; 489 490 /// Fold memref.copy(%x, %x). 491 struct FoldSelfCopy : public OpRewritePattern<CopyOp> { 492 using OpRewritePattern<CopyOp>::OpRewritePattern; 493 494 LogicalResult matchAndRewrite(CopyOp copyOp, 495 PatternRewriter &rewriter) const override { 496 if (copyOp.source() != copyOp.target()) 497 return failure(); 498 499 rewriter.eraseOp(copyOp); 500 return success(); 501 } 502 }; 503 } // namespace 504 505 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, 506 MLIRContext *context) { 507 results.add<FoldCopyOfCast, FoldSelfCopy>(context); 508 } 509 510 //===----------------------------------------------------------------------===// 511 // DeallocOp 512 //===----------------------------------------------------------------------===// 513 514 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 515 SmallVectorImpl<OpFoldResult> &results) { 516 /// dealloc(memrefcast) -> dealloc 517 return foldMemRefCast(*this); 518 } 519 520 //===----------------------------------------------------------------------===// 521 // DimOp 522 //===----------------------------------------------------------------------===// 523 524 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 525 int64_t index) { 526 auto loc = result.location; 527 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index); 528 build(builder, result, source, indexValue); 529 } 530 531 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 532 Value index) { 533 auto indexTy = builder.getIndexType(); 534 build(builder, result, indexTy, source, index); 535 } 536 537 Optional<int64_t> DimOp::getConstantIndex() { 538 if (auto constantOp = index().getDefiningOp<arith::ConstantOp>()) 539 return constantOp.getValue().cast<IntegerAttr>().getInt(); 540 return {}; 541 } 542 543 static LogicalResult verify(DimOp op) { 544 // Assume unknown index to be in range. 545 Optional<int64_t> index = op.getConstantIndex(); 546 if (!index.hasValue()) 547 return success(); 548 549 // Check that constant index is not knowingly out of range. 550 auto type = op.source().getType(); 551 if (auto memrefType = type.dyn_cast<MemRefType>()) { 552 if (index.getValue() >= memrefType.getRank()) 553 return op.emitOpError("index is out of range"); 554 } else if (type.isa<UnrankedMemRefType>()) { 555 // Assume index to be in range. 556 } else { 557 llvm_unreachable("expected operand with memref type"); 558 } 559 return success(); 560 } 561 562 /// Return a map with key being elements in `vals` and data being number of 563 /// occurences of it. Use std::map, since the `vals` here are strides and the 564 /// dynamic stride value is the same as the tombstone value for 565 /// `DenseMap<int64_t>`. 566 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) { 567 std::map<int64_t, unsigned> numOccurences; 568 for (auto val : vals) 569 numOccurences[val]++; 570 return numOccurences; 571 } 572 573 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed 574 /// to be a subset of `originalType` with some `1` entries erased, return the 575 /// set of indices that specifies which of the entries of `originalShape` are 576 /// dropped to obtain `reducedShape`. 577 /// This accounts for cases where there are multiple unit-dims, but only a 578 /// subset of those are dropped. For MemRefTypes these can be disambiguated 579 /// using the strides. If a dimension is dropped the stride must be dropped too. 580 static llvm::Optional<llvm::SmallDenseSet<unsigned>> 581 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, 582 ArrayRef<OpFoldResult> sizes) { 583 llvm::SmallDenseSet<unsigned> unusedDims; 584 if (originalType.getRank() == reducedType.getRank()) 585 return unusedDims; 586 587 for (const auto &dim : llvm::enumerate(sizes)) 588 if (auto attr = dim.value().dyn_cast<Attribute>()) 589 if (attr.cast<IntegerAttr>().getInt() == 1) 590 unusedDims.insert(dim.index()); 591 592 SmallVector<int64_t> originalStrides, candidateStrides; 593 int64_t originalOffset, candidateOffset; 594 if (failed( 595 getStridesAndOffset(originalType, originalStrides, originalOffset)) || 596 failed( 597 getStridesAndOffset(reducedType, candidateStrides, candidateOffset))) 598 return llvm::None; 599 600 // For memrefs, a dimension is truly dropped if its corresponding stride is 601 // also dropped. This is particularly important when more than one of the dims 602 // is 1. Track the number of occurences of the strides in the original type 603 // and the candidate type. For each unused dim that stride should not be 604 // present in the candidate type. Note that there could be multiple dimensions 605 // that have the same size. We dont need to exactly figure out which dim 606 // corresponds to which stride, we just need to verify that the number of 607 // reptitions of a stride in the original + number of unused dims with that 608 // stride == number of repititions of a stride in the candidate. 609 std::map<int64_t, unsigned> currUnaccountedStrides = 610 getNumOccurences(originalStrides); 611 std::map<int64_t, unsigned> candidateStridesNumOccurences = 612 getNumOccurences(candidateStrides); 613 llvm::SmallDenseSet<unsigned> prunedUnusedDims; 614 for (unsigned dim : unusedDims) { 615 int64_t originalStride = originalStrides[dim]; 616 if (currUnaccountedStrides[originalStride] > 617 candidateStridesNumOccurences[originalStride]) { 618 // This dim can be treated as dropped. 619 currUnaccountedStrides[originalStride]--; 620 continue; 621 } 622 if (currUnaccountedStrides[originalStride] == 623 candidateStridesNumOccurences[originalStride]) { 624 // The stride for this is not dropped. Keep as is. 625 prunedUnusedDims.insert(dim); 626 continue; 627 } 628 if (currUnaccountedStrides[originalStride] < 629 candidateStridesNumOccurences[originalStride]) { 630 // This should never happen. Cant have a stride in the reduced rank type 631 // that wasnt in the original one. 632 return llvm::None; 633 } 634 } 635 636 for (auto prunedDim : prunedUnusedDims) 637 unusedDims.erase(prunedDim); 638 if (unusedDims.size() + reducedType.getRank() != originalType.getRank()) 639 return llvm::None; 640 return unusedDims; 641 } 642 643 llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() { 644 MemRefType sourceType = getSourceType(); 645 MemRefType resultType = getType(); 646 llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims = 647 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes()); 648 assert(unusedDims && "unable to find unused dims of subview"); 649 return *unusedDims; 650 } 651 652 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 653 // All forms of folding require a known index. 654 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 655 if (!index) 656 return {}; 657 658 // Folding for unranked types (UnrankedMemRefType) is not supported. 659 auto memrefType = source().getType().dyn_cast<MemRefType>(); 660 if (!memrefType) 661 return {}; 662 663 // Fold if the shape extent along the given index is known. 664 if (!memrefType.isDynamicDim(index.getInt())) { 665 Builder builder(getContext()); 666 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); 667 } 668 669 // The size at the given index is now known to be a dynamic size. 670 unsigned unsignedIndex = index.getValue().getZExtValue(); 671 672 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 673 Operation *definingOp = source().getDefiningOp(); 674 675 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 676 return *(alloc.getDynamicSizes().begin() + 677 memrefType.getDynamicDimIndex(unsignedIndex)); 678 679 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 680 return *(alloca.getDynamicSizes().begin() + 681 memrefType.getDynamicDimIndex(unsignedIndex)); 682 683 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 684 return *(view.getDynamicSizes().begin() + 685 memrefType.getDynamicDimIndex(unsignedIndex)); 686 687 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 688 llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims(); 689 unsigned resultIndex = 0; 690 unsigned sourceRank = subview.getSourceType().getRank(); 691 unsigned sourceIndex = 0; 692 for (auto i : llvm::seq<unsigned>(0, sourceRank)) { 693 if (unusedDims.count(i)) 694 continue; 695 if (resultIndex == unsignedIndex) { 696 sourceIndex = i; 697 break; 698 } 699 resultIndex++; 700 } 701 assert(subview.isDynamicSize(sourceIndex) && 702 "expected dynamic subview size"); 703 return subview.getDynamicSize(sourceIndex); 704 } 705 706 if (auto sizeInterface = 707 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 708 assert(sizeInterface.isDynamicSize(unsignedIndex) && 709 "Expected dynamic subview size"); 710 return sizeInterface.getDynamicSize(unsignedIndex); 711 } 712 713 // dim(memrefcast) -> dim 714 if (succeeded(foldMemRefCast(*this))) 715 return getResult(); 716 717 return {}; 718 } 719 720 namespace { 721 /// Fold dim of a memref reshape operation to a load into the reshape's shape 722 /// operand. 723 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 724 using OpRewritePattern<DimOp>::OpRewritePattern; 725 726 LogicalResult matchAndRewrite(DimOp dim, 727 PatternRewriter &rewriter) const override { 728 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 729 730 if (!reshape) 731 return failure(); 732 733 // Place the load directly after the reshape to ensure that the shape memref 734 // was not mutated. 735 rewriter.setInsertionPointAfter(reshape); 736 Location loc = dim.getLoc(); 737 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 738 if (load.getType() != dim.getType()) 739 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); 740 rewriter.replaceOp(dim, load); 741 return success(); 742 } 743 }; 744 745 } // namespace 746 747 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 748 MLIRContext *context) { 749 results.add<DimOfMemRefReshape>(context); 750 } 751 752 // --------------------------------------------------------------------------- 753 // DmaStartOp 754 // --------------------------------------------------------------------------- 755 756 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 757 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 758 ValueRange destIndices, Value numElements, 759 Value tagMemRef, ValueRange tagIndices, Value stride, 760 Value elementsPerStride) { 761 result.addOperands(srcMemRef); 762 result.addOperands(srcIndices); 763 result.addOperands(destMemRef); 764 result.addOperands(destIndices); 765 result.addOperands({numElements, tagMemRef}); 766 result.addOperands(tagIndices); 767 if (stride) 768 result.addOperands({stride, elementsPerStride}); 769 } 770 771 static void print(OpAsmPrinter &p, DmaStartOp op) { 772 p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], " 773 << op.getDstMemRef() << '[' << op.getDstIndices() << "], " 774 << op.getNumElements() << ", " << op.getTagMemRef() << '[' 775 << op.getTagIndices() << ']'; 776 if (op.isStrided()) 777 p << ", " << op.getStride() << ", " << op.getNumElementsPerStride(); 778 779 p.printOptionalAttrDict(op->getAttrs()); 780 p << " : " << op.getSrcMemRef().getType() << ", " 781 << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType(); 782 } 783 784 // Parse DmaStartOp. 785 // Ex: 786 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 787 // %tag[%index], %stride, %num_elt_per_stride : 788 // : memref<3076 x f32, 0>, 789 // memref<1024 x f32, 2>, 790 // memref<1 x i32> 791 // 792 static ParseResult parseDmaStartOp(OpAsmParser &parser, 793 OperationState &result) { 794 OpAsmParser::OperandType srcMemRefInfo; 795 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 796 OpAsmParser::OperandType dstMemRefInfo; 797 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 798 OpAsmParser::OperandType numElementsInfo; 799 OpAsmParser::OperandType tagMemrefInfo; 800 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 801 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 802 803 SmallVector<Type, 3> types; 804 auto indexType = parser.getBuilder().getIndexType(); 805 806 // Parse and resolve the following list of operands: 807 // *) source memref followed by its indices (in square brackets). 808 // *) destination memref followed by its indices (in square brackets). 809 // *) dma size in KiB. 810 if (parser.parseOperand(srcMemRefInfo) || 811 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 812 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 813 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 814 parser.parseComma() || parser.parseOperand(numElementsInfo) || 815 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 816 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 817 return failure(); 818 819 // Parse optional stride and elements per stride. 820 if (parser.parseTrailingOperandList(strideInfo)) 821 return failure(); 822 823 bool isStrided = strideInfo.size() == 2; 824 if (!strideInfo.empty() && !isStrided) { 825 return parser.emitError(parser.getNameLoc(), 826 "expected two stride related operands"); 827 } 828 829 if (parser.parseColonTypeList(types)) 830 return failure(); 831 if (types.size() != 3) 832 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 833 834 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 835 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 836 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 837 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 838 // size should be an index. 839 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 840 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 841 // tag indices should be index. 842 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 843 return failure(); 844 845 if (isStrided) { 846 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 847 return failure(); 848 } 849 850 return success(); 851 } 852 853 static LogicalResult verify(DmaStartOp op) { 854 unsigned numOperands = op.getNumOperands(); 855 856 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 857 // the number of elements. 858 if (numOperands < 4) 859 return op.emitOpError("expected at least 4 operands"); 860 861 // Check types of operands. The order of these calls is important: the later 862 // calls rely on some type properties to compute the operand position. 863 // 1. Source memref. 864 if (!op.getSrcMemRef().getType().isa<MemRefType>()) 865 return op.emitOpError("expected source to be of memref type"); 866 if (numOperands < op.getSrcMemRefRank() + 4) 867 return op.emitOpError() 868 << "expected at least " << op.getSrcMemRefRank() + 4 << " operands"; 869 if (!op.getSrcIndices().empty() && 870 !llvm::all_of(op.getSrcIndices().getTypes(), 871 [](Type t) { return t.isIndex(); })) 872 return op.emitOpError("expected source indices to be of index type"); 873 874 // 2. Destination memref. 875 if (!op.getDstMemRef().getType().isa<MemRefType>()) 876 return op.emitOpError("expected destination to be of memref type"); 877 unsigned numExpectedOperands = 878 op.getSrcMemRefRank() + op.getDstMemRefRank() + 4; 879 if (numOperands < numExpectedOperands) 880 return op.emitOpError() 881 << "expected at least " << numExpectedOperands << " operands"; 882 if (!op.getDstIndices().empty() && 883 !llvm::all_of(op.getDstIndices().getTypes(), 884 [](Type t) { return t.isIndex(); })) 885 return op.emitOpError("expected destination indices to be of index type"); 886 887 // 3. Number of elements. 888 if (!op.getNumElements().getType().isIndex()) 889 return op.emitOpError("expected num elements to be of index type"); 890 891 // 4. Tag memref. 892 if (!op.getTagMemRef().getType().isa<MemRefType>()) 893 return op.emitOpError("expected tag to be of memref type"); 894 numExpectedOperands += op.getTagMemRefRank(); 895 if (numOperands < numExpectedOperands) 896 return op.emitOpError() 897 << "expected at least " << numExpectedOperands << " operands"; 898 if (!op.getTagIndices().empty() && 899 !llvm::all_of(op.getTagIndices().getTypes(), 900 [](Type t) { return t.isIndex(); })) 901 return op.emitOpError("expected tag indices to be of index type"); 902 903 // Optional stride-related operands must be either both present or both 904 // absent. 905 if (numOperands != numExpectedOperands && 906 numOperands != numExpectedOperands + 2) 907 return op.emitOpError("incorrect number of operands"); 908 909 // 5. Strides. 910 if (op.isStrided()) { 911 if (!op.getStride().getType().isIndex() || 912 !op.getNumElementsPerStride().getType().isIndex()) 913 return op.emitOpError( 914 "expected stride and num elements per stride to be of type index"); 915 } 916 917 return success(); 918 } 919 920 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 921 SmallVectorImpl<OpFoldResult> &results) { 922 /// dma_start(memrefcast) -> dma_start 923 return foldMemRefCast(*this); 924 } 925 926 // --------------------------------------------------------------------------- 927 // DmaWaitOp 928 // --------------------------------------------------------------------------- 929 930 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 931 SmallVectorImpl<OpFoldResult> &results) { 932 /// dma_wait(memrefcast) -> dma_wait 933 return foldMemRefCast(*this); 934 } 935 936 static LogicalResult verify(DmaWaitOp op) { 937 // Check that the number of tag indices matches the tagMemRef rank. 938 unsigned numTagIndices = op.tagIndices().size(); 939 unsigned tagMemRefRank = op.getTagMemRefRank(); 940 if (numTagIndices != tagMemRefRank) 941 return op.emitOpError() << "expected tagIndices to have the same number of " 942 "elements as the tagMemRef rank, expected " 943 << tagMemRefRank << ", but got " << numTagIndices; 944 return success(); 945 } 946 947 //===----------------------------------------------------------------------===// 948 // GlobalOp 949 //===----------------------------------------------------------------------===// 950 951 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 952 TypeAttr type, 953 Attribute initialValue) { 954 p << type; 955 if (!op.isExternal()) { 956 p << " = "; 957 if (op.isUninitialized()) 958 p << "uninitialized"; 959 else 960 p.printAttributeWithoutType(initialValue); 961 } 962 } 963 964 static ParseResult 965 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 966 Attribute &initialValue) { 967 Type type; 968 if (parser.parseType(type)) 969 return failure(); 970 971 auto memrefType = type.dyn_cast<MemRefType>(); 972 if (!memrefType || !memrefType.hasStaticShape()) 973 return parser.emitError(parser.getNameLoc()) 974 << "type should be static shaped memref, but got " << type; 975 typeAttr = TypeAttr::get(type); 976 977 if (parser.parseOptionalEqual()) 978 return success(); 979 980 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 981 initialValue = UnitAttr::get(parser.getContext()); 982 return success(); 983 } 984 985 Type tensorType = getTensorTypeFromMemRefType(memrefType); 986 if (parser.parseAttribute(initialValue, tensorType)) 987 return failure(); 988 if (!initialValue.isa<ElementsAttr>()) 989 return parser.emitError(parser.getNameLoc()) 990 << "initial value should be a unit or elements attribute"; 991 return success(); 992 } 993 994 static LogicalResult verify(GlobalOp op) { 995 auto memrefType = op.type().dyn_cast<MemRefType>(); 996 if (!memrefType || !memrefType.hasStaticShape()) 997 return op.emitOpError("type should be static shaped memref, but got ") 998 << op.type(); 999 1000 // Verify that the initial value, if present, is either a unit attribute or 1001 // an elements attribute. 1002 if (op.initial_value().hasValue()) { 1003 Attribute initValue = op.initial_value().getValue(); 1004 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1005 return op.emitOpError("initial value should be a unit or elements " 1006 "attribute, but got ") 1007 << initValue; 1008 1009 // Check that the type of the initial value is compatible with the type of 1010 // the global variable. 1011 if (initValue.isa<ElementsAttr>()) { 1012 Type initType = initValue.getType(); 1013 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1014 if (initType != tensorType) 1015 return op.emitOpError("initial value expected to be of type ") 1016 << tensorType << ", but was of type " << initType; 1017 } 1018 } 1019 1020 if (Optional<uint64_t> alignAttr = op.alignment()) { 1021 uint64_t alignment = alignAttr.getValue(); 1022 1023 if (!llvm::isPowerOf2_64(alignment)) 1024 return op->emitError() << "alignment attribute value " << alignment 1025 << " is not a power of 2"; 1026 } 1027 1028 // TODO: verify visibility for declarations. 1029 return success(); 1030 } 1031 1032 //===----------------------------------------------------------------------===// 1033 // GetGlobalOp 1034 //===----------------------------------------------------------------------===// 1035 1036 LogicalResult 1037 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1038 // Verify that the result type is same as the type of the referenced 1039 // memref.global op. 1040 auto global = 1041 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1042 if (!global) 1043 return emitOpError("'") 1044 << name() << "' does not reference a valid global memref"; 1045 1046 Type resultType = result().getType(); 1047 if (global.type() != resultType) 1048 return emitOpError("result type ") 1049 << resultType << " does not match type " << global.type() 1050 << " of the global memref @" << name(); 1051 return success(); 1052 } 1053 1054 //===----------------------------------------------------------------------===// 1055 // LoadOp 1056 //===----------------------------------------------------------------------===// 1057 1058 static LogicalResult verify(LoadOp op) { 1059 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1060 return op.emitOpError("incorrect number of indices for load"); 1061 return success(); 1062 } 1063 1064 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1065 /// load(memrefcast) -> load 1066 if (succeeded(foldMemRefCast(*this))) 1067 return getResult(); 1068 return OpFoldResult(); 1069 } 1070 1071 //===----------------------------------------------------------------------===// 1072 // PrefetchOp 1073 //===----------------------------------------------------------------------===// 1074 1075 static void print(OpAsmPrinter &p, PrefetchOp op) { 1076 p << " " << op.memref() << '['; 1077 p.printOperands(op.indices()); 1078 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1079 p << ", locality<" << op.localityHint(); 1080 p << ">, " << (op.isDataCache() ? "data" : "instr"); 1081 p.printOptionalAttrDict( 1082 op->getAttrs(), 1083 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1084 p << " : " << op.getMemRefType(); 1085 } 1086 1087 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1088 OperationState &result) { 1089 OpAsmParser::OperandType memrefInfo; 1090 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1091 IntegerAttr localityHint; 1092 MemRefType type; 1093 StringRef readOrWrite, cacheType; 1094 1095 auto indexTy = parser.getBuilder().getIndexType(); 1096 auto i32Type = parser.getBuilder().getIntegerType(32); 1097 if (parser.parseOperand(memrefInfo) || 1098 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1099 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1100 parser.parseComma() || parser.parseKeyword("locality") || 1101 parser.parseLess() || 1102 parser.parseAttribute(localityHint, i32Type, "localityHint", 1103 result.attributes) || 1104 parser.parseGreater() || parser.parseComma() || 1105 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1106 parser.resolveOperand(memrefInfo, type, result.operands) || 1107 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1108 return failure(); 1109 1110 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1111 return parser.emitError(parser.getNameLoc(), 1112 "rw specifier has to be 'read' or 'write'"); 1113 result.addAttribute( 1114 PrefetchOp::getIsWriteAttrName(), 1115 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1116 1117 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1118 return parser.emitError(parser.getNameLoc(), 1119 "cache type has to be 'data' or 'instr'"); 1120 1121 result.addAttribute( 1122 PrefetchOp::getIsDataCacheAttrName(), 1123 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1124 1125 return success(); 1126 } 1127 1128 static LogicalResult verify(PrefetchOp op) { 1129 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1130 return op.emitOpError("too few indices"); 1131 1132 return success(); 1133 } 1134 1135 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1136 SmallVectorImpl<OpFoldResult> &results) { 1137 // prefetch(memrefcast) -> prefetch 1138 return foldMemRefCast(*this); 1139 } 1140 1141 //===----------------------------------------------------------------------===// 1142 // RankOp 1143 //===----------------------------------------------------------------------===// 1144 1145 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1146 // Constant fold rank when the rank of the operand is known. 1147 auto type = getOperand().getType(); 1148 auto shapedType = type.dyn_cast<ShapedType>(); 1149 if (shapedType && shapedType.hasRank()) 1150 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1151 return IntegerAttr(); 1152 } 1153 1154 //===----------------------------------------------------------------------===// 1155 // ReinterpretCastOp 1156 //===----------------------------------------------------------------------===// 1157 1158 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1159 /// `staticSizes` and `staticStrides` are automatically filled with 1160 /// source-memref-rank sentinel values that encode dynamic entries. 1161 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1162 MemRefType resultType, Value source, 1163 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1164 ArrayRef<OpFoldResult> strides, 1165 ArrayRef<NamedAttribute> attrs) { 1166 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1167 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1168 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1169 ShapedType::kDynamicStrideOrOffset); 1170 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1171 ShapedType::kDynamicSize); 1172 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1173 ShapedType::kDynamicStrideOrOffset); 1174 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1175 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1176 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1177 result.addAttributes(attrs); 1178 } 1179 1180 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1181 MemRefType resultType, Value source, 1182 int64_t offset, ArrayRef<int64_t> sizes, 1183 ArrayRef<int64_t> strides, 1184 ArrayRef<NamedAttribute> attrs) { 1185 SmallVector<OpFoldResult> sizeValues = 1186 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1187 return b.getI64IntegerAttr(v); 1188 })); 1189 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1190 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1191 return b.getI64IntegerAttr(v); 1192 })); 1193 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1194 strideValues, attrs); 1195 } 1196 1197 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1198 MemRefType resultType, Value source, Value offset, 1199 ValueRange sizes, ValueRange strides, 1200 ArrayRef<NamedAttribute> attrs) { 1201 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1202 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1203 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1204 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1205 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1206 } 1207 1208 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1209 // completed automatically, like we have for subview and extract_slice. 1210 static LogicalResult verify(ReinterpretCastOp op) { 1211 // The source and result memrefs should be in the same memory space. 1212 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1213 auto resultType = op.getType().cast<MemRefType>(); 1214 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1215 return op.emitError("different memory spaces specified for source type ") 1216 << srcType << " and result memref type " << resultType; 1217 if (srcType.getElementType() != resultType.getElementType()) 1218 return op.emitError("different element types specified for source type ") 1219 << srcType << " and result memref type " << resultType; 1220 1221 // Match sizes in result memref type and in static_sizes attribute. 1222 for (auto &en : 1223 llvm::enumerate(llvm::zip(resultType.getShape(), 1224 extractFromI64ArrayAttr(op.static_sizes())))) { 1225 int64_t resultSize = std::get<0>(en.value()); 1226 int64_t expectedSize = std::get<1>(en.value()); 1227 if (!ShapedType::isDynamic(resultSize) && 1228 !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) 1229 return op.emitError("expected result type with size = ") 1230 << expectedSize << " instead of " << resultSize 1231 << " in dim = " << en.index(); 1232 } 1233 1234 // Match offset and strides in static_offset and static_strides attributes. If 1235 // result memref type has no affine map specified, this will assume an 1236 // identity layout. 1237 int64_t resultOffset; 1238 SmallVector<int64_t, 4> resultStrides; 1239 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1240 return op.emitError( 1241 "expected result type to have strided layout but found ") 1242 << resultType; 1243 1244 // Match offset in result memref type and in static_offsets attribute. 1245 int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); 1246 if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && 1247 !ShapedType::isDynamicStrideOrOffset(expectedOffset) && 1248 resultOffset != expectedOffset) 1249 return op.emitError("expected result type with offset = ") 1250 << resultOffset << " instead of " << expectedOffset; 1251 1252 // Match strides in result memref type and in static_strides attribute. 1253 for (auto &en : llvm::enumerate(llvm::zip( 1254 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1255 int64_t resultStride = std::get<0>(en.value()); 1256 int64_t expectedStride = std::get<1>(en.value()); 1257 if (!ShapedType::isDynamicStrideOrOffset(resultStride) && 1258 !ShapedType::isDynamicStrideOrOffset(expectedStride) && 1259 resultStride != expectedStride) 1260 return op.emitError("expected result type with stride = ") 1261 << expectedStride << " instead of " << resultStride 1262 << " in dim = " << en.index(); 1263 } 1264 1265 return success(); 1266 } 1267 1268 //===----------------------------------------------------------------------===// 1269 // Reassociative reshape ops 1270 //===----------------------------------------------------------------------===// 1271 1272 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1273 return getSymbolLessAffineMaps(getReassociationExprs()); 1274 } 1275 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1276 return convertReassociationIndicesToExprs(getContext(), 1277 getReassociationIndices()); 1278 } 1279 1280 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1281 return getSymbolLessAffineMaps(getReassociationExprs()); 1282 } 1283 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1284 return convertReassociationIndicesToExprs(getContext(), 1285 getReassociationIndices()); 1286 } 1287 1288 static void print(OpAsmPrinter &p, ExpandShapeOp op) { 1289 ::mlir::printReshapeOp<ExpandShapeOp>(p, op); 1290 } 1291 1292 static void print(OpAsmPrinter &p, CollapseShapeOp op) { 1293 ::mlir::printReshapeOp<CollapseShapeOp>(p, op); 1294 } 1295 1296 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1297 /// copies. 1298 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1299 ArrayRef<int64_t> sizes, 1300 ArrayRef<AffineExpr> strides) { 1301 // Bands of extent one can be reshaped, as they are not reshaped at all. 1302 if (extent == 1) 1303 return true; 1304 // Otherwise, the size of the first dimension needs to be known. 1305 if (ShapedType::isDynamic(sizes[dim])) 1306 return false; 1307 assert(sizes.size() == strides.size() && "mismatched ranks"); 1308 // off by 1 indexing to avoid out of bounds 1309 // V 1310 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1311 // Only bands of static shapes are reshapable. This is due to the fact that 1312 // there is no relation between dynamic sizes and dynamic strides: we do not 1313 // have enough information to know whether a "-1" size corresponds to the 1314 // proper symbol in the AffineExpr of a stride. 1315 if (ShapedType::isDynamic(sizes[idx + 1])) 1316 return false; 1317 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1318 // simplify on the fly and catch more reshapable cases. 1319 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1320 return false; 1321 } 1322 return true; 1323 } 1324 1325 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1326 /// expected to be valid) to `type`. 1327 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1328 /// MemRefType. 1329 static MemRefType 1330 computeReshapeCollapsedType(MemRefType type, 1331 ArrayRef<AffineMap> reassociation) { 1332 auto sizes = type.getShape(); 1333 AffineExpr offset; 1334 SmallVector<AffineExpr, 4> strides; 1335 auto status = getStridesAndOffset(type, strides, offset); 1336 (void)status; 1337 assert(succeeded(status) && "expected strided memref"); 1338 1339 SmallVector<int64_t, 4> newSizes; 1340 newSizes.reserve(reassociation.size()); 1341 SmallVector<AffineExpr, 4> newStrides; 1342 newStrides.reserve(reassociation.size()); 1343 1344 // Use the fact that reassociation is valid to simplify the logic: only use 1345 // each map's rank. 1346 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1347 unsigned currentDim = 0; 1348 for (AffineMap m : reassociation) { 1349 unsigned dim = m.getNumResults(); 1350 int64_t size = 1; 1351 AffineExpr stride = strides[currentDim + dim - 1]; 1352 if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { 1353 size = ShapedType::kDynamicSize; 1354 stride = AffineExpr(); 1355 } else { 1356 for (unsigned d = 0; d < dim; ++d) 1357 size *= sizes[currentDim + d]; 1358 } 1359 newSizes.push_back(size); 1360 newStrides.push_back(stride); 1361 currentDim += dim; 1362 } 1363 1364 // Early-exit: if `type` is contiguous, the result must be contiguous. 1365 if (canonicalizeStridedLayout(type).getLayout().isIdentity()) 1366 return MemRefType::Builder(type).setShape(newSizes).setLayout({}); 1367 1368 // Convert back to int64_t because we don't have enough information to create 1369 // new strided layouts from AffineExpr only. This corresponds to a case where 1370 // copies may be necessary. 1371 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1372 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1373 intOffset = o.getValue(); 1374 SmallVector<int64_t, 4> intStrides; 1375 intStrides.reserve(strides.size()); 1376 for (auto stride : newStrides) { 1377 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1378 intStrides.push_back(cst.getValue()); 1379 else 1380 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1381 } 1382 auto layout = 1383 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1384 return canonicalizeStridedLayout( 1385 MemRefType::Builder(type).setShape(newSizes).setLayout( 1386 AffineMapAttr::get(layout))); 1387 } 1388 1389 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1390 ArrayRef<ReassociationIndices> reassociation, 1391 ArrayRef<NamedAttribute> attrs) { 1392 auto memRefType = src.getType().cast<MemRefType>(); 1393 auto resultType = computeReshapeCollapsedType( 1394 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1395 b.getContext(), reassociation))); 1396 build(b, result, resultType, src, attrs); 1397 result.addAttribute(getReassociationAttrName(), 1398 getReassociationIndicesAttribute(b, reassociation)); 1399 } 1400 1401 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1402 ArrayRef<ReassociationIndices> reassociation, 1403 ArrayRef<NamedAttribute> attrs) { 1404 auto memRefType = src.getType().cast<MemRefType>(); 1405 auto resultType = computeReshapeCollapsedType( 1406 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1407 b.getContext(), reassociation))); 1408 build(b, result, resultType, src, attrs); 1409 result.addAttribute(getReassociationAttrName(), 1410 getReassociationIndicesAttribute(b, reassociation)); 1411 } 1412 1413 template <typename ReshapeOp, 1414 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1415 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1416 MemRefType collapsedType) { 1417 if (failed( 1418 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1419 return failure(); 1420 auto maps = op.getReassociationMaps(); 1421 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1422 if (collapsedType != expectedType) 1423 return op.emitOpError("expected collapsed type to be ") 1424 << expectedType << ", but got " << collapsedType; 1425 return success(); 1426 } 1427 1428 static LogicalResult verify(ExpandShapeOp op) { 1429 return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); 1430 } 1431 1432 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1433 MLIRContext *context) { 1434 results.add<CollapseReshapeOps<ExpandShapeOp>, 1435 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1436 } 1437 1438 static LogicalResult verify(CollapseShapeOp op) { 1439 return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); 1440 } 1441 1442 struct CollapseShapeOpMemRefCastFolder 1443 : public OpRewritePattern<CollapseShapeOp> { 1444 public: 1445 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1446 1447 LogicalResult matchAndRewrite(CollapseShapeOp op, 1448 PatternRewriter &rewriter) const override { 1449 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1450 if (!cast) 1451 return failure(); 1452 1453 if (!CastOp::canFoldIntoConsumerOp(cast)) 1454 return failure(); 1455 1456 Type newResultType = computeReshapeCollapsedType( 1457 cast.getOperand().getType().cast<MemRefType>(), 1458 op.getReassociationMaps()); 1459 1460 if (newResultType == op.getResultType()) { 1461 rewriter.updateRootInPlace( 1462 op, [&]() { op.srcMutable().assign(cast.source()); }); 1463 } else { 1464 Value newOp = rewriter.create<CollapseShapeOp>( 1465 op->getLoc(), cast.source(), op.getReassociationIndices()); 1466 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1467 } 1468 return success(); 1469 } 1470 }; 1471 1472 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1473 MLIRContext *context) { 1474 results.add<CollapseReshapeOps<CollapseShapeOp>, 1475 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1476 CollapseShapeOpMemRefCastFolder>(context); 1477 } 1478 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1479 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1480 } 1481 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1482 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1483 } 1484 1485 //===----------------------------------------------------------------------===// 1486 // ReshapeOp 1487 //===----------------------------------------------------------------------===// 1488 1489 static LogicalResult verify(ReshapeOp op) { 1490 Type operandType = op.source().getType(); 1491 Type resultType = op.result().getType(); 1492 1493 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1494 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1495 if (operandElementType != resultElementType) 1496 return op.emitOpError("element types of source and destination memref " 1497 "types should be the same"); 1498 1499 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1500 if (!operandMemRefType.getLayout().isIdentity()) 1501 return op.emitOpError( 1502 "source memref type should have identity affine map"); 1503 1504 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1505 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1506 if (resultMemRefType) { 1507 if (!resultMemRefType.getLayout().isIdentity()) 1508 return op.emitOpError( 1509 "result memref type should have identity affine map"); 1510 if (shapeSize == ShapedType::kDynamicSize) 1511 return op.emitOpError("cannot use shape operand with dynamic length to " 1512 "reshape to statically-ranked memref type"); 1513 if (shapeSize != resultMemRefType.getRank()) 1514 return op.emitOpError( 1515 "length of shape operand differs from the result's memref rank"); 1516 } 1517 return success(); 1518 } 1519 1520 //===----------------------------------------------------------------------===// 1521 // StoreOp 1522 //===----------------------------------------------------------------------===// 1523 1524 static LogicalResult verify(StoreOp op) { 1525 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1526 return op.emitOpError("store index operand count not equal to memref rank"); 1527 1528 return success(); 1529 } 1530 1531 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1532 SmallVectorImpl<OpFoldResult> &results) { 1533 /// store(memrefcast) -> store 1534 return foldMemRefCast(*this, getValueToStore()); 1535 } 1536 1537 //===----------------------------------------------------------------------===// 1538 // SubViewOp 1539 //===----------------------------------------------------------------------===// 1540 1541 namespace { 1542 /// Helpers to write more idiomatic operations. 1543 namespace saturated_arith { 1544 struct Wrapper { 1545 explicit Wrapper(int64_t v) : v(v) {} 1546 operator int64_t() { return v; } 1547 int64_t v; 1548 }; 1549 Wrapper operator+(Wrapper a, int64_t b) { 1550 if (ShapedType::isDynamicStrideOrOffset(a) || 1551 ShapedType::isDynamicStrideOrOffset(b)) 1552 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1553 return Wrapper(a.v + b); 1554 } 1555 Wrapper operator*(Wrapper a, int64_t b) { 1556 if (ShapedType::isDynamicStrideOrOffset(a) || 1557 ShapedType::isDynamicStrideOrOffset(b)) 1558 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1559 return Wrapper(a.v * b); 1560 } 1561 } // namespace saturated_arith 1562 } // namespace 1563 1564 /// A subview result type can be fully inferred from the source type and the 1565 /// static representation of offsets, sizes and strides. Special sentinels 1566 /// encode the dynamic case. 1567 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1568 ArrayRef<int64_t> staticOffsets, 1569 ArrayRef<int64_t> staticSizes, 1570 ArrayRef<int64_t> staticStrides) { 1571 unsigned rank = sourceMemRefType.getRank(); 1572 (void)rank; 1573 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 1574 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 1575 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 1576 1577 // Extract source offset and strides. 1578 int64_t sourceOffset; 1579 SmallVector<int64_t, 4> sourceStrides; 1580 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1581 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1582 (void)res; 1583 1584 // Compute target offset whose value is: 1585 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1586 int64_t targetOffset = sourceOffset; 1587 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1588 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1589 using namespace saturated_arith; 1590 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1591 } 1592 1593 // Compute target stride whose value is: 1594 // `sourceStrides_i * staticStrides_i`. 1595 SmallVector<int64_t, 4> targetStrides; 1596 targetStrides.reserve(staticOffsets.size()); 1597 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1598 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1599 using namespace saturated_arith; 1600 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1601 } 1602 1603 // The type is now known. 1604 return MemRefType::get( 1605 staticSizes, sourceMemRefType.getElementType(), 1606 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1607 sourceMemRefType.getContext()), 1608 sourceMemRefType.getMemorySpace()); 1609 } 1610 1611 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1612 ArrayRef<OpFoldResult> offsets, 1613 ArrayRef<OpFoldResult> sizes, 1614 ArrayRef<OpFoldResult> strides) { 1615 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1616 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1617 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1618 ShapedType::kDynamicStrideOrOffset); 1619 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1620 ShapedType::kDynamicSize); 1621 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1622 ShapedType::kDynamicStrideOrOffset); 1623 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1624 staticSizes, staticStrides); 1625 } 1626 1627 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1628 MemRefType sourceRankedTensorType, 1629 ArrayRef<int64_t> offsets, 1630 ArrayRef<int64_t> sizes, 1631 ArrayRef<int64_t> strides) { 1632 auto inferredType = 1633 inferResultType(sourceRankedTensorType, offsets, sizes, strides) 1634 .cast<MemRefType>(); 1635 assert(inferredType.getRank() >= resultRank && "expected "); 1636 int rankDiff = inferredType.getRank() - resultRank; 1637 if (rankDiff > 0) { 1638 auto shape = inferredType.getShape(); 1639 llvm::SmallDenseSet<unsigned> dimsToProject; 1640 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1641 SmallVector<int64_t> projectedShape; 1642 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1643 if (!dimsToProject.contains(pos)) 1644 projectedShape.push_back(shape[pos]); 1645 1646 AffineMap map = inferredType.getLayout().getAffineMap(); 1647 if (!map.isIdentity()) 1648 map = getProjectedMap(map, dimsToProject); 1649 inferredType = 1650 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1651 inferredType.getMemorySpace()); 1652 } 1653 return inferredType; 1654 } 1655 1656 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1657 MemRefType sourceRankedTensorType, 1658 ArrayRef<OpFoldResult> offsets, 1659 ArrayRef<OpFoldResult> sizes, 1660 ArrayRef<OpFoldResult> strides) { 1661 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1662 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1663 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1664 ShapedType::kDynamicStrideOrOffset); 1665 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1666 ShapedType::kDynamicSize); 1667 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1668 ShapedType::kDynamicStrideOrOffset); 1669 return SubViewOp::inferRankReducedResultType( 1670 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1671 staticStrides); 1672 } 1673 // Build a SubViewOp with mixed static and dynamic entries and custom result 1674 // type. If the type passed is nullptr, it is inferred. 1675 void SubViewOp::build(OpBuilder &b, OperationState &result, 1676 MemRefType resultType, Value source, 1677 ArrayRef<OpFoldResult> offsets, 1678 ArrayRef<OpFoldResult> sizes, 1679 ArrayRef<OpFoldResult> strides, 1680 ArrayRef<NamedAttribute> attrs) { 1681 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1682 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1683 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1684 ShapedType::kDynamicStrideOrOffset); 1685 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1686 ShapedType::kDynamicSize); 1687 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1688 ShapedType::kDynamicStrideOrOffset); 1689 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1690 // Structuring implementation this way avoids duplication between builders. 1691 if (!resultType) { 1692 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1693 staticSizes, staticStrides) 1694 .cast<MemRefType>(); 1695 } 1696 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1697 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1698 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1699 result.addAttributes(attrs); 1700 } 1701 1702 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1703 // type. 1704 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1705 ArrayRef<OpFoldResult> offsets, 1706 ArrayRef<OpFoldResult> sizes, 1707 ArrayRef<OpFoldResult> strides, 1708 ArrayRef<NamedAttribute> attrs) { 1709 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1710 } 1711 1712 // Build a SubViewOp with static entries and inferred result type. 1713 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1714 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1715 ArrayRef<int64_t> strides, 1716 ArrayRef<NamedAttribute> attrs) { 1717 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1718 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1719 return b.getI64IntegerAttr(v); 1720 })); 1721 SmallVector<OpFoldResult> sizeValues = 1722 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1723 return b.getI64IntegerAttr(v); 1724 })); 1725 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1726 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1727 return b.getI64IntegerAttr(v); 1728 })); 1729 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1730 } 1731 1732 // Build a SubViewOp with dynamic entries and custom result type. If the 1733 // type passed is nullptr, it is inferred. 1734 void SubViewOp::build(OpBuilder &b, OperationState &result, 1735 MemRefType resultType, Value source, 1736 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1737 ArrayRef<int64_t> strides, 1738 ArrayRef<NamedAttribute> attrs) { 1739 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1740 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1741 return b.getI64IntegerAttr(v); 1742 })); 1743 SmallVector<OpFoldResult> sizeValues = 1744 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1745 return b.getI64IntegerAttr(v); 1746 })); 1747 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1748 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1749 return b.getI64IntegerAttr(v); 1750 })); 1751 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1752 attrs); 1753 } 1754 1755 // Build a SubViewOp with dynamic entries and custom result type. If the type 1756 // passed is nullptr, it is inferred. 1757 void SubViewOp::build(OpBuilder &b, OperationState &result, 1758 MemRefType resultType, Value source, ValueRange offsets, 1759 ValueRange sizes, ValueRange strides, 1760 ArrayRef<NamedAttribute> attrs) { 1761 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1762 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1763 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1764 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1765 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1766 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1767 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1768 } 1769 1770 // Build a SubViewOp with dynamic entries and inferred result type. 1771 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1772 ValueRange offsets, ValueRange sizes, ValueRange strides, 1773 ArrayRef<NamedAttribute> attrs) { 1774 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1775 } 1776 1777 /// For ViewLikeOpInterface. 1778 Value SubViewOp::getViewSource() { return source(); } 1779 1780 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static 1781 /// value). 1782 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { 1783 AffineExpr t1Offset, t2Offset; 1784 SmallVector<AffineExpr> t1Strides, t2Strides; 1785 auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset); 1786 auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset); 1787 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset; 1788 } 1789 1790 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1791 /// This function is slight variant of `is subsequence` algorithm where 1792 /// not matching dimension must be 1. 1793 static SliceVerificationResult 1794 isRankReducedMemRefType(MemRefType originalType, 1795 MemRefType candidateRankReducedType, 1796 ArrayRef<OpFoldResult> sizes) { 1797 auto partialRes = isRankReducedType(originalType, candidateRankReducedType); 1798 if (partialRes != SliceVerificationResult::Success) 1799 return partialRes; 1800 1801 auto optionalUnusedDimsMask = computeMemRefRankReductionMask( 1802 originalType, candidateRankReducedType, sizes); 1803 1804 // Sizes cannot be matched in case empty vector is returned. 1805 if (!optionalUnusedDimsMask.hasValue()) 1806 return SliceVerificationResult::LayoutMismatch; 1807 1808 if (originalType.getMemorySpace() != 1809 candidateRankReducedType.getMemorySpace()) 1810 return SliceVerificationResult::MemSpaceMismatch; 1811 1812 // No amount of stride dropping can reconcile incompatible offsets. 1813 if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) 1814 return SliceVerificationResult::LayoutMismatch; 1815 1816 return SliceVerificationResult::Success; 1817 } 1818 1819 template <typename OpTy> 1820 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 1821 OpTy op, Type expectedType) { 1822 auto memrefType = expectedType.cast<ShapedType>(); 1823 switch (result) { 1824 case SliceVerificationResult::Success: 1825 return success(); 1826 case SliceVerificationResult::RankTooLarge: 1827 return op.emitError("expected result rank to be smaller or equal to ") 1828 << "the source rank. "; 1829 case SliceVerificationResult::SizeMismatch: 1830 return op.emitError("expected result type to be ") 1831 << expectedType 1832 << " or a rank-reduced version. (mismatch of result sizes) "; 1833 case SliceVerificationResult::ElemTypeMismatch: 1834 return op.emitError("expected result element type to be ") 1835 << memrefType.getElementType(); 1836 case SliceVerificationResult::MemSpaceMismatch: 1837 return op.emitError("expected result and source memory spaces to match."); 1838 case SliceVerificationResult::LayoutMismatch: 1839 return op.emitError("expected result type to be ") 1840 << expectedType 1841 << " or a rank-reduced version. (mismatch of result layout) "; 1842 } 1843 llvm_unreachable("unexpected subview verification result"); 1844 } 1845 1846 /// Verifier for SubViewOp. 1847 static LogicalResult verify(SubViewOp op) { 1848 MemRefType baseType = op.getSourceType(); 1849 MemRefType subViewType = op.getType(); 1850 1851 // The base memref and the view memref should be in the same memory space. 1852 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 1853 return op.emitError("different memory spaces specified for base memref " 1854 "type ") 1855 << baseType << " and subview memref type " << subViewType; 1856 1857 // Verify that the base memref type has a strided layout map. 1858 if (!isStrided(baseType)) 1859 return op.emitError("base type ") << baseType << " is not strided"; 1860 1861 // Verify result type against inferred type. 1862 auto expectedType = SubViewOp::inferResultType( 1863 baseType, extractFromI64ArrayAttr(op.static_offsets()), 1864 extractFromI64ArrayAttr(op.static_sizes()), 1865 extractFromI64ArrayAttr(op.static_strides())); 1866 1867 auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), 1868 subViewType, op.getMixedSizes()); 1869 return produceSubViewErrorMsg(result, op, expectedType); 1870 } 1871 1872 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 1873 return os << "range " << range.offset << ":" << range.size << ":" 1874 << range.stride; 1875 } 1876 1877 /// Return the list of Range (i.e. offset, size, stride). Each Range 1878 /// entry contains either the dynamic value or a ConstantIndexOp constructed 1879 /// with `b` at location `loc`. 1880 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 1881 OpBuilder &b, Location loc) { 1882 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 1883 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 1884 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 1885 SmallVector<Range, 8> res; 1886 unsigned rank = ranks[0]; 1887 res.reserve(rank); 1888 for (unsigned idx = 0; idx < rank; ++idx) { 1889 Value offset = 1890 op.isDynamicOffset(idx) 1891 ? op.getDynamicOffset(idx) 1892 : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); 1893 Value size = 1894 op.isDynamicSize(idx) 1895 ? op.getDynamicSize(idx) 1896 : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); 1897 Value stride = 1898 op.isDynamicStride(idx) 1899 ? op.getDynamicStride(idx) 1900 : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); 1901 res.emplace_back(Range{offset, size, stride}); 1902 } 1903 return res; 1904 } 1905 1906 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 1907 /// deduce the result type for the given `sourceType`. Additionally, reduce the 1908 /// rank of the inferred result type if `currentResultType` is lower rank than 1909 /// `currentSourceType`. Use this signature if `sourceType` is updated together 1910 /// with the result type. In this case, it is important to compute the dropped 1911 /// dimensions using `currentSourceType` whose strides align with 1912 /// `currentResultType`. 1913 static MemRefType getCanonicalSubViewResultType( 1914 MemRefType currentResultType, MemRefType currentSourceType, 1915 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, 1916 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { 1917 auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, 1918 mixedSizes, mixedStrides) 1919 .cast<MemRefType>(); 1920 llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims = 1921 computeMemRefRankReductionMask(currentSourceType, currentResultType, 1922 mixedSizes); 1923 // Return nullptr as failure mode. 1924 if (!unusedDims) 1925 return nullptr; 1926 SmallVector<int64_t> shape; 1927 for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) { 1928 if (unusedDims->count(sizes.index())) 1929 continue; 1930 shape.push_back(sizes.value()); 1931 } 1932 AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap(); 1933 if (!layoutMap.isIdentity()) 1934 layoutMap = getProjectedMap(layoutMap, unusedDims.getValue()); 1935 return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap, 1936 nonRankReducedType.getMemorySpace()); 1937 } 1938 1939 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 1940 /// deduce the result type. Additionally, reduce the rank of the inferred result 1941 /// type if `currentResultType` is lower rank than `sourceType`. 1942 static MemRefType getCanonicalSubViewResultType( 1943 MemRefType currentResultType, MemRefType sourceType, 1944 ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, 1945 ArrayRef<OpFoldResult> mixedStrides) { 1946 return getCanonicalSubViewResultType(currentResultType, sourceType, 1947 sourceType, mixedOffsets, mixedSizes, 1948 mixedStrides); 1949 } 1950 1951 /// Helper method to check if a `subview` operation is trivially a no-op. This 1952 /// is the case if the all offsets are zero, all strides are 1, and the source 1953 /// shape is same as the size of the subview. In such cases, the subview can be 1954 /// folded into its source. 1955 static bool isTrivialSubViewOp(SubViewOp subViewOp) { 1956 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank()) 1957 return false; 1958 1959 auto mixedOffsets = subViewOp.getMixedOffsets(); 1960 auto mixedSizes = subViewOp.getMixedSizes(); 1961 auto mixedStrides = subViewOp.getMixedStrides(); 1962 1963 // Check offsets are zero. 1964 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) { 1965 Optional<int64_t> intValue = getConstantIntValue(ofr); 1966 return !intValue || intValue.getValue() != 0; 1967 })) 1968 return false; 1969 1970 // Check strides are one. 1971 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) { 1972 Optional<int64_t> intValue = getConstantIntValue(ofr); 1973 return !intValue || intValue.getValue() != 1; 1974 })) 1975 return false; 1976 1977 // Check all size values are static and matches the (static) source shape. 1978 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape(); 1979 for (const auto &size : llvm::enumerate(mixedSizes)) { 1980 Optional<int64_t> intValue = getConstantIntValue(size.value()); 1981 if (!intValue || intValue.getValue() != sourceShape[size.index()]) 1982 return false; 1983 } 1984 // All conditions met. The `SubViewOp` is foldable as a no-op. 1985 return true; 1986 } 1987 1988 namespace { 1989 /// Pattern to rewrite a subview op with MemRefCast arguments. 1990 /// This essentially pushes memref.cast past its consuming subview when 1991 /// `canFoldIntoConsumerOp` is true. 1992 /// 1993 /// Example: 1994 /// ``` 1995 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 1996 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 1997 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 1998 /// ``` 1999 /// is rewritten into: 2000 /// ``` 2001 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 2002 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 2003 /// memref<3x4xf32, offset:?, strides:[?, 1]> 2004 /// ``` 2005 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 2006 public: 2007 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2008 2009 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2010 PatternRewriter &rewriter) const override { 2011 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 2012 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 2013 return matchPattern(operand, matchConstantIndex()); 2014 })) 2015 return failure(); 2016 2017 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 2018 if (!castOp) 2019 return failure(); 2020 2021 if (!CastOp::canFoldIntoConsumerOp(castOp)) 2022 return failure(); 2023 2024 // Compute the SubViewOp result type after folding the MemRefCastOp. Use the 2025 // MemRefCastOp source operand type to infer the result type and the current 2026 // SubViewOp source operand type to compute the dropped dimensions if the 2027 // operation is rank-reducing. 2028 auto resultType = getCanonicalSubViewResultType( 2029 subViewOp.getType(), subViewOp.getSourceType(), 2030 castOp.source().getType().cast<MemRefType>(), 2031 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), 2032 subViewOp.getMixedStrides()); 2033 if (!resultType) 2034 return failure(); 2035 2036 Value newSubView = rewriter.create<SubViewOp>( 2037 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), 2038 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), 2039 subViewOp.static_sizes(), subViewOp.static_strides()); 2040 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(), 2041 newSubView); 2042 return success(); 2043 } 2044 }; 2045 2046 /// Canonicalize subview ops that are no-ops. When the source shape is not same 2047 /// as a result shape due to use of `affine_map`. 2048 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> { 2049 public: 2050 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2051 2052 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2053 PatternRewriter &rewriter) const override { 2054 if (!isTrivialSubViewOp(subViewOp)) 2055 return failure(); 2056 if (subViewOp.getSourceType() == subViewOp.getType()) { 2057 rewriter.replaceOp(subViewOp, subViewOp.source()); 2058 return success(); 2059 } 2060 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(), 2061 subViewOp.getType()); 2062 return success(); 2063 } 2064 }; 2065 } // namespace 2066 2067 /// Return the canonical type of the result of a subview. 2068 struct SubViewReturnTypeCanonicalizer { 2069 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets, 2070 ArrayRef<OpFoldResult> mixedSizes, 2071 ArrayRef<OpFoldResult> mixedStrides) { 2072 return getCanonicalSubViewResultType(op.getType(), op.getSourceType(), 2073 mixedOffsets, mixedSizes, 2074 mixedStrides); 2075 } 2076 }; 2077 2078 /// A canonicalizer wrapper to replace SubViewOps. 2079 struct SubViewCanonicalizer { 2080 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 2081 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType()); 2082 } 2083 }; 2084 2085 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2086 MLIRContext *context) { 2087 results 2088 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 2089 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 2090 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context); 2091 } 2092 2093 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 2094 auto resultShapedType = getResult().getType().cast<ShapedType>(); 2095 auto sourceShapedType = source().getType().cast<ShapedType>(); 2096 2097 if (resultShapedType.hasStaticShape() && 2098 resultShapedType == sourceShapedType) { 2099 return getViewSource(); 2100 } 2101 2102 return {}; 2103 } 2104 2105 //===----------------------------------------------------------------------===// 2106 // TransposeOp 2107 //===----------------------------------------------------------------------===// 2108 2109 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 2110 static MemRefType inferTransposeResultType(MemRefType memRefType, 2111 AffineMap permutationMap) { 2112 auto rank = memRefType.getRank(); 2113 auto originalSizes = memRefType.getShape(); 2114 // Compute permuted sizes. 2115 SmallVector<int64_t, 4> sizes(rank, 0); 2116 for (const auto &en : llvm::enumerate(permutationMap.getResults())) 2117 sizes[en.index()] = 2118 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2119 2120 // Compute permuted strides. 2121 int64_t offset; 2122 SmallVector<int64_t, 4> strides; 2123 auto res = getStridesAndOffset(memRefType, strides, offset); 2124 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2125 (void)res; 2126 auto map = 2127 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2128 map = permutationMap ? map.compose(permutationMap) : map; 2129 return MemRefType::Builder(memRefType) 2130 .setShape(sizes) 2131 .setLayout(AffineMapAttr::get(map)); 2132 } 2133 2134 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2135 AffineMapAttr permutation, 2136 ArrayRef<NamedAttribute> attrs) { 2137 auto permutationMap = permutation.getValue(); 2138 assert(permutationMap); 2139 2140 auto memRefType = in.getType().cast<MemRefType>(); 2141 // Compute result type. 2142 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2143 2144 build(b, result, resultType, in, attrs); 2145 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2146 } 2147 2148 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2149 static void print(OpAsmPrinter &p, TransposeOp op) { 2150 p << " " << op.in() << " " << op.permutation(); 2151 p.printOptionalAttrDict(op->getAttrs(), 2152 {TransposeOp::getPermutationAttrName()}); 2153 p << " : " << op.in().getType() << " to " << op.getType(); 2154 } 2155 2156 static ParseResult parseTransposeOp(OpAsmParser &parser, 2157 OperationState &result) { 2158 OpAsmParser::OperandType in; 2159 AffineMap permutation; 2160 MemRefType srcType, dstType; 2161 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2162 parser.parseOptionalAttrDict(result.attributes) || 2163 parser.parseColonType(srcType) || 2164 parser.resolveOperand(in, srcType, result.operands) || 2165 parser.parseKeywordType("to", dstType) || 2166 parser.addTypeToList(dstType, result.types)) 2167 return failure(); 2168 2169 result.addAttribute(TransposeOp::getPermutationAttrName(), 2170 AffineMapAttr::get(permutation)); 2171 return success(); 2172 } 2173 2174 static LogicalResult verify(TransposeOp op) { 2175 if (!op.permutation().isPermutation()) 2176 return op.emitOpError("expected a permutation map"); 2177 if (op.permutation().getNumDims() != op.getShapedType().getRank()) 2178 return op.emitOpError( 2179 "expected a permutation map of same rank as the input"); 2180 2181 auto srcType = op.in().getType().cast<MemRefType>(); 2182 auto dstType = op.getType().cast<MemRefType>(); 2183 auto transposedType = inferTransposeResultType(srcType, op.permutation()); 2184 if (dstType != transposedType) 2185 return op.emitOpError("output type ") 2186 << dstType << " does not match transposed input type " << srcType 2187 << ", " << transposedType; 2188 return success(); 2189 } 2190 2191 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2192 if (succeeded(foldMemRefCast(*this))) 2193 return getResult(); 2194 return {}; 2195 } 2196 2197 //===----------------------------------------------------------------------===// 2198 // ViewOp 2199 //===----------------------------------------------------------------------===// 2200 2201 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { 2202 OpAsmParser::OperandType srcInfo; 2203 SmallVector<OpAsmParser::OperandType, 1> offsetInfo; 2204 SmallVector<OpAsmParser::OperandType, 4> sizesInfo; 2205 auto indexType = parser.getBuilder().getIndexType(); 2206 Type srcType, dstType; 2207 llvm::SMLoc offsetLoc; 2208 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || 2209 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) 2210 return failure(); 2211 2212 if (offsetInfo.size() != 1) 2213 return parser.emitError(offsetLoc) << "expects 1 offset operand"; 2214 2215 return failure( 2216 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || 2217 parser.parseOptionalAttrDict(result.attributes) || 2218 parser.parseColonType(srcType) || 2219 parser.resolveOperand(srcInfo, srcType, result.operands) || 2220 parser.resolveOperands(offsetInfo, indexType, result.operands) || 2221 parser.resolveOperands(sizesInfo, indexType, result.operands) || 2222 parser.parseKeywordType("to", dstType) || 2223 parser.addTypeToList(dstType, result.types)); 2224 } 2225 2226 static void print(OpAsmPrinter &p, ViewOp op) { 2227 p << ' ' << op.getOperand(0) << '['; 2228 p.printOperand(op.byte_shift()); 2229 p << "][" << op.sizes() << ']'; 2230 p.printOptionalAttrDict(op->getAttrs()); 2231 p << " : " << op.getOperand(0).getType() << " to " << op.getType(); 2232 } 2233 2234 static LogicalResult verify(ViewOp op) { 2235 auto baseType = op.getOperand(0).getType().cast<MemRefType>(); 2236 auto viewType = op.getType(); 2237 2238 // The base memref should have identity layout map (or none). 2239 if (!baseType.getLayout().isIdentity()) 2240 return op.emitError("unsupported map for base memref type ") << baseType; 2241 2242 // The result memref should have identity layout map (or none). 2243 if (!viewType.getLayout().isIdentity()) 2244 return op.emitError("unsupported map for result memref type ") << viewType; 2245 2246 // The base memref and the view memref should be in the same memory space. 2247 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2248 return op.emitError("different memory spaces specified for base memref " 2249 "type ") 2250 << baseType << " and view memref type " << viewType; 2251 2252 // Verify that we have the correct number of sizes for the result type. 2253 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2254 if (op.sizes().size() != numDynamicDims) 2255 return op.emitError("incorrect number of size operands for type ") 2256 << viewType; 2257 2258 return success(); 2259 } 2260 2261 Value ViewOp::getViewSource() { return source(); } 2262 2263 namespace { 2264 2265 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2266 using OpRewritePattern<ViewOp>::OpRewritePattern; 2267 2268 LogicalResult matchAndRewrite(ViewOp viewOp, 2269 PatternRewriter &rewriter) const override { 2270 // Return if none of the operands are constants. 2271 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2272 return matchPattern(operand, matchConstantIndex()); 2273 })) 2274 return failure(); 2275 2276 // Get result memref type. 2277 auto memrefType = viewOp.getType(); 2278 2279 // Get offset from old memref view type 'memRefType'. 2280 int64_t oldOffset; 2281 SmallVector<int64_t, 4> oldStrides; 2282 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2283 return failure(); 2284 assert(oldOffset == 0 && "Expected 0 offset"); 2285 2286 SmallVector<Value, 4> newOperands; 2287 2288 // Offset cannot be folded into result type. 2289 2290 // Fold any dynamic dim operands which are produced by a constant. 2291 SmallVector<int64_t, 4> newShapeConstants; 2292 newShapeConstants.reserve(memrefType.getRank()); 2293 2294 unsigned dynamicDimPos = 0; 2295 unsigned rank = memrefType.getRank(); 2296 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2297 int64_t dimSize = memrefType.getDimSize(dim); 2298 // If this is already static dimension, keep it. 2299 if (!ShapedType::isDynamic(dimSize)) { 2300 newShapeConstants.push_back(dimSize); 2301 continue; 2302 } 2303 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2304 if (auto constantIndexOp = 2305 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2306 // Dynamic shape dimension will be folded. 2307 newShapeConstants.push_back(constantIndexOp.value()); 2308 } else { 2309 // Dynamic shape dimension not folded; copy operand from old memref. 2310 newShapeConstants.push_back(dimSize); 2311 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2312 } 2313 dynamicDimPos++; 2314 } 2315 2316 // Create new memref type with constant folded dims. 2317 MemRefType newMemRefType = 2318 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2319 // Nothing new, don't fold. 2320 if (newMemRefType == memrefType) 2321 return failure(); 2322 2323 // Create new ViewOp. 2324 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2325 viewOp.getOperand(0), 2326 viewOp.byte_shift(), newOperands); 2327 // Insert a cast so we have the same type as the old memref type. 2328 rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType()); 2329 return success(); 2330 } 2331 }; 2332 2333 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2334 using OpRewritePattern<ViewOp>::OpRewritePattern; 2335 2336 LogicalResult matchAndRewrite(ViewOp viewOp, 2337 PatternRewriter &rewriter) const override { 2338 Value memrefOperand = viewOp.getOperand(0); 2339 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2340 if (!memrefCastOp) 2341 return failure(); 2342 Value allocOperand = memrefCastOp.getOperand(); 2343 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2344 if (!allocOp) 2345 return failure(); 2346 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2347 viewOp.byte_shift(), viewOp.sizes()); 2348 return success(); 2349 } 2350 }; 2351 2352 } // namespace 2353 2354 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2355 MLIRContext *context) { 2356 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2357 } 2358 2359 //===----------------------------------------------------------------------===// 2360 // AtomicRMWOp 2361 //===----------------------------------------------------------------------===// 2362 2363 static LogicalResult verify(AtomicRMWOp op) { 2364 if (op.getMemRefType().getRank() != op.getNumOperands() - 2) 2365 return op.emitOpError( 2366 "expects the number of subscripts to be equal to memref rank"); 2367 switch (op.kind()) { 2368 case arith::AtomicRMWKind::addf: 2369 case arith::AtomicRMWKind::maxf: 2370 case arith::AtomicRMWKind::minf: 2371 case arith::AtomicRMWKind::mulf: 2372 if (!op.value().getType().isa<FloatType>()) 2373 return op.emitOpError() 2374 << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) 2375 << "' expects a floating-point type"; 2376 break; 2377 case arith::AtomicRMWKind::addi: 2378 case arith::AtomicRMWKind::maxs: 2379 case arith::AtomicRMWKind::maxu: 2380 case arith::AtomicRMWKind::mins: 2381 case arith::AtomicRMWKind::minu: 2382 case arith::AtomicRMWKind::muli: 2383 case arith::AtomicRMWKind::ori: 2384 case arith::AtomicRMWKind::andi: 2385 if (!op.value().getType().isa<IntegerType>()) 2386 return op.emitOpError() 2387 << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) 2388 << "' expects an integer type"; 2389 break; 2390 default: 2391 break; 2392 } 2393 return success(); 2394 } 2395 2396 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) { 2397 /// atomicrmw(memrefcast) -> atomicrmw 2398 if (succeeded(foldMemRefCast(*this, value()))) 2399 return getResult(); 2400 return OpFoldResult(); 2401 } 2402 2403 //===----------------------------------------------------------------------===// 2404 // TableGen'd op method definitions 2405 //===----------------------------------------------------------------------===// 2406 2407 #define GET_OP_CLASSES 2408 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2409