1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/InferTypeOpInterface.h" 21 #include "mlir/Interfaces/SideEffectInterfaces.h" 22 #include "mlir/Interfaces/ViewLikeInterface.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/SmallBitVector.h" 25 26 using namespace mlir; 27 using namespace mlir::memref; 28 29 /// Materialize a single constant operation from a given attribute value with 30 /// the desired resultant type. 31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 32 Attribute value, Type type, 33 Location loc) { 34 if (arith::ConstantOp::isBuildableWith(value, type)) 35 return builder.create<arith::ConstantOp>(loc, value, type); 36 return nullptr; 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Common canonicalization pattern support logic 41 //===----------------------------------------------------------------------===// 42 43 /// This is a common class used for patterns of the form 44 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 45 /// into the root operation directly. 46 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { 47 bool folded = false; 48 for (OpOperand &operand : op->getOpOperands()) { 49 auto cast = operand.get().getDefiningOp<CastOp>(); 50 if (cast && operand.get() != inner && 51 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 52 operand.set(cast.getOperand()); 53 folded = true; 54 } 55 } 56 return success(folded); 57 } 58 59 /// Return an unranked/ranked tensor type for the given unranked/ranked memref 60 /// type. 
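///
/// For illustration (summarizing the behavior of the function below, not an
/// addition to it): a ranked memref such as memref<4x?xf32, #layout> maps to
/// tensor<4x?xf32>, an unranked memref<*xf32> maps to tensor<*xf32>, and any
/// non-memref type yields NoneType.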
61 Type mlir::memref::getTensorTypeFromMemRefType(Type type) { 62 if (auto memref = type.dyn_cast<MemRefType>()) 63 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 64 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 65 return UnrankedTensorType::get(memref.getElementType()); 66 return NoneType::get(type.getContext()); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // AllocOp / AllocaOp 71 //===----------------------------------------------------------------------===// 72 73 template <typename AllocLikeOp> 74 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 75 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 76 "applies to only alloc or alloca"); 77 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 78 if (!memRefType) 79 return op.emitOpError("result must be a memref"); 80 81 if (static_cast<int64_t>(op.dynamicSizes().size()) != 82 memRefType.getNumDynamicDims()) 83 return op.emitOpError("dimension operand count does not equal memref " 84 "dynamic dimension count"); 85 86 unsigned numSymbols = 0; 87 if (!memRefType.getLayout().isIdentity()) 88 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 89 if (op.symbolOperands().size() != numSymbols) 90 return op.emitOpError("symbol operand count does not equal memref symbol " 91 "count: expected ") 92 << numSymbols << ", got " << op.symbolOperands().size(); 93 94 return success(); 95 } 96 97 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); } 98 99 LogicalResult AllocaOp::verify() { 100 // An alloca op needs to have an ancestor with an allocation scope trait. 101 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 102 return emitOpError( 103 "requires an ancestor op with AutomaticAllocationScope trait"); 104 105 return verifyAllocLikeOp(*this); 106 } 107 108 namespace { 109 /// Fold constant dimensions into an alloc like operation. 110 template <typename AllocLikeOp> 111 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 112 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 113 114 LogicalResult matchAndRewrite(AllocLikeOp alloc, 115 PatternRewriter &rewriter) const override { 116 // Check to see if any dimensions operands are constants. If so, we can 117 // substitute and drop them. 118 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 119 return matchPattern(operand, matchConstantIndex()); 120 })) 121 return failure(); 122 123 auto memrefType = alloc.getType(); 124 125 // Ok, we have one or more constant operands. Collect the non-constant ones 126 // and keep track of the resultant memref type to build. 127 SmallVector<int64_t, 4> newShapeConstants; 128 newShapeConstants.reserve(memrefType.getRank()); 129 SmallVector<Value, 4> dynamicSizes; 130 131 unsigned dynamicDimPos = 0; 132 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 133 int64_t dimSize = memrefType.getDimSize(dim); 134 // If this is already static dimension, keep it. 135 if (dimSize != -1) { 136 newShapeConstants.push_back(dimSize); 137 continue; 138 } 139 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 140 auto *defOp = dynamicSize.getDefiningOp(); 141 if (auto constantIndexOp = 142 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 143 // Dynamic shape dimension will be folded. 144 newShapeConstants.push_back(constantIndexOp.value()); 145 } else { 146 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 
147 newShapeConstants.push_back(-1); 148 dynamicSizes.push_back(dynamicSize); 149 } 150 dynamicDimPos++; 151 } 152 153 // Create new memref type (which will have fewer dynamic dimensions). 154 MemRefType newMemRefType = 155 MemRefType::Builder(memrefType).setShape(newShapeConstants); 156 assert(static_cast<int64_t>(dynamicSizes.size()) == 157 newMemRefType.getNumDynamicDims()); 158 159 // Create and insert the alloc op for the new memref. 160 auto newAlloc = rewriter.create<AllocLikeOp>( 161 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 162 alloc.alignmentAttr()); 163 // Insert a cast so we have the same type as the old alloc. 164 auto resultCast = 165 rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc); 166 167 rewriter.replaceOp(alloc, {resultCast}); 168 return success(); 169 } 170 }; 171 172 /// Fold alloc operations with no users or only store and dealloc uses. 173 template <typename T> 174 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 175 using OpRewritePattern<T>::OpRewritePattern; 176 177 LogicalResult matchAndRewrite(T alloc, 178 PatternRewriter &rewriter) const override { 179 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 180 if (auto storeOp = dyn_cast<StoreOp>(op)) 181 return storeOp.value() == alloc; 182 return !isa<DeallocOp>(op); 183 })) 184 return failure(); 185 186 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 187 rewriter.eraseOp(user); 188 189 rewriter.eraseOp(alloc); 190 return success(); 191 } 192 }; 193 } // namespace 194 195 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 196 MLIRContext *context) { 197 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 198 } 199 200 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 201 MLIRContext *context) { 202 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 203 context); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // AllocaScopeOp 208 //===----------------------------------------------------------------------===// 209 210 void AllocaScopeOp::print(OpAsmPrinter &p) { 211 bool printBlockTerminators = false; 212 213 p << ' '; 214 if (!results().empty()) { 215 p << " -> (" << getResultTypes() << ")"; 216 printBlockTerminators = true; 217 } 218 p << ' '; 219 p.printRegion(bodyRegion(), 220 /*printEntryBlockArgs=*/false, 221 /*printBlockTerminators=*/printBlockTerminators); 222 p.printOptionalAttrDict((*this)->getAttrs()); 223 } 224 225 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { 226 // Create a region for the body. 227 result.regions.reserve(1); 228 Region *bodyRegion = result.addRegion(); 229 230 // Parse optional results type list. 231 if (parser.parseOptionalArrowTypeList(result.types)) 232 return failure(); 233 234 // Parse the body region. 235 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 236 return failure(); 237 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 238 result.location); 239 240 // Parse the optional attribute list. 
241 if (parser.parseOptionalAttrDict(result.attributes)) 242 return failure(); 243 244 return success(); 245 } 246 247 void AllocaScopeOp::getSuccessorRegions( 248 Optional<unsigned> index, ArrayRef<Attribute> operands, 249 SmallVectorImpl<RegionSuccessor> ®ions) { 250 if (index.hasValue()) { 251 regions.push_back(RegionSuccessor(getResults())); 252 return; 253 } 254 255 regions.push_back(RegionSuccessor(&bodyRegion())); 256 } 257 258 /// Given an operation, return whether this op is guaranteed to 259 /// allocate an AutomaticAllocationScopeResource 260 static bool isGuaranteedAutomaticAllocation(Operation *op) { 261 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); 262 if (!interface) 263 return false; 264 for (auto res : op->getResults()) { 265 if (auto effect = 266 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { 267 if (isa<SideEffects::AutomaticAllocationScopeResource>( 268 effect->getResource())) 269 return true; 270 } 271 } 272 return false; 273 } 274 275 /// Given an operation, return whether this op itself could 276 /// allocate an AutomaticAllocationScopeResource. Note that 277 /// this will not check whether an operation contained within 278 /// the op can allocate. 279 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) { 280 // This op itself doesn't create a stack allocation, 281 // the inner allocation should be handled separately. 282 if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) 283 return false; 284 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); 285 if (!interface) 286 return true; 287 for (auto res : op->getResults()) { 288 if (auto effect = 289 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { 290 if (isa<SideEffects::AutomaticAllocationScopeResource>( 291 effect->getResource())) 292 return true; 293 } 294 } 295 return false; 296 } 297 298 /// Return whether this op is the last non terminating op 299 /// in a region. That is to say, it is in a one-block region 300 /// and is only followed by a terminator. This prevents 301 /// extending the lifetime of allocations. 302 static bool lastNonTerminatorInRegion(Operation *op) { 303 return op->getNextNode() == op->getBlock()->getTerminator() && 304 op->getParentRegion()->getBlocks().size() == 1; 305 } 306 307 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope 308 /// or it contains no allocation. 309 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> { 310 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern; 311 312 LogicalResult matchAndRewrite(AllocaScopeOp op, 313 PatternRewriter &rewriter) const override { 314 bool hasPotentialAlloca = 315 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) { 316 if (alloc == op) 317 return WalkResult::advance(); 318 if (isOpItselfPotentialAutomaticAllocation(alloc)) 319 return WalkResult::interrupt(); 320 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>()) 321 return WalkResult::skip(); 322 return WalkResult::advance(); 323 }).wasInterrupted(); 324 325 // If this contains no potential allocation, it is always legal to 326 // inline. Otherwise, consider two conditions: 327 if (hasPotentialAlloca) { 328 // If the parent isn't an allocation scope, or we are not the last 329 // non-terminator op in the parent, we will extend the lifetime. 
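      // Illustrative sketch of the rewrite performed below (assumed IR, not
      // taken from an existing test):
      //   %r = memref.alloca_scope -> (f32) {
      //     ...
      //     memref.alloca_scope.return %v : f32
      //   }
      // is inlined by splicing the scope's body into the parent block and
      // substituting %v for %r.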
      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
        return failure();
      if (!lastNonTerminatorInRegion(op))
        return failure();
    }

    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.mergeBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};

/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {

    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator op in the block (lest
    // lifetimes be extended) of a one-block region.
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(alignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
490 int64_t sourceOffset, resultOffset; 491 SmallVector<int64_t, 4> sourceStrides, resultStrides; 492 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 493 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 494 return false; 495 496 // If cast is towards more static sizes along any dimension, don't fold. 497 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 498 auto ss = std::get<0>(it), st = std::get<1>(it); 499 if (ss != st) 500 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) 501 return false; 502 } 503 504 // If cast is towards more static offset along any dimension, don't fold. 505 if (sourceOffset != resultOffset) 506 if (ShapedType::isDynamicStrideOrOffset(sourceOffset) && 507 !ShapedType::isDynamicStrideOrOffset(resultOffset)) 508 return false; 509 510 // If cast is towards more static strides along any dimension, don't fold. 511 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 512 auto ss = std::get<0>(it), st = std::get<1>(it); 513 if (ss != st) 514 if (ShapedType::isDynamicStrideOrOffset(ss) && 515 !ShapedType::isDynamicStrideOrOffset(st)) 516 return false; 517 } 518 519 return true; 520 } 521 522 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 523 if (inputs.size() != 1 || outputs.size() != 1) 524 return false; 525 Type a = inputs.front(), b = outputs.front(); 526 auto aT = a.dyn_cast<MemRefType>(); 527 auto bT = b.dyn_cast<MemRefType>(); 528 529 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 530 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 531 532 if (aT && bT) { 533 if (aT.getElementType() != bT.getElementType()) 534 return false; 535 if (aT.getLayout() != bT.getLayout()) { 536 int64_t aOffset, bOffset; 537 SmallVector<int64_t, 4> aStrides, bStrides; 538 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 539 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 540 aStrides.size() != bStrides.size()) 541 return false; 542 543 // Strides along a dimension/offset are compatible if the value in the 544 // source memref is static and the value in the target memref is the 545 // same. They are also compatible if either one is dynamic (see 546 // description of MemRefCastOp for details). 547 auto checkCompatible = [](int64_t a, int64_t b) { 548 return (a == MemRefType::getDynamicStrideOrOffset() || 549 b == MemRefType::getDynamicStrideOrOffset() || a == b); 550 }; 551 if (!checkCompatible(aOffset, bOffset)) 552 return false; 553 for (const auto &aStride : enumerate(aStrides)) 554 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 555 return false; 556 } 557 if (aT.getMemorySpace() != bT.getMemorySpace()) 558 return false; 559 560 // They must have the same rank, and any specified dimensions must match. 561 if (aT.getRank() != bT.getRank()) 562 return false; 563 564 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 565 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 566 if (aDim != -1 && bDim != -1 && aDim != bDim) 567 return false; 568 } 569 return true; 570 } else { 571 if (!aT && !uaT) 572 return false; 573 if (!bT && !ubT) 574 return false; 575 // Unranked to unranked casting is unsupported 576 if (uaT && ubT) 577 return false; 578 579 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 580 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 581 if (aEltType != bEltType) 582 return false; 583 584 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 585 auto bMemSpace = (bT) ? 
bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the layout
/// of the type.
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    return success(modified);
  }
};

/// Fold memref.copy(%x, %x).
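///
/// Illustrative example (assumed IR, not taken from an existing test): a copy
/// whose source and target are the same value, e.g.
/// ```mlir
///   memref.copy %buf, %buf : memref<4xf32> to memref<4xf32>
/// ```
/// has no observable effect and can simply be erased.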
646 struct FoldSelfCopy : public OpRewritePattern<CopyOp> { 647 using OpRewritePattern<CopyOp>::OpRewritePattern; 648 649 LogicalResult matchAndRewrite(CopyOp copyOp, 650 PatternRewriter &rewriter) const override { 651 if (copyOp.source() != copyOp.target()) 652 return failure(); 653 654 rewriter.eraseOp(copyOp); 655 return success(); 656 } 657 }; 658 } // namespace 659 660 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, 661 MLIRContext *context) { 662 results.add<FoldCopyOfCast, FoldSelfCopy>(context); 663 } 664 665 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands, 666 SmallVectorImpl<OpFoldResult> &results) { 667 /// copy(memrefcast) -> copy 668 bool folded = false; 669 Operation *op = *this; 670 for (OpOperand &operand : op->getOpOperands()) { 671 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 672 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 673 operand.set(castOp.getOperand()); 674 folded = true; 675 } 676 } 677 return success(folded); 678 } 679 680 //===----------------------------------------------------------------------===// 681 // DeallocOp 682 //===----------------------------------------------------------------------===// 683 684 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 685 SmallVectorImpl<OpFoldResult> &results) { 686 /// dealloc(memrefcast) -> dealloc 687 return foldMemRefCast(*this); 688 } 689 690 //===----------------------------------------------------------------------===// 691 // DimOp 692 //===----------------------------------------------------------------------===// 693 694 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 695 int64_t index) { 696 auto loc = result.location; 697 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index); 698 build(builder, result, source, indexValue); 699 } 700 701 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 702 Value index) { 703 auto indexTy = builder.getIndexType(); 704 build(builder, result, indexTy, source, index); 705 } 706 707 Optional<int64_t> DimOp::getConstantIndex() { 708 if (auto constantOp = index().getDefiningOp<arith::ConstantOp>()) 709 return constantOp.getValue().cast<IntegerAttr>().getInt(); 710 return {}; 711 } 712 713 LogicalResult DimOp::verify() { 714 // Assume unknown index to be in range. 715 Optional<int64_t> index = getConstantIndex(); 716 if (!index.hasValue()) 717 return success(); 718 719 // Check that constant index is not knowingly out of range. 720 auto type = source().getType(); 721 if (auto memrefType = type.dyn_cast<MemRefType>()) { 722 if (index.getValue() >= memrefType.getRank()) 723 return emitOpError("index is out of range"); 724 } else if (type.isa<UnrankedMemRefType>()) { 725 // Assume index to be in range. 726 } else { 727 llvm_unreachable("expected operand with memref type"); 728 } 729 return success(); 730 } 731 732 /// Return a map with key being elements in `vals` and data being number of 733 /// occurences of it. Use std::map, since the `vals` here are strides and the 734 /// dynamic stride value is the same as the tombstone value for 735 /// `DenseMap<int64_t>`. 
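/// For instance (illustrative), getNumOccurences({50, 5, 5, 1}) returns
/// {{1, 1}, {5, 2}, {50, 1}}.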
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
/// to be a subset of `originalType` with some `1` entries erased, return the
/// set of indices that specifies which of the entries of `originalShape` are
/// dropped to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped too.
static llvm::Optional<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = dim.value().dyn_cast<Attribute>())
      if (attr.cast<IntegerAttr>().getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim that stride should not be
  // present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original type minus the number of
  // unused dims with that stride equals the number of repetitions of that
  // stride in the candidate type.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank type
      // that wasn't in the original one.
809 return llvm::None; 810 } 811 } 812 813 if ((int64_t)unusedDims.count() + reducedType.getRank() != 814 originalType.getRank()) 815 return llvm::None; 816 return unusedDims; 817 } 818 819 llvm::SmallBitVector SubViewOp::getDroppedDims() { 820 MemRefType sourceType = getSourceType(); 821 MemRefType resultType = getType(); 822 llvm::Optional<llvm::SmallBitVector> unusedDims = 823 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes()); 824 assert(unusedDims && "unable to find unused dims of subview"); 825 return *unusedDims; 826 } 827 828 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 829 // All forms of folding require a known index. 830 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 831 if (!index) 832 return {}; 833 834 // Folding for unranked types (UnrankedMemRefType) is not supported. 835 auto memrefType = source().getType().dyn_cast<MemRefType>(); 836 if (!memrefType) 837 return {}; 838 839 // Fold if the shape extent along the given index is known. 840 if (!memrefType.isDynamicDim(index.getInt())) { 841 Builder builder(getContext()); 842 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); 843 } 844 845 // The size at the given index is now known to be a dynamic size. 846 unsigned unsignedIndex = index.getValue().getZExtValue(); 847 848 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 849 Operation *definingOp = source().getDefiningOp(); 850 851 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 852 return *(alloc.getDynamicSizes().begin() + 853 memrefType.getDynamicDimIndex(unsignedIndex)); 854 855 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 856 return *(alloca.getDynamicSizes().begin() + 857 memrefType.getDynamicDimIndex(unsignedIndex)); 858 859 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 860 return *(view.getDynamicSizes().begin() + 861 memrefType.getDynamicDimIndex(unsignedIndex)); 862 863 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 864 llvm::SmallBitVector unusedDims = subview.getDroppedDims(); 865 unsigned resultIndex = 0; 866 unsigned sourceRank = subview.getSourceType().getRank(); 867 unsigned sourceIndex = 0; 868 for (auto i : llvm::seq<unsigned>(0, sourceRank)) { 869 if (unusedDims.test(i)) 870 continue; 871 if (resultIndex == unsignedIndex) { 872 sourceIndex = i; 873 break; 874 } 875 resultIndex++; 876 } 877 assert(subview.isDynamicSize(sourceIndex) && 878 "expected dynamic subview size"); 879 return subview.getDynamicSize(sourceIndex); 880 } 881 882 if (auto sizeInterface = 883 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 884 assert(sizeInterface.isDynamicSize(unsignedIndex) && 885 "Expected dynamic subview size"); 886 return sizeInterface.getDynamicSize(unsignedIndex); 887 } 888 889 // dim(memrefcast) -> dim 890 if (succeeded(foldMemRefCast(*this))) 891 return getResult(); 892 893 return {}; 894 } 895 896 namespace { 897 /// Fold dim of a memref reshape operation to a load into the reshape's shape 898 /// operand. 899 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 900 using OpRewritePattern<DimOp>::OpRewritePattern; 901 902 LogicalResult matchAndRewrite(DimOp dim, 903 PatternRewriter &rewriter) const override { 904 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 905 906 if (!reshape) 907 return failure(); 908 909 // Place the load directly after the reshape to ensure that the shape memref 910 // was not mutated. 
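    // Illustrative sketch (assumed IR, not taken from an existing test):
    //   %r = memref.reshape %src(%shape)
    //       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
    //   %d = memref.dim %r, %c1 : memref<?x?xf32>
    // becomes a memref.load of %shape[%c1] (plus an arith.index_cast when the
    // shape element type is not index).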
911 rewriter.setInsertionPointAfter(reshape); 912 Location loc = dim.getLoc(); 913 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 914 if (load.getType() != dim.getType()) 915 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); 916 rewriter.replaceOp(dim, load); 917 return success(); 918 } 919 }; 920 921 } // namespace 922 923 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 924 MLIRContext *context) { 925 results.add<DimOfMemRefReshape>(context); 926 } 927 928 // --------------------------------------------------------------------------- 929 // DmaStartOp 930 // --------------------------------------------------------------------------- 931 932 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 933 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 934 ValueRange destIndices, Value numElements, 935 Value tagMemRef, ValueRange tagIndices, Value stride, 936 Value elementsPerStride) { 937 result.addOperands(srcMemRef); 938 result.addOperands(srcIndices); 939 result.addOperands(destMemRef); 940 result.addOperands(destIndices); 941 result.addOperands({numElements, tagMemRef}); 942 result.addOperands(tagIndices); 943 if (stride) 944 result.addOperands({stride, elementsPerStride}); 945 } 946 947 void DmaStartOp::print(OpAsmPrinter &p) { 948 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " 949 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() 950 << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; 951 if (isStrided()) 952 p << ", " << getStride() << ", " << getNumElementsPerStride(); 953 954 p.printOptionalAttrDict((*this)->getAttrs()); 955 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 956 << ", " << getTagMemRef().getType(); 957 } 958 959 // Parse DmaStartOp. 960 // Ex: 961 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 962 // %tag[%index], %stride, %num_elt_per_stride : 963 // : memref<3076 x f32, 0>, 964 // memref<1024 x f32, 2>, 965 // memref<1 x i32> 966 // 967 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 968 OpAsmParser::UnresolvedOperand srcMemRefInfo; 969 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos; 970 OpAsmParser::UnresolvedOperand dstMemRefInfo; 971 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos; 972 OpAsmParser::UnresolvedOperand numElementsInfo; 973 OpAsmParser::UnresolvedOperand tagMemrefInfo; 974 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos; 975 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo; 976 977 SmallVector<Type, 3> types; 978 auto indexType = parser.getBuilder().getIndexType(); 979 980 // Parse and resolve the following list of operands: 981 // *) source memref followed by its indices (in square brackets). 982 // *) destination memref followed by its indices (in square brackets). 983 // *) dma size in KiB. 984 if (parser.parseOperand(srcMemRefInfo) || 985 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 986 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 987 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 988 parser.parseComma() || parser.parseOperand(numElementsInfo) || 989 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 990 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 991 return failure(); 992 993 // Parse optional stride and elements per stride. 
994 if (parser.parseTrailingOperandList(strideInfo)) 995 return failure(); 996 997 bool isStrided = strideInfo.size() == 2; 998 if (!strideInfo.empty() && !isStrided) { 999 return parser.emitError(parser.getNameLoc(), 1000 "expected two stride related operands"); 1001 } 1002 1003 if (parser.parseColonTypeList(types)) 1004 return failure(); 1005 if (types.size() != 3) 1006 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 1007 1008 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 1009 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 1010 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 1011 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 1012 // size should be an index. 1013 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 1014 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 1015 // tag indices should be index. 1016 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 1017 return failure(); 1018 1019 if (isStrided) { 1020 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 1021 return failure(); 1022 } 1023 1024 return success(); 1025 } 1026 1027 LogicalResult DmaStartOp::verify() { 1028 unsigned numOperands = getNumOperands(); 1029 1030 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 1031 // the number of elements. 1032 if (numOperands < 4) 1033 return emitOpError("expected at least 4 operands"); 1034 1035 // Check types of operands. The order of these calls is important: the later 1036 // calls rely on some type properties to compute the operand position. 1037 // 1. Source memref. 1038 if (!getSrcMemRef().getType().isa<MemRefType>()) 1039 return emitOpError("expected source to be of memref type"); 1040 if (numOperands < getSrcMemRefRank() + 4) 1041 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 1042 << " operands"; 1043 if (!getSrcIndices().empty() && 1044 !llvm::all_of(getSrcIndices().getTypes(), 1045 [](Type t) { return t.isIndex(); })) 1046 return emitOpError("expected source indices to be of index type"); 1047 1048 // 2. Destination memref. 1049 if (!getDstMemRef().getType().isa<MemRefType>()) 1050 return emitOpError("expected destination to be of memref type"); 1051 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 1052 if (numOperands < numExpectedOperands) 1053 return emitOpError() << "expected at least " << numExpectedOperands 1054 << " operands"; 1055 if (!getDstIndices().empty() && 1056 !llvm::all_of(getDstIndices().getTypes(), 1057 [](Type t) { return t.isIndex(); })) 1058 return emitOpError("expected destination indices to be of index type"); 1059 1060 // 3. Number of elements. 1061 if (!getNumElements().getType().isIndex()) 1062 return emitOpError("expected num elements to be of index type"); 1063 1064 // 4. Tag memref. 1065 if (!getTagMemRef().getType().isa<MemRefType>()) 1066 return emitOpError("expected tag to be of memref type"); 1067 numExpectedOperands += getTagMemRefRank(); 1068 if (numOperands < numExpectedOperands) 1069 return emitOpError() << "expected at least " << numExpectedOperands 1070 << " operands"; 1071 if (!getTagIndices().empty() && 1072 !llvm::all_of(getTagIndices().getTypes(), 1073 [](Type t) { return t.isIndex(); })) 1074 return emitOpError("expected tag indices to be of index type"); 1075 1076 // Optional stride-related operands must be either both present or both 1077 // absent. 
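  // For example (illustrative), with a 2-D source, a 2-D destination and a
  // 1-D tag memref, numExpectedOperands is 2 + 2 + 1 + 4 = 9, and a strided
  // dma_start then carries 11 operands.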
1078 if (numOperands != numExpectedOperands && 1079 numOperands != numExpectedOperands + 2) 1080 return emitOpError("incorrect number of operands"); 1081 1082 // 5. Strides. 1083 if (isStrided()) { 1084 if (!getStride().getType().isIndex() || 1085 !getNumElementsPerStride().getType().isIndex()) 1086 return emitOpError( 1087 "expected stride and num elements per stride to be of type index"); 1088 } 1089 1090 return success(); 1091 } 1092 1093 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 1094 SmallVectorImpl<OpFoldResult> &results) { 1095 /// dma_start(memrefcast) -> dma_start 1096 return foldMemRefCast(*this); 1097 } 1098 1099 // --------------------------------------------------------------------------- 1100 // DmaWaitOp 1101 // --------------------------------------------------------------------------- 1102 1103 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1104 SmallVectorImpl<OpFoldResult> &results) { 1105 /// dma_wait(memrefcast) -> dma_wait 1106 return foldMemRefCast(*this); 1107 } 1108 1109 LogicalResult DmaWaitOp::verify() { 1110 // Check that the number of tag indices matches the tagMemRef rank. 1111 unsigned numTagIndices = tagIndices().size(); 1112 unsigned tagMemRefRank = getTagMemRefRank(); 1113 if (numTagIndices != tagMemRefRank) 1114 return emitOpError() << "expected tagIndices to have the same number of " 1115 "elements as the tagMemRef rank, expected " 1116 << tagMemRefRank << ", but got " << numTagIndices; 1117 return success(); 1118 } 1119 1120 //===----------------------------------------------------------------------===// 1121 // GenericAtomicRMWOp 1122 //===----------------------------------------------------------------------===// 1123 1124 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, 1125 Value memref, ValueRange ivs) { 1126 result.addOperands(memref); 1127 result.addOperands(ivs); 1128 1129 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) { 1130 Type elementType = memrefType.getElementType(); 1131 result.addTypes(elementType); 1132 1133 Region *bodyRegion = result.addRegion(); 1134 bodyRegion->push_back(new Block()); 1135 bodyRegion->addArgument(elementType, memref.getLoc()); 1136 } 1137 } 1138 1139 LogicalResult GenericAtomicRMWOp::verify() { 1140 auto &body = getRegion(); 1141 if (body.getNumArguments() != 1) 1142 return emitOpError("expected single number of entry block arguments"); 1143 1144 if (getResult().getType() != body.getArgument(0).getType()) 1145 return emitOpError("expected block argument of the same type result type"); 1146 1147 bool hasSideEffects = 1148 body.walk([&](Operation *nestedOp) { 1149 if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) 1150 return WalkResult::advance(); 1151 nestedOp->emitError( 1152 "body of 'memref.generic_atomic_rmw' should contain " 1153 "only operations with no side effects"); 1154 return WalkResult::interrupt(); 1155 }) 1156 .wasInterrupted(); 1157 return hasSideEffects ? 
failure() : success(); 1158 } 1159 1160 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser, 1161 OperationState &result) { 1162 OpAsmParser::UnresolvedOperand memref; 1163 Type memrefType; 1164 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs; 1165 1166 Type indexType = parser.getBuilder().getIndexType(); 1167 if (parser.parseOperand(memref) || 1168 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || 1169 parser.parseColonType(memrefType) || 1170 parser.resolveOperand(memref, memrefType, result.operands) || 1171 parser.resolveOperands(ivs, indexType, result.operands)) 1172 return failure(); 1173 1174 Region *body = result.addRegion(); 1175 if (parser.parseRegion(*body, llvm::None, llvm::None) || 1176 parser.parseOptionalAttrDict(result.attributes)) 1177 return failure(); 1178 result.types.push_back(memrefType.cast<MemRefType>().getElementType()); 1179 return success(); 1180 } 1181 1182 void GenericAtomicRMWOp::print(OpAsmPrinter &p) { 1183 p << ' ' << memref() << "[" << indices() << "] : " << memref().getType() 1184 << ' '; 1185 p.printRegion(getRegion()); 1186 p.printOptionalAttrDict((*this)->getAttrs()); 1187 } 1188 1189 //===----------------------------------------------------------------------===// 1190 // AtomicYieldOp 1191 //===----------------------------------------------------------------------===// 1192 1193 LogicalResult AtomicYieldOp::verify() { 1194 Type parentType = (*this)->getParentOp()->getResultTypes().front(); 1195 Type resultType = result().getType(); 1196 if (parentType != resultType) 1197 return emitOpError() << "types mismatch between yield op: " << resultType 1198 << " and its parent: " << parentType; 1199 return success(); 1200 } 1201 1202 //===----------------------------------------------------------------------===// 1203 // GlobalOp 1204 //===----------------------------------------------------------------------===// 1205 1206 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1207 TypeAttr type, 1208 Attribute initialValue) { 1209 p << type; 1210 if (!op.isExternal()) { 1211 p << " = "; 1212 if (op.isUninitialized()) 1213 p << "uninitialized"; 1214 else 1215 p.printAttributeWithoutType(initialValue); 1216 } 1217 } 1218 1219 static ParseResult 1220 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1221 Attribute &initialValue) { 1222 Type type; 1223 if (parser.parseType(type)) 1224 return failure(); 1225 1226 auto memrefType = type.dyn_cast<MemRefType>(); 1227 if (!memrefType || !memrefType.hasStaticShape()) 1228 return parser.emitError(parser.getNameLoc()) 1229 << "type should be static shaped memref, but got " << type; 1230 typeAttr = TypeAttr::get(type); 1231 1232 if (parser.parseOptionalEqual()) 1233 return success(); 1234 1235 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1236 initialValue = UnitAttr::get(parser.getContext()); 1237 return success(); 1238 } 1239 1240 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1241 if (parser.parseAttribute(initialValue, tensorType)) 1242 return failure(); 1243 if (!initialValue.isa<ElementsAttr>()) 1244 return parser.emitError(parser.getNameLoc()) 1245 << "initial value should be a unit or elements attribute"; 1246 return success(); 1247 } 1248 1249 LogicalResult GlobalOp::verify() { 1250 auto memrefType = type().dyn_cast<MemRefType>(); 1251 if (!memrefType || !memrefType.hasStaticShape()) 1252 return emitOpError("type should be static shaped memref, but got ") 1253 << type(); 1254 1255 // Verify that the initial 
value, if present, is either a unit attribute or 1256 // an elements attribute. 1257 if (initial_value().hasValue()) { 1258 Attribute initValue = initial_value().getValue(); 1259 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1260 return emitOpError("initial value should be a unit or elements " 1261 "attribute, but got ") 1262 << initValue; 1263 1264 // Check that the type of the initial value is compatible with the type of 1265 // the global variable. 1266 if (initValue.isa<ElementsAttr>()) { 1267 Type initType = initValue.getType(); 1268 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1269 if (initType != tensorType) 1270 return emitOpError("initial value expected to be of type ") 1271 << tensorType << ", but was of type " << initType; 1272 } 1273 } 1274 1275 if (Optional<uint64_t> alignAttr = alignment()) { 1276 uint64_t alignment = alignAttr.getValue(); 1277 1278 if (!llvm::isPowerOf2_64(alignment)) 1279 return emitError() << "alignment attribute value " << alignment 1280 << " is not a power of 2"; 1281 } 1282 1283 // TODO: verify visibility for declarations. 1284 return success(); 1285 } 1286 1287 ElementsAttr GlobalOp::getConstantInitValue() { 1288 auto initVal = initial_value(); 1289 if (constant() && initVal.hasValue()) 1290 return initVal.getValue().cast<ElementsAttr>(); 1291 return {}; 1292 } 1293 1294 //===----------------------------------------------------------------------===// 1295 // GetGlobalOp 1296 //===----------------------------------------------------------------------===// 1297 1298 LogicalResult 1299 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1300 // Verify that the result type is same as the type of the referenced 1301 // memref.global op. 1302 auto global = 1303 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1304 if (!global) 1305 return emitOpError("'") 1306 << name() << "' does not reference a valid global memref"; 1307 1308 Type resultType = result().getType(); 1309 if (global.type() != resultType) 1310 return emitOpError("result type ") 1311 << resultType << " does not match type " << global.type() 1312 << " of the global memref @" << name(); 1313 return success(); 1314 } 1315 1316 //===----------------------------------------------------------------------===// 1317 // LoadOp 1318 //===----------------------------------------------------------------------===// 1319 1320 LogicalResult LoadOp::verify() { 1321 if (getNumOperands() != 1 + getMemRefType().getRank()) 1322 return emitOpError("incorrect number of indices for load"); 1323 return success(); 1324 } 1325 1326 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1327 /// load(memrefcast) -> load 1328 if (succeeded(foldMemRefCast(*this))) 1329 return getResult(); 1330 return OpFoldResult(); 1331 } 1332 1333 //===----------------------------------------------------------------------===// 1334 // PrefetchOp 1335 //===----------------------------------------------------------------------===// 1336 1337 void PrefetchOp::print(OpAsmPrinter &p) { 1338 p << " " << memref() << '['; 1339 p.printOperands(indices()); 1340 p << ']' << ", " << (isWrite() ? "write" : "read"); 1341 p << ", locality<" << localityHint(); 1342 p << ">, " << (isDataCache() ? 
"data" : "instr"); 1343 p.printOptionalAttrDict( 1344 (*this)->getAttrs(), 1345 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1346 p << " : " << getMemRefType(); 1347 } 1348 1349 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { 1350 OpAsmParser::UnresolvedOperand memrefInfo; 1351 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo; 1352 IntegerAttr localityHint; 1353 MemRefType type; 1354 StringRef readOrWrite, cacheType; 1355 1356 auto indexTy = parser.getBuilder().getIndexType(); 1357 auto i32Type = parser.getBuilder().getIntegerType(32); 1358 if (parser.parseOperand(memrefInfo) || 1359 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1360 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1361 parser.parseComma() || parser.parseKeyword("locality") || 1362 parser.parseLess() || 1363 parser.parseAttribute(localityHint, i32Type, "localityHint", 1364 result.attributes) || 1365 parser.parseGreater() || parser.parseComma() || 1366 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1367 parser.resolveOperand(memrefInfo, type, result.operands) || 1368 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1369 return failure(); 1370 1371 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1372 return parser.emitError(parser.getNameLoc(), 1373 "rw specifier has to be 'read' or 'write'"); 1374 result.addAttribute( 1375 PrefetchOp::getIsWriteAttrName(), 1376 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1377 1378 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1379 return parser.emitError(parser.getNameLoc(), 1380 "cache type has to be 'data' or 'instr'"); 1381 1382 result.addAttribute( 1383 PrefetchOp::getIsDataCacheAttrName(), 1384 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1385 1386 return success(); 1387 } 1388 1389 LogicalResult PrefetchOp::verify() { 1390 if (getNumOperands() != 1 + getMemRefType().getRank()) 1391 return emitOpError("too few indices"); 1392 1393 return success(); 1394 } 1395 1396 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1397 SmallVectorImpl<OpFoldResult> &results) { 1398 // prefetch(memrefcast) -> prefetch 1399 return foldMemRefCast(*this); 1400 } 1401 1402 //===----------------------------------------------------------------------===// 1403 // RankOp 1404 //===----------------------------------------------------------------------===// 1405 1406 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1407 // Constant fold rank when the rank of the operand is known. 1408 auto type = getOperand().getType(); 1409 auto shapedType = type.dyn_cast<ShapedType>(); 1410 if (shapedType && shapedType.hasRank()) 1411 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1412 return IntegerAttr(); 1413 } 1414 1415 //===----------------------------------------------------------------------===// 1416 // ReinterpretCastOp 1417 //===----------------------------------------------------------------------===// 1418 1419 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1420 /// `staticSizes` and `staticStrides` are automatically filled with 1421 /// source-memref-rank sentinel values that encode dynamic entries. 
1422 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1423 MemRefType resultType, Value source, 1424 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1425 ArrayRef<OpFoldResult> strides, 1426 ArrayRef<NamedAttribute> attrs) { 1427 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1428 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1429 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1430 ShapedType::kDynamicStrideOrOffset); 1431 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1432 ShapedType::kDynamicSize); 1433 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1434 ShapedType::kDynamicStrideOrOffset); 1435 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1436 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1437 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1438 result.addAttributes(attrs); 1439 } 1440 1441 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1442 MemRefType resultType, Value source, 1443 int64_t offset, ArrayRef<int64_t> sizes, 1444 ArrayRef<int64_t> strides, 1445 ArrayRef<NamedAttribute> attrs) { 1446 SmallVector<OpFoldResult> sizeValues = 1447 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1448 return b.getI64IntegerAttr(v); 1449 })); 1450 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1451 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1452 return b.getI64IntegerAttr(v); 1453 })); 1454 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1455 strideValues, attrs); 1456 } 1457 1458 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1459 MemRefType resultType, Value source, Value offset, 1460 ValueRange sizes, ValueRange strides, 1461 ArrayRef<NamedAttribute> attrs) { 1462 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1463 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1464 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1465 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1466 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1467 } 1468 1469 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1470 // completed automatically, like we have for subview and extract_slice. 1471 LogicalResult ReinterpretCastOp::verify() { 1472 // The source and result memrefs should be in the same memory space. 1473 auto srcType = source().getType().cast<BaseMemRefType>(); 1474 auto resultType = getType().cast<MemRefType>(); 1475 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1476 return emitError("different memory spaces specified for source type ") 1477 << srcType << " and result memref type " << resultType; 1478 if (srcType.getElementType() != resultType.getElementType()) 1479 return emitError("different element types specified for source type ") 1480 << srcType << " and result memref type " << resultType; 1481 1482 // Match sizes in result memref type and in static_sizes attribute. 
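  // For example (illustrative), a result type of memref<4x?xf32> is rejected
  // here if static_sizes claims 8 for dim 0, while a dynamic entry on either
  // side is accepted by this check.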
1483   for (auto &en : llvm::enumerate(llvm::zip(
1484            resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
1485     int64_t resultSize = std::get<0>(en.value());
1486     int64_t expectedSize = std::get<1>(en.value());
1487     if (!ShapedType::isDynamic(resultSize) &&
1488         !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1489       return emitError("expected result type with size = ")
1490              << expectedSize << " instead of " << resultSize
1491              << " in dim = " << en.index();
1492   }
1493 
1494   // Match offset and strides in the static_offsets and static_strides
1495   // attributes. If the result memref type has no affine map specified, an
1496   // identity layout is assumed.
1497   int64_t resultOffset;
1498   SmallVector<int64_t, 4> resultStrides;
1499   if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1500     return emitError("expected result type to have strided layout but found ")
1501            << resultType;
1502 
1503   // Match offset in result memref type and in static_offsets attribute.
1504   int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
1505   if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
1506       !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
1507       resultOffset != expectedOffset)
1508     return emitError("expected result type with offset = ")
1509            << expectedOffset << " instead of " << resultOffset;
1510 
1511   // Match strides in result memref type and in static_strides attribute.
1512   for (auto &en : llvm::enumerate(llvm::zip(
1513            resultStrides, extractFromI64ArrayAttr(static_strides())))) {
1514     int64_t resultStride = std::get<0>(en.value());
1515     int64_t expectedStride = std::get<1>(en.value());
1516     if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
1517         !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
1518         resultStride != expectedStride)
1519       return emitError("expected result type with stride = ")
1520              << expectedStride << " instead of " << resultStride
1521              << " in dim = " << en.index();
1522   }
1523 
1524   return success();
1525 }
1526 
1527 OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
1528   Value src = source();
1529   auto getPrevSrc = [&]() -> Value {
1530     // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1531     if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1532       return prev.source();
1533 
1534     // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1535     if (auto prev = src.getDefiningOp<CastOp>())
1536       return prev.source();
1537 
1538     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1539     // are 0.
1540     if (auto prev = src.getDefiningOp<SubViewOp>())
1541       if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1542             return isConstantIntValue(val, 0);
1543           }))
1544         return prev.source();
1545 
1546     return nullptr;
1547   };
1548 
1549   if (auto prevSrc = getPrevSrc()) {
1550     sourceMutable().assign(prevSrc);
1551     return getResult();
1552   }
1553 
1554   return nullptr;
1555 }
1556 
1557 //===----------------------------------------------------------------------===//
1558 // Reassociative reshape ops
1559 //===----------------------------------------------------------------------===//
1560 
1561 /// Helper function that computes a stride based on the size/stride of the
1562 /// previous dimension.
1563 ///
1564 /// E.g., memref<20x10x5xf32, offset: 0, strides: [50, 5, 1]>
1565 ///                                               ^^
1566 ///                                       compute this one
1567 ///   prevStride = 5, prevDimSize = 10
1568 ///   nextStride = 5 * 10 = 50
1569 static int64_t computeNextStride(int64_t prevStride, int64_t prevDimSize) {
1570   if (ShapedType::isDynamicStrideOrOffset(prevStride))
1571     return ShapedType::kDynamicStrideOrOffset;
1572 
1573   if (ShapedType::isDynamic(prevDimSize))
1574     return ShapedType::kDynamicStrideOrOffset;
1575 
1576   return prevStride * prevDimSize;
1577 }
1578 
1579 /// Helper function for verifying the shapes of ExpandShapeOp and
1580 /// CollapseShapeOp result and operand. Layout maps are verified separately.
1581 ///
1582 /// If `allowMultipleDynamicDimsPerGroup` is set, multiple dynamic dimensions
1583 /// are allowed in a reassociation group.
1584 static LogicalResult
1585 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
1586                      ArrayRef<int64_t> expandedShape,
1587                      ArrayRef<ReassociationIndices> reassociation,
1588                      bool allowMultipleDynamicDimsPerGroup) {
1589   // There must be one reassociation group per collapsed dimension.
1590   if (collapsedShape.size() != reassociation.size())
1591     return op->emitOpError("invalid number of reassociation groups: found ")
1592            << reassociation.size() << ", expected " << collapsedShape.size();
1593 
1594   // The next expected expanded dimension index (while iterating over
1595   // reassociation indices).
1596   int64_t nextDim = 0;
1597   for (const auto &it : llvm::enumerate(reassociation)) {
1598     ReassociationIndices group = it.value();
1599     int64_t collapsedDim = it.index();
1600 
1601     bool foundDynamic = false;
1602     for (int64_t expandedDim : group) {
1603       if (expandedDim != nextDim++)
1604         return op->emitOpError("reassociation indices must be contiguous");
1605 
1606       if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
1607         return op->emitOpError("reassociation index ")
1608                << expandedDim << " is out of bounds";
1609 
1610       // Check if there are multiple dynamic dims in a reassociation group.
1611       if (ShapedType::isDynamic(expandedShape[expandedDim])) {
1612         if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
1613           return op->emitOpError(
1614               "at most one dimension in a reassociation group may be dynamic");
1615         foundDynamic = true;
1616       }
1617     }
1618 
1619     // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
1620     if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
1621       return op->emitOpError("collapsed dim (")
1622              << collapsedDim
1623              << ") must be dynamic if and only if reassociation group is "
1624                 "dynamic";
1625 
1626     // If all dims in the reassociation group are static, the size of the
1627     // collapsed dim can be verified.
1628     if (!foundDynamic) {
1629       int64_t groupSize = 1;
1630       for (int64_t expandedDim : group)
1631         groupSize *= expandedShape[expandedDim];
1632       if (groupSize != collapsedShape[collapsedDim])
1633         return op->emitOpError("collapsed dim size (")
1634                << collapsedShape[collapsedDim]
1635                << ") must equal reassociation group size (" << groupSize << ")";
1636     }
1637   }
1638 
1639   if (collapsedShape.empty()) {
1640     // Rank 0: All expanded dimensions must be 1.
1641     for (int64_t d : expandedShape)
1642       if (d != 1)
1643         return op->emitOpError(
1644             "rank 0 memrefs can only be extended/collapsed with/from ones");
1645   } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
1646     // Rank >= 1: The number of dimensions among all reassociation groups must
1647     // match the rank of the expanded shape.
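    // E.g. (illustrative): reassociation = [[0, 1]] only covers expanded dims
    // 0 and 1, so an expanded shape of rank 3 leaves dim 2 unaccounted for and
    // is rejected below.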
1648 return op->emitOpError("expanded rank (") 1649 << expandedShape.size() 1650 << ") inconsistent with number of reassociation indices (" << nextDim 1651 << ")"; 1652 } 1653 1654 return success(); 1655 } 1656 1657 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1658 return getSymbolLessAffineMaps(getReassociationExprs()); 1659 } 1660 1661 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1662 return convertReassociationIndicesToExprs(getContext(), 1663 getReassociationIndices()); 1664 } 1665 1666 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1667 return getSymbolLessAffineMaps(getReassociationExprs()); 1668 } 1669 1670 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1671 return convertReassociationIndicesToExprs(getContext(), 1672 getReassociationIndices()); 1673 } 1674 1675 /// Compute the layout map after expanding a given source MemRef type with the 1676 /// specified reassociation indices. 1677 static FailureOr<AffineMap> 1678 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape, 1679 ArrayRef<ReassociationIndices> reassociation) { 1680 SmallVector<int64_t> srcStrides, resultStrides(resultShape.size(), 0); 1681 int64_t srcOffset; 1682 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 1683 return failure(); 1684 assert(srcStrides.size() == reassociation.size() && "invalid reassociation"); 1685 1686 // Ensure that inner strides are the fastest-varying ones. Other source layout 1687 // maps are currently not supported. 1688 int64_t lastStride = 0; 1689 for (int64_t s : llvm::reverse(srcStrides)) { 1690 if (!ShapedType::isDynamicStrideOrOffset(s)) { 1691 if (s < lastStride) 1692 return failure(); 1693 lastStride = s; 1694 } 1695 } 1696 1697 // Iterate over all reassociation groups from the back. Example: 1698 // strides = [1000, ?, 2] 1699 // source shape = [20, 10, 5] 1700 // result shape = [ 2, 10, 2, 5, 5] 1701 // reassociation = [[0, 1], [2, 3], [4]] 1702 for (const auto &it : llvm::reverse(llvm::zip(reassociation, srcStrides))) { 1703 ReassociationIndices indices = std::get<0>(it); 1704 int64_t srcGroupStride = std::get<1>(it); 1705 1706 // The first result dimension (least significant one) in each reassociation 1707 // group has the same stride as the corresponding source dimension. E.g.: 1708 // reassociation = [[0, 1], [2, 3], [4]] 1709 // | | | 1710 // v v v 1711 // 1000 ? 2 1712 resultStrides[indices.pop_back_val()] = srcGroupStride; 1713 1714 // Compute the strides for the remaining dims in the reassociation group. 1715 for (int64_t resultDim : llvm::reverse(indices)) { 1716 // E.g.: 1717 // reassociation = [[0, 1], [2, 3], [4]] 1718 // | 1719 // v 1720 // 1000 * 10 = 10000 1721 // 1722 // If the previous stride or the previous dimension was dynamic, then this 1723 // stride will also be dynamic. 1724 resultStrides[resultDim] = computeNextStride(resultStrides[resultDim + 1], 1725 resultShape[resultDim + 1]); 1726 } 1727 } 1728 1729 return makeStridedLinearLayoutMap(resultStrides, srcOffset, 1730 srcType.getContext()); 1731 } 1732 1733 static FailureOr<MemRefType> 1734 computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape, 1735 ArrayRef<ReassociationIndices> reassociation) { 1736 if (srcType.getLayout().isIdentity()) { 1737 // If the source is contiguous (i.e., no layout map specified), so is the 1738 // result. 
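    // E.g. (illustrative): expanding memref<12x5xf32> with reassociation
    // [[0, 1], [2]] to result shape [3, 4, 5] yields memref<3x4x5xf32> with
    // the default identity layout.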
1739 MemRefLayoutAttrInterface layout; 1740 return MemRefType::get(resultShape, srcType.getElementType(), layout, 1741 srcType.getMemorySpace()); 1742 } 1743 1744 // Source may not be contiguous. Compute the layout map. 1745 FailureOr<AffineMap> computedLayout = 1746 computeExpandedLayoutMap(srcType, resultShape, reassociation); 1747 if (failed(computedLayout)) 1748 return failure(); 1749 auto computedType = 1750 MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, 1751 srcType.getMemorySpaceAsInt()); 1752 return canonicalizeStridedLayout(computedType); 1753 } 1754 1755 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, 1756 ArrayRef<int64_t> resultShape, Value src, 1757 ArrayRef<ReassociationIndices> reassociation) { 1758 // Only ranked memref source values are supported. 1759 auto srcType = src.getType().cast<MemRefType>(); 1760 FailureOr<MemRefType> resultType = 1761 computeExpandedType(srcType, resultShape, reassociation); 1762 // Failure of this assertion usually indicates a problem with the source 1763 // type, e.g., could not get strides/offset. 1764 assert(succeeded(resultType) && "could not compute layout"); 1765 build(builder, result, *resultType, src, reassociation); 1766 } 1767 1768 LogicalResult ExpandShapeOp::verify() { 1769 MemRefType srcType = getSrcType(); 1770 MemRefType resultType = getResultType(); 1771 1772 // Verify result shape. 1773 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(), 1774 resultType.getShape(), 1775 getReassociationIndices(), 1776 /*allowMultipleDynamicDimsPerGroup=*/false))) 1777 return failure(); 1778 1779 // Compute expected result type (including layout map). 1780 FailureOr<MemRefType> expectedResultType = computeExpandedType( 1781 srcType, resultType.getShape(), getReassociationIndices()); 1782 if (failed(expectedResultType)) 1783 return emitOpError("invalid source layout map"); 1784 1785 // Check actual result type. 1786 auto canonicalizedResultType = canonicalizeStridedLayout(resultType); 1787 if (*expectedResultType != canonicalizedResultType) 1788 return emitOpError("expected expanded type to be ") 1789 << *expectedResultType << " but found " << canonicalizedResultType; 1790 1791 return success(); 1792 } 1793 1794 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1795 MLIRContext *context) { 1796 results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>, 1797 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>( 1798 context); 1799 } 1800 1801 /// Compute the layout map after collapsing a given source MemRef type with the 1802 /// specified reassociation indices. 1803 /// 1804 /// Note: All collapsed dims in a reassociation group must be contiguous. It is 1805 /// not possible to check this by inspecting a MemRefType in the general case. 1806 /// But it is assumed. If this is not the case, the behavior is undefined. 1807 static FailureOr<AffineMap> 1808 computeCollapsedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape, 1809 ArrayRef<ReassociationIndices> reassociation) { 1810 SmallVector<int64_t> srcStrides, resultStrides; 1811 int64_t srcOffset; 1812 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 1813 return failure(); 1814 assert(resultShape.size() == reassociation.size() && "invalid reassociation"); 1815 1816 // Iterate over all reassociation groups from the back. 
Example: 1817 // source shape = [20, ?, 5, 10, 2] 1818 // source strides = [ ?, ?, 800, 80, 4] 1819 // reassociation = [[0, 1], [2, 3], [4]] 1820 // result shape = [ ?, 50, 2] 1821 // 1822 // Note: The result shape is not needed in this computation. It is just used 1823 // check that the size of the reassociation is correct. 1824 for (ReassociationIndices group : llvm::reverse(reassociation)) { 1825 // A result dim has the same stride as the first dimension (least 1826 // significant one) in the corresponding reassociation group. E.g.: 1827 // reassociation = [[0, 1], [2, 3], [4]] 1828 // | | | 1829 // v v v 1830 // ? 80 4 1831 int64_t resultStride = srcStrides[group.pop_back_val()]; 1832 1833 // The following is just a best-effort check for non-contiguous source 1834 // strides within a reassociation group. E.g.: 1835 // reassociation = [[0, 1], [2, 3], [4]] 1836 // ^^^^^^ 1837 // Iteratively compute the next stride within the reassociation group 1838 // one-by-one. Start with the stride computed above. E.g.: 1839 // reassociation = [[0, 1], [2, 3], [4]] 1840 // | 1841 // v 1842 // nextStride = 80 1843 int64_t nextStride = resultStride; 1844 for (int64_t nextDim : llvm::reverse(group)) { 1845 // Next expected stride is previous stride multiplied by dim size, e.g.: 1846 // reassociation = [[0, 1], [2, 3], [4]] 1847 // | 1848 // v 1849 // nextStride = 80 * 10 = 800 1850 nextStride = 1851 computeNextStride(nextStride, srcType.getDimSize(nextDim + 1)); 1852 1853 // Ensure that the source actually has this stride value. E.g.: 1854 // source strides = [ ?, ?, 800, 80, 4] 1855 // | 1856 // v 1857 // same stride, OK 1858 // If strides are dynamic, we cannot verify anything statically. 1859 if (!ShapedType::isDynamicStrideOrOffset(srcStrides[nextDim]) && 1860 !ShapedType::isDynamicStrideOrOffset(nextStride) && 1861 srcStrides[nextDim] != nextStride) { 1862 // Attempting to collapse non-contiguous dimensions. This is forbidden. 1863 // Note: This check does not handle cases where strides and dimension 1864 // sizes are dynamic. Such dims could still turn out to be non- 1865 // contiguous at runtime. This check is only a best effort to catch 1866 // illegal collapses at verification time. 1867 return failure(); 1868 } 1869 } 1870 1871 resultStrides.push_back(resultStride); 1872 } 1873 1874 return makeStridedLinearLayoutMap( 1875 llvm::to_vector<8>(llvm::reverse(resultStrides)), srcOffset, 1876 srcType.getContext()); 1877 } 1878 1879 static MemRefType 1880 computeCollapsedType(MemRefType srcType, 1881 ArrayRef<ReassociationIndices> reassociation) { 1882 SmallVector<int64_t> resultShape; 1883 for (const ReassociationIndices &group : reassociation) { 1884 int64_t groupSize = 1; 1885 for (int64_t srcDim : group) { 1886 if (srcType.isDynamicDim(srcDim)) { 1887 // Source dim is dynamic, so the collapsed dim is also dynamic. 1888 groupSize = ShapedType::kDynamicSize; 1889 break; 1890 } 1891 1892 groupSize *= srcType.getDimSize(srcDim); 1893 } 1894 1895 resultShape.push_back(groupSize); 1896 } 1897 1898 if (srcType.getLayout().isIdentity()) { 1899 // If the source is contiguous (i.e., no layout map specified), so is the 1900 // result. 1901 MemRefLayoutAttrInterface layout; 1902 return MemRefType::get(resultShape, srcType.getElementType(), layout, 1903 srcType.getMemorySpace()); 1904 } 1905 1906 // Source may not be fully contiguous. Compute the layout map. 1907 // Note: Dimensions that are collapsed into a single dim are assumed to be 1908 // contiguous. 
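  // E.g. (illustrative): a source memref<4x5xf32> with strides [5, 1] and
  // reassociation [[0, 1]] collapses to shape [20]; the layout map computed
  // below then has the single stride [1].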
1909 FailureOr<AffineMap> computedLayout = 1910 computeCollapsedLayoutMap(srcType, resultShape, reassociation); 1911 assert(succeeded(computedLayout) && 1912 "invalid source layout map or collapsing non-contiguous dims"); 1913 auto computedType = 1914 MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, 1915 srcType.getMemorySpaceAsInt()); 1916 return canonicalizeStridedLayout(computedType); 1917 } 1918 1919 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1920 ArrayRef<ReassociationIndices> reassociation, 1921 ArrayRef<NamedAttribute> attrs) { 1922 auto srcType = src.getType().cast<MemRefType>(); 1923 MemRefType resultType = computeCollapsedType(srcType, reassociation); 1924 build(b, result, resultType, src, attrs); 1925 result.addAttribute(getReassociationAttrName(), 1926 getReassociationIndicesAttribute(b, reassociation)); 1927 } 1928 1929 LogicalResult CollapseShapeOp::verify() { 1930 MemRefType srcType = getSrcType(); 1931 MemRefType resultType = getResultType(); 1932 1933 // Verify result shape. 1934 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(), 1935 srcType.getShape(), getReassociationIndices(), 1936 /*allowMultipleDynamicDimsPerGroup=*/true))) 1937 return failure(); 1938 1939 // Compute expected result type (including layout map). 1940 MemRefType expectedResultType; 1941 if (srcType.getLayout().isIdentity()) { 1942 // If the source is contiguous (i.e., no layout map specified), so is the 1943 // result. 1944 MemRefLayoutAttrInterface layout; 1945 expectedResultType = 1946 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout, 1947 srcType.getMemorySpace()); 1948 } else { 1949 // Source may not be fully contiguous. Compute the layout map. 1950 // Note: Dimensions that are collapsed into a single dim are assumed to be 1951 // contiguous. 
1952 FailureOr<AffineMap> computedLayout = computeCollapsedLayoutMap( 1953 srcType, resultType.getShape(), getReassociationIndices()); 1954 if (failed(computedLayout)) 1955 return emitOpError( 1956 "invalid source layout map or collapsing non-contiguous dims"); 1957 auto computedType = 1958 MemRefType::get(resultType.getShape(), srcType.getElementType(), 1959 *computedLayout, srcType.getMemorySpaceAsInt()); 1960 expectedResultType = canonicalizeStridedLayout(computedType); 1961 } 1962 1963 auto canonicalizedResultType = canonicalizeStridedLayout(resultType); 1964 if (expectedResultType != canonicalizedResultType) 1965 return emitOpError("expected collapsed type to be ") 1966 << expectedResultType << " but found " << canonicalizedResultType; 1967 1968 return success(); 1969 } 1970 1971 struct CollapseShapeOpMemRefCastFolder 1972 : public OpRewritePattern<CollapseShapeOp> { 1973 public: 1974 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1975 1976 LogicalResult matchAndRewrite(CollapseShapeOp op, 1977 PatternRewriter &rewriter) const override { 1978 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1979 if (!cast) 1980 return failure(); 1981 1982 if (!CastOp::canFoldIntoConsumerOp(cast)) 1983 return failure(); 1984 1985 Type newResultType = 1986 computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(), 1987 op.getReassociationIndices()); 1988 1989 if (newResultType == op.getResultType()) { 1990 rewriter.updateRootInPlace( 1991 op, [&]() { op.srcMutable().assign(cast.source()); }); 1992 } else { 1993 Value newOp = rewriter.create<CollapseShapeOp>( 1994 op->getLoc(), cast.source(), op.getReassociationIndices()); 1995 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1996 } 1997 return success(); 1998 } 1999 }; 2000 2001 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 2002 MLIRContext *context) { 2003 results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>, 2004 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>, 2005 CollapseShapeOpMemRefCastFolder>(context); 2006 } 2007 2008 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 2009 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 2010 } 2011 2012 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 2013 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 2014 } 2015 2016 //===----------------------------------------------------------------------===// 2017 // ReshapeOp 2018 //===----------------------------------------------------------------------===// 2019 2020 LogicalResult ReshapeOp::verify() { 2021 Type operandType = source().getType(); 2022 Type resultType = result().getType(); 2023 2024 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 2025 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 2026 if (operandElementType != resultElementType) 2027 return emitOpError("element types of source and destination memref " 2028 "types should be the same"); 2029 2030 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 2031 if (!operandMemRefType.getLayout().isIdentity()) 2032 return emitOpError("source memref type should have identity affine map"); 2033 2034 int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0); 2035 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 2036 if (resultMemRefType) { 2037 if (!resultMemRefType.getLayout().isIdentity()) 2038 return emitOpError("result memref type should have identity affine 
map"); 2039 if (shapeSize == ShapedType::kDynamicSize) 2040 return emitOpError("cannot use shape operand with dynamic length to " 2041 "reshape to statically-ranked memref type"); 2042 if (shapeSize != resultMemRefType.getRank()) 2043 return emitOpError( 2044 "length of shape operand differs from the result's memref rank"); 2045 } 2046 return success(); 2047 } 2048 2049 //===----------------------------------------------------------------------===// 2050 // StoreOp 2051 //===----------------------------------------------------------------------===// 2052 2053 LogicalResult StoreOp::verify() { 2054 if (getNumOperands() != 2 + getMemRefType().getRank()) 2055 return emitOpError("store index operand count not equal to memref rank"); 2056 2057 return success(); 2058 } 2059 2060 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 2061 SmallVectorImpl<OpFoldResult> &results) { 2062 /// store(memrefcast) -> store 2063 return foldMemRefCast(*this, getValueToStore()); 2064 } 2065 2066 //===----------------------------------------------------------------------===// 2067 // SubViewOp 2068 //===----------------------------------------------------------------------===// 2069 2070 namespace { 2071 /// Helpers to write more idiomatic operations. 2072 namespace saturated_arith { 2073 struct Wrapper { 2074 explicit Wrapper(int64_t v) : v(v) {} 2075 operator int64_t() { return v; } 2076 int64_t v; 2077 }; 2078 Wrapper operator+(Wrapper a, int64_t b) { 2079 if (ShapedType::isDynamicStrideOrOffset(a) || 2080 ShapedType::isDynamicStrideOrOffset(b)) 2081 return Wrapper(ShapedType::kDynamicStrideOrOffset); 2082 return Wrapper(a.v + b); 2083 } 2084 Wrapper operator*(Wrapper a, int64_t b) { 2085 if (ShapedType::isDynamicStrideOrOffset(a) || 2086 ShapedType::isDynamicStrideOrOffset(b)) 2087 return Wrapper(ShapedType::kDynamicStrideOrOffset); 2088 return Wrapper(a.v * b); 2089 } 2090 } // namespace saturated_arith 2091 } // namespace 2092 2093 /// A subview result type can be fully inferred from the source type and the 2094 /// static representation of offsets, sizes and strides. Special sentinels 2095 /// encode the dynamic case. 2096 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 2097 ArrayRef<int64_t> staticOffsets, 2098 ArrayRef<int64_t> staticSizes, 2099 ArrayRef<int64_t> staticStrides) { 2100 unsigned rank = sourceMemRefType.getRank(); 2101 (void)rank; 2102 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 2103 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 2104 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 2105 2106 // Extract source offset and strides. 2107 int64_t sourceOffset; 2108 SmallVector<int64_t, 4> sourceStrides; 2109 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 2110 assert(succeeded(res) && "SubViewOp expected strided memref type"); 2111 (void)res; 2112 2113 // Compute target offset whose value is: 2114 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 2115 int64_t targetOffset = sourceOffset; 2116 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 2117 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 2118 using namespace saturated_arith; 2119 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 2120 } 2121 2122 // Compute target stride whose value is: 2123 // `sourceStrides_i * staticStrides_i`. 
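  // Worked example (illustrative): for a source memref<8x16x4xf32> (strides
  // [64, 4, 1], offset 0) with offsets [3, 4, 2], sizes [1, 6, 3] and strides
  // [1, 1, 1], the loop above computes targetOffset = 3*64 + 4*4 + 2*1 = 210,
  // and the loop below yields targetStrides = [64, 4, 1], giving the result
  // type memref<1x6x3xf32, offset: 210, strides: [64, 4, 1]>.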
2124 SmallVector<int64_t, 4> targetStrides; 2125 targetStrides.reserve(staticOffsets.size()); 2126 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 2127 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 2128 using namespace saturated_arith; 2129 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 2130 } 2131 2132 // The type is now known. 2133 return MemRefType::get( 2134 staticSizes, sourceMemRefType.getElementType(), 2135 makeStridedLinearLayoutMap(targetStrides, targetOffset, 2136 sourceMemRefType.getContext()), 2137 sourceMemRefType.getMemorySpace()); 2138 } 2139 2140 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 2141 ArrayRef<OpFoldResult> offsets, 2142 ArrayRef<OpFoldResult> sizes, 2143 ArrayRef<OpFoldResult> strides) { 2144 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2145 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2146 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 2147 ShapedType::kDynamicStrideOrOffset); 2148 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 2149 ShapedType::kDynamicSize); 2150 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 2151 ShapedType::kDynamicStrideOrOffset); 2152 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 2153 staticSizes, staticStrides); 2154 } 2155 2156 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 2157 MemRefType sourceRankedTensorType, 2158 ArrayRef<int64_t> offsets, 2159 ArrayRef<int64_t> sizes, 2160 ArrayRef<int64_t> strides) { 2161 auto inferredType = 2162 inferResultType(sourceRankedTensorType, offsets, sizes, strides) 2163 .cast<MemRefType>(); 2164 assert(inferredType.getRank() >= resultRank && "expected "); 2165 int rankDiff = inferredType.getRank() - resultRank; 2166 if (rankDiff > 0) { 2167 auto shape = inferredType.getShape(); 2168 llvm::SmallBitVector dimsToProject = 2169 getPositionsOfShapeOne(rankDiff, shape); 2170 SmallVector<int64_t> projectedShape; 2171 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 2172 if (!dimsToProject.test(pos)) 2173 projectedShape.push_back(shape[pos]); 2174 2175 AffineMap map = inferredType.getLayout().getAffineMap(); 2176 if (!map.isIdentity()) 2177 map = getProjectedMap(map, dimsToProject); 2178 inferredType = 2179 MemRefType::get(projectedShape, inferredType.getElementType(), map, 2180 inferredType.getMemorySpace()); 2181 } 2182 return inferredType; 2183 } 2184 2185 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 2186 MemRefType sourceRankedTensorType, 2187 ArrayRef<OpFoldResult> offsets, 2188 ArrayRef<OpFoldResult> sizes, 2189 ArrayRef<OpFoldResult> strides) { 2190 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2191 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2192 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 2193 ShapedType::kDynamicStrideOrOffset); 2194 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 2195 ShapedType::kDynamicSize); 2196 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 2197 ShapedType::kDynamicStrideOrOffset); 2198 return SubViewOp::inferRankReducedResultType( 2199 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 2200 staticStrides); 2201 } 2202 // Build a SubViewOp with mixed static and dynamic entries and custom result 2203 // type. If the type passed is nullptr, it is inferred. 
2204 void SubViewOp::build(OpBuilder &b, OperationState &result,
2205                       MemRefType resultType, Value source,
2206                       ArrayRef<OpFoldResult> offsets,
2207                       ArrayRef<OpFoldResult> sizes,
2208                       ArrayRef<OpFoldResult> strides,
2209                       ArrayRef<NamedAttribute> attrs) {
2210   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2211   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2212   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2213                              ShapedType::kDynamicStrideOrOffset);
2214   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2215                              ShapedType::kDynamicSize);
2216   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2217                              ShapedType::kDynamicStrideOrOffset);
2218   auto sourceMemRefType = source.getType().cast<MemRefType>();
2219   // Structuring implementation this way avoids duplication between builders.
2220   if (!resultType) {
2221     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2222                                             staticSizes, staticStrides)
2223                      .cast<MemRefType>();
2224   }
2225   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2226         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
2227         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
2228   result.addAttributes(attrs);
2229 }
2230 
2231 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2232 // type.
2233 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2234                       ArrayRef<OpFoldResult> offsets,
2235                       ArrayRef<OpFoldResult> sizes,
2236                       ArrayRef<OpFoldResult> strides,
2237                       ArrayRef<NamedAttribute> attrs) {
2238   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2239 }
2240 
2241 // Build a SubViewOp with static entries and inferred result type.
2242 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2243                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2244                       ArrayRef<int64_t> strides,
2245                       ArrayRef<NamedAttribute> attrs) {
2246   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2247       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2248         return b.getI64IntegerAttr(v);
2249       }));
2250   SmallVector<OpFoldResult> sizeValues =
2251       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2252         return b.getI64IntegerAttr(v);
2253       }));
2254   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2255       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2256         return b.getI64IntegerAttr(v);
2257       }));
2258   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2259 }
2260 
2261 // Build a SubViewOp with static entries and custom result type. If the type
2262 // passed is nullptr, it is inferred.
2263 void SubViewOp::build(OpBuilder &b, OperationState &result,
2264                       MemRefType resultType, Value source,
2265                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2266                       ArrayRef<int64_t> strides,
2267                       ArrayRef<NamedAttribute> attrs) {
2268   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2269       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2270         return b.getI64IntegerAttr(v);
2271       }));
2272   SmallVector<OpFoldResult> sizeValues =
2273       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2274         return b.getI64IntegerAttr(v);
2275       }));
2276   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2277       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2278         return b.getI64IntegerAttr(v);
2279       }));
2280   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2281         attrs);
2282 }
2283 
2284 // Build a SubViewOp with dynamic entries and custom result type. If the type
2285 // passed is nullptr, it is inferred.
2286 void SubViewOp::build(OpBuilder &b, OperationState &result,
2287                       MemRefType resultType, Value source, ValueRange offsets,
2288                       ValueRange sizes, ValueRange strides,
2289                       ArrayRef<NamedAttribute> attrs) {
2290   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2291       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2292   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2293       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2294   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2295       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2296   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2297 }
2298 
2299 // Build a SubViewOp with dynamic entries and inferred result type.
2300 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2301                       ValueRange offsets, ValueRange sizes, ValueRange strides,
2302                       ArrayRef<NamedAttribute> attrs) {
2303   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2304 }
2305 
2306 /// For ViewLikeOpInterface.
2307 Value SubViewOp::getViewSource() { return source(); }
2308 
2309 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static
2310 /// value).
2311 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2312   AffineExpr t1Offset, t2Offset;
2313   SmallVector<AffineExpr> t1Strides, t2Strides;
2314   auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
2315   auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
2316   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2317 }
2318 
2319 /// Checks if `originalType` can be rank-reduced to `candidateRankReducedType`.
2320 /// This is a slight variant of the `is subsequence` algorithm: dimensions that
2321 /// do not match must be 1.
2322 static SliceVerificationResult
2323 isRankReducedMemRefType(MemRefType originalType,
2324                         MemRefType candidateRankReducedType,
2325                         ArrayRef<OpFoldResult> sizes) {
2326   auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
2327   if (partialRes != SliceVerificationResult::Success)
2328     return partialRes;
2329 
2330   auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
2331       originalType, candidateRankReducedType, sizes);
2332 
2333   // If no rank-reduction mask could be computed, the sizes do not match.
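  // E.g. (illustrative): reducing memref<1x6x1x3xf32> to memref<6x3xf32> with
  // sizes [1, 6, 1, 3] drops the unit dims {0, 2}; if no such mask can be
  // computed, the candidate type is not a valid rank reduction.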
2334 if (!optionalUnusedDimsMask.hasValue()) 2335 return SliceVerificationResult::LayoutMismatch; 2336 2337 if (originalType.getMemorySpace() != 2338 candidateRankReducedType.getMemorySpace()) 2339 return SliceVerificationResult::MemSpaceMismatch; 2340 2341 // No amount of stride dropping can reconcile incompatible offsets. 2342 if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) 2343 return SliceVerificationResult::LayoutMismatch; 2344 2345 return SliceVerificationResult::Success; 2346 } 2347 2348 template <typename OpTy> 2349 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 2350 OpTy op, Type expectedType) { 2351 auto memrefType = expectedType.cast<ShapedType>(); 2352 switch (result) { 2353 case SliceVerificationResult::Success: 2354 return success(); 2355 case SliceVerificationResult::RankTooLarge: 2356 return op.emitError("expected result rank to be smaller or equal to ") 2357 << "the source rank. "; 2358 case SliceVerificationResult::SizeMismatch: 2359 return op.emitError("expected result type to be ") 2360 << expectedType 2361 << " or a rank-reduced version. (mismatch of result sizes) "; 2362 case SliceVerificationResult::ElemTypeMismatch: 2363 return op.emitError("expected result element type to be ") 2364 << memrefType.getElementType(); 2365 case SliceVerificationResult::MemSpaceMismatch: 2366 return op.emitError("expected result and source memory spaces to match."); 2367 case SliceVerificationResult::LayoutMismatch: 2368 return op.emitError("expected result type to be ") 2369 << expectedType 2370 << " or a rank-reduced version. (mismatch of result layout) "; 2371 } 2372 llvm_unreachable("unexpected subview verification result"); 2373 } 2374 2375 /// Verifier for SubViewOp. 2376 LogicalResult SubViewOp::verify() { 2377 MemRefType baseType = getSourceType(); 2378 MemRefType subViewType = getType(); 2379 2380 // The base memref and the view memref should be in the same memory space. 2381 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 2382 return emitError("different memory spaces specified for base memref " 2383 "type ") 2384 << baseType << " and subview memref type " << subViewType; 2385 2386 // Verify that the base memref type has a strided layout map. 2387 if (!isStrided(baseType)) 2388 return emitError("base type ") << baseType << " is not strided"; 2389 2390 // Verify result type against inferred type. 2391 auto expectedType = SubViewOp::inferResultType( 2392 baseType, extractFromI64ArrayAttr(static_offsets()), 2393 extractFromI64ArrayAttr(static_sizes()), 2394 extractFromI64ArrayAttr(static_strides())); 2395 2396 auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), 2397 subViewType, getMixedSizes()); 2398 return produceSubViewErrorMsg(result, *this, expectedType); 2399 } 2400 2401 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 2402 return os << "range " << range.offset << ":" << range.size << ":" 2403 << range.stride; 2404 } 2405 2406 /// Return the list of Range (i.e. offset, size, stride). Each Range 2407 /// entry contains either the dynamic value or a ConstantIndexOp constructed 2408 /// with `b` at location `loc`. 
2409 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 2410 OpBuilder &b, Location loc) { 2411 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 2412 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 2413 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 2414 SmallVector<Range, 8> res; 2415 unsigned rank = ranks[0]; 2416 res.reserve(rank); 2417 for (unsigned idx = 0; idx < rank; ++idx) { 2418 Value offset = 2419 op.isDynamicOffset(idx) 2420 ? op.getDynamicOffset(idx) 2421 : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); 2422 Value size = 2423 op.isDynamicSize(idx) 2424 ? op.getDynamicSize(idx) 2425 : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); 2426 Value stride = 2427 op.isDynamicStride(idx) 2428 ? op.getDynamicStride(idx) 2429 : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); 2430 res.emplace_back(Range{offset, size, stride}); 2431 } 2432 return res; 2433 } 2434 2435 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 2436 /// deduce the result type for the given `sourceType`. Additionally, reduce the 2437 /// rank of the inferred result type if `currentResultType` is lower rank than 2438 /// `currentSourceType`. Use this signature if `sourceType` is updated together 2439 /// with the result type. In this case, it is important to compute the dropped 2440 /// dimensions using `currentSourceType` whose strides align with 2441 /// `currentResultType`. 2442 static MemRefType getCanonicalSubViewResultType( 2443 MemRefType currentResultType, MemRefType currentSourceType, 2444 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, 2445 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { 2446 auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, 2447 mixedSizes, mixedStrides) 2448 .cast<MemRefType>(); 2449 llvm::Optional<llvm::SmallBitVector> unusedDims = 2450 computeMemRefRankReductionMask(currentSourceType, currentResultType, 2451 mixedSizes); 2452 // Return nullptr as failure mode. 2453 if (!unusedDims) 2454 return nullptr; 2455 SmallVector<int64_t> shape; 2456 for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) { 2457 if (unusedDims->test(sizes.index())) 2458 continue; 2459 shape.push_back(sizes.value()); 2460 } 2461 AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap(); 2462 if (!layoutMap.isIdentity()) 2463 layoutMap = getProjectedMap(layoutMap, unusedDims.getValue()); 2464 return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap, 2465 nonRankReducedType.getMemorySpace()); 2466 } 2467 2468 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 2469 /// deduce the result type. Additionally, reduce the rank of the inferred result 2470 /// type if `currentResultType` is lower rank than `sourceType`. 2471 static MemRefType getCanonicalSubViewResultType( 2472 MemRefType currentResultType, MemRefType sourceType, 2473 ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, 2474 ArrayRef<OpFoldResult> mixedStrides) { 2475 return getCanonicalSubViewResultType(currentResultType, sourceType, 2476 sourceType, mixedOffsets, mixedSizes, 2477 mixedStrides); 2478 } 2479 2480 /// Helper method to check if a `subview` operation is trivially a no-op. This 2481 /// is the case if the all offsets are zero, all strides are 1, and the source 2482 /// shape is same as the size of the subview. 
In such cases, the subview can be 2483 /// folded into its source. 2484 static bool isTrivialSubViewOp(SubViewOp subViewOp) { 2485 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank()) 2486 return false; 2487 2488 auto mixedOffsets = subViewOp.getMixedOffsets(); 2489 auto mixedSizes = subViewOp.getMixedSizes(); 2490 auto mixedStrides = subViewOp.getMixedStrides(); 2491 2492 // Check offsets are zero. 2493 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) { 2494 Optional<int64_t> intValue = getConstantIntValue(ofr); 2495 return !intValue || intValue.getValue() != 0; 2496 })) 2497 return false; 2498 2499 // Check strides are one. 2500 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) { 2501 Optional<int64_t> intValue = getConstantIntValue(ofr); 2502 return !intValue || intValue.getValue() != 1; 2503 })) 2504 return false; 2505 2506 // Check all size values are static and matches the (static) source shape. 2507 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape(); 2508 for (const auto &size : llvm::enumerate(mixedSizes)) { 2509 Optional<int64_t> intValue = getConstantIntValue(size.value()); 2510 if (!intValue || intValue.getValue() != sourceShape[size.index()]) 2511 return false; 2512 } 2513 // All conditions met. The `SubViewOp` is foldable as a no-op. 2514 return true; 2515 } 2516 2517 namespace { 2518 /// Pattern to rewrite a subview op with MemRefCast arguments. 2519 /// This essentially pushes memref.cast past its consuming subview when 2520 /// `canFoldIntoConsumerOp` is true. 2521 /// 2522 /// Example: 2523 /// ``` 2524 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 2525 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 2526 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 2527 /// ``` 2528 /// is rewritten into: 2529 /// ``` 2530 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 2531 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 2532 /// memref<3x4xf32, offset:?, strides:[?, 1]> 2533 /// ``` 2534 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 2535 public: 2536 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2537 2538 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2539 PatternRewriter &rewriter) const override { 2540 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 2541 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 2542 return matchPattern(operand, matchConstantIndex()); 2543 })) 2544 return failure(); 2545 2546 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 2547 if (!castOp) 2548 return failure(); 2549 2550 if (!CastOp::canFoldIntoConsumerOp(castOp)) 2551 return failure(); 2552 2553 // Compute the SubViewOp result type after folding the MemRefCastOp. Use the 2554 // MemRefCastOp source operand type to infer the result type and the current 2555 // SubViewOp source operand type to compute the dropped dimensions if the 2556 // operation is rank-reducing. 
2557     auto resultType = getCanonicalSubViewResultType(
2558         subViewOp.getType(), subViewOp.getSourceType(),
2559         castOp.source().getType().cast<MemRefType>(),
2560         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2561         subViewOp.getMixedStrides());
2562     if (!resultType)
2563       return failure();
2564 
2565     Value newSubView = rewriter.create<SubViewOp>(
2566         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2567         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2568         subViewOp.static_sizes(), subViewOp.static_strides());
2569     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2570                                         newSubView);
2571     return success();
2572   }
2573 };
2574 
2575 /// Canonicalize subview ops that are no-ops: replace the subview with its
2576 /// source, inserting a cast if the types differ (e.g., due to an `affine_map`).
2577 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
2578 public:
2579   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2580 
2581   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2582                                 PatternRewriter &rewriter) const override {
2583     if (!isTrivialSubViewOp(subViewOp))
2584       return failure();
2585     if (subViewOp.getSourceType() == subViewOp.getType()) {
2586       rewriter.replaceOp(subViewOp, subViewOp.source());
2587       return success();
2588     }
2589     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2590                                         subViewOp.source());
2591     return success();
2592   }
2593 };
2594 } // namespace
2595 
2596 /// Return the canonical type of the result of a subview.
2597 struct SubViewReturnTypeCanonicalizer {
2598   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2599                         ArrayRef<OpFoldResult> mixedSizes,
2600                         ArrayRef<OpFoldResult> mixedStrides) {
2601     return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
2602                                          mixedOffsets, mixedSizes,
2603                                          mixedStrides);
2604   }
2605 };
2606 
2607 /// A canonicalizer wrapper to replace SubViewOps.
2608 struct SubViewCanonicalizer {
2609   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2610     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2611   }
2612 };
2613 
2614 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2615                                             MLIRContext *context) {
2616   results
2617       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2618                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2619            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
2620 }
2621 
2622 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2623   auto resultShapedType = getResult().getType().cast<ShapedType>();
2624   auto sourceShapedType = source().getType().cast<ShapedType>();
2625 
2626   if (resultShapedType.hasStaticShape() &&
2627       resultShapedType == sourceShapedType) {
2628     return getViewSource();
2629   }
2630 
2631   return {};
2632 }
2633 
2634 //===----------------------------------------------------------------------===//
2635 // TransposeOp
2636 //===----------------------------------------------------------------------===//
2637 
2638 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
2639 static MemRefType inferTransposeResultType(MemRefType memRefType,
2640                                            AffineMap permutationMap) {
2641   auto rank = memRefType.getRank();
2642   auto originalSizes = memRefType.getShape();
2643   // Compute permuted sizes.
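  // E.g. (illustrative): originalSizes = [2, 3] under permutationMap
  // (d0, d1) -> (d1, d0) become sizes = [3, 2]; composing the strided layout
  // map below likewise turns strides [3, 1] into [1, 3].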
2644 SmallVector<int64_t, 4> sizes(rank, 0); 2645 for (const auto &en : llvm::enumerate(permutationMap.getResults())) 2646 sizes[en.index()] = 2647 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2648 2649 // Compute permuted strides. 2650 int64_t offset; 2651 SmallVector<int64_t, 4> strides; 2652 auto res = getStridesAndOffset(memRefType, strides, offset); 2653 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2654 (void)res; 2655 auto map = 2656 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2657 map = permutationMap ? map.compose(permutationMap) : map; 2658 return MemRefType::Builder(memRefType) 2659 .setShape(sizes) 2660 .setLayout(AffineMapAttr::get(map)); 2661 } 2662 2663 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2664 AffineMapAttr permutation, 2665 ArrayRef<NamedAttribute> attrs) { 2666 auto permutationMap = permutation.getValue(); 2667 assert(permutationMap); 2668 2669 auto memRefType = in.getType().cast<MemRefType>(); 2670 // Compute result type. 2671 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2672 2673 build(b, result, resultType, in, attrs); 2674 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2675 } 2676 2677 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2678 void TransposeOp::print(OpAsmPrinter &p) { 2679 p << " " << in() << " " << permutation(); 2680 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); 2681 p << " : " << in().getType() << " to " << getType(); 2682 } 2683 2684 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { 2685 OpAsmParser::UnresolvedOperand in; 2686 AffineMap permutation; 2687 MemRefType srcType, dstType; 2688 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2689 parser.parseOptionalAttrDict(result.attributes) || 2690 parser.parseColonType(srcType) || 2691 parser.resolveOperand(in, srcType, result.operands) || 2692 parser.parseKeywordType("to", dstType) || 2693 parser.addTypeToList(dstType, result.types)) 2694 return failure(); 2695 2696 result.addAttribute(TransposeOp::getPermutationAttrName(), 2697 AffineMapAttr::get(permutation)); 2698 return success(); 2699 } 2700 2701 LogicalResult TransposeOp::verify() { 2702 if (!permutation().isPermutation()) 2703 return emitOpError("expected a permutation map"); 2704 if (permutation().getNumDims() != getShapedType().getRank()) 2705 return emitOpError("expected a permutation map of same rank as the input"); 2706 2707 auto srcType = in().getType().cast<MemRefType>(); 2708 auto dstType = getType().cast<MemRefType>(); 2709 auto transposedType = inferTransposeResultType(srcType, permutation()); 2710 if (dstType != transposedType) 2711 return emitOpError("output type ") 2712 << dstType << " does not match transposed input type " << srcType 2713 << ", " << transposedType; 2714 return success(); 2715 } 2716 2717 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2718 if (succeeded(foldMemRefCast(*this))) 2719 return getResult(); 2720 return {}; 2721 } 2722 2723 //===----------------------------------------------------------------------===// 2724 // ViewOp 2725 //===----------------------------------------------------------------------===// 2726 2727 LogicalResult ViewOp::verify() { 2728 auto baseType = getOperand(0).getType().cast<MemRefType>(); 2729 auto viewType = getType(); 2730 2731 // The base memref should have identity layout map (or none). 
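  // E.g. (illustrative; %buf, %offset and %size are placeholder names):
  //   %v = memref.view %buf[%offset][%size]
  //        : memref<2048xi8> to memref<?x4xf32>
  // has an identity-layout base and result, and one size operand for the
  // single dynamic dimension.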
2732 if (!baseType.getLayout().isIdentity()) 2733 return emitError("unsupported map for base memref type ") << baseType; 2734 2735 // The result memref should have identity layout map (or none). 2736 if (!viewType.getLayout().isIdentity()) 2737 return emitError("unsupported map for result memref type ") << viewType; 2738 2739 // The base memref and the view memref should be in the same memory space. 2740 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2741 return emitError("different memory spaces specified for base memref " 2742 "type ") 2743 << baseType << " and view memref type " << viewType; 2744 2745 // Verify that we have the correct number of sizes for the result type. 2746 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2747 if (sizes().size() != numDynamicDims) 2748 return emitError("incorrect number of size operands for type ") << viewType; 2749 2750 return success(); 2751 } 2752 2753 Value ViewOp::getViewSource() { return source(); } 2754 2755 namespace { 2756 2757 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2758 using OpRewritePattern<ViewOp>::OpRewritePattern; 2759 2760 LogicalResult matchAndRewrite(ViewOp viewOp, 2761 PatternRewriter &rewriter) const override { 2762 // Return if none of the operands are constants. 2763 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2764 return matchPattern(operand, matchConstantIndex()); 2765 })) 2766 return failure(); 2767 2768 // Get result memref type. 2769 auto memrefType = viewOp.getType(); 2770 2771 // Get offset from old memref view type 'memRefType'. 2772 int64_t oldOffset; 2773 SmallVector<int64_t, 4> oldStrides; 2774 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2775 return failure(); 2776 assert(oldOffset == 0 && "Expected 0 offset"); 2777 2778 SmallVector<Value, 4> newOperands; 2779 2780 // Offset cannot be folded into result type. 2781 2782 // Fold any dynamic dim operands which are produced by a constant. 2783 SmallVector<int64_t, 4> newShapeConstants; 2784 newShapeConstants.reserve(memrefType.getRank()); 2785 2786 unsigned dynamicDimPos = 0; 2787 unsigned rank = memrefType.getRank(); 2788 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2789 int64_t dimSize = memrefType.getDimSize(dim); 2790 // If this is already static dimension, keep it. 2791 if (!ShapedType::isDynamic(dimSize)) { 2792 newShapeConstants.push_back(dimSize); 2793 continue; 2794 } 2795 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2796 if (auto constantIndexOp = 2797 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2798 // Dynamic shape dimension will be folded. 2799 newShapeConstants.push_back(constantIndexOp.value()); 2800 } else { 2801 // Dynamic shape dimension not folded; copy operand from old memref. 2802 newShapeConstants.push_back(dimSize); 2803 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2804 } 2805 dynamicDimPos++; 2806 } 2807 2808 // Create new memref type with constant folded dims. 2809 MemRefType newMemRefType = 2810 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2811 // Nothing new, don't fold. 2812 if (newMemRefType == memrefType) 2813 return failure(); 2814 2815 // Create new ViewOp. 2816 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2817 viewOp.getOperand(0), 2818 viewOp.byte_shift(), newOperands); 2819 // Insert a cast so we have the same type as the old memref type. 
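    // E.g. (illustrative): if the single dynamic size operand is the constant
    // 8, a view producing memref<?x4xf32> is rebuilt to produce
    // memref<8x4xf32>, and the cast below converts it back to memref<?x4xf32>
    // for the existing users.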
2820 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp); 2821 return success(); 2822 } 2823 }; 2824 2825 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2826 using OpRewritePattern<ViewOp>::OpRewritePattern; 2827 2828 LogicalResult matchAndRewrite(ViewOp viewOp, 2829 PatternRewriter &rewriter) const override { 2830 Value memrefOperand = viewOp.getOperand(0); 2831 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2832 if (!memrefCastOp) 2833 return failure(); 2834 Value allocOperand = memrefCastOp.getOperand(); 2835 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2836 if (!allocOp) 2837 return failure(); 2838 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2839 viewOp.byte_shift(), viewOp.sizes()); 2840 return success(); 2841 } 2842 }; 2843 2844 } // namespace 2845 2846 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2847 MLIRContext *context) { 2848 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2849 } 2850 2851 //===----------------------------------------------------------------------===// 2852 // AtomicRMWOp 2853 //===----------------------------------------------------------------------===// 2854 2855 LogicalResult AtomicRMWOp::verify() { 2856 if (getMemRefType().getRank() != getNumOperands() - 2) 2857 return emitOpError( 2858 "expects the number of subscripts to be equal to memref rank"); 2859 switch (kind()) { 2860 case arith::AtomicRMWKind::addf: 2861 case arith::AtomicRMWKind::maxf: 2862 case arith::AtomicRMWKind::minf: 2863 case arith::AtomicRMWKind::mulf: 2864 if (!value().getType().isa<FloatType>()) 2865 return emitOpError() << "with kind '" 2866 << arith::stringifyAtomicRMWKind(kind()) 2867 << "' expects a floating-point type"; 2868 break; 2869 case arith::AtomicRMWKind::addi: 2870 case arith::AtomicRMWKind::maxs: 2871 case arith::AtomicRMWKind::maxu: 2872 case arith::AtomicRMWKind::mins: 2873 case arith::AtomicRMWKind::minu: 2874 case arith::AtomicRMWKind::muli: 2875 case arith::AtomicRMWKind::ori: 2876 case arith::AtomicRMWKind::andi: 2877 if (!value().getType().isa<IntegerType>()) 2878 return emitOpError() << "with kind '" 2879 << arith::stringifyAtomicRMWKind(kind()) 2880 << "' expects an integer type"; 2881 break; 2882 default: 2883 break; 2884 } 2885 return success(); 2886 } 2887 2888 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) { 2889 /// atomicrmw(memrefcast) -> atomicrmw 2890 if (succeeded(foldMemRefCast(*this, value()))) 2891 return getResult(); 2892 return OpFoldResult(); 2893 } 2894 2895 //===----------------------------------------------------------------------===// 2896 // TableGen'd op method definitions 2897 //===----------------------------------------------------------------------===// 2898 2899 #define GET_OP_CLASSES 2900 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2901