//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::memref;

namespace {
/// Idiomatic saturated operations on offsets, sizes and strides.
namespace saturated_arith {
struct Wrapper {
  static Wrapper stride(int64_t v) {
    return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
                                                    : Wrapper{false, v};
  }
  static Wrapper offset(int64_t v) {
    return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
                                                    : Wrapper{false, v};
  }
  static Wrapper size(int64_t v) {
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
  }
  int64_t asOffset() {
    return saturated ? ShapedType::kDynamicStrideOrOffset : v;
  }
  int64_t asSize() { return saturated ? ShapedType::kDynamicSize : v; }
  int64_t asStride() {
    return saturated ? ShapedType::kDynamicStrideOrOffset : v;
  }
  bool operator==(Wrapper other) {
    return (saturated && other.saturated) ||
           (!saturated && !other.saturated && v == other.v);
  }
  bool operator!=(Wrapper other) { return !(*this == other); }
  Wrapper operator+(Wrapper other) {
    if (saturated || other.saturated)
      return Wrapper{true, 0};
    return Wrapper{false, other.v + v};
  }
  Wrapper operator*(Wrapper other) {
    if (saturated || other.saturated)
      return Wrapper{true, 0};
    return Wrapper{false, other.v * v};
  }
  bool saturated;
  int64_t v;
};
} // namespace saturated_arith
} // namespace

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used by patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
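///
/// For example (an illustrative sketch; %m, %v and %i are placeholder
/// values), a load through such a cast
///   %0 = memref.cast %m : memref<8xf32> to memref<?xf32>
///   %1 = memref.load %0[%i] : memref<?xf32>
/// may fold so that the load reads from %m directly.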
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}

namespace {
/// Fold constant dimensions into an alloc-like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
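    // For example (illustrative, with placeholder values), given
    //   %c8 = arith.constant 8 : index
    //   %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
    // the alloc can be rewritten as
    //   %1 = memref.alloc(%n) : memref<8x?xf32>
    //   %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>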
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!results().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
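  // The expected textual form is, e.g. (an illustrative sketch with
  // placeholder values):
  //   %res = memref.alloca_scope -> (f32) {
  //     ...
  //     memref.alloca_scope.return %value : f32
  //   }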
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource. Note that
/// this will not check whether an operation contained within
/// the op can allocate.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  // This op itself doesn't create a stack allocation;
  // the inner allocation should be handled separately.
  if (op->hasTrait<OpTrait::HasRecursiveSideEffects>())
    return false;
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Return whether this op is the last non-terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This prevents
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
  return op->getNextNode() == op->getBlock()->getTerminator() &&
         op->getParentRegion()->getBlocks().size() == 1;
}

/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
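///
/// For example (illustrative), a scope whose body performs no stack
/// allocation of its own:
///   memref.alloca_scope {
///     memref.store %v, %m[%i] : memref<?xf32>
///   }
/// can simply be replaced by its body.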
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    bool hasPotentialAlloca =
        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
            if (alloc == op)
              return WalkResult::advance();
            if (isOpItselfPotentialAutomaticAllocation(alloc))
              return WalkResult::interrupt();
            if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
              return WalkResult::skip();
            return WalkResult::advance();
          }).wasInterrupted();

    // If this contains no potential allocation, it is always legal to
    // inline. Otherwise, consider two conditions:
    if (hasPotentialAlloca) {
      // If the parent isn't an allocation scope, or we are not the last
      // non-terminator op in the parent, we will extend the lifetime.
      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
        return failure();
      if (!lastNonTerminatorInRegion(op))
        return failure();
    }

    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.mergeBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};

/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {

    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator op in the block of a
    // one-block region (lest the lifetime of allocations be extended).
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
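      // E.g. (illustrative), an alloca whose dynamic size is computed inside
      // this region would not dominate the hoist point and must stay put.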
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(alignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
        !ShapedType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamicStrideOrOffset(ss) &&
          !ShapedType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the layout
/// of the type.
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;

    // Check source.
    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = castOp.getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    return success(modified);
  }
};

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.source() != copyOp.target())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  bool folded = false;
  Operation *op = *this;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

LogicalResult DimOp::verify() {
  // Assume unknown index to be in range.
  Optional<int64_t> index = getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
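/// For example (illustrative), strides (4, 1, 1) map to {1: 2, 4: 1}.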
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `reducedType` whose shape is assumed to be a
/// subset of `originalType` with some `1` entries erased, return the set of
/// indices that specifies which of the entries of the original shape are
/// dropped to obtain the reduced shape.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped
/// too.
static llvm::Optional<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = dim.value().dyn_cast<Attribute>())
      if (attr.cast<IntegerAttr>().getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim, that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + the number of unused
  // dims with that stride == the number of repetitions of that stride in the
  // candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank type
      // that wasn't in the original one.
      return llvm::None;
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return llvm::None;
  return unusedDims;
}

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  llvm::Optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(unusedDims && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation into a load from the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.source().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
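    // For example (illustrative, with placeholder values):
    //   %r = memref.reshape %m(%shape)
    //          : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
    //   %d = memref.dim %r, %c1 : memref<?x?xf32>
    // becomes a load of the requested extent:
    //   %d = memref.load %shape[%c1] : memref<2xindex>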
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size (number of elements).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
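  // E.g. (illustrative), with a 2-D source, 3-D destination and 1-D tag
  // memref, numExpectedOperands is 2 + 3 + 4 + 1 = 10 (12 when the optional
  // stride pair is present).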
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = tagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    bodyRegion->push_back(new Block());
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::UnresolvedOperand memref;
  Type memrefType;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, {}) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
  return success();
}

void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << memref() << "[" << indices() << "] : " << memref().getType()
    << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicYieldOp::verify() {
  Type parentType = (*this)->getParentOp()->getResultTypes().front();
  Type resultType = result().getType();
  if (parentType != resultType)
    return emitOpError() << "types mismatch between yield op: " << resultType
                         << " and its parent: " << parentType;
  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

LogicalResult GlobalOp::verify() {
  auto memrefType = type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (initial_value().hasValue()) {
    Attribute initValue = initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (Optional<uint64_t> alignAttr = alignment()) {
    uint64_t alignment = alignAttr.getValue();

    if (!llvm::isPowerOf2_64(alignment))
      return emitError() << "alignment attribute value " << alignment
                         << " is not a power of 2";
  }

  // TODO: verify visibility for declarations.
  return success();
}

ElementsAttr GlobalOp::getConstantInitValue() {
  auto initVal = initial_value();
  if (constant() && initVal.hasValue())
    return initVal.getValue().cast<ElementsAttr>();
  return {};
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

LogicalResult LoadOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

void PrefetchOp::print(OpAsmPrinter &p) {
  p << " " << memref() << '[';
  p.printOperands(indices());
  p << ']' << ", " << (isWrite() ? "write" : "read");
  p << ", locality<" << localityHint();
  p << ">, " << (isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << getMemRefType();
}

ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand memrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

LogicalResult PrefetchOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
    return emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = type.dyn_cast<ShapedType>();
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with mixed static and dynamic entries:
/// `static_offsets`, `static_sizes` and `static_strides` are automatically
/// filled with sentinel values for the entries that are passed as `Value`s.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and extract_slice.
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = source().getType().cast<BaseMemRefType>();
  auto resultType = getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en : llvm::enumerate(llvm::zip(
           resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (!ShapedType::isDynamic(resultSize) &&
        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offsets and static_strides attributes.
  // If the result memref type has no affine map specified, this assumes an
  // identity layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
  if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
      !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
      resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << expectedOffset << " instead of " << resultOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto &en : llvm::enumerate(llvm::zip(
           resultStrides, extractFromI64ArrayAttr(static_strides())))) {
    int64_t resultStride = std::get<0>(en.value());
    int64_t expectedStride = std::get<1>(en.value());
    if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
        !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
        resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << expectedStride << " instead of " << resultStride
             << " in dim = " << en.index();
  }

  return success();
}

OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
  Value src = source();
  auto getPrevSrc = [&]() -> Value {
    // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
      return prev.source();

    // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
    if (auto prev = src.getDefiningOp<CastOp>())
      return prev.source();

    // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
    // are 0.
    if (auto prev = src.getDefiningOp<SubViewOp>())
      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
            return isConstantIntValue(val, 0);
          }))
        return prev.source();

    return nullptr;
  };

  if (auto prevSrc = getPrevSrc()) {
    sourceMutable().assign(prevSrc);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

/// Helper function for verifying the shape of ExpandShapeOp and
/// CollapseShapeOp result and operand. Layout maps are verified separately.
///
/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
/// allowed in a reassociation group.
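///
/// Illustrative example (assumed shapes): collapsing an expanded shape of
/// [2, 3, 4] with reassociation [[0, 1], [2]] matches a collapsed shape of
/// [6, 4]; a group such as [[0, 2], [1]] is rejected because its indices are
/// not contiguous, and a fully static group may not collapse into a dynamic
/// dimension (or vice versa).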
1609 static LogicalResult 1610 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape, 1611 ArrayRef<int64_t> expandedShape, 1612 ArrayRef<ReassociationIndices> reassociation, 1613 bool allowMultipleDynamicDimsPerGroup) { 1614 // There must be one reassociation group per collapsed dimension. 1615 if (collapsedShape.size() != reassociation.size()) 1616 return op->emitOpError("invalid number of reassociation groups: found ") 1617 << reassociation.size() << ", expected " << collapsedShape.size(); 1618 1619 // The next expected expanded dimension index (while iterating over 1620 // reassociation indices). 1621 int64_t nextDim = 0; 1622 for (const auto &it : llvm::enumerate(reassociation)) { 1623 ReassociationIndices group = it.value(); 1624 int64_t collapsedDim = it.index(); 1625 1626 bool foundDynamic = false; 1627 for (int64_t expandedDim : group) { 1628 if (expandedDim != nextDim++) 1629 return op->emitOpError("reassociation indices must be contiguous"); 1630 1631 if (expandedDim >= static_cast<int64_t>(expandedShape.size())) 1632 return op->emitOpError("reassociation index ") 1633 << expandedDim << " is out of bounds"; 1634 1635 // Check if there are multiple dynamic dims in a reassociation group. 1636 if (ShapedType::isDynamic(expandedShape[expandedDim])) { 1637 if (foundDynamic && !allowMultipleDynamicDimsPerGroup) 1638 return op->emitOpError( 1639 "at most one dimension in a reassociation group may be dynamic"); 1640 foundDynamic = true; 1641 } 1642 } 1643 1644 // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity. 1645 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic) 1646 return op->emitOpError("collapsed dim (") 1647 << collapsedDim 1648 << ") must be dynamic if and only if reassociation group is " 1649 "dynamic"; 1650 1651 // If all dims in the reassociation group are static, the size of the 1652 // collapsed dim can be verified. 1653 if (!foundDynamic) { 1654 int64_t groupSize = 1; 1655 for (int64_t expandedDim : group) 1656 groupSize *= expandedShape[expandedDim]; 1657 if (groupSize != collapsedShape[collapsedDim]) 1658 return op->emitOpError("collapsed dim size (") 1659 << collapsedShape[collapsedDim] 1660 << ") must equal reassociation group size (" << groupSize << ")"; 1661 } 1662 } 1663 1664 if (collapsedShape.empty()) { 1665 // Rank 0: All expanded dimensions must be 1. 1666 for (int64_t d : expandedShape) 1667 if (d != 1) 1668 return op->emitOpError( 1669 "rank 0 memrefs can only be extended/collapsed with/from ones"); 1670 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) { 1671 // Rank >= 1: Number of dimensions among all reassociation groups must match 1672 // the result memref rank. 
1673 return op->emitOpError("expanded rank (") 1674 << expandedShape.size() 1675 << ") inconsistent with number of reassociation indices (" << nextDim 1676 << ")"; 1677 } 1678 1679 return success(); 1680 } 1681 1682 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1683 return getSymbolLessAffineMaps(getReassociationExprs()); 1684 } 1685 1686 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1687 return convertReassociationIndicesToExprs(getContext(), 1688 getReassociationIndices()); 1689 } 1690 1691 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1692 return getSymbolLessAffineMaps(getReassociationExprs()); 1693 } 1694 1695 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1696 return convertReassociationIndicesToExprs(getContext(), 1697 getReassociationIndices()); 1698 } 1699 1700 /// Compute the layout map after expanding a given source MemRef type with the 1701 /// specified reassociation indices. 1702 static FailureOr<AffineMap> 1703 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape, 1704 ArrayRef<ReassociationIndices> reassociation) { 1705 int64_t srcOffset; 1706 SmallVector<int64_t> srcStrides; 1707 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 1708 return failure(); 1709 assert(srcStrides.size() == reassociation.size() && "invalid reassociation"); 1710 1711 // 1-1 mapping between srcStrides and reassociation packs. 1712 // Each srcStride starts with the given value and gets expanded according to 1713 // the proper entries in resultShape. 1714 // Example: 1715 // srcStrides = [10000, 1 , 100 ], 1716 // reassociations = [ [0], [1], [2, 3, 4]], 1717 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]] 1718 // -> For the purpose of stride calculation, the useful sizes are: 1719 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]]. 1720 // resultStrides = [10000, 1, 600, 200, 100] 1721 // Note that a stride does not get expanded along the first entry of each 1722 // shape pack. 1723 SmallVector<int64_t> reverseResultStrides; 1724 reverseResultStrides.reserve(resultShape.size()); 1725 unsigned shapeIndex = resultShape.size() - 1; 1726 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) { 1727 ReassociationIndices reassoc = std::get<0>(it); 1728 int64_t currentStrideToExpand = std::get<1>(it); 1729 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) { 1730 using saturated_arith::Wrapper; 1731 reverseResultStrides.push_back(currentStrideToExpand); 1732 currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) * 1733 Wrapper::size(resultShape[shapeIndex--])) 1734 .asStride(); 1735 } 1736 } 1737 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides)); 1738 resultStrides.resize(resultShape.size(), 1); 1739 return makeStridedLinearLayoutMap(resultStrides, srcOffset, 1740 srcType.getContext()); 1741 } 1742 1743 static FailureOr<MemRefType> 1744 computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape, 1745 ArrayRef<ReassociationIndices> reassociation) { 1746 if (srcType.getLayout().isIdentity()) { 1747 // If the source is contiguous (i.e., no layout map specified), so is the 1748 // result. 1749 MemRefLayoutAttrInterface layout; 1750 return MemRefType::get(resultShape, srcType.getElementType(), layout, 1751 srcType.getMemorySpace()); 1752 } 1753 1754 // Source may not be contiguous. Compute the layout map. 
1755 FailureOr<AffineMap> computedLayout = 1756 computeExpandedLayoutMap(srcType, resultShape, reassociation); 1757 if (failed(computedLayout)) 1758 return failure(); 1759 auto computedType = 1760 MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, 1761 srcType.getMemorySpaceAsInt()); 1762 return canonicalizeStridedLayout(computedType); 1763 } 1764 1765 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, 1766 ArrayRef<int64_t> resultShape, Value src, 1767 ArrayRef<ReassociationIndices> reassociation) { 1768 // Only ranked memref source values are supported. 1769 auto srcType = src.getType().cast<MemRefType>(); 1770 FailureOr<MemRefType> resultType = 1771 computeExpandedType(srcType, resultShape, reassociation); 1772 // Failure of this assertion usually indicates a problem with the source 1773 // type, e.g., could not get strides/offset. 1774 assert(succeeded(resultType) && "could not compute layout"); 1775 build(builder, result, *resultType, src, reassociation); 1776 } 1777 1778 LogicalResult ExpandShapeOp::verify() { 1779 MemRefType srcType = getSrcType(); 1780 MemRefType resultType = getResultType(); 1781 1782 // Verify result shape. 1783 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(), 1784 resultType.getShape(), 1785 getReassociationIndices(), 1786 /*allowMultipleDynamicDimsPerGroup=*/false))) 1787 return failure(); 1788 1789 // Compute expected result type (including layout map). 1790 FailureOr<MemRefType> expectedResultType = computeExpandedType( 1791 srcType, resultType.getShape(), getReassociationIndices()); 1792 if (failed(expectedResultType)) 1793 return emitOpError("invalid source layout map"); 1794 1795 // Check actual result type. 1796 auto canonicalizedResultType = canonicalizeStridedLayout(resultType); 1797 if (*expectedResultType != canonicalizedResultType) 1798 return emitOpError("expected expanded type to be ") 1799 << *expectedResultType << " but found " << canonicalizedResultType; 1800 1801 return success(); 1802 } 1803 1804 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1805 MLIRContext *context) { 1806 results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>, 1807 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>( 1808 context); 1809 } 1810 1811 /// Compute the layout map after collapsing a given source MemRef type with the 1812 /// specified reassociation indices. 1813 /// 1814 /// Note: All collapsed dims in a reassociation group must be contiguous. It is 1815 /// not possible to check this by inspecting a MemRefType in the general case. 1816 /// If non-contiguity cannot be checked statically, the collapse is assumed to 1817 /// be valid (and thus accepted by this function) unless `strict = true`. 1818 static FailureOr<AffineMap> 1819 computeCollapsedLayoutMap(MemRefType srcType, 1820 ArrayRef<ReassociationIndices> reassociation, 1821 bool strict = false) { 1822 int64_t srcOffset; 1823 SmallVector<int64_t> srcStrides; 1824 auto srcShape = srcType.getShape(); 1825 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 1826 return failure(); 1827 1828 // The result stride of a reassociation group is the stride of the last entry 1829 // of the reassociation. (TODO: Should be the minimum stride in the 1830 // reassociation because strides are not necessarily sorted. E.g., when using 1831 // memref.transpose.) Dimensions of size 1 should be skipped, because their 1832 // strides are meaningless and could have any arbitrary value. 
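  // Illustrative example (assumed values): for srcShape = [2, 3, 4] with
  // srcStrides = [12, 4, 1] and reassociation = [[0, 1], [2]], the group
  // strides computed below are [4, 1] (the stride of the last non-unit dim of
  // each group), and the contiguity check afterwards verifies that 4 * 3 == 12.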
1833 SmallVector<int64_t> resultStrides; 1834 resultStrides.reserve(reassociation.size()); 1835 for (const ReassociationIndices &reassoc : reassociation) { 1836 ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc); 1837 while (srcShape[ref.back()] == 1 && ref.size() > 1) 1838 ref = ref.drop_back(); 1839 if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { 1840 resultStrides.push_back(srcStrides[ref.back()]); 1841 } else { 1842 // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so 1843 // the corresponding stride may have to be skipped. (See above comment.) 1844 // Therefore, the result stride cannot be statically determined and must 1845 // be dynamic. 1846 resultStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1847 } 1848 } 1849 1850 // Validate that each reassociation group is contiguous. 1851 unsigned resultStrideIndex = resultStrides.size() - 1; 1852 for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) { 1853 auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front(); 1854 using saturated_arith::Wrapper; 1855 auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]); 1856 for (int64_t idx : llvm::reverse(trailingReassocs)) { 1857 stride = stride * Wrapper::size(srcShape[idx]); 1858 1859 // Both source and result stride must have the same static value. In that 1860 // case, we can be sure, that the dimensions are collapsible (because they 1861 // are contiguous). 1862 // 1863 // One special case is when the srcShape is `1`, in which case it can 1864 // never produce non-contiguity. 1865 if (srcShape[idx] == 1) 1866 continue; 1867 1868 // If `strict = false` (default during op verification), we accept cases 1869 // where one or both strides are dynamic. This is best effort: We reject 1870 // ops where obviously non-contiguous dims are collapsed, but accept ops 1871 // where we cannot be sure statically. Such ops may fail at runtime. See 1872 // the op documentation for details. 1873 auto srcStride = Wrapper::stride(srcStrides[idx - 1]); 1874 if (strict && (stride.saturated || srcStride.saturated)) 1875 return failure(); 1876 1877 if (!stride.saturated && !srcStride.saturated && stride != srcStride) 1878 return failure(); 1879 } 1880 } 1881 return makeStridedLinearLayoutMap(resultStrides, srcOffset, 1882 srcType.getContext()); 1883 } 1884 1885 bool CollapseShapeOp::isGuaranteedCollapsible( 1886 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) { 1887 // MemRefs with standard layout are always collapsible. 1888 if (srcType.getLayout().isIdentity()) 1889 return true; 1890 1891 return succeeded(computeCollapsedLayoutMap(srcType, reassociation, 1892 /*strict=*/true)); 1893 } 1894 1895 static MemRefType 1896 computeCollapsedType(MemRefType srcType, 1897 ArrayRef<ReassociationIndices> reassociation) { 1898 SmallVector<int64_t> resultShape; 1899 resultShape.reserve(reassociation.size()); 1900 for (const ReassociationIndices &group : reassociation) { 1901 using saturated_arith::Wrapper; 1902 auto groupSize = Wrapper::size(1); 1903 for (int64_t srcDim : group) 1904 groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim)); 1905 resultShape.push_back(groupSize.asSize()); 1906 } 1907 1908 if (srcType.getLayout().isIdentity()) { 1909 // If the source is contiguous (i.e., no layout map specified), so is the 1910 // result. 
1911 MemRefLayoutAttrInterface layout; 1912 return MemRefType::get(resultShape, srcType.getElementType(), layout, 1913 srcType.getMemorySpace()); 1914 } 1915 1916 // Source may not be fully contiguous. Compute the layout map. 1917 // Note: Dimensions that are collapsed into a single dim are assumed to be 1918 // contiguous. 1919 FailureOr<AffineMap> computedLayout = 1920 computeCollapsedLayoutMap(srcType, reassociation); 1921 assert(succeeded(computedLayout) && 1922 "invalid source layout map or collapsing non-contiguous dims"); 1923 auto computedType = 1924 MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, 1925 srcType.getMemorySpaceAsInt()); 1926 return canonicalizeStridedLayout(computedType); 1927 } 1928 1929 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1930 ArrayRef<ReassociationIndices> reassociation, 1931 ArrayRef<NamedAttribute> attrs) { 1932 auto srcType = src.getType().cast<MemRefType>(); 1933 MemRefType resultType = computeCollapsedType(srcType, reassociation); 1934 build(b, result, resultType, src, attrs); 1935 result.addAttribute(getReassociationAttrName(), 1936 getReassociationIndicesAttribute(b, reassociation)); 1937 } 1938 1939 LogicalResult CollapseShapeOp::verify() { 1940 MemRefType srcType = getSrcType(); 1941 MemRefType resultType = getResultType(); 1942 1943 // Verify result shape. 1944 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(), 1945 srcType.getShape(), getReassociationIndices(), 1946 /*allowMultipleDynamicDimsPerGroup=*/true))) 1947 return failure(); 1948 1949 // Compute expected result type (including layout map). 1950 MemRefType expectedResultType; 1951 if (srcType.getLayout().isIdentity()) { 1952 // If the source is contiguous (i.e., no layout map specified), so is the 1953 // result. 1954 MemRefLayoutAttrInterface layout; 1955 expectedResultType = 1956 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout, 1957 srcType.getMemorySpace()); 1958 } else { 1959 // Source may not be fully contiguous. Compute the layout map. 1960 // Note: Dimensions that are collapsed into a single dim are assumed to be 1961 // contiguous. 
1962 FailureOr<AffineMap> computedLayout = 1963 computeCollapsedLayoutMap(srcType, getReassociationIndices()); 1964 if (failed(computedLayout)) 1965 return emitOpError( 1966 "invalid source layout map or collapsing non-contiguous dims"); 1967 auto computedType = 1968 MemRefType::get(resultType.getShape(), srcType.getElementType(), 1969 *computedLayout, srcType.getMemorySpaceAsInt()); 1970 expectedResultType = canonicalizeStridedLayout(computedType); 1971 } 1972 1973 auto canonicalizedResultType = canonicalizeStridedLayout(resultType); 1974 if (expectedResultType != canonicalizedResultType) 1975 return emitOpError("expected collapsed type to be ") 1976 << expectedResultType << " but found " << canonicalizedResultType; 1977 1978 return success(); 1979 } 1980 1981 struct CollapseShapeOpMemRefCastFolder 1982 : public OpRewritePattern<CollapseShapeOp> { 1983 public: 1984 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1985 1986 LogicalResult matchAndRewrite(CollapseShapeOp op, 1987 PatternRewriter &rewriter) const override { 1988 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1989 if (!cast) 1990 return failure(); 1991 1992 if (!CastOp::canFoldIntoConsumerOp(cast)) 1993 return failure(); 1994 1995 Type newResultType = 1996 computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(), 1997 op.getReassociationIndices()); 1998 1999 if (newResultType == op.getResultType()) { 2000 rewriter.updateRootInPlace( 2001 op, [&]() { op.srcMutable().assign(cast.source()); }); 2002 } else { 2003 Value newOp = rewriter.create<CollapseShapeOp>( 2004 op->getLoc(), cast.source(), op.getReassociationIndices()); 2005 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 2006 } 2007 return success(); 2008 } 2009 }; 2010 2011 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 2012 MLIRContext *context) { 2013 results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>, 2014 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>, 2015 CollapseShapeOpMemRefCastFolder>(context); 2016 } 2017 2018 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 2019 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 2020 } 2021 2022 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 2023 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 2024 } 2025 2026 //===----------------------------------------------------------------------===// 2027 // ReshapeOp 2028 //===----------------------------------------------------------------------===// 2029 2030 LogicalResult ReshapeOp::verify() { 2031 Type operandType = source().getType(); 2032 Type resultType = result().getType(); 2033 2034 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 2035 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 2036 if (operandElementType != resultElementType) 2037 return emitOpError("element types of source and destination memref " 2038 "types should be the same"); 2039 2040 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 2041 if (!operandMemRefType.getLayout().isIdentity()) 2042 return emitOpError("source memref type should have identity affine map"); 2043 2044 int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0); 2045 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 2046 if (resultMemRefType) { 2047 if (!resultMemRefType.getLayout().isIdentity()) 2048 return emitOpError("result memref type should have identity affine map"); 2049 if (shapeSize 
== ShapedType::kDynamicSize) 2050 return emitOpError("cannot use shape operand with dynamic length to " 2051 "reshape to statically-ranked memref type"); 2052 if (shapeSize != resultMemRefType.getRank()) 2053 return emitOpError( 2054 "length of shape operand differs from the result's memref rank"); 2055 } 2056 return success(); 2057 } 2058 2059 //===----------------------------------------------------------------------===// 2060 // StoreOp 2061 //===----------------------------------------------------------------------===// 2062 2063 LogicalResult StoreOp::verify() { 2064 if (getNumOperands() != 2 + getMemRefType().getRank()) 2065 return emitOpError("store index operand count not equal to memref rank"); 2066 2067 return success(); 2068 } 2069 2070 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 2071 SmallVectorImpl<OpFoldResult> &results) { 2072 /// store(memrefcast) -> store 2073 return foldMemRefCast(*this, getValueToStore()); 2074 } 2075 2076 //===----------------------------------------------------------------------===// 2077 // SubViewOp 2078 //===----------------------------------------------------------------------===// 2079 2080 /// A subview result type can be fully inferred from the source type and the 2081 /// static representation of offsets, sizes and strides. Special sentinels 2082 /// encode the dynamic case. 2083 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 2084 ArrayRef<int64_t> staticOffsets, 2085 ArrayRef<int64_t> staticSizes, 2086 ArrayRef<int64_t> staticStrides) { 2087 unsigned rank = sourceMemRefType.getRank(); 2088 (void)rank; 2089 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 2090 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 2091 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 2092 2093 // Extract source offset and strides. 2094 int64_t sourceOffset; 2095 SmallVector<int64_t, 4> sourceStrides; 2096 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 2097 assert(succeeded(res) && "SubViewOp expected strided memref type"); 2098 (void)res; 2099 2100 // Compute target offset whose value is: 2101 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 2102 int64_t targetOffset = sourceOffset; 2103 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 2104 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 2105 using saturated_arith::Wrapper; 2106 targetOffset = 2107 (Wrapper::offset(targetOffset) + 2108 Wrapper::offset(staticOffset) * Wrapper::stride(targetStride)) 2109 .asOffset(); 2110 } 2111 2112 // Compute target stride whose value is: 2113 // `sourceStrides_i * staticStrides_i`. 2114 SmallVector<int64_t, 4> targetStrides; 2115 targetStrides.reserve(staticOffsets.size()); 2116 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 2117 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 2118 using saturated_arith::Wrapper; 2119 targetStrides.push_back( 2120 (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride)) 2121 .asStride()); 2122 } 2123 2124 // The type is now known. 
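  // Worked illustration of the computation above (assumed values): with
  // sourceOffset = 0, sourceStrides = [8, 1], staticOffsets = [2, 3] and
  // staticStrides = [1, 2], the target offset is 0 + 2 * 8 + 3 * 1 = 19 and
  // the target strides are [8 * 1, 1 * 2] = [8, 2].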
2125 return MemRefType::get( 2126 staticSizes, sourceMemRefType.getElementType(), 2127 makeStridedLinearLayoutMap(targetStrides, targetOffset, 2128 sourceMemRefType.getContext()), 2129 sourceMemRefType.getMemorySpace()); 2130 } 2131 2132 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 2133 ArrayRef<OpFoldResult> offsets, 2134 ArrayRef<OpFoldResult> sizes, 2135 ArrayRef<OpFoldResult> strides) { 2136 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2137 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2138 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 2139 ShapedType::kDynamicStrideOrOffset); 2140 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 2141 ShapedType::kDynamicSize); 2142 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 2143 ShapedType::kDynamicStrideOrOffset); 2144 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 2145 staticSizes, staticStrides); 2146 } 2147 2148 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 2149 MemRefType sourceRankedTensorType, 2150 ArrayRef<int64_t> offsets, 2151 ArrayRef<int64_t> sizes, 2152 ArrayRef<int64_t> strides) { 2153 auto inferredType = 2154 inferResultType(sourceRankedTensorType, offsets, sizes, strides) 2155 .cast<MemRefType>(); 2156 assert(inferredType.getRank() >= resultRank && "expected "); 2157 int rankDiff = inferredType.getRank() - resultRank; 2158 if (rankDiff > 0) { 2159 auto shape = inferredType.getShape(); 2160 llvm::SmallBitVector dimsToProject = 2161 getPositionsOfShapeOne(rankDiff, shape); 2162 SmallVector<int64_t> projectedShape; 2163 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 2164 if (!dimsToProject.test(pos)) 2165 projectedShape.push_back(shape[pos]); 2166 2167 AffineMap map = inferredType.getLayout().getAffineMap(); 2168 if (!map.isIdentity()) 2169 map = getProjectedMap(map, dimsToProject); 2170 inferredType = 2171 MemRefType::get(projectedShape, inferredType.getElementType(), map, 2172 inferredType.getMemorySpace()); 2173 } 2174 return inferredType; 2175 } 2176 2177 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 2178 MemRefType sourceRankedTensorType, 2179 ArrayRef<OpFoldResult> offsets, 2180 ArrayRef<OpFoldResult> sizes, 2181 ArrayRef<OpFoldResult> strides) { 2182 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2183 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2184 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 2185 ShapedType::kDynamicStrideOrOffset); 2186 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 2187 ShapedType::kDynamicSize); 2188 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 2189 ShapedType::kDynamicStrideOrOffset); 2190 return SubViewOp::inferRankReducedResultType( 2191 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 2192 staticStrides); 2193 } 2194 // Build a SubViewOp with mixed static and dynamic entries and custom result 2195 // type. If the type passed is nullptr, it is inferred. 
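// Hypothetical usage sketch (names assumed, not part of this file): passing
//   offsets = {b.getIndexAttr(0), dynOffsetValue}
// yields static_offsets = [0, kDynamicStrideOrOffset] plus a single dynamic
// offset operand, mirroring how dispatchIndexOpFoldResults splits each
// OpFoldResult below.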
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

/// Return true if t1 and t2 have equal offsets (both dynamic or of same
/// static value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  AffineExpr t1Offset, t2Offset;
  SmallVector<AffineExpr> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Checks if the `original` type can be rank-reduced to the `reduced` type.
/// This function is a slight variant of the `is subsequence` algorithm where
/// a non-matching dimension must be 1.
static SliceVerificationResult
isRankReducedMemRefType(MemRefType originalType,
                        MemRefType candidateRankReducedType,
                        ArrayRef<OpFoldResult> sizes) {
  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
  if (partialRes != SliceVerificationResult::Success)
    return partialRes;

  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
      originalType, candidateRankReducedType, sizes);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
2326 if (!optionalUnusedDimsMask.hasValue()) 2327 return SliceVerificationResult::LayoutMismatch; 2328 2329 if (originalType.getMemorySpace() != 2330 candidateRankReducedType.getMemorySpace()) 2331 return SliceVerificationResult::MemSpaceMismatch; 2332 2333 // No amount of stride dropping can reconcile incompatible offsets. 2334 if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) 2335 return SliceVerificationResult::LayoutMismatch; 2336 2337 return SliceVerificationResult::Success; 2338 } 2339 2340 template <typename OpTy> 2341 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 2342 OpTy op, Type expectedType) { 2343 auto memrefType = expectedType.cast<ShapedType>(); 2344 switch (result) { 2345 case SliceVerificationResult::Success: 2346 return success(); 2347 case SliceVerificationResult::RankTooLarge: 2348 return op.emitError("expected result rank to be smaller or equal to ") 2349 << "the source rank. "; 2350 case SliceVerificationResult::SizeMismatch: 2351 return op.emitError("expected result type to be ") 2352 << expectedType 2353 << " or a rank-reduced version. (mismatch of result sizes) "; 2354 case SliceVerificationResult::ElemTypeMismatch: 2355 return op.emitError("expected result element type to be ") 2356 << memrefType.getElementType(); 2357 case SliceVerificationResult::MemSpaceMismatch: 2358 return op.emitError("expected result and source memory spaces to match."); 2359 case SliceVerificationResult::LayoutMismatch: 2360 return op.emitError("expected result type to be ") 2361 << expectedType 2362 << " or a rank-reduced version. (mismatch of result layout) "; 2363 } 2364 llvm_unreachable("unexpected subview verification result"); 2365 } 2366 2367 /// Verifier for SubViewOp. 2368 LogicalResult SubViewOp::verify() { 2369 MemRefType baseType = getSourceType(); 2370 MemRefType subViewType = getType(); 2371 2372 // The base memref and the view memref should be in the same memory space. 2373 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 2374 return emitError("different memory spaces specified for base memref " 2375 "type ") 2376 << baseType << " and subview memref type " << subViewType; 2377 2378 // Verify that the base memref type has a strided layout map. 2379 if (!isStrided(baseType)) 2380 return emitError("base type ") << baseType << " is not strided"; 2381 2382 // Verify result type against inferred type. 2383 auto expectedType = SubViewOp::inferResultType( 2384 baseType, extractFromI64ArrayAttr(static_offsets()), 2385 extractFromI64ArrayAttr(static_sizes()), 2386 extractFromI64ArrayAttr(static_strides())); 2387 2388 auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), 2389 subViewType, getMixedSizes()); 2390 return produceSubViewErrorMsg(result, *this, expectedType); 2391 } 2392 2393 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 2394 return os << "range " << range.offset << ":" << range.size << ":" 2395 << range.stride; 2396 } 2397 2398 /// Return the list of Range (i.e. offset, size, stride). Each Range 2399 /// entry contains either the dynamic value or a ConstantIndexOp constructed 2400 /// with `b` at location `loc`. 
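///
/// For example (illustrative): for a dimension with static offset 0, dynamic
/// size %sz and static stride 1, the returned entry is roughly
/// `Range{arith.constant 0 : index, %sz, arith.constant 1 : index}`.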
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is lower rank
/// than `currentSourceType`. Use this signature if `sourceType` is updated
/// together with the result type. In this case, it is important to compute
/// the dropped dimensions using `currentSourceType` whose strides align with
/// `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                                       mixedSizes, mixedStrides)
                                .cast<MemRefType>();
  llvm::Optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                     mixedSizes);
  // Return nullptr as failure mode.
  if (!unusedDims)
    return nullptr;
  SmallVector<int64_t> shape;
  for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
    if (unusedDims->test(sizes.index()))
      continue;
    shape.push_back(sizes.value());
  }
  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
  if (!layoutMap.isIdentity())
    layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
                         nonRankReducedType.getMemorySpace());
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type. Additionally, reduce the rank of the inferred
/// result type if `currentResultType` is lower rank than `sourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType sourceType,
    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
    ArrayRef<OpFoldResult> mixedStrides) {
  return getCanonicalSubViewResultType(currentResultType, sourceType,
                                       sourceType, mixedOffsets, mixedSizes,
                                       mixedStrides);
}

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are 1, and the source
/// shape is the same as the size of the subview. In such cases, the subview
/// can be folded into its source.
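///
/// For instance (illustrative IR, not taken from a test):
/// ```
///   memref.subview %m[0, 0] [4, 4] [1, 1] : memref<4x4xf32> to memref<4x4xf32>
/// ```
/// is trivially a no-op, whereas any non-zero offset, non-unit stride, or size
/// smaller than the source extent makes it a real view.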
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.getValue() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.getValue() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    Optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || intValue.getValue() != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, return early and let
    // SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops: if the source type and the result
/// type match, the subview is replaced by its source; if they differ only in
/// their layout (e.g., due to use of `affine_map`), a cast to the result type
/// is inserted instead.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.source());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.source());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
                                         mixedOffsets, mixedSizes,
                                         mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
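  // Illustrative example (assumed values): with permutationMap
  // (d0, d1) -> (d1, d0), a memref<3x4xf32> (strides [4, 1]) infers the
  // transposed type memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>,
  // i.e. strides [1, 4].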
2637 SmallVector<int64_t, 4> sizes(rank, 0); 2638 for (const auto &en : llvm::enumerate(permutationMap.getResults())) 2639 sizes[en.index()] = 2640 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2641 2642 // Compute permuted strides. 2643 int64_t offset; 2644 SmallVector<int64_t, 4> strides; 2645 auto res = getStridesAndOffset(memRefType, strides, offset); 2646 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2647 (void)res; 2648 auto map = 2649 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2650 map = permutationMap ? map.compose(permutationMap) : map; 2651 return MemRefType::Builder(memRefType) 2652 .setShape(sizes) 2653 .setLayout(AffineMapAttr::get(map)); 2654 } 2655 2656 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2657 AffineMapAttr permutation, 2658 ArrayRef<NamedAttribute> attrs) { 2659 auto permutationMap = permutation.getValue(); 2660 assert(permutationMap); 2661 2662 auto memRefType = in.getType().cast<MemRefType>(); 2663 // Compute result type. 2664 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2665 2666 build(b, result, resultType, in, attrs); 2667 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2668 } 2669 2670 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2671 void TransposeOp::print(OpAsmPrinter &p) { 2672 p << " " << in() << " " << permutation(); 2673 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); 2674 p << " : " << in().getType() << " to " << getType(); 2675 } 2676 2677 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { 2678 OpAsmParser::UnresolvedOperand in; 2679 AffineMap permutation; 2680 MemRefType srcType, dstType; 2681 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2682 parser.parseOptionalAttrDict(result.attributes) || 2683 parser.parseColonType(srcType) || 2684 parser.resolveOperand(in, srcType, result.operands) || 2685 parser.parseKeywordType("to", dstType) || 2686 parser.addTypeToList(dstType, result.types)) 2687 return failure(); 2688 2689 result.addAttribute(TransposeOp::getPermutationAttrName(), 2690 AffineMapAttr::get(permutation)); 2691 return success(); 2692 } 2693 2694 LogicalResult TransposeOp::verify() { 2695 if (!permutation().isPermutation()) 2696 return emitOpError("expected a permutation map"); 2697 if (permutation().getNumDims() != getShapedType().getRank()) 2698 return emitOpError("expected a permutation map of same rank as the input"); 2699 2700 auto srcType = in().getType().cast<MemRefType>(); 2701 auto dstType = getType().cast<MemRefType>(); 2702 auto transposedType = inferTransposeResultType(srcType, permutation()); 2703 if (dstType != transposedType) 2704 return emitOpError("output type ") 2705 << dstType << " does not match transposed input type " << srcType 2706 << ", " << transposedType; 2707 return success(); 2708 } 2709 2710 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2711 if (succeeded(foldMemRefCast(*this))) 2712 return getResult(); 2713 return {}; 2714 } 2715 2716 //===----------------------------------------------------------------------===// 2717 // ViewOp 2718 //===----------------------------------------------------------------------===// 2719 2720 LogicalResult ViewOp::verify() { 2721 auto baseType = getOperand(0).getType().cast<MemRefType>(); 2722 auto viewType = getType(); 2723 2724 // The base memref should have identity layout map (or none). 
2725 if (!baseType.getLayout().isIdentity()) 2726 return emitError("unsupported map for base memref type ") << baseType; 2727 2728 // The result memref should have identity layout map (or none). 2729 if (!viewType.getLayout().isIdentity()) 2730 return emitError("unsupported map for result memref type ") << viewType; 2731 2732 // The base memref and the view memref should be in the same memory space. 2733 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2734 return emitError("different memory spaces specified for base memref " 2735 "type ") 2736 << baseType << " and view memref type " << viewType; 2737 2738 // Verify that we have the correct number of sizes for the result type. 2739 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2740 if (sizes().size() != numDynamicDims) 2741 return emitError("incorrect number of size operands for type ") << viewType; 2742 2743 return success(); 2744 } 2745 2746 Value ViewOp::getViewSource() { return source(); } 2747 2748 namespace { 2749 2750 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2751 using OpRewritePattern<ViewOp>::OpRewritePattern; 2752 2753 LogicalResult matchAndRewrite(ViewOp viewOp, 2754 PatternRewriter &rewriter) const override { 2755 // Return if none of the operands are constants. 2756 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2757 return matchPattern(operand, matchConstantIndex()); 2758 })) 2759 return failure(); 2760 2761 // Get result memref type. 2762 auto memrefType = viewOp.getType(); 2763 2764 // Get offset from old memref view type 'memRefType'. 2765 int64_t oldOffset; 2766 SmallVector<int64_t, 4> oldStrides; 2767 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2768 return failure(); 2769 assert(oldOffset == 0 && "Expected 0 offset"); 2770 2771 SmallVector<Value, 4> newOperands; 2772 2773 // Offset cannot be folded into result type. 2774 2775 // Fold any dynamic dim operands which are produced by a constant. 2776 SmallVector<int64_t, 4> newShapeConstants; 2777 newShapeConstants.reserve(memrefType.getRank()); 2778 2779 unsigned dynamicDimPos = 0; 2780 unsigned rank = memrefType.getRank(); 2781 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2782 int64_t dimSize = memrefType.getDimSize(dim); 2783 // If this is already static dimension, keep it. 2784 if (!ShapedType::isDynamic(dimSize)) { 2785 newShapeConstants.push_back(dimSize); 2786 continue; 2787 } 2788 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2789 if (auto constantIndexOp = 2790 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2791 // Dynamic shape dimension will be folded. 2792 newShapeConstants.push_back(constantIndexOp.value()); 2793 } else { 2794 // Dynamic shape dimension not folded; copy operand from old memref. 2795 newShapeConstants.push_back(dimSize); 2796 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2797 } 2798 dynamicDimPos++; 2799 } 2800 2801 // Create new memref type with constant folded dims. 2802 MemRefType newMemRefType = 2803 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2804 // Nothing new, don't fold. 2805 if (newMemRefType == memrefType) 2806 return failure(); 2807 2808 // Create new ViewOp. 2809 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2810 viewOp.getOperand(0), 2811 viewOp.byte_shift(), newOperands); 2812 // Insert a cast so we have the same type as the old memref type. 
2813 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp); 2814 return success(); 2815 } 2816 }; 2817 2818 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2819 using OpRewritePattern<ViewOp>::OpRewritePattern; 2820 2821 LogicalResult matchAndRewrite(ViewOp viewOp, 2822 PatternRewriter &rewriter) const override { 2823 Value memrefOperand = viewOp.getOperand(0); 2824 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2825 if (!memrefCastOp) 2826 return failure(); 2827 Value allocOperand = memrefCastOp.getOperand(); 2828 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2829 if (!allocOp) 2830 return failure(); 2831 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2832 viewOp.byte_shift(), viewOp.sizes()); 2833 return success(); 2834 } 2835 }; 2836 2837 } // namespace 2838 2839 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2840 MLIRContext *context) { 2841 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2842 } 2843 2844 //===----------------------------------------------------------------------===// 2845 // AtomicRMWOp 2846 //===----------------------------------------------------------------------===// 2847 2848 LogicalResult AtomicRMWOp::verify() { 2849 if (getMemRefType().getRank() != getNumOperands() - 2) 2850 return emitOpError( 2851 "expects the number of subscripts to be equal to memref rank"); 2852 switch (kind()) { 2853 case arith::AtomicRMWKind::addf: 2854 case arith::AtomicRMWKind::maxf: 2855 case arith::AtomicRMWKind::minf: 2856 case arith::AtomicRMWKind::mulf: 2857 if (!value().getType().isa<FloatType>()) 2858 return emitOpError() << "with kind '" 2859 << arith::stringifyAtomicRMWKind(kind()) 2860 << "' expects a floating-point type"; 2861 break; 2862 case arith::AtomicRMWKind::addi: 2863 case arith::AtomicRMWKind::maxs: 2864 case arith::AtomicRMWKind::maxu: 2865 case arith::AtomicRMWKind::mins: 2866 case arith::AtomicRMWKind::minu: 2867 case arith::AtomicRMWKind::muli: 2868 case arith::AtomicRMWKind::ori: 2869 case arith::AtomicRMWKind::andi: 2870 if (!value().getType().isa<IntegerType>()) 2871 return emitOpError() << "with kind '" 2872 << arith::stringifyAtomicRMWKind(kind()) 2873 << "' expects an integer type"; 2874 break; 2875 default: 2876 break; 2877 } 2878 return success(); 2879 } 2880 2881 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) { 2882 /// atomicrmw(memrefcast) -> atomicrmw 2883 if (succeeded(foldMemRefCast(*this, value()))) 2884 return getResult(); 2885 return OpFoldResult(); 2886 } 2887 2888 //===----------------------------------------------------------------------===// 2889 // TableGen'd op method definitions 2890 //===----------------------------------------------------------------------===// 2891 2892 #define GET_OP_CLASSES 2893 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2894