//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
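
// For illustration, the kind of rewrite this utility enables when a consumer's
// fold hook calls it (SSA names and shapes below are hypothetical):
//
// ```mlir
// %0 = memref.cast %arg0 : memref<8xf32> to memref<?xf32>
// memref.dealloc %0 : memref<?xf32>
// ```
//
// folds to:
//
// ```mlir
// memref.dealloc %arg0 : memref<8xf32>
// ```
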
/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}
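
// For illustration, an alloc that satisfies verifyAllocLikeOp above: one size
// operand per dynamic dimension (SSA names are hypothetical):
//
// ```mlir
// %m = memref.alloc(%w, %h) : memref<?x?xf32>
// ```
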
namespace {
/// Fold constant dimensions into an alloc-like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant
    // ones and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}
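
// For illustration, the rewrite SimplifyAllocConst performs when a size
// operand is a constant (SSA names are hypothetical):
//
// ```mlir
// %c8 = arith.constant 8 : index
// %0 = memref.alloc(%c8) : memref<?xf32>
// ```
//
// becomes:
//
// ```mlir
// %1 = memref.alloc() : memref<8xf32>
// %0 = memref.cast %1 : memref<8xf32> to memref<?xf32>
// ```
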
//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!results().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

LogicalResult AllocaScopeOp::verify() {
  return RegionBranchOpInterface::verifyTypes(*this);
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(alignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable
/// memref.cast operations are typically inserted as `view` and `subview` ops
/// are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
        !ShapedType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamicStrideOrOffset(ss) &&
          !ShapedType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}
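
// For illustration, casts accepted by areCastCompatible above (types and
// names are hypothetical):
//
// ```mlir
// // Ranked to unranked, same element type and memory space.
// %u = memref.cast %r : memref<4x4xf32> to memref<*xf32>
// // Ranked to ranked: static and dynamic extents agree where both are static.
// %d = memref.cast %r : memref<4x4xf32> to memref<?x4xf32>
// ```
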
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the
/// layout of the type.
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.source().getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.source().getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    return success(modified);
  }
};
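
// For illustration, the rewrite FoldCopyOfCast performs when a cast only
// changes the layout (SSA names and the layout map are hypothetical):
//
// ```mlir
// %0 = memref.cast %src
//        : memref<4xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
// memref.copy %0, %dst
//        : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
// ```
//
// is rewritten to copy directly from %src:
//
// ```mlir
// memref.copy %src, %dst : memref<4xf32> to memref<4xf32>
// ```
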
/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.source() != copyOp.target())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  bool folded = false;
  Operation *op = *this;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

LogicalResult DimOp::verify() {
  // Assume unknown index to be in range.
  Optional<int64_t> index = getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}
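
// For illustration, a dim op the verifier above rejects: the source has rank
// 2, so the constant index 2 is out of range (SSA names are hypothetical):
//
// ```mlir
// %c2 = arith.constant 2 : index
// %d = memref.dim %m, %c2 : memref<4x8xf32>
// ```
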
/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `candidateReducedType` whose shape is
/// assumed to be a subset of `originalType` with some `1` entries erased,
/// return the set of indices that specifies which of the entries of
/// `originalShape` are dropped to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped
/// too.
static llvm::Optional<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = dim.value().dyn_cast<Attribute>())
      if (attr.cast<IntegerAttr>().getInt() == 1)
        unusedDims.set(dim.index());

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(getStridesAndOffset(originalType, originalStrides,
                                 originalOffset)) ||
      failed(getStridesAndOffset(reducedType, candidateStrides,
                                 candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim, that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + the number of unused
  // dims with that stride == the number of repetitions of that stride in the
  // candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped. Keep as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank type
      // that wasn't in the original one.
      return llvm::None;
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return llvm::None;
  return unusedDims;
}

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  llvm::Optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(unusedDims && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
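
// For illustration, two folds the hook above performs (SSA names are
// hypothetical):
//
// ```mlir
// %c0 = arith.constant 0 : index
// %d0 = memref.dim %m, %c0 : memref<4x?xf32>  // folds to the constant 4
//
// %a = memref.alloc(%n) : memref<4x?xf32>
// %c1 = arith.constant 1 : index
// %d1 = memref.dim %a, %c1 : memref<4x?xf32>  // folds to %n
// ```
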
namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.source().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size (number of elements).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = tagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    bodyRegion->push_back(new Block());
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType memref;
  Type memrefType;
  SmallVector<OpAsmParser::OperandType, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
  return success();
}

void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << memref() << "[" << indices() << "] : " << memref().getType()
    << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = result().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}
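
// For illustration, globals in the custom form handled above (symbol names
// are hypothetical):
//
// ```mlir
// memref.global "private" @z : memref<2xf32> = dense<0.0>
// memref.global @u : memref<4xi32> = uninitialized
// ```
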
LogicalResult GlobalOp::verify() {
  auto memrefType = type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (initial_value().hasValue()) {
    Attribute initValue = initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (Optional<uint64_t> alignAttr = alignment()) {
    uint64_t alignment = alignAttr.getValue();

    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

LogicalResult LoadOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//
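
// For illustration, the custom assembly form printed and parsed below (SSA
// names and shapes are hypothetical):
//
// ```mlir
// memref.prefetch %m[%i, %j], read, locality<3>, data : memref<8x8xf32>
// ```
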
"data" : "instr"); 1169 p.printOptionalAttrDict( 1170 (*this)->getAttrs(), 1171 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1172 p << " : " << getMemRefType(); 1173 } 1174 1175 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { 1176 OpAsmParser::OperandType memrefInfo; 1177 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1178 IntegerAttr localityHint; 1179 MemRefType type; 1180 StringRef readOrWrite, cacheType; 1181 1182 auto indexTy = parser.getBuilder().getIndexType(); 1183 auto i32Type = parser.getBuilder().getIntegerType(32); 1184 if (parser.parseOperand(memrefInfo) || 1185 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1186 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1187 parser.parseComma() || parser.parseKeyword("locality") || 1188 parser.parseLess() || 1189 parser.parseAttribute(localityHint, i32Type, "localityHint", 1190 result.attributes) || 1191 parser.parseGreater() || parser.parseComma() || 1192 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1193 parser.resolveOperand(memrefInfo, type, result.operands) || 1194 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1195 return failure(); 1196 1197 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1198 return parser.emitError(parser.getNameLoc(), 1199 "rw specifier has to be 'read' or 'write'"); 1200 result.addAttribute( 1201 PrefetchOp::getIsWriteAttrName(), 1202 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1203 1204 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1205 return parser.emitError(parser.getNameLoc(), 1206 "cache type has to be 'data' or 'instr'"); 1207 1208 result.addAttribute( 1209 PrefetchOp::getIsDataCacheAttrName(), 1210 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1211 1212 return success(); 1213 } 1214 1215 LogicalResult PrefetchOp::verify() { 1216 if (getNumOperands() != 1 + getMemRefType().getRank()) 1217 return emitOpError("too few indices"); 1218 1219 return success(); 1220 } 1221 1222 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1223 SmallVectorImpl<OpFoldResult> &results) { 1224 // prefetch(memrefcast) -> prefetch 1225 return foldMemRefCast(*this); 1226 } 1227 1228 //===----------------------------------------------------------------------===// 1229 // RankOp 1230 //===----------------------------------------------------------------------===// 1231 1232 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1233 // Constant fold rank when the rank of the operand is known. 1234 auto type = getOperand().getType(); 1235 auto shapedType = type.dyn_cast<ShapedType>(); 1236 if (shapedType && shapedType.hasRank()) 1237 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1238 return IntegerAttr(); 1239 } 1240 1241 //===----------------------------------------------------------------------===// 1242 // ReinterpretCastOp 1243 //===----------------------------------------------------------------------===// 1244 1245 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1246 /// `staticSizes` and `staticStrides` are automatically filled with 1247 /// source-memref-rank sentinel values that encode dynamic entries. 
/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

// TODO: ponder whether we want to allow missing trailing sizes/strides that
// are completed automatically, like we have for subview and extract_slice.
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = source().getType().cast<BaseMemRefType>();
  auto resultType = getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en : llvm::enumerate(llvm::zip(
           resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (!ShapedType::isDynamic(resultSize) &&
        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offset and static_strides attributes.
  // If result memref type has no affine map specified, this will assume an
  // identity layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
  if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
      !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
      resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << resultOffset << " instead of " << expectedOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto &en : llvm::enumerate(llvm::zip(
           resultStrides, extractFromI64ArrayAttr(static_strides())))) {
    int64_t resultStride = std::get<0>(en.value());
    int64_t expectedStride = std::get<1>(en.value());
    if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
        !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
        resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << expectedStride << " instead of " << resultStride
             << " in dim = " << en.index();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseReshapeLikeOp(parser, result);
}
void ExpandShapeOp::print(OpAsmPrinter &p) {
  ::mlir::printReshapeOp<ExpandShapeOp>(p, *this);
}

ParseResult CollapseShapeOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  return parseReshapeLikeOp(parser, result);
}
void CollapseShapeOp::print(OpAsmPrinter &p) {
  ::mlir::printReshapeOp<CollapseShapeOp>(p, *this);
}
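
// For illustration, the reassociation [[0, 1], [2]] relates the types in the
// collapse/expand pair below (SSA names are hypothetical):
//
// ```mlir
// %c = memref.collapse_shape %m [[0, 1], [2]]
//      : memref<4x8x16xf32> into memref<32x16xf32>
// %e = memref.expand_shape %c [[0, 1], [2]]
//      : memref<32x16xf32> into memref<4x8x16xf32>
// ```
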
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
/// copies.
static bool isReshapableDimBand(unsigned dim, unsigned extent,
                                ArrayRef<int64_t> sizes,
                                ArrayRef<AffineExpr> strides) {
  // Bands of extent one can be reshaped, as they are not reshaped at all.
  if (extent == 1)
    return true;
  // Otherwise, the size of the first dimension needs to be known.
  if (ShapedType::isDynamic(sizes[dim]))
    return false;
  assert(sizes.size() == strides.size() && "mismatched ranks");
  // Note: the `idx + 1` indexing below avoids reading out of bounds.
  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
    // Only bands of static shapes are reshapable. This is due to the fact that
    // there is no relation between dynamic sizes and dynamic strides: we do
    // not have enough information to know whether a "-1" size corresponds to
    // the proper symbol in the AffineExpr of a stride.
    if (ShapedType::isDynamic(sizes[idx + 1]))
      return false;
    // TODO: Refine this by passing the proper nDims and nSymbols so we can
    // simplify on the fly and catch more reshapable cases.
    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
      return false;
  }
  return true;
}

/// Compute the MemRefType obtained by applying the `reassociation` (which is
/// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
static MemRefType
computeReshapeCollapsedType(MemRefType type,
                            ArrayRef<AffineMap> reassociation) {
  auto sizes = type.getShape();
  AffineExpr offset;
  SmallVector<AffineExpr, 4> strides;
  auto status = getStridesAndOffset(type, strides, offset);
  auto isIdentityLayout = type.getLayout().isIdentity();
  (void)status;
  assert(succeeded(status) && "expected strided memref");

  SmallVector<int64_t, 4> newSizes;
  newSizes.reserve(reassociation.size());
  SmallVector<AffineExpr, 4> newStrides;
  newStrides.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    AffineExpr stride = strides[currentDim + dim - 1];
    if (isIdentityLayout ||
        isReshapableDimBand(currentDim, dim, sizes, strides)) {
      for (unsigned d = 0; d < dim; ++d) {
        int64_t currentSize = sizes[currentDim + d];
        if (ShapedType::isDynamic(currentSize)) {
          size = ShapedType::kDynamicSize;
          break;
        }
        size *= currentSize;
      }
    } else {
      size = ShapedType::kDynamicSize;
      stride = AffineExpr();
    }
    newSizes.push_back(size);
    newStrides.push_back(stride);
    currentDim += dim;
  }

  // Early-exit: if `type` is contiguous, the result must be contiguous.
  if (canonicalizeStridedLayout(type).getLayout().isIdentity())
    return MemRefType::Builder(type).setShape(newSizes).setLayout({});

  // Convert back to int64_t because we don't have enough information to create
  // new strided layouts from AffineExpr only. This corresponds to a case where
  // copies may be necessary.
1471 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1472 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1473 intOffset = o.getValue(); 1474 SmallVector<int64_t, 4> intStrides; 1475 intStrides.reserve(strides.size()); 1476 for (auto stride : newStrides) { 1477 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1478 intStrides.push_back(cst.getValue()); 1479 else 1480 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1481 } 1482 auto layout = 1483 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1484 return canonicalizeStridedLayout( 1485 MemRefType::Builder(type).setShape(newSizes).setLayout( 1486 AffineMapAttr::get(layout))); 1487 } 1488 1489 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1490 ArrayRef<ReassociationIndices> reassociation, 1491 ArrayRef<NamedAttribute> attrs) { 1492 auto memRefType = src.getType().cast<MemRefType>(); 1493 auto resultType = computeReshapeCollapsedType( 1494 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1495 b.getContext(), reassociation))); 1496 build(b, result, resultType, src, attrs); 1497 result.addAttribute(getReassociationAttrName(), 1498 getReassociationIndicesAttribute(b, reassociation)); 1499 } 1500 1501 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1502 ArrayRef<ReassociationIndices> reassociation, 1503 ArrayRef<NamedAttribute> attrs) { 1504 auto memRefType = src.getType().cast<MemRefType>(); 1505 auto resultType = computeReshapeCollapsedType( 1506 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1507 b.getContext(), reassociation))); 1508 build(b, result, resultType, src, attrs); 1509 result.addAttribute(getReassociationAttrName(), 1510 getReassociationIndicesAttribute(b, reassociation)); 1511 } 1512 1513 template <typename ReshapeOp, 1514 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1515 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1516 MemRefType collapsedType) { 1517 if (failed( 1518 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1519 return failure(); 1520 auto maps = op.getReassociationMaps(); 1521 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1522 if (collapsedType != expectedType) 1523 return op.emitOpError("expected collapsed type to be ") 1524 << expectedType << ", but got " << collapsedType; 1525 return success(); 1526 } 1527 1528 LogicalResult ExpandShapeOp::verify() { 1529 return verifyReshapeOp(*this, getResultType(), getSrcType()); 1530 } 1531 1532 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1533 MLIRContext *context) { 1534 results.add<CollapseReshapeOps<ExpandShapeOp>, 1535 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1536 } 1537 1538 LogicalResult CollapseShapeOp::verify() { 1539 return verifyReshapeOp(*this, getSrcType(), getResultType()); 1540 } 1541 1542 struct CollapseShapeOpMemRefCastFolder 1543 : public OpRewritePattern<CollapseShapeOp> { 1544 public: 1545 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1546 1547 LogicalResult matchAndRewrite(CollapseShapeOp op, 1548 PatternRewriter &rewriter) const override { 1549 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1550 if (!cast) 1551 return failure(); 1552 1553 if (!CastOp::canFoldIntoConsumerOp(cast)) 1554 return failure(); 1555 1556 Type newResultType = computeReshapeCollapsedType( 1557 cast.getOperand().getType().cast<MemRefType>(), 1558 
op.getReassociationMaps()); 1559 1560 if (newResultType == op.getResultType()) { 1561 rewriter.updateRootInPlace( 1562 op, [&]() { op.srcMutable().assign(cast.source()); }); 1563 } else { 1564 Value newOp = rewriter.create<CollapseShapeOp>( 1565 op->getLoc(), cast.source(), op.getReassociationIndices()); 1566 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1567 } 1568 return success(); 1569 } 1570 }; 1571 1572 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1573 MLIRContext *context) { 1574 results.add<CollapseReshapeOps<CollapseShapeOp>, 1575 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1576 CollapseShapeOpMemRefCastFolder>(context); 1577 } 1578 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1579 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1580 } 1581 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1582 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1583 } 1584 1585 //===----------------------------------------------------------------------===// 1586 // ReshapeOp 1587 //===----------------------------------------------------------------------===// 1588 1589 LogicalResult ReshapeOp::verify() { 1590 Type operandType = source().getType(); 1591 Type resultType = result().getType(); 1592 1593 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1594 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1595 if (operandElementType != resultElementType) 1596 return emitOpError("element types of source and destination memref " 1597 "types should be the same"); 1598 1599 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1600 if (!operandMemRefType.getLayout().isIdentity()) 1601 return emitOpError("source memref type should have identity affine map"); 1602 1603 int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0); 1604 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1605 if (resultMemRefType) { 1606 if (!resultMemRefType.getLayout().isIdentity()) 1607 return emitOpError("result memref type should have identity affine map"); 1608 if (shapeSize == ShapedType::kDynamicSize) 1609 return emitOpError("cannot use shape operand with dynamic length to " 1610 "reshape to statically-ranked memref type"); 1611 if (shapeSize != resultMemRefType.getRank()) 1612 return emitOpError( 1613 "length of shape operand differs from the result's memref rank"); 1614 } 1615 return success(); 1616 } 1617 1618 //===----------------------------------------------------------------------===// 1619 // StoreOp 1620 //===----------------------------------------------------------------------===// 1621 1622 LogicalResult StoreOp::verify() { 1623 if (getNumOperands() != 2 + getMemRefType().getRank()) 1624 return emitOpError("store index operand count not equal to memref rank"); 1625 1626 return success(); 1627 } 1628 1629 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1630 SmallVectorImpl<OpFoldResult> &results) { 1631 /// store(memrefcast) -> store 1632 return foldMemRefCast(*this, getValueToStore()); 1633 } 1634 1635 //===----------------------------------------------------------------------===// 1636 // SubViewOp 1637 //===----------------------------------------------------------------------===// 1638 1639 namespace { 1640 /// Helpers to write more idiomatic operations. 
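/// For example (illustrative values): Wrapper(8) * 2 + 4 evaluates to 20,
/// while any operand equal to ShapedType::kDynamicStrideOrOffset saturates the
/// whole expression to kDynamicStrideOrOffset.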
1641 namespace saturated_arith { 1642 struct Wrapper { 1643 explicit Wrapper(int64_t v) : v(v) {} 1644 operator int64_t() { return v; } 1645 int64_t v; 1646 }; 1647 Wrapper operator+(Wrapper a, int64_t b) { 1648 if (ShapedType::isDynamicStrideOrOffset(a) || 1649 ShapedType::isDynamicStrideOrOffset(b)) 1650 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1651 return Wrapper(a.v + b); 1652 } 1653 Wrapper operator*(Wrapper a, int64_t b) { 1654 if (ShapedType::isDynamicStrideOrOffset(a) || 1655 ShapedType::isDynamicStrideOrOffset(b)) 1656 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1657 return Wrapper(a.v * b); 1658 } 1659 } // namespace saturated_arith 1660 } // namespace 1661 1662 /// A subview result type can be fully inferred from the source type and the 1663 /// static representation of offsets, sizes and strides. Special sentinels 1664 /// encode the dynamic case. 1665 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1666 ArrayRef<int64_t> staticOffsets, 1667 ArrayRef<int64_t> staticSizes, 1668 ArrayRef<int64_t> staticStrides) { 1669 unsigned rank = sourceMemRefType.getRank(); 1670 (void)rank; 1671 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 1672 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 1673 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 1674 1675 // Extract source offset and strides. 1676 int64_t sourceOffset; 1677 SmallVector<int64_t, 4> sourceStrides; 1678 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1679 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1680 (void)res; 1681 1682 // Compute target offset whose value is: 1683 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1684 int64_t targetOffset = sourceOffset; 1685 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1686 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1687 using namespace saturated_arith; 1688 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1689 } 1690 1691 // Compute target stride whose value is: 1692 // `sourceStrides_i * staticStrides_i`. 1693 SmallVector<int64_t, 4> targetStrides; 1694 targetStrides.reserve(staticOffsets.size()); 1695 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1696 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1697 using namespace saturated_arith; 1698 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1699 } 1700 1701 // The type is now known. 
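  // Worked example (hypothetical operands): for a source memref<8x16xf32>
  // (strides [16, 1], offset 0) with static offsets [2, 4], sizes [4, 8] and
  // strides [1, 2], this yields targetOffset = 0 + 2 * 16 + 4 * 1 = 36 and
  // targetStrides = [16 * 1, 1 * 2] = [16, 2].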
1702 return MemRefType::get( 1703 staticSizes, sourceMemRefType.getElementType(), 1704 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1705 sourceMemRefType.getContext()), 1706 sourceMemRefType.getMemorySpace()); 1707 } 1708 1709 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1710 ArrayRef<OpFoldResult> offsets, 1711 ArrayRef<OpFoldResult> sizes, 1712 ArrayRef<OpFoldResult> strides) { 1713 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1714 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1715 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1716 ShapedType::kDynamicStrideOrOffset); 1717 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1718 ShapedType::kDynamicSize); 1719 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1720 ShapedType::kDynamicStrideOrOffset); 1721 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1722 staticSizes, staticStrides); 1723 } 1724 1725 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1726 MemRefType sourceRankedTensorType, 1727 ArrayRef<int64_t> offsets, 1728 ArrayRef<int64_t> sizes, 1729 ArrayRef<int64_t> strides) { 1730 auto inferredType = 1731 inferResultType(sourceRankedTensorType, offsets, sizes, strides) 1732 .cast<MemRefType>(); 1733 assert(inferredType.getRank() >= resultRank && "expected "); 1734 int rankDiff = inferredType.getRank() - resultRank; 1735 if (rankDiff > 0) { 1736 auto shape = inferredType.getShape(); 1737 llvm::SmallBitVector dimsToProject = 1738 getPositionsOfShapeOne(rankDiff, shape); 1739 SmallVector<int64_t> projectedShape; 1740 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1741 if (!dimsToProject.test(pos)) 1742 projectedShape.push_back(shape[pos]); 1743 1744 AffineMap map = inferredType.getLayout().getAffineMap(); 1745 if (!map.isIdentity()) 1746 map = getProjectedMap(map, dimsToProject); 1747 inferredType = 1748 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1749 inferredType.getMemorySpace()); 1750 } 1751 return inferredType; 1752 } 1753 1754 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1755 MemRefType sourceRankedTensorType, 1756 ArrayRef<OpFoldResult> offsets, 1757 ArrayRef<OpFoldResult> sizes, 1758 ArrayRef<OpFoldResult> strides) { 1759 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1760 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1761 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1762 ShapedType::kDynamicStrideOrOffset); 1763 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1764 ShapedType::kDynamicSize); 1765 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1766 ShapedType::kDynamicStrideOrOffset); 1767 return SubViewOp::inferRankReducedResultType( 1768 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1769 staticStrides); 1770 } 1771 // Build a SubViewOp with mixed static and dynamic entries and custom result 1772 // type. If the type passed is nullptr, it is inferred. 
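// A hypothetical call site (all names are placeholders) can mix
// attribute-backed constants and SSA values; dispatchIndexOpFoldResults below
// splits them into the static and dynamic operand lists:
//   SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0), dynamicOffset};
//   b.create<memref::SubViewOp>(loc, /*resultType=*/MemRefType(), source,
//                               offsets, sizes, strides);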
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

/// Return true if t1 and t2 have equal offsets (both dynamic or of the same
/// static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  AffineExpr t1Offset, t2Offset;
  SmallVector<AffineExpr> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Checks if `originalType` can be rank reduced to `candidateRankReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm where
/// the non-matching dimensions must be 1.
static SliceVerificationResult
isRankReducedMemRefType(MemRefType originalType,
                        MemRefType candidateRankReducedType,
                        ArrayRef<OpFoldResult> sizes) {
  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
  if (partialRes != SliceVerificationResult::Success)
    return partialRes;

  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
      originalType, candidateRankReducedType, sizes);

  // The sizes cannot be matched if no rank-reduction mask was computed.
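  // Illustrative example (hypothetical types): reducing memref<1x4x1x8xf32> to
  // memref<4x8xf32> yields the mask {0, 2} of dropped unit dims; llvm::None is
  // returned when no consistent set of dropped unit dims exists.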
1903 if (!optionalUnusedDimsMask.hasValue()) 1904 return SliceVerificationResult::LayoutMismatch; 1905 1906 if (originalType.getMemorySpace() != 1907 candidateRankReducedType.getMemorySpace()) 1908 return SliceVerificationResult::MemSpaceMismatch; 1909 1910 // No amount of stride dropping can reconcile incompatible offsets. 1911 if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) 1912 return SliceVerificationResult::LayoutMismatch; 1913 1914 return SliceVerificationResult::Success; 1915 } 1916 1917 template <typename OpTy> 1918 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 1919 OpTy op, Type expectedType) { 1920 auto memrefType = expectedType.cast<ShapedType>(); 1921 switch (result) { 1922 case SliceVerificationResult::Success: 1923 return success(); 1924 case SliceVerificationResult::RankTooLarge: 1925 return op.emitError("expected result rank to be smaller or equal to ") 1926 << "the source rank. "; 1927 case SliceVerificationResult::SizeMismatch: 1928 return op.emitError("expected result type to be ") 1929 << expectedType 1930 << " or a rank-reduced version. (mismatch of result sizes) "; 1931 case SliceVerificationResult::ElemTypeMismatch: 1932 return op.emitError("expected result element type to be ") 1933 << memrefType.getElementType(); 1934 case SliceVerificationResult::MemSpaceMismatch: 1935 return op.emitError("expected result and source memory spaces to match."); 1936 case SliceVerificationResult::LayoutMismatch: 1937 return op.emitError("expected result type to be ") 1938 << expectedType 1939 << " or a rank-reduced version. (mismatch of result layout) "; 1940 } 1941 llvm_unreachable("unexpected subview verification result"); 1942 } 1943 1944 /// Verifier for SubViewOp. 1945 LogicalResult SubViewOp::verify() { 1946 MemRefType baseType = getSourceType(); 1947 MemRefType subViewType = getType(); 1948 1949 // The base memref and the view memref should be in the same memory space. 1950 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 1951 return emitError("different memory spaces specified for base memref " 1952 "type ") 1953 << baseType << " and subview memref type " << subViewType; 1954 1955 // Verify that the base memref type has a strided layout map. 1956 if (!isStrided(baseType)) 1957 return emitError("base type ") << baseType << " is not strided"; 1958 1959 // Verify result type against inferred type. 1960 auto expectedType = SubViewOp::inferResultType( 1961 baseType, extractFromI64ArrayAttr(static_offsets()), 1962 extractFromI64ArrayAttr(static_sizes()), 1963 extractFromI64ArrayAttr(static_strides())); 1964 1965 auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), 1966 subViewType, getMixedSizes()); 1967 return produceSubViewErrorMsg(result, *this, expectedType); 1968 } 1969 1970 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 1971 return os << "range " << range.offset << ":" << range.size << ":" 1972 << range.stride; 1973 } 1974 1975 /// Return the list of Range (i.e. offset, size, stride). Each Range 1976 /// entry contains either the dynamic value or a ConstantIndexOp constructed 1977 /// with `b` at location `loc`. 
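/// For illustration (hypothetical operands): for an entry with a dynamic
/// offset %off, a static size 4 and a static stride 1, the returned Range is
/// {%off, arith.constant 4 : index, arith.constant 1 : index}.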
1978 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 1979 OpBuilder &b, Location loc) { 1980 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 1981 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 1982 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 1983 SmallVector<Range, 8> res; 1984 unsigned rank = ranks[0]; 1985 res.reserve(rank); 1986 for (unsigned idx = 0; idx < rank; ++idx) { 1987 Value offset = 1988 op.isDynamicOffset(idx) 1989 ? op.getDynamicOffset(idx) 1990 : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); 1991 Value size = 1992 op.isDynamicSize(idx) 1993 ? op.getDynamicSize(idx) 1994 : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); 1995 Value stride = 1996 op.isDynamicStride(idx) 1997 ? op.getDynamicStride(idx) 1998 : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); 1999 res.emplace_back(Range{offset, size, stride}); 2000 } 2001 return res; 2002 } 2003 2004 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 2005 /// deduce the result type for the given `sourceType`. Additionally, reduce the 2006 /// rank of the inferred result type if `currentResultType` is lower rank than 2007 /// `currentSourceType`. Use this signature if `sourceType` is updated together 2008 /// with the result type. In this case, it is important to compute the dropped 2009 /// dimensions using `currentSourceType` whose strides align with 2010 /// `currentResultType`. 2011 static MemRefType getCanonicalSubViewResultType( 2012 MemRefType currentResultType, MemRefType currentSourceType, 2013 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, 2014 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { 2015 auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, 2016 mixedSizes, mixedStrides) 2017 .cast<MemRefType>(); 2018 llvm::Optional<llvm::SmallBitVector> unusedDims = 2019 computeMemRefRankReductionMask(currentSourceType, currentResultType, 2020 mixedSizes); 2021 // Return nullptr as failure mode. 2022 if (!unusedDims) 2023 return nullptr; 2024 SmallVector<int64_t> shape; 2025 for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) { 2026 if (unusedDims->test(sizes.index())) 2027 continue; 2028 shape.push_back(sizes.value()); 2029 } 2030 AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap(); 2031 if (!layoutMap.isIdentity()) 2032 layoutMap = getProjectedMap(layoutMap, unusedDims.getValue()); 2033 return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap, 2034 nonRankReducedType.getMemorySpace()); 2035 } 2036 2037 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 2038 /// deduce the result type. Additionally, reduce the rank of the inferred result 2039 /// type if `currentResultType` is lower rank than `sourceType`. 2040 static MemRefType getCanonicalSubViewResultType( 2041 MemRefType currentResultType, MemRefType sourceType, 2042 ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, 2043 ArrayRef<OpFoldResult> mixedStrides) { 2044 return getCanonicalSubViewResultType(currentResultType, sourceType, 2045 sourceType, mixedOffsets, mixedSizes, 2046 mixedStrides); 2047 } 2048 2049 /// Helper method to check if a `subview` operation is trivially a no-op. This 2050 /// is the case if the all offsets are zero, all strides are 1, and the source 2051 /// shape is same as the size of the subview. 
In such cases, the subview can be 2052 /// folded into its source. 2053 static bool isTrivialSubViewOp(SubViewOp subViewOp) { 2054 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank()) 2055 return false; 2056 2057 auto mixedOffsets = subViewOp.getMixedOffsets(); 2058 auto mixedSizes = subViewOp.getMixedSizes(); 2059 auto mixedStrides = subViewOp.getMixedStrides(); 2060 2061 // Check offsets are zero. 2062 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) { 2063 Optional<int64_t> intValue = getConstantIntValue(ofr); 2064 return !intValue || intValue.getValue() != 0; 2065 })) 2066 return false; 2067 2068 // Check strides are one. 2069 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) { 2070 Optional<int64_t> intValue = getConstantIntValue(ofr); 2071 return !intValue || intValue.getValue() != 1; 2072 })) 2073 return false; 2074 2075 // Check all size values are static and matches the (static) source shape. 2076 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape(); 2077 for (const auto &size : llvm::enumerate(mixedSizes)) { 2078 Optional<int64_t> intValue = getConstantIntValue(size.value()); 2079 if (!intValue || intValue.getValue() != sourceShape[size.index()]) 2080 return false; 2081 } 2082 // All conditions met. The `SubViewOp` is foldable as a no-op. 2083 return true; 2084 } 2085 2086 namespace { 2087 /// Pattern to rewrite a subview op with MemRefCast arguments. 2088 /// This essentially pushes memref.cast past its consuming subview when 2089 /// `canFoldIntoConsumerOp` is true. 2090 /// 2091 /// Example: 2092 /// ``` 2093 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 2094 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 2095 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 2096 /// ``` 2097 /// is rewritten into: 2098 /// ``` 2099 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 2100 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 2101 /// memref<3x4xf32, offset:?, strides:[?, 1]> 2102 /// ``` 2103 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 2104 public: 2105 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2106 2107 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2108 PatternRewriter &rewriter) const override { 2109 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 2110 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 2111 return matchPattern(operand, matchConstantIndex()); 2112 })) 2113 return failure(); 2114 2115 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 2116 if (!castOp) 2117 return failure(); 2118 2119 if (!CastOp::canFoldIntoConsumerOp(castOp)) 2120 return failure(); 2121 2122 // Compute the SubViewOp result type after folding the MemRefCastOp. Use the 2123 // MemRefCastOp source operand type to infer the result type and the current 2124 // SubViewOp source operand type to compute the dropped dimensions if the 2125 // operation is rank-reducing. 
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops. If the source and result types
/// match, the subview is replaced by its source; when they differ only due to
/// the use of an `affine_map` layout, it is replaced by a cast of its source.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.source());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.source());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
                                         mixedOffsets, mixedSizes,
                                         mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
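  // E.g. (illustrative): for memref<3x4x5xf32> and the permutation map
  // (d0, d1, d2) -> (d2, d0, d1), the permuted sizes become [5, 3, 4].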
2213 SmallVector<int64_t, 4> sizes(rank, 0); 2214 for (const auto &en : llvm::enumerate(permutationMap.getResults())) 2215 sizes[en.index()] = 2216 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2217 2218 // Compute permuted strides. 2219 int64_t offset; 2220 SmallVector<int64_t, 4> strides; 2221 auto res = getStridesAndOffset(memRefType, strides, offset); 2222 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2223 (void)res; 2224 auto map = 2225 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2226 map = permutationMap ? map.compose(permutationMap) : map; 2227 return MemRefType::Builder(memRefType) 2228 .setShape(sizes) 2229 .setLayout(AffineMapAttr::get(map)); 2230 } 2231 2232 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2233 AffineMapAttr permutation, 2234 ArrayRef<NamedAttribute> attrs) { 2235 auto permutationMap = permutation.getValue(); 2236 assert(permutationMap); 2237 2238 auto memRefType = in.getType().cast<MemRefType>(); 2239 // Compute result type. 2240 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2241 2242 build(b, result, resultType, in, attrs); 2243 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2244 } 2245 2246 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2247 void TransposeOp::print(OpAsmPrinter &p) { 2248 p << " " << in() << " " << permutation(); 2249 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); 2250 p << " : " << in().getType() << " to " << getType(); 2251 } 2252 2253 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { 2254 OpAsmParser::OperandType in; 2255 AffineMap permutation; 2256 MemRefType srcType, dstType; 2257 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2258 parser.parseOptionalAttrDict(result.attributes) || 2259 parser.parseColonType(srcType) || 2260 parser.resolveOperand(in, srcType, result.operands) || 2261 parser.parseKeywordType("to", dstType) || 2262 parser.addTypeToList(dstType, result.types)) 2263 return failure(); 2264 2265 result.addAttribute(TransposeOp::getPermutationAttrName(), 2266 AffineMapAttr::get(permutation)); 2267 return success(); 2268 } 2269 2270 LogicalResult TransposeOp::verify() { 2271 if (!permutation().isPermutation()) 2272 return emitOpError("expected a permutation map"); 2273 if (permutation().getNumDims() != getShapedType().getRank()) 2274 return emitOpError("expected a permutation map of same rank as the input"); 2275 2276 auto srcType = in().getType().cast<MemRefType>(); 2277 auto dstType = getType().cast<MemRefType>(); 2278 auto transposedType = inferTransposeResultType(srcType, permutation()); 2279 if (dstType != transposedType) 2280 return emitOpError("output type ") 2281 << dstType << " does not match transposed input type " << srcType 2282 << ", " << transposedType; 2283 return success(); 2284 } 2285 2286 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2287 if (succeeded(foldMemRefCast(*this))) 2288 return getResult(); 2289 return {}; 2290 } 2291 2292 //===----------------------------------------------------------------------===// 2293 // ViewOp 2294 //===----------------------------------------------------------------------===// 2295 2296 ParseResult ViewOp::parse(OpAsmParser &parser, OperationState &result) { 2297 OpAsmParser::OperandType srcInfo; 2298 SmallVector<OpAsmParser::OperandType, 1> offsetInfo; 2299 SmallVector<OpAsmParser::OperandType, 4> sizesInfo; 2300 auto 
indexType = parser.getBuilder().getIndexType(); 2301 Type srcType, dstType; 2302 SMLoc offsetLoc; 2303 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || 2304 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) 2305 return failure(); 2306 2307 if (offsetInfo.size() != 1) 2308 return parser.emitError(offsetLoc) << "expects 1 offset operand"; 2309 2310 return failure( 2311 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || 2312 parser.parseOptionalAttrDict(result.attributes) || 2313 parser.parseColonType(srcType) || 2314 parser.resolveOperand(srcInfo, srcType, result.operands) || 2315 parser.resolveOperands(offsetInfo, indexType, result.operands) || 2316 parser.resolveOperands(sizesInfo, indexType, result.operands) || 2317 parser.parseKeywordType("to", dstType) || 2318 parser.addTypeToList(dstType, result.types)); 2319 } 2320 2321 void ViewOp::print(OpAsmPrinter &p) { 2322 p << ' ' << getOperand(0) << '['; 2323 p.printOperand(byte_shift()); 2324 p << "][" << sizes() << ']'; 2325 p.printOptionalAttrDict((*this)->getAttrs()); 2326 p << " : " << getOperand(0).getType() << " to " << getType(); 2327 } 2328 2329 LogicalResult ViewOp::verify() { 2330 auto baseType = getOperand(0).getType().cast<MemRefType>(); 2331 auto viewType = getType(); 2332 2333 // The base memref should have identity layout map (or none). 2334 if (!baseType.getLayout().isIdentity()) 2335 return emitError("unsupported map for base memref type ") << baseType; 2336 2337 // The result memref should have identity layout map (or none). 2338 if (!viewType.getLayout().isIdentity()) 2339 return emitError("unsupported map for result memref type ") << viewType; 2340 2341 // The base memref and the view memref should be in the same memory space. 2342 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2343 return emitError("different memory spaces specified for base memref " 2344 "type ") 2345 << baseType << " and view memref type " << viewType; 2346 2347 // Verify that we have the correct number of sizes for the result type. 2348 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2349 if (sizes().size() != numDynamicDims) 2350 return emitError("incorrect number of size operands for type ") << viewType; 2351 2352 return success(); 2353 } 2354 2355 Value ViewOp::getViewSource() { return source(); } 2356 2357 namespace { 2358 2359 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2360 using OpRewritePattern<ViewOp>::OpRewritePattern; 2361 2362 LogicalResult matchAndRewrite(ViewOp viewOp, 2363 PatternRewriter &rewriter) const override { 2364 // Return if none of the operands are constants. 2365 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2366 return matchPattern(operand, matchConstantIndex()); 2367 })) 2368 return failure(); 2369 2370 // Get result memref type. 2371 auto memrefType = viewOp.getType(); 2372 2373 // Get offset from old memref view type 'memRefType'. 2374 int64_t oldOffset; 2375 SmallVector<int64_t, 4> oldStrides; 2376 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2377 return failure(); 2378 assert(oldOffset == 0 && "Expected 0 offset"); 2379 2380 SmallVector<Value, 4> newOperands; 2381 2382 // Offset cannot be folded into result type. 2383 2384 // Fold any dynamic dim operands which are produced by a constant. 
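    // Illustrative example (hypothetical IR): a view returning memref<?x?xf32>
    // whose first size operand is a constant 10 is rewritten to return
    // memref<10x?xf32>, keeping only the remaining dynamic size operand; the
    // cast inserted below restores the original type for existing users.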
2385 SmallVector<int64_t, 4> newShapeConstants; 2386 newShapeConstants.reserve(memrefType.getRank()); 2387 2388 unsigned dynamicDimPos = 0; 2389 unsigned rank = memrefType.getRank(); 2390 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2391 int64_t dimSize = memrefType.getDimSize(dim); 2392 // If this is already static dimension, keep it. 2393 if (!ShapedType::isDynamic(dimSize)) { 2394 newShapeConstants.push_back(dimSize); 2395 continue; 2396 } 2397 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2398 if (auto constantIndexOp = 2399 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2400 // Dynamic shape dimension will be folded. 2401 newShapeConstants.push_back(constantIndexOp.value()); 2402 } else { 2403 // Dynamic shape dimension not folded; copy operand from old memref. 2404 newShapeConstants.push_back(dimSize); 2405 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2406 } 2407 dynamicDimPos++; 2408 } 2409 2410 // Create new memref type with constant folded dims. 2411 MemRefType newMemRefType = 2412 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2413 // Nothing new, don't fold. 2414 if (newMemRefType == memrefType) 2415 return failure(); 2416 2417 // Create new ViewOp. 2418 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2419 viewOp.getOperand(0), 2420 viewOp.byte_shift(), newOperands); 2421 // Insert a cast so we have the same type as the old memref type. 2422 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp); 2423 return success(); 2424 } 2425 }; 2426 2427 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2428 using OpRewritePattern<ViewOp>::OpRewritePattern; 2429 2430 LogicalResult matchAndRewrite(ViewOp viewOp, 2431 PatternRewriter &rewriter) const override { 2432 Value memrefOperand = viewOp.getOperand(0); 2433 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2434 if (!memrefCastOp) 2435 return failure(); 2436 Value allocOperand = memrefCastOp.getOperand(); 2437 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2438 if (!allocOp) 2439 return failure(); 2440 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2441 viewOp.byte_shift(), viewOp.sizes()); 2442 return success(); 2443 } 2444 }; 2445 2446 } // namespace 2447 2448 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2449 MLIRContext *context) { 2450 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2451 } 2452 2453 //===----------------------------------------------------------------------===// 2454 // AtomicRMWOp 2455 //===----------------------------------------------------------------------===// 2456 2457 LogicalResult AtomicRMWOp::verify() { 2458 if (getMemRefType().getRank() != getNumOperands() - 2) 2459 return emitOpError( 2460 "expects the number of subscripts to be equal to memref rank"); 2461 switch (kind()) { 2462 case arith::AtomicRMWKind::addf: 2463 case arith::AtomicRMWKind::maxf: 2464 case arith::AtomicRMWKind::minf: 2465 case arith::AtomicRMWKind::mulf: 2466 if (!value().getType().isa<FloatType>()) 2467 return emitOpError() << "with kind '" 2468 << arith::stringifyAtomicRMWKind(kind()) 2469 << "' expects a floating-point type"; 2470 break; 2471 case arith::AtomicRMWKind::addi: 2472 case arith::AtomicRMWKind::maxs: 2473 case arith::AtomicRMWKind::maxu: 2474 case arith::AtomicRMWKind::mins: 2475 case arith::AtomicRMWKind::minu: 2476 case arith::AtomicRMWKind::muli: 2477 case arith::AtomicRMWKind::ori: 2478 case arith::AtomicRMWKind::andi: 
2479 if (!value().getType().isa<IntegerType>()) 2480 return emitOpError() << "with kind '" 2481 << arith::stringifyAtomicRMWKind(kind()) 2482 << "' expects an integer type"; 2483 break; 2484 default: 2485 break; 2486 } 2487 return success(); 2488 } 2489 2490 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) { 2491 /// atomicrmw(memrefcast) -> atomicrmw 2492 if (succeeded(foldMemRefCast(*this, value()))) 2493 return getResult(); 2494 return OpFoldResult(); 2495 } 2496 2497 //===----------------------------------------------------------------------===// 2498 // TableGen'd op method definitions 2499 //===----------------------------------------------------------------------===// 2500 2501 #define GET_OP_CLASSES 2502 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2503