//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, value, type);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
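
// As a hand-written illustration (not taken from the test suite), the
// cast-folding helper above enables folds such as dealloc(memref.cast):
//
//   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
//
// becomes
//
//   memref.dealloc %arg : memref<8xf32>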

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}
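
// For reference, a hand-written example of the operand structure these
// verifiers expect: one size operand per dynamic dimension and one symbol
// operand per symbol in the layout map.
//
//   %m = memref.alloc(%width)[%offset]
//       : memref<4x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 100 + d1 + s0)>>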

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}
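
// For reference, a hand-written example of the custom form produced by the
// printer above and consumed by the parser below (the region terminator is
// memref.alloca_scope.return):
//
//   %r = memref.alloca_scope -> (f32) {
//     %buf = memref.alloca() : memref<16xf32>
//     ...
//     memref.alloca_scope.return %v : f32
//   }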

static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}
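
// A hand-written example of the op checked above: assume_alignment produces
// no results and only asserts an alignment fact about its operand.
//
//   memref.assume_alignment %buf, 64 : memref<1024xf32>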

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
        !ShapedType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamicStrideOrOffset(ss) &&
          !ShapedType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}
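
// A few hand-written examples of casts accepted by `areCastCompatible` below
// (illustrative only):
//
//   // Erasing static size information.
//   memref.cast %a : memref<4x4xf32> to memref<?x?xf32>
//   // Ranked to unranked.
//   memref.cast %b : memref<4x4xf32> to memref<*xf32>
//   // Unranked back to ranked.
//   memref.cast %c : memref<*xf32> to memref<4x4xf32>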

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the layout
/// of the type.
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getResult().getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getResult().getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    return success(modified);
  }
};
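
// A hand-written illustration of the cast-skipping fold above: a cast that
// only changes the layout can be bypassed because memref.copy only requires
// matching shapes and element types.
//
//   %0 = memref.cast %src
//       : memref<16xf32> to memref<16xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
//   memref.copy %0, %dst
//       : memref<16xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<16xf32>
//
// becomes
//
//   memref.copy %src, %dst : memref<16xf32> to memref<16xf32>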

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.source() != copyOp.target())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
/// to be a subset of `originalType` with some `1` entries erased, return the
/// set of indices that specifies which of the entries of `originalShape` are
/// dropped to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped too.
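/// For instance (hand-written illustration), reducing `memref<4x1x1x8xf32>`
/// to `memref<4x1x8xf32>` drops exactly one of the two unit dimensions, and
/// only a comparison of the stride multisets of the two types can tell which
/// one was dropped.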
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallDenseSet<unsigned> unusedDims;
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = dim.value().dyn_cast<Attribute>())
      if (attr.cast<IntegerAttr>().getInt() == 1)
        unusedDims.insert(dim.index());

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim that stride should not be
  // present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to exactly figure out
  // which dim corresponds to which stride, we just need to verify that the
  // number of repetitions of a stride in the original + number of unused dims
  // with that stride == number of repetitions of a stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
  for (unsigned dim : unusedDims) {
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      prunedUnusedDims.insert(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank type
      // that wasn't in the original one.
      return llvm::None;
    }
  }

  for (auto prunedDim : prunedUnusedDims)
    unusedDims.erase(prunedDim);
  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
    return llvm::None;
  return unusedDims;
}

llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(unusedDims && "unable to find unused dims of subview");
  return *unusedDims;
}
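
// A hand-written example of the folding implemented below: `memref.dim` with
// a constant index on a dynamically sized allocation folds to the matching
// size operand.
//
//   %m = memref.alloc(%n) : memref<4x?xf32>
//   %c1 = arith.constant 1 : index
//   %d = memref.dim %m, %c1 : memref<4x?xf32>
//
// folds %d to %n.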

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.count(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
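
// Hand-written illustration of the `DimOfMemRefReshape` pattern defined
// below: the result shape of memref.reshape is fully described by its shape
// operand, so the dim can be loaded from that operand instead.
//
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
//
// becomes
//
//   %d = memref.load %shape[%c1] : memref<2xindex>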

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.source().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

static void print(OpAsmPrinter &p, DmaStartOp op) {
  p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
    << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
    << op.getNumElements() << ", " << op.getTagMemRef() << '['
    << op.getTagIndices() << ']';
  if (op.isStrided())
    p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();

  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getSrcMemRef().getType() << ", "
    << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
static ParseResult parseDmaStartOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size (number of elements to transfer).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

static LogicalResult verify(DmaStartOp op) {
  unsigned numOperands = op.getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return op.emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!op.getSrcMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected source to be of memref type");
  if (numOperands < op.getSrcMemRefRank() + 4)
    return op.emitOpError()
           << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
  if (!op.getSrcIndices().empty() &&
      !llvm::all_of(op.getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!op.getDstMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands =
      op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return op.emitOpError()
           << "expected at least " << numExpectedOperands << " operands";
  if (!op.getDstIndices().empty() &&
      !llvm::all_of(op.getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!op.getNumElements().getType().isIndex())
    return op.emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!op.getTagMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected tag to be of memref type");
  numExpectedOperands += op.getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return op.emitOpError()
           << "expected at least " << numExpectedOperands << " operands";
  if (!op.getTagIndices().empty() &&
      !llvm::all_of(op.getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return op.emitOpError("incorrect number of operands");

  // 5. Strides.
  if (op.isStrided()) {
    if (!op.getStride().getType().isIndex() ||
        !op.getNumElementsPerStride().getType().isIndex())
      return op.emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

static LogicalResult verify(DmaWaitOp op) {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = op.tagIndices().size();
  unsigned tagMemRefRank = op.getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return op.emitOpError() << "expected tagIndices to have the same number of "
                               "elements as the tagMemRef rank, expected "
                            << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    bodyRegion->push_back(new Block());
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}

static LogicalResult verify(GenericAtomicRMWOp op) {
  auto &body = op.getRegion();
  if (body.getNumArguments() != 1)
    return op.emitOpError("expected single number of entry block arguments");

  if (op.getResult().getType() != body.getArgument(0).getType())
    return op.emitOpError(
        "expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}
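
// The custom form handled by the parser and printer below looks like this
// (hand-written example):
//
//   %x = memref.generic_atomic_rmw %buf[%i] : memref<10xf32> {
//   ^bb0(%current : f32):
//     %new = arith.addf %current, %cst : f32
//     memref.atomic_yield %new : f32
//   }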

static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
                                           OperationState &result) {
  OpAsmParser::OperandType memref;
  Type memrefType;
  SmallVector<OpAsmParser::OperandType, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
  return success();
}

static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
  p << ' ' << op.memref() << "[" << op.indices()
    << "] : " << op.memref().getType() << ' ';
  p.printRegion(op.getRegion());
  p.printOptionalAttrDict(op->getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicYieldOp op) {
  Type parentType = op->getParentOp()->getResultTypes().front();
  Type resultType = op.result().getType();
  if (parentType != resultType)
    return op.emitOpError() << "types mismatch between yield op: " << resultType
                            << " and its parent: " << parentType;
  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}
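
// For reference, hand-written examples of the forms accepted above and
// checked by the verifier below: an external declaration, an uninitialized
// definition, and a constant definition.
//
//   memref.global "private" @external_buffer : memref<4xf32>
//   memref.global @scratch : memref<4xf32> = uninitialized
//   memref.global "private" constant @lut : memref<2xf32> = dense<[0.5, 1.5]>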

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (Optional<uint64_t> alignAttr = op.alignment()) {
    uint64_t alignment = alignAttr.getValue();

    if (!llvm::isPowerOf2_64(alignment))
      return op->emitError() << "alignment attribute value " << alignment
                             << " is not a power of 2";
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}
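
// For reference, a hand-written example of the symbol use verified above:
//
//   memref.global "private" constant @lut : memref<2xf32> = dense<[0.5, 1.5]>
//   ...
//   %0 = memref.get_global @lut : memref<2xf32>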
"data" : "instr"); 1165 p.printOptionalAttrDict( 1166 op->getAttrs(), 1167 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1168 p << " : " << op.getMemRefType(); 1169 } 1170 1171 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1172 OperationState &result) { 1173 OpAsmParser::OperandType memrefInfo; 1174 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1175 IntegerAttr localityHint; 1176 MemRefType type; 1177 StringRef readOrWrite, cacheType; 1178 1179 auto indexTy = parser.getBuilder().getIndexType(); 1180 auto i32Type = parser.getBuilder().getIntegerType(32); 1181 if (parser.parseOperand(memrefInfo) || 1182 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1183 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1184 parser.parseComma() || parser.parseKeyword("locality") || 1185 parser.parseLess() || 1186 parser.parseAttribute(localityHint, i32Type, "localityHint", 1187 result.attributes) || 1188 parser.parseGreater() || parser.parseComma() || 1189 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1190 parser.resolveOperand(memrefInfo, type, result.operands) || 1191 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1192 return failure(); 1193 1194 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1195 return parser.emitError(parser.getNameLoc(), 1196 "rw specifier has to be 'read' or 'write'"); 1197 result.addAttribute( 1198 PrefetchOp::getIsWriteAttrName(), 1199 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1200 1201 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1202 return parser.emitError(parser.getNameLoc(), 1203 "cache type has to be 'data' or 'instr'"); 1204 1205 result.addAttribute( 1206 PrefetchOp::getIsDataCacheAttrName(), 1207 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1208 1209 return success(); 1210 } 1211 1212 static LogicalResult verify(PrefetchOp op) { 1213 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1214 return op.emitOpError("too few indices"); 1215 1216 return success(); 1217 } 1218 1219 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1220 SmallVectorImpl<OpFoldResult> &results) { 1221 // prefetch(memrefcast) -> prefetch 1222 return foldMemRefCast(*this); 1223 } 1224 1225 //===----------------------------------------------------------------------===// 1226 // RankOp 1227 //===----------------------------------------------------------------------===// 1228 1229 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1230 // Constant fold rank when the rank of the operand is known. 1231 auto type = getOperand().getType(); 1232 auto shapedType = type.dyn_cast<ShapedType>(); 1233 if (shapedType && shapedType.hasRank()) 1234 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1235 return IntegerAttr(); 1236 } 1237 1238 //===----------------------------------------------------------------------===// 1239 // ReinterpretCastOp 1240 //===----------------------------------------------------------------------===// 1241 1242 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1243 /// `staticSizes` and `staticStrides` are automatically filled with 1244 /// source-memref-rank sentinel values that encode dynamic entries. 
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and extract_slice.
static LogicalResult verify(ReinterpretCastOp op) {
  // The source and result memrefs should be in the same memory space.
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto resultType = op.getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return op.emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return op.emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (!ShapedType::isDynamic(resultSize) &&
        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offset and static_strides attributes. If
  // result memref type has no affine map specified, this will assume an
  // identity layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return op.emitError(
               "expected result type to have strided layout but found ")
           << resultType;

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
  if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
      !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
      resultOffset != expectedOffset)
    return op.emitError("expected result type with offset = ")
           << resultOffset << " instead of " << expectedOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto &en : llvm::enumerate(llvm::zip(
           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
    int64_t resultStride = std::get<0>(en.value());
    int64_t expectedStride = std::get<1>(en.value());
    if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
        !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
        resultStride != expectedStride)
      return op.emitError("expected result type with stride = ")
             << expectedStride << " instead of " << resultStride
             << " in dim = " << en.index();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

static void print(OpAsmPrinter &p, ExpandShapeOp op) {
  ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
}

static void print(OpAsmPrinter &p, CollapseShapeOp op) {
  ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
}

/// Detect whether memref dims [dim, dim + extent) can be reshaped without
/// copies.
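/// For example (hand-written illustration), for sizes [4, 8, 2] with strides
/// [16, 2, 1] the band covering all three dims is contiguous and therefore
/// reshapable (16 == 8 * 2 and 2 == 2 * 1), whereas strides [32, 2, 1] leave
/// a gap after the innermost 8x2 block and would require a copy.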
static bool isReshapableDimBand(unsigned dim, unsigned extent,
                                ArrayRef<int64_t> sizes,
                                ArrayRef<AffineExpr> strides) {
  // Bands of extent one can be reshaped, as they are not reshaped at all.
  if (extent == 1)
    return true;
  // Otherwise, the size of the first dimension needs to be known.
  if (ShapedType::isDynamic(sizes[dim]))
    return false;
  assert(sizes.size() == strides.size() && "mismatched ranks");
  // off by 1 indexing to avoid out of bounds
  //                       V
  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
    // Only bands of static shapes are reshapable. This is due to the fact that
    // there is no relation between dynamic sizes and dynamic strides: we do not
    // have enough information to know whether a "-1" size corresponds to the
    // proper symbol in the AffineExpr of a stride.
    if (ShapedType::isDynamic(sizes[idx + 1]))
      return false;
    // TODO: Refine this by passing the proper nDims and nSymbols so we can
    // simplify on the fly and catch more reshapable cases.
    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
      return false;
  }
  return true;
}

/// Compute the MemRefType obtained by applying the `reassociation` (which is
/// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
static MemRefType
computeReshapeCollapsedType(MemRefType type,
                            ArrayRef<AffineMap> reassociation) {
  auto sizes = type.getShape();
  AffineExpr offset;
  SmallVector<AffineExpr, 4> strides;
  auto status = getStridesAndOffset(type, strides, offset);
  auto isIdentityLayout = type.getLayout().isIdentity();
  (void)status;
  assert(succeeded(status) && "expected strided memref");

  SmallVector<int64_t, 4> newSizes;
  newSizes.reserve(reassociation.size());
  SmallVector<AffineExpr, 4> newStrides;
  newStrides.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    AffineExpr stride = strides[currentDim + dim - 1];
    if (isIdentityLayout ||
        isReshapableDimBand(currentDim, dim, sizes, strides)) {
      for (unsigned d = 0; d < dim; ++d) {
        int64_t currentSize = sizes[currentDim + d];
        if (ShapedType::isDynamic(currentSize)) {
          size = ShapedType::kDynamicSize;
          break;
        }
        size *= currentSize;
      }
    } else {
      size = ShapedType::kDynamicSize;
      stride = AffineExpr();
    }
    newSizes.push_back(size);
    newStrides.push_back(stride);
    currentDim += dim;
  }

  // Early-exit: if `type` is contiguous, the result must be contiguous.
  if (canonicalizeStridedLayout(type).getLayout().isIdentity())
    return MemRefType::Builder(type).setShape(newSizes).setLayout({});

  // Convert back to int64_t because we don't have enough information to create
  // new strided layouts from AffineExpr only. This corresponds to a case where
  // copies may be necessary.
1463 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1464 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1465 intOffset = o.getValue(); 1466 SmallVector<int64_t, 4> intStrides; 1467 intStrides.reserve(strides.size()); 1468 for (auto stride : newStrides) { 1469 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1470 intStrides.push_back(cst.getValue()); 1471 else 1472 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1473 } 1474 auto layout = 1475 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1476 return canonicalizeStridedLayout( 1477 MemRefType::Builder(type).setShape(newSizes).setLayout( 1478 AffineMapAttr::get(layout))); 1479 } 1480 1481 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1482 ArrayRef<ReassociationIndices> reassociation, 1483 ArrayRef<NamedAttribute> attrs) { 1484 auto memRefType = src.getType().cast<MemRefType>(); 1485 auto resultType = computeReshapeCollapsedType( 1486 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1487 b.getContext(), reassociation))); 1488 build(b, result, resultType, src, attrs); 1489 result.addAttribute(getReassociationAttrName(), 1490 getReassociationIndicesAttribute(b, reassociation)); 1491 } 1492 1493 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1494 ArrayRef<ReassociationIndices> reassociation, 1495 ArrayRef<NamedAttribute> attrs) { 1496 auto memRefType = src.getType().cast<MemRefType>(); 1497 auto resultType = computeReshapeCollapsedType( 1498 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1499 b.getContext(), reassociation))); 1500 build(b, result, resultType, src, attrs); 1501 result.addAttribute(getReassociationAttrName(), 1502 getReassociationIndicesAttribute(b, reassociation)); 1503 } 1504 1505 template <typename ReshapeOp, 1506 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1507 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1508 MemRefType collapsedType) { 1509 if (failed( 1510 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1511 return failure(); 1512 auto maps = op.getReassociationMaps(); 1513 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1514 if (collapsedType != expectedType) 1515 return op.emitOpError("expected collapsed type to be ") 1516 << expectedType << ", but got " << collapsedType; 1517 return success(); 1518 } 1519 1520 static LogicalResult verify(ExpandShapeOp op) { 1521 return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); 1522 } 1523 1524 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1525 MLIRContext *context) { 1526 results.add<CollapseReshapeOps<ExpandShapeOp>, 1527 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1528 } 1529 1530 static LogicalResult verify(CollapseShapeOp op) { 1531 return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); 1532 } 1533 1534 struct CollapseShapeOpMemRefCastFolder 1535 : public OpRewritePattern<CollapseShapeOp> { 1536 public: 1537 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1538 1539 LogicalResult matchAndRewrite(CollapseShapeOp op, 1540 PatternRewriter &rewriter) const override { 1541 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1542 if (!cast) 1543 return failure(); 1544 1545 if (!CastOp::canFoldIntoConsumerOp(cast)) 1546 return failure(); 1547 1548 Type newResultType = computeReshapeCollapsedType( 1549 
cast.getOperand().getType().cast<MemRefType>(), 1550 op.getReassociationMaps()); 1551 1552 if (newResultType == op.getResultType()) { 1553 rewriter.updateRootInPlace( 1554 op, [&]() { op.srcMutable().assign(cast.source()); }); 1555 } else { 1556 Value newOp = rewriter.create<CollapseShapeOp>( 1557 op->getLoc(), cast.source(), op.getReassociationIndices()); 1558 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1559 } 1560 return success(); 1561 } 1562 }; 1563 1564 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1565 MLIRContext *context) { 1566 results.add<CollapseReshapeOps<CollapseShapeOp>, 1567 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1568 CollapseShapeOpMemRefCastFolder>(context); 1569 } 1570 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1571 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1572 } 1573 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1574 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1575 } 1576 1577 //===----------------------------------------------------------------------===// 1578 // ReshapeOp 1579 //===----------------------------------------------------------------------===// 1580 1581 static LogicalResult verify(ReshapeOp op) { 1582 Type operandType = op.source().getType(); 1583 Type resultType = op.result().getType(); 1584 1585 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1586 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1587 if (operandElementType != resultElementType) 1588 return op.emitOpError("element types of source and destination memref " 1589 "types should be the same"); 1590 1591 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1592 if (!operandMemRefType.getLayout().isIdentity()) 1593 return op.emitOpError( 1594 "source memref type should have identity affine map"); 1595 1596 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1597 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1598 if (resultMemRefType) { 1599 if (!resultMemRefType.getLayout().isIdentity()) 1600 return op.emitOpError( 1601 "result memref type should have identity affine map"); 1602 if (shapeSize == ShapedType::kDynamicSize) 1603 return op.emitOpError("cannot use shape operand with dynamic length to " 1604 "reshape to statically-ranked memref type"); 1605 if (shapeSize != resultMemRefType.getRank()) 1606 return op.emitOpError( 1607 "length of shape operand differs from the result's memref rank"); 1608 } 1609 return success(); 1610 } 1611 1612 //===----------------------------------------------------------------------===// 1613 // StoreOp 1614 //===----------------------------------------------------------------------===// 1615 1616 static LogicalResult verify(StoreOp op) { 1617 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1618 return op.emitOpError("store index operand count not equal to memref rank"); 1619 1620 return success(); 1621 } 1622 1623 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1624 SmallVectorImpl<OpFoldResult> &results) { 1625 /// store(memrefcast) -> store 1626 return foldMemRefCast(*this, getValueToStore()); 1627 } 1628 1629 //===----------------------------------------------------------------------===// 1630 // SubViewOp 1631 //===----------------------------------------------------------------------===// 1632 1633 namespace { 1634 /// Helpers to write more idiomatic operations. 
1635 namespace saturated_arith { 1636 struct Wrapper { 1637 explicit Wrapper(int64_t v) : v(v) {} 1638 operator int64_t() { return v; } 1639 int64_t v; 1640 }; 1641 Wrapper operator+(Wrapper a, int64_t b) { 1642 if (ShapedType::isDynamicStrideOrOffset(a) || 1643 ShapedType::isDynamicStrideOrOffset(b)) 1644 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1645 return Wrapper(a.v + b); 1646 } 1647 Wrapper operator*(Wrapper a, int64_t b) { 1648 if (ShapedType::isDynamicStrideOrOffset(a) || 1649 ShapedType::isDynamicStrideOrOffset(b)) 1650 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1651 return Wrapper(a.v * b); 1652 } 1653 } // namespace saturated_arith 1654 } // namespace 1655 1656 /// A subview result type can be fully inferred from the source type and the 1657 /// static representation of offsets, sizes and strides. Special sentinels 1658 /// encode the dynamic case. 1659 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1660 ArrayRef<int64_t> staticOffsets, 1661 ArrayRef<int64_t> staticSizes, 1662 ArrayRef<int64_t> staticStrides) { 1663 unsigned rank = sourceMemRefType.getRank(); 1664 (void)rank; 1665 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 1666 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 1667 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 1668 1669 // Extract source offset and strides. 1670 int64_t sourceOffset; 1671 SmallVector<int64_t, 4> sourceStrides; 1672 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1673 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1674 (void)res; 1675 1676 // Compute target offset whose value is: 1677 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1678 int64_t targetOffset = sourceOffset; 1679 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1680 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1681 using namespace saturated_arith; 1682 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1683 } 1684 1685 // Compute target stride whose value is: 1686 // `sourceStrides_i * staticStrides_i`. 1687 SmallVector<int64_t, 4> targetStrides; 1688 targetStrides.reserve(staticOffsets.size()); 1689 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1690 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1691 using namespace saturated_arith; 1692 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1693 } 1694 1695 // The type is now known. 
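  // Worked example (hypothetical values): for a source memref<8x16xf32>
  // (strides [16, 1], offset 0) with static offsets [2, 4], sizes [4, 8] and
  // strides [1, 2], the target offset is 0 + 2 * 16 + 4 * 1 = 36 and the
  // target strides are [16 * 1, 1 * 2] = [16, 2], i.e. the inferred type is
  // memref<4x8xf32, offset: 36, strides: [16, 2]>. Any dynamic operand
  // saturates the corresponding entry to the dynamic sentinel via the helpers
  // above.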
1696 return MemRefType::get( 1697 staticSizes, sourceMemRefType.getElementType(), 1698 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1699 sourceMemRefType.getContext()), 1700 sourceMemRefType.getMemorySpace()); 1701 } 1702 1703 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1704 ArrayRef<OpFoldResult> offsets, 1705 ArrayRef<OpFoldResult> sizes, 1706 ArrayRef<OpFoldResult> strides) { 1707 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1708 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1709 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1710 ShapedType::kDynamicStrideOrOffset); 1711 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1712 ShapedType::kDynamicSize); 1713 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1714 ShapedType::kDynamicStrideOrOffset); 1715 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1716 staticSizes, staticStrides); 1717 } 1718 1719 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1720 MemRefType sourceRankedTensorType, 1721 ArrayRef<int64_t> offsets, 1722 ArrayRef<int64_t> sizes, 1723 ArrayRef<int64_t> strides) { 1724 auto inferredType = 1725 inferResultType(sourceRankedTensorType, offsets, sizes, strides) 1726 .cast<MemRefType>(); 1727 assert(inferredType.getRank() >= resultRank && "expected "); 1728 int rankDiff = inferredType.getRank() - resultRank; 1729 if (rankDiff > 0) { 1730 auto shape = inferredType.getShape(); 1731 llvm::SmallDenseSet<unsigned> dimsToProject; 1732 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1733 SmallVector<int64_t> projectedShape; 1734 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1735 if (!dimsToProject.contains(pos)) 1736 projectedShape.push_back(shape[pos]); 1737 1738 AffineMap map = inferredType.getLayout().getAffineMap(); 1739 if (!map.isIdentity()) 1740 map = getProjectedMap(map, dimsToProject); 1741 inferredType = 1742 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1743 inferredType.getMemorySpace()); 1744 } 1745 return inferredType; 1746 } 1747 1748 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1749 MemRefType sourceRankedTensorType, 1750 ArrayRef<OpFoldResult> offsets, 1751 ArrayRef<OpFoldResult> sizes, 1752 ArrayRef<OpFoldResult> strides) { 1753 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1754 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1755 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1756 ShapedType::kDynamicStrideOrOffset); 1757 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1758 ShapedType::kDynamicSize); 1759 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1760 ShapedType::kDynamicStrideOrOffset); 1761 return SubViewOp::inferRankReducedResultType( 1762 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1763 staticStrides); 1764 } 1765 // Build a SubViewOp with mixed static and dynamic entries and custom result 1766 // type. If the type passed is nullptr, it is inferred. 
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the type
// passed is nullptr, it is inferred.
1826 void SubViewOp::build(OpBuilder &b, OperationState &result, 1827 MemRefType resultType, Value source, 1828 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1829 ArrayRef<int64_t> strides, 1830 ArrayRef<NamedAttribute> attrs) { 1831 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1832 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1833 return b.getI64IntegerAttr(v); 1834 })); 1835 SmallVector<OpFoldResult> sizeValues = 1836 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1837 return b.getI64IntegerAttr(v); 1838 })); 1839 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1840 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1841 return b.getI64IntegerAttr(v); 1842 })); 1843 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1844 attrs); 1845 } 1846 1847 // Build a SubViewOp with dynamic entries and custom result type. If the type 1848 // passed is nullptr, it is inferred. 1849 void SubViewOp::build(OpBuilder &b, OperationState &result, 1850 MemRefType resultType, Value source, ValueRange offsets, 1851 ValueRange sizes, ValueRange strides, 1852 ArrayRef<NamedAttribute> attrs) { 1853 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1854 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1855 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1856 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1857 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1858 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1859 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1860 } 1861 1862 // Build a SubViewOp with dynamic entries and inferred result type. 1863 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1864 ValueRange offsets, ValueRange sizes, ValueRange strides, 1865 ArrayRef<NamedAttribute> attrs) { 1866 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1867 } 1868 1869 /// For ViewLikeOpInterface. 1870 Value SubViewOp::getViewSource() { return source(); } 1871 1872 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static 1873 /// value). 1874 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { 1875 AffineExpr t1Offset, t2Offset; 1876 SmallVector<AffineExpr> t1Strides, t2Strides; 1877 auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset); 1878 auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset); 1879 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset; 1880 } 1881 1882 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1883 /// This function is slight variant of `is subsequence` algorithm where 1884 /// not matching dimension must be 1. 1885 static SliceVerificationResult 1886 isRankReducedMemRefType(MemRefType originalType, 1887 MemRefType candidateRankReducedType, 1888 ArrayRef<OpFoldResult> sizes) { 1889 auto partialRes = isRankReducedType(originalType, candidateRankReducedType); 1890 if (partialRes != SliceVerificationResult::Success) 1891 return partialRes; 1892 1893 auto optionalUnusedDimsMask = computeMemRefRankReductionMask( 1894 originalType, candidateRankReducedType, sizes); 1895 1896 // Sizes cannot be matched in case empty vector is returned. 
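  // Example (hypothetical types): for an inferred type of
  // memref<1x4x1x8xf32, ...>, the candidate rank-reduced type
  // memref<4x8xf32, ...> is explained by dropping the unit dims 0 and 2, so a
  // mask is returned. If the candidate sizes cannot be obtained by dropping
  // unit dims, no mask exists and we report a layout mismatch below.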
1897 if (!optionalUnusedDimsMask.hasValue()) 1898 return SliceVerificationResult::LayoutMismatch; 1899 1900 if (originalType.getMemorySpace() != 1901 candidateRankReducedType.getMemorySpace()) 1902 return SliceVerificationResult::MemSpaceMismatch; 1903 1904 // No amount of stride dropping can reconcile incompatible offsets. 1905 if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) 1906 return SliceVerificationResult::LayoutMismatch; 1907 1908 return SliceVerificationResult::Success; 1909 } 1910 1911 template <typename OpTy> 1912 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 1913 OpTy op, Type expectedType) { 1914 auto memrefType = expectedType.cast<ShapedType>(); 1915 switch (result) { 1916 case SliceVerificationResult::Success: 1917 return success(); 1918 case SliceVerificationResult::RankTooLarge: 1919 return op.emitError("expected result rank to be smaller or equal to ") 1920 << "the source rank. "; 1921 case SliceVerificationResult::SizeMismatch: 1922 return op.emitError("expected result type to be ") 1923 << expectedType 1924 << " or a rank-reduced version. (mismatch of result sizes) "; 1925 case SliceVerificationResult::ElemTypeMismatch: 1926 return op.emitError("expected result element type to be ") 1927 << memrefType.getElementType(); 1928 case SliceVerificationResult::MemSpaceMismatch: 1929 return op.emitError("expected result and source memory spaces to match."); 1930 case SliceVerificationResult::LayoutMismatch: 1931 return op.emitError("expected result type to be ") 1932 << expectedType 1933 << " or a rank-reduced version. (mismatch of result layout) "; 1934 } 1935 llvm_unreachable("unexpected subview verification result"); 1936 } 1937 1938 /// Verifier for SubViewOp. 1939 static LogicalResult verify(SubViewOp op) { 1940 MemRefType baseType = op.getSourceType(); 1941 MemRefType subViewType = op.getType(); 1942 1943 // The base memref and the view memref should be in the same memory space. 1944 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 1945 return op.emitError("different memory spaces specified for base memref " 1946 "type ") 1947 << baseType << " and subview memref type " << subViewType; 1948 1949 // Verify that the base memref type has a strided layout map. 1950 if (!isStrided(baseType)) 1951 return op.emitError("base type ") << baseType << " is not strided"; 1952 1953 // Verify result type against inferred type. 1954 auto expectedType = SubViewOp::inferResultType( 1955 baseType, extractFromI64ArrayAttr(op.static_offsets()), 1956 extractFromI64ArrayAttr(op.static_sizes()), 1957 extractFromI64ArrayAttr(op.static_strides())); 1958 1959 auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), 1960 subViewType, op.getMixedSizes()); 1961 return produceSubViewErrorMsg(result, op, expectedType); 1962 } 1963 1964 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 1965 return os << "range " << range.offset << ":" << range.size << ":" 1966 << range.stride; 1967 } 1968 1969 /// Return the list of Range (i.e. offset, size, stride). Each Range 1970 /// entry contains either the dynamic value or a ConstantIndexOp constructed 1971 /// with `b` at location `loc`. 
1972 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 1973 OpBuilder &b, Location loc) { 1974 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 1975 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 1976 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 1977 SmallVector<Range, 8> res; 1978 unsigned rank = ranks[0]; 1979 res.reserve(rank); 1980 for (unsigned idx = 0; idx < rank; ++idx) { 1981 Value offset = 1982 op.isDynamicOffset(idx) 1983 ? op.getDynamicOffset(idx) 1984 : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); 1985 Value size = 1986 op.isDynamicSize(idx) 1987 ? op.getDynamicSize(idx) 1988 : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); 1989 Value stride = 1990 op.isDynamicStride(idx) 1991 ? op.getDynamicStride(idx) 1992 : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); 1993 res.emplace_back(Range{offset, size, stride}); 1994 } 1995 return res; 1996 } 1997 1998 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 1999 /// deduce the result type for the given `sourceType`. Additionally, reduce the 2000 /// rank of the inferred result type if `currentResultType` is lower rank than 2001 /// `currentSourceType`. Use this signature if `sourceType` is updated together 2002 /// with the result type. In this case, it is important to compute the dropped 2003 /// dimensions using `currentSourceType` whose strides align with 2004 /// `currentResultType`. 2005 static MemRefType getCanonicalSubViewResultType( 2006 MemRefType currentResultType, MemRefType currentSourceType, 2007 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, 2008 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { 2009 auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, 2010 mixedSizes, mixedStrides) 2011 .cast<MemRefType>(); 2012 llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims = 2013 computeMemRefRankReductionMask(currentSourceType, currentResultType, 2014 mixedSizes); 2015 // Return nullptr as failure mode. 2016 if (!unusedDims) 2017 return nullptr; 2018 SmallVector<int64_t> shape; 2019 for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) { 2020 if (unusedDims->count(sizes.index())) 2021 continue; 2022 shape.push_back(sizes.value()); 2023 } 2024 AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap(); 2025 if (!layoutMap.isIdentity()) 2026 layoutMap = getProjectedMap(layoutMap, unusedDims.getValue()); 2027 return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap, 2028 nonRankReducedType.getMemorySpace()); 2029 } 2030 2031 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 2032 /// deduce the result type. Additionally, reduce the rank of the inferred result 2033 /// type if `currentResultType` is lower rank than `sourceType`. 2034 static MemRefType getCanonicalSubViewResultType( 2035 MemRefType currentResultType, MemRefType sourceType, 2036 ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, 2037 ArrayRef<OpFoldResult> mixedStrides) { 2038 return getCanonicalSubViewResultType(currentResultType, sourceType, 2039 sourceType, mixedOffsets, mixedSizes, 2040 mixedStrides); 2041 } 2042 2043 /// Helper method to check if a `subview` operation is trivially a no-op. 
This 2044 /// is the case if the all offsets are zero, all strides are 1, and the source 2045 /// shape is same as the size of the subview. In such cases, the subview can be 2046 /// folded into its source. 2047 static bool isTrivialSubViewOp(SubViewOp subViewOp) { 2048 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank()) 2049 return false; 2050 2051 auto mixedOffsets = subViewOp.getMixedOffsets(); 2052 auto mixedSizes = subViewOp.getMixedSizes(); 2053 auto mixedStrides = subViewOp.getMixedStrides(); 2054 2055 // Check offsets are zero. 2056 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) { 2057 Optional<int64_t> intValue = getConstantIntValue(ofr); 2058 return !intValue || intValue.getValue() != 0; 2059 })) 2060 return false; 2061 2062 // Check strides are one. 2063 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) { 2064 Optional<int64_t> intValue = getConstantIntValue(ofr); 2065 return !intValue || intValue.getValue() != 1; 2066 })) 2067 return false; 2068 2069 // Check all size values are static and matches the (static) source shape. 2070 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape(); 2071 for (const auto &size : llvm::enumerate(mixedSizes)) { 2072 Optional<int64_t> intValue = getConstantIntValue(size.value()); 2073 if (!intValue || intValue.getValue() != sourceShape[size.index()]) 2074 return false; 2075 } 2076 // All conditions met. The `SubViewOp` is foldable as a no-op. 2077 return true; 2078 } 2079 2080 namespace { 2081 /// Pattern to rewrite a subview op with MemRefCast arguments. 2082 /// This essentially pushes memref.cast past its consuming subview when 2083 /// `canFoldIntoConsumerOp` is true. 2084 /// 2085 /// Example: 2086 /// ``` 2087 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 2088 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 2089 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 2090 /// ``` 2091 /// is rewritten into: 2092 /// ``` 2093 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 2094 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 2095 /// memref<3x4xf32, offset:?, strides:[?, 1]> 2096 /// ``` 2097 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 2098 public: 2099 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2100 2101 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2102 PatternRewriter &rewriter) const override { 2103 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 2104 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 2105 return matchPattern(operand, matchConstantIndex()); 2106 })) 2107 return failure(); 2108 2109 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 2110 if (!castOp) 2111 return failure(); 2112 2113 if (!CastOp::canFoldIntoConsumerOp(castOp)) 2114 return failure(); 2115 2116 // Compute the SubViewOp result type after folding the MemRefCastOp. Use the 2117 // MemRefCastOp source operand type to infer the result type and the current 2118 // SubViewOp source operand type to compute the dropped dimensions if the 2119 // operation is rank-reducing. 
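    // Sketch of the rank-reducing case (hypothetical types): if the cast
    // source has type memref<16x1x16xf32> and the subview drops the unit dim
    // to yield memref<3x4xf32, ...>, the current source/result pair determines
    // which dimension is dropped, so the type re-inferred from the cast source
    // below stays rank-reduced in the same way.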
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops, i.e. where the source shape
/// differs from the result shape only through the layout `affine_map`. Such
/// ops are replaced by their source (or a cast to the result type).
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.source());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
                                        subViewOp.getType());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
                                         mixedOffsets, mixedSizes,
                                         mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
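  // Worked example (hypothetical type): transposing memref<3x4xf32> (strides
  // [4, 1], offset 0) with the permutation (d0, d1) -> (d1, d0) permutes the
  // sizes to [4, 3] and composes the strided layout with the permutation,
  // giving memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>.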
2207 SmallVector<int64_t, 4> sizes(rank, 0); 2208 for (const auto &en : llvm::enumerate(permutationMap.getResults())) 2209 sizes[en.index()] = 2210 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2211 2212 // Compute permuted strides. 2213 int64_t offset; 2214 SmallVector<int64_t, 4> strides; 2215 auto res = getStridesAndOffset(memRefType, strides, offset); 2216 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2217 (void)res; 2218 auto map = 2219 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2220 map = permutationMap ? map.compose(permutationMap) : map; 2221 return MemRefType::Builder(memRefType) 2222 .setShape(sizes) 2223 .setLayout(AffineMapAttr::get(map)); 2224 } 2225 2226 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2227 AffineMapAttr permutation, 2228 ArrayRef<NamedAttribute> attrs) { 2229 auto permutationMap = permutation.getValue(); 2230 assert(permutationMap); 2231 2232 auto memRefType = in.getType().cast<MemRefType>(); 2233 // Compute result type. 2234 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2235 2236 build(b, result, resultType, in, attrs); 2237 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2238 } 2239 2240 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2241 static void print(OpAsmPrinter &p, TransposeOp op) { 2242 p << " " << op.in() << " " << op.permutation(); 2243 p.printOptionalAttrDict(op->getAttrs(), 2244 {TransposeOp::getPermutationAttrName()}); 2245 p << " : " << op.in().getType() << " to " << op.getType(); 2246 } 2247 2248 static ParseResult parseTransposeOp(OpAsmParser &parser, 2249 OperationState &result) { 2250 OpAsmParser::OperandType in; 2251 AffineMap permutation; 2252 MemRefType srcType, dstType; 2253 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2254 parser.parseOptionalAttrDict(result.attributes) || 2255 parser.parseColonType(srcType) || 2256 parser.resolveOperand(in, srcType, result.operands) || 2257 parser.parseKeywordType("to", dstType) || 2258 parser.addTypeToList(dstType, result.types)) 2259 return failure(); 2260 2261 result.addAttribute(TransposeOp::getPermutationAttrName(), 2262 AffineMapAttr::get(permutation)); 2263 return success(); 2264 } 2265 2266 static LogicalResult verify(TransposeOp op) { 2267 if (!op.permutation().isPermutation()) 2268 return op.emitOpError("expected a permutation map"); 2269 if (op.permutation().getNumDims() != op.getShapedType().getRank()) 2270 return op.emitOpError( 2271 "expected a permutation map of same rank as the input"); 2272 2273 auto srcType = op.in().getType().cast<MemRefType>(); 2274 auto dstType = op.getType().cast<MemRefType>(); 2275 auto transposedType = inferTransposeResultType(srcType, op.permutation()); 2276 if (dstType != transposedType) 2277 return op.emitOpError("output type ") 2278 << dstType << " does not match transposed input type " << srcType 2279 << ", " << transposedType; 2280 return success(); 2281 } 2282 2283 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2284 if (succeeded(foldMemRefCast(*this))) 2285 return getResult(); 2286 return {}; 2287 } 2288 2289 //===----------------------------------------------------------------------===// 2290 // ViewOp 2291 //===----------------------------------------------------------------------===// 2292 2293 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { 2294 OpAsmParser::OperandType srcInfo; 2295 
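  // The custom syntax parsed here looks like the following sketch
  // (hypothetical SSA names and types):
  // ```
  // %v = memref.view %buffer[%byte_shift][%size0, %size1]
  //     : memref<2048xi8> to memref<?x?xf32>
  // ```
  // with exactly one byte-shift operand in the first bracket list and one size
  // operand per dynamic dimension of the result type.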
SmallVector<OpAsmParser::OperandType, 1> offsetInfo; 2296 SmallVector<OpAsmParser::OperandType, 4> sizesInfo; 2297 auto indexType = parser.getBuilder().getIndexType(); 2298 Type srcType, dstType; 2299 SMLoc offsetLoc; 2300 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || 2301 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) 2302 return failure(); 2303 2304 if (offsetInfo.size() != 1) 2305 return parser.emitError(offsetLoc) << "expects 1 offset operand"; 2306 2307 return failure( 2308 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || 2309 parser.parseOptionalAttrDict(result.attributes) || 2310 parser.parseColonType(srcType) || 2311 parser.resolveOperand(srcInfo, srcType, result.operands) || 2312 parser.resolveOperands(offsetInfo, indexType, result.operands) || 2313 parser.resolveOperands(sizesInfo, indexType, result.operands) || 2314 parser.parseKeywordType("to", dstType) || 2315 parser.addTypeToList(dstType, result.types)); 2316 } 2317 2318 static void print(OpAsmPrinter &p, ViewOp op) { 2319 p << ' ' << op.getOperand(0) << '['; 2320 p.printOperand(op.byte_shift()); 2321 p << "][" << op.sizes() << ']'; 2322 p.printOptionalAttrDict(op->getAttrs()); 2323 p << " : " << op.getOperand(0).getType() << " to " << op.getType(); 2324 } 2325 2326 static LogicalResult verify(ViewOp op) { 2327 auto baseType = op.getOperand(0).getType().cast<MemRefType>(); 2328 auto viewType = op.getType(); 2329 2330 // The base memref should have identity layout map (or none). 2331 if (!baseType.getLayout().isIdentity()) 2332 return op.emitError("unsupported map for base memref type ") << baseType; 2333 2334 // The result memref should have identity layout map (or none). 2335 if (!viewType.getLayout().isIdentity()) 2336 return op.emitError("unsupported map for result memref type ") << viewType; 2337 2338 // The base memref and the view memref should be in the same memory space. 2339 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2340 return op.emitError("different memory spaces specified for base memref " 2341 "type ") 2342 << baseType << " and view memref type " << viewType; 2343 2344 // Verify that we have the correct number of sizes for the result type. 2345 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2346 if (op.sizes().size() != numDynamicDims) 2347 return op.emitError("incorrect number of size operands for type ") 2348 << viewType; 2349 2350 return success(); 2351 } 2352 2353 Value ViewOp::getViewSource() { return source(); } 2354 2355 namespace { 2356 2357 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2358 using OpRewritePattern<ViewOp>::OpRewritePattern; 2359 2360 LogicalResult matchAndRewrite(ViewOp viewOp, 2361 PatternRewriter &rewriter) const override { 2362 // Return if none of the operands are constants. 2363 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2364 return matchPattern(operand, matchConstantIndex()); 2365 })) 2366 return failure(); 2367 2368 // Get result memref type. 2369 auto memrefType = viewOp.getType(); 2370 2371 // Get offset from old memref view type 'memRefType'. 2372 int64_t oldOffset; 2373 SmallVector<int64_t, 4> oldStrides; 2374 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2375 return failure(); 2376 assert(oldOffset == 0 && "Expected 0 offset"); 2377 2378 SmallVector<Value, 4> newOperands; 2379 2380 // Offset cannot be folded into result type. 2381 2382 // Fold any dynamic dim operands which are produced by a constant. 
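    // Sketch of the rewrite (hypothetical IR): if both sizes of
    //   %v = memref.view %buf[%c0][%c16, %c32]
    //       : memref<2048xi8> to memref<?x?xf32>
    // are defined by arith.constant ops, a new view with the folded type
    // memref<16x32xf32> is created below and cast back to memref<?x?xf32> for
    // the existing users.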
2383 SmallVector<int64_t, 4> newShapeConstants; 2384 newShapeConstants.reserve(memrefType.getRank()); 2385 2386 unsigned dynamicDimPos = 0; 2387 unsigned rank = memrefType.getRank(); 2388 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2389 int64_t dimSize = memrefType.getDimSize(dim); 2390 // If this is already static dimension, keep it. 2391 if (!ShapedType::isDynamic(dimSize)) { 2392 newShapeConstants.push_back(dimSize); 2393 continue; 2394 } 2395 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2396 if (auto constantIndexOp = 2397 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2398 // Dynamic shape dimension will be folded. 2399 newShapeConstants.push_back(constantIndexOp.value()); 2400 } else { 2401 // Dynamic shape dimension not folded; copy operand from old memref. 2402 newShapeConstants.push_back(dimSize); 2403 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2404 } 2405 dynamicDimPos++; 2406 } 2407 2408 // Create new memref type with constant folded dims. 2409 MemRefType newMemRefType = 2410 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2411 // Nothing new, don't fold. 2412 if (newMemRefType == memrefType) 2413 return failure(); 2414 2415 // Create new ViewOp. 2416 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2417 viewOp.getOperand(0), 2418 viewOp.byte_shift(), newOperands); 2419 // Insert a cast so we have the same type as the old memref type. 2420 rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType()); 2421 return success(); 2422 } 2423 }; 2424 2425 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2426 using OpRewritePattern<ViewOp>::OpRewritePattern; 2427 2428 LogicalResult matchAndRewrite(ViewOp viewOp, 2429 PatternRewriter &rewriter) const override { 2430 Value memrefOperand = viewOp.getOperand(0); 2431 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2432 if (!memrefCastOp) 2433 return failure(); 2434 Value allocOperand = memrefCastOp.getOperand(); 2435 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2436 if (!allocOp) 2437 return failure(); 2438 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2439 viewOp.byte_shift(), viewOp.sizes()); 2440 return success(); 2441 } 2442 }; 2443 2444 } // namespace 2445 2446 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2447 MLIRContext *context) { 2448 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2449 } 2450 2451 //===----------------------------------------------------------------------===// 2452 // AtomicRMWOp 2453 //===----------------------------------------------------------------------===// 2454 2455 static LogicalResult verify(AtomicRMWOp op) { 2456 if (op.getMemRefType().getRank() != op.getNumOperands() - 2) 2457 return op.emitOpError( 2458 "expects the number of subscripts to be equal to memref rank"); 2459 switch (op.kind()) { 2460 case arith::AtomicRMWKind::addf: 2461 case arith::AtomicRMWKind::maxf: 2462 case arith::AtomicRMWKind::minf: 2463 case arith::AtomicRMWKind::mulf: 2464 if (!op.value().getType().isa<FloatType>()) 2465 return op.emitOpError() 2466 << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) 2467 << "' expects a floating-point type"; 2468 break; 2469 case arith::AtomicRMWKind::addi: 2470 case arith::AtomicRMWKind::maxs: 2471 case arith::AtomicRMWKind::maxu: 2472 case arith::AtomicRMWKind::mins: 2473 case arith::AtomicRMWKind::minu: 2474 case arith::AtomicRMWKind::muli: 2475 case arith::AtomicRMWKind::ori: 2476 case 
arith::AtomicRMWKind::andi: 2477 if (!op.value().getType().isa<IntegerType>()) 2478 return op.emitOpError() 2479 << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) 2480 << "' expects an integer type"; 2481 break; 2482 default: 2483 break; 2484 } 2485 return success(); 2486 } 2487 2488 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) { 2489 /// atomicrmw(memrefcast) -> atomicrmw 2490 if (succeeded(foldMemRefCast(*this, value()))) 2491 return getResult(); 2492 return OpFoldResult(); 2493 } 2494 2495 //===----------------------------------------------------------------------===// 2496 // TableGen'd op method definitions 2497 //===----------------------------------------------------------------------===// 2498 2499 #define GET_OP_CLASSES 2500 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2501