1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/InferTypeOpInterface.h" 21 #include "mlir/Interfaces/SideEffectInterfaces.h" 22 #include "mlir/Interfaces/ViewLikeInterface.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/SmallBitVector.h" 25 26 using namespace mlir; 27 using namespace mlir::memref; 28 29 /// Materialize a single constant operation from a given attribute value with 30 /// the desired resultant type. 31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 32 Attribute value, Type type, 33 Location loc) { 34 if (arith::ConstantOp::isBuildableWith(value, type)) 35 return builder.create<arith::ConstantOp>(loc, value, type); 36 return nullptr; 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Common canonicalization pattern support logic 41 //===----------------------------------------------------------------------===// 42 43 /// This is a common class used for patterns of the form 44 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 45 /// into the root operation directly. 
46 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { 47 bool folded = false; 48 for (OpOperand &operand : op->getOpOperands()) { 49 auto cast = operand.get().getDefiningOp<CastOp>(); 50 if (cast && operand.get() != inner && 51 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 52 operand.set(cast.getOperand()); 53 folded = true; 54 } 55 } 56 return success(folded); 57 } 58 59 /// Return an unranked/ranked tensor type for the given unranked/ranked memref 60 /// type. 61 Type mlir::memref::getTensorTypeFromMemRefType(Type type) { 62 if (auto memref = type.dyn_cast<MemRefType>()) 63 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 64 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 65 return UnrankedTensorType::get(memref.getElementType()); 66 return NoneType::get(type.getContext()); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // AllocOp / AllocaOp 71 //===----------------------------------------------------------------------===// 72 73 template <typename AllocLikeOp> 74 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 75 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 76 "applies to only alloc or alloca"); 77 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 78 if (!memRefType) 79 return op.emitOpError("result must be a memref"); 80 81 if (static_cast<int64_t>(op.dynamicSizes().size()) != 82 memRefType.getNumDynamicDims()) 83 return op.emitOpError("dimension operand count does not equal memref " 84 "dynamic dimension count"); 85 86 unsigned numSymbols = 0; 87 if (!memRefType.getLayout().isIdentity()) 88 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 89 if (op.symbolOperands().size() != numSymbols) 90 return op.emitOpError("symbol operand count does not equal memref symbol " 91 "count: expected ") 92 << numSymbols << ", got " << op.symbolOperands().size(); 93 94 return success(); 95 
} 96 97 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); } 98 99 LogicalResult AllocaOp::verify() { 100 // An alloca op needs to have an ancestor with an allocation scope trait. 101 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 102 return emitOpError( 103 "requires an ancestor op with AutomaticAllocationScope trait"); 104 105 return verifyAllocLikeOp(*this); 106 } 107 108 namespace { 109 /// Fold constant dimensions into an alloc like operation. 110 template <typename AllocLikeOp> 111 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 112 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 113 114 LogicalResult matchAndRewrite(AllocLikeOp alloc, 115 PatternRewriter &rewriter) const override { 116 // Check to see if any dimensions operands are constants. If so, we can 117 // substitute and drop them. 118 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 119 return matchPattern(operand, matchConstantIndex()); 120 })) 121 return failure(); 122 123 auto memrefType = alloc.getType(); 124 125 // Ok, we have one or more constant operands. Collect the non-constant ones 126 // and keep track of the resultant memref type to build. 127 SmallVector<int64_t, 4> newShapeConstants; 128 newShapeConstants.reserve(memrefType.getRank()); 129 SmallVector<Value, 4> dynamicSizes; 130 131 unsigned dynamicDimPos = 0; 132 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 133 int64_t dimSize = memrefType.getDimSize(dim); 134 // If this is already static dimension, keep it. 135 if (dimSize != -1) { 136 newShapeConstants.push_back(dimSize); 137 continue; 138 } 139 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 140 auto *defOp = dynamicSize.getDefiningOp(); 141 if (auto constantIndexOp = 142 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 143 // Dynamic shape dimension will be folded. 
144 newShapeConstants.push_back(constantIndexOp.value()); 145 } else { 146 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 147 newShapeConstants.push_back(-1); 148 dynamicSizes.push_back(dynamicSize); 149 } 150 dynamicDimPos++; 151 } 152 153 // Create new memref type (which will have fewer dynamic dimensions). 154 MemRefType newMemRefType = 155 MemRefType::Builder(memrefType).setShape(newShapeConstants); 156 assert(static_cast<int64_t>(dynamicSizes.size()) == 157 newMemRefType.getNumDynamicDims()); 158 159 // Create and insert the alloc op for the new memref. 160 auto newAlloc = rewriter.create<AllocLikeOp>( 161 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 162 alloc.alignmentAttr()); 163 // Insert a cast so we have the same type as the old alloc. 164 auto resultCast = 165 rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc); 166 167 rewriter.replaceOp(alloc, {resultCast}); 168 return success(); 169 } 170 }; 171 172 /// Fold alloc operations with no users or only store and dealloc uses. 
173 template <typename T> 174 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 175 using OpRewritePattern<T>::OpRewritePattern; 176 177 LogicalResult matchAndRewrite(T alloc, 178 PatternRewriter &rewriter) const override { 179 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 180 if (auto storeOp = dyn_cast<StoreOp>(op)) 181 return storeOp.value() == alloc; 182 return !isa<DeallocOp>(op); 183 })) 184 return failure(); 185 186 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 187 rewriter.eraseOp(user); 188 189 rewriter.eraseOp(alloc); 190 return success(); 191 } 192 }; 193 } // namespace 194 195 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 196 MLIRContext *context) { 197 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 198 } 199 200 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 201 MLIRContext *context) { 202 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 203 context); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // AllocaScopeOp 208 //===----------------------------------------------------------------------===// 209 210 void AllocaScopeOp::print(OpAsmPrinter &p) { 211 bool printBlockTerminators = false; 212 213 p << ' '; 214 if (!results().empty()) { 215 p << " -> (" << getResultTypes() << ")"; 216 printBlockTerminators = true; 217 } 218 p << ' '; 219 p.printRegion(bodyRegion(), 220 /*printEntryBlockArgs=*/false, 221 /*printBlockTerminators=*/printBlockTerminators); 222 p.printOptionalAttrDict((*this)->getAttrs()); 223 } 224 225 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { 226 // Create a region for the body. 227 result.regions.reserve(1); 228 Region *bodyRegion = result.addRegion(); 229 230 // Parse optional results type list. 231 if (parser.parseOptionalArrowTypeList(result.types)) 232 return failure(); 233 234 // Parse the body region. 
235 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 236 return failure(); 237 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 238 result.location); 239 240 // Parse the optional attribute list. 241 if (parser.parseOptionalAttrDict(result.attributes)) 242 return failure(); 243 244 return success(); 245 } 246 247 void AllocaScopeOp::getSuccessorRegions( 248 Optional<unsigned> index, ArrayRef<Attribute> operands, 249 SmallVectorImpl<RegionSuccessor> ®ions) { 250 if (index.hasValue()) { 251 regions.push_back(RegionSuccessor(getResults())); 252 return; 253 } 254 255 regions.push_back(RegionSuccessor(&bodyRegion())); 256 } 257 258 /// Given an operation, return whether this op is guaranteed to 259 /// allocate an AutomaticAllocationScopeResource 260 static bool isGuaranteedAutomaticAllocation(Operation *op) { 261 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); 262 if (!interface) 263 return false; 264 for (auto res : op->getResults()) { 265 if (auto effect = 266 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { 267 if (isa<SideEffects::AutomaticAllocationScopeResource>( 268 effect->getResource())) 269 return true; 270 } 271 } 272 return false; 273 } 274 275 /// Given an operation, return whether this op itself could 276 /// allocate an AutomaticAllocationScopeResource. Note that 277 /// this will not check whether an operation contained within 278 /// the op can allocate. 279 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) { 280 // This op itself doesn't create a stack allocation, 281 // the inner allocation should be handled separately. 
282 if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) 283 return false; 284 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); 285 if (!interface) 286 return true; 287 for (auto res : op->getResults()) { 288 if (auto effect = 289 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { 290 if (isa<SideEffects::AutomaticAllocationScopeResource>( 291 effect->getResource())) 292 return true; 293 } 294 } 295 return false; 296 } 297 298 /// Return whether this op is the last non terminating op 299 /// in a region. That is to say, it is in a one-block region 300 /// and is only followed by a terminator. This prevents 301 /// extending the lifetime of allocations. 302 static bool lastNonTerminatorInRegion(Operation *op) { 303 return op->getNextNode() == op->getBlock()->getTerminator() && 304 op->getParentRegion()->getBlocks().size() == 1; 305 } 306 307 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope 308 /// or it contains no allocation. 309 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> { 310 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern; 311 312 LogicalResult matchAndRewrite(AllocaScopeOp op, 313 PatternRewriter &rewriter) const override { 314 bool hasPotentialAlloca = 315 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) { 316 if (alloc == op) 317 return WalkResult::advance(); 318 if (isOpItselfPotentialAutomaticAllocation(alloc)) 319 return WalkResult::interrupt(); 320 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>()) 321 return WalkResult::skip(); 322 return WalkResult::advance(); 323 }).wasInterrupted(); 324 325 // If this contains no potential allocation, it is always legal to 326 // inline. Otherwise, consider two conditions: 327 if (hasPotentialAlloca) { 328 // If the parent isn't an allocation scope, or we are not the last 329 // non-terminator op in the parent, we will extend the lifetime. 
330 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) 331 return failure(); 332 if (!lastNonTerminatorInRegion(op)) 333 return failure(); 334 } 335 336 Block *block = &op.getRegion().front(); 337 Operation *terminator = block->getTerminator(); 338 ValueRange results = terminator->getOperands(); 339 rewriter.mergeBlockBefore(block, op); 340 rewriter.replaceOp(op, results); 341 rewriter.eraseOp(terminator); 342 return success(); 343 } 344 }; 345 346 /// Move allocations into an allocation scope, if it is legal to 347 /// move them (e.g. their operands are available at the location 348 /// the op would be moved to). 349 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> { 350 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern; 351 352 LogicalResult matchAndRewrite(AllocaScopeOp op, 353 PatternRewriter &rewriter) const override { 354 355 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 356 return failure(); 357 358 Operation *lastParentWithoutScope = op->getParentOp(); 359 360 if (!lastParentWithoutScope || 361 lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>()) 362 return failure(); 363 364 // Only apply to if this is this last non-terminator 365 // op in the block (lest lifetime be extended) of a one 366 // block region 367 if (!lastNonTerminatorInRegion(op) || 368 !lastNonTerminatorInRegion(lastParentWithoutScope)) 369 return failure(); 370 371 while (!lastParentWithoutScope->getParentOp() 372 ->hasTrait<OpTrait::AutomaticAllocationScope>()) { 373 lastParentWithoutScope = lastParentWithoutScope->getParentOp(); 374 if (!lastParentWithoutScope || 375 !lastNonTerminatorInRegion(lastParentWithoutScope)) 376 return failure(); 377 } 378 assert(lastParentWithoutScope->getParentOp() 379 ->hasTrait<OpTrait::AutomaticAllocationScope>()); 380 381 Region *containingRegion = nullptr; 382 for (auto &r : lastParentWithoutScope->getRegions()) { 383 if (r.isAncestor(op->getParentRegion())) { 384 
assert(containingRegion == nullptr && 385 "only one region can contain the op"); 386 containingRegion = &r; 387 } 388 } 389 assert(containingRegion && "op must be contained in a region"); 390 391 SmallVector<Operation *> toHoist; 392 op->walk([&](Operation *alloc) { 393 if (!isGuaranteedAutomaticAllocation(alloc)) 394 return WalkResult::skip(); 395 396 // If any operand is not defined before the location of 397 // lastParentWithoutScope (i.e. where we would hoist to), skip. 398 if (llvm::any_of(alloc->getOperands(), [&](Value v) { 399 return containingRegion->isAncestor(v.getParentRegion()); 400 })) 401 return WalkResult::skip(); 402 toHoist.push_back(alloc); 403 return WalkResult::advance(); 404 }); 405 406 if (toHoist.empty()) 407 return failure(); 408 rewriter.setInsertionPoint(lastParentWithoutScope); 409 for (auto *op : toHoist) { 410 auto *cloned = rewriter.clone(*op); 411 rewriter.replaceOp(op, cloned->getResults()); 412 } 413 return success(); 414 } 415 }; 416 417 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results, 418 MLIRContext *context) { 419 results.add<AllocaScopeInliner, AllocaScopeHoister>(context); 420 } 421 422 //===----------------------------------------------------------------------===// 423 // AssumeAlignmentOp 424 //===----------------------------------------------------------------------===// 425 426 LogicalResult AssumeAlignmentOp::verify() { 427 if (!llvm::isPowerOf2_32(alignment())) 428 return emitOpError("alignment must be power of 2"); 429 return success(); 430 } 431 432 //===----------------------------------------------------------------------===// 433 // CastOp 434 //===----------------------------------------------------------------------===// 435 436 /// Determines whether MemRef_CastOp casts to a more dynamic version of the 437 /// source memref. 
This is useful to to fold a memref.cast into a consuming op 438 /// and implement canonicalization patterns for ops in different dialects that 439 /// may consume the results of memref.cast operations. Such foldable memref.cast 440 /// operations are typically inserted as `view` and `subview` ops are 441 /// canonicalized, to preserve the type compatibility of their uses. 442 /// 443 /// Returns true when all conditions are met: 444 /// 1. source and result are ranked memrefs with strided semantics and same 445 /// element type and rank. 446 /// 2. each of the source's size, offset or stride has more static information 447 /// than the corresponding result's size, offset or stride. 448 /// 449 /// Example 1: 450 /// ```mlir 451 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 452 /// %2 = consumer %1 ... : memref<?x?xf32> ... 453 /// ``` 454 /// 455 /// may fold into: 456 /// 457 /// ```mlir 458 /// %2 = consumer %0 ... : memref<8x16xf32> ... 459 /// ``` 460 /// 461 /// Example 2: 462 /// ``` 463 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 464 /// to memref<?x?xf32> 465 /// consumer %1 : memref<?x?xf32> ... 466 /// ``` 467 /// 468 /// may fold into: 469 /// 470 /// ``` 471 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 472 /// ``` 473 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 474 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 475 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 476 477 // Requires ranked MemRefType. 478 if (!sourceType || !resultType) 479 return false; 480 481 // Requires same elemental type. 482 if (sourceType.getElementType() != resultType.getElementType()) 483 return false; 484 485 // Requires same rank. 486 if (sourceType.getRank() != resultType.getRank()) 487 return false; 488 489 // Only fold casts between strided memref forms. 
490 int64_t sourceOffset, resultOffset; 491 SmallVector<int64_t, 4> sourceStrides, resultStrides; 492 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 493 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 494 return false; 495 496 // If cast is towards more static sizes along any dimension, don't fold. 497 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 498 auto ss = std::get<0>(it), st = std::get<1>(it); 499 if (ss != st) 500 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) 501 return false; 502 } 503 504 // If cast is towards more static offset along any dimension, don't fold. 505 if (sourceOffset != resultOffset) 506 if (ShapedType::isDynamicStrideOrOffset(sourceOffset) && 507 !ShapedType::isDynamicStrideOrOffset(resultOffset)) 508 return false; 509 510 // If cast is towards more static strides along any dimension, don't fold. 511 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 512 auto ss = std::get<0>(it), st = std::get<1>(it); 513 if (ss != st) 514 if (ShapedType::isDynamicStrideOrOffset(ss) && 515 !ShapedType::isDynamicStrideOrOffset(st)) 516 return false; 517 } 518 519 return true; 520 } 521 522 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 523 if (inputs.size() != 1 || outputs.size() != 1) 524 return false; 525 Type a = inputs.front(), b = outputs.front(); 526 auto aT = a.dyn_cast<MemRefType>(); 527 auto bT = b.dyn_cast<MemRefType>(); 528 529 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 530 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 531 532 if (aT && bT) { 533 if (aT.getElementType() != bT.getElementType()) 534 return false; 535 if (aT.getLayout() != bT.getLayout()) { 536 int64_t aOffset, bOffset; 537 SmallVector<int64_t, 4> aStrides, bStrides; 538 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 539 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 540 aStrides.size() != bStrides.size()) 541 return false; 542 543 // Strides 
along a dimension/offset are compatible if the value in the 544 // source memref is static and the value in the target memref is the 545 // same. They are also compatible if either one is dynamic (see 546 // description of MemRefCastOp for details). 547 auto checkCompatible = [](int64_t a, int64_t b) { 548 return (a == MemRefType::getDynamicStrideOrOffset() || 549 b == MemRefType::getDynamicStrideOrOffset() || a == b); 550 }; 551 if (!checkCompatible(aOffset, bOffset)) 552 return false; 553 for (const auto &aStride : enumerate(aStrides)) 554 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 555 return false; 556 } 557 if (aT.getMemorySpace() != bT.getMemorySpace()) 558 return false; 559 560 // They must have the same rank, and any specified dimensions must match. 561 if (aT.getRank() != bT.getRank()) 562 return false; 563 564 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 565 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 566 if (aDim != -1 && bDim != -1 && aDim != bDim) 567 return false; 568 } 569 return true; 570 } else { 571 if (!aT && !uaT) 572 return false; 573 if (!bT && !ubT) 574 return false; 575 // Unranked to unranked casting is unsupported 576 if (uaT && ubT) 577 return false; 578 579 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 580 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 581 if (aEltType != bEltType) 582 return false; 583 584 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 585 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 586 return aMemSpace == bMemSpace; 587 } 588 589 return false; 590 } 591 592 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 593 return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value(); 594 } 595 596 //===----------------------------------------------------------------------===// 597 // CopyOp 598 //===----------------------------------------------------------------------===// 599 600 namespace { 601 /// If the source/target of a CopyOp is a CastOp that does not modify the shape 602 /// and element type, the cast can be skipped. Such CastOps only cast the layout 603 /// of the type. 604 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> { 605 using OpRewritePattern<CopyOp>::OpRewritePattern; 606 607 LogicalResult matchAndRewrite(CopyOp copyOp, 608 PatternRewriter &rewriter) const override { 609 bool modified = false; 610 611 // Check source. 612 if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) { 613 auto fromType = castOp.source().getType().dyn_cast<MemRefType>(); 614 auto toType = castOp.source().getType().dyn_cast<MemRefType>(); 615 616 if (fromType && toType) { 617 if (fromType.getShape() == toType.getShape() && 618 fromType.getElementType() == toType.getElementType()) { 619 rewriter.updateRootInPlace( 620 copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); }); 621 modified = true; 622 } 623 } 624 } 625 626 // Check target. 627 if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) { 628 auto fromType = castOp.source().getType().dyn_cast<MemRefType>(); 629 auto toType = castOp.source().getType().dyn_cast<MemRefType>(); 630 631 if (fromType && toType) { 632 if (fromType.getShape() == toType.getShape() && 633 fromType.getElementType() == toType.getElementType()) { 634 rewriter.updateRootInPlace( 635 copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); }); 636 modified = true; 637 } 638 } 639 } 640 641 return success(modified); 642 } 643 }; 644 645 /// Fold memref.copy(%x, %x). 
646 struct FoldSelfCopy : public OpRewritePattern<CopyOp> { 647 using OpRewritePattern<CopyOp>::OpRewritePattern; 648 649 LogicalResult matchAndRewrite(CopyOp copyOp, 650 PatternRewriter &rewriter) const override { 651 if (copyOp.source() != copyOp.target()) 652 return failure(); 653 654 rewriter.eraseOp(copyOp); 655 return success(); 656 } 657 }; 658 } // namespace 659 660 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, 661 MLIRContext *context) { 662 results.add<FoldCopyOfCast, FoldSelfCopy>(context); 663 } 664 665 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands, 666 SmallVectorImpl<OpFoldResult> &results) { 667 /// copy(memrefcast) -> copy 668 bool folded = false; 669 Operation *op = *this; 670 for (OpOperand &operand : op->getOpOperands()) { 671 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 672 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 673 operand.set(castOp.getOperand()); 674 folded = true; 675 } 676 } 677 return success(folded); 678 } 679 680 //===----------------------------------------------------------------------===// 681 // DeallocOp 682 //===----------------------------------------------------------------------===// 683 684 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 685 SmallVectorImpl<OpFoldResult> &results) { 686 /// dealloc(memrefcast) -> dealloc 687 return foldMemRefCast(*this); 688 } 689 690 //===----------------------------------------------------------------------===// 691 // DimOp 692 //===----------------------------------------------------------------------===// 693 694 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 695 int64_t index) { 696 auto loc = result.location; 697 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index); 698 build(builder, result, source, indexValue); 699 } 700 701 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 702 Value index) { 703 auto indexTy = 
builder.getIndexType(); 704 build(builder, result, indexTy, source, index); 705 } 706 707 Optional<int64_t> DimOp::getConstantIndex() { 708 if (auto constantOp = index().getDefiningOp<arith::ConstantOp>()) 709 return constantOp.getValue().cast<IntegerAttr>().getInt(); 710 return {}; 711 } 712 713 LogicalResult DimOp::verify() { 714 // Assume unknown index to be in range. 715 Optional<int64_t> index = getConstantIndex(); 716 if (!index.hasValue()) 717 return success(); 718 719 // Check that constant index is not knowingly out of range. 720 auto type = source().getType(); 721 if (auto memrefType = type.dyn_cast<MemRefType>()) { 722 if (index.getValue() >= memrefType.getRank()) 723 return emitOpError("index is out of range"); 724 } else if (type.isa<UnrankedMemRefType>()) { 725 // Assume index to be in range. 726 } else { 727 llvm_unreachable("expected operand with memref type"); 728 } 729 return success(); 730 } 731 732 /// Return a map with key being elements in `vals` and data being number of 733 /// occurences of it. Use std::map, since the `vals` here are strides and the 734 /// dynamic stride value is the same as the tombstone value for 735 /// `DenseMap<int64_t>`. 736 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) { 737 std::map<int64_t, unsigned> numOccurences; 738 for (auto val : vals) 739 numOccurences[val]++; 740 return numOccurences; 741 } 742 743 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed 744 /// to be a subset of `originalType` with some `1` entries erased, return the 745 /// set of indices that specifies which of the entries of `originalShape` are 746 /// dropped to obtain `reducedShape`. 747 /// This accounts for cases where there are multiple unit-dims, but only a 748 /// subset of those are dropped. For MemRefTypes these can be disambiguated 749 /// using the strides. If a dimension is dropped the stride must be dropped too. 
750 static llvm::Optional<llvm::SmallBitVector> 751 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, 752 ArrayRef<OpFoldResult> sizes) { 753 llvm::SmallBitVector unusedDims(originalType.getRank()); 754 if (originalType.getRank() == reducedType.getRank()) 755 return unusedDims; 756 757 for (const auto &dim : llvm::enumerate(sizes)) 758 if (auto attr = dim.value().dyn_cast<Attribute>()) 759 if (attr.cast<IntegerAttr>().getInt() == 1) 760 unusedDims.set(dim.index()); 761 762 SmallVector<int64_t> originalStrides, candidateStrides; 763 int64_t originalOffset, candidateOffset; 764 if (failed( 765 getStridesAndOffset(originalType, originalStrides, originalOffset)) || 766 failed( 767 getStridesAndOffset(reducedType, candidateStrides, candidateOffset))) 768 return llvm::None; 769 770 // For memrefs, a dimension is truly dropped if its corresponding stride is 771 // also dropped. This is particularly important when more than one of the dims 772 // is 1. Track the number of occurences of the strides in the original type 773 // and the candidate type. For each unused dim that stride should not be 774 // present in the candidate type. Note that there could be multiple dimensions 775 // that have the same size. We dont need to exactly figure out which dim 776 // corresponds to which stride, we just need to verify that the number of 777 // reptitions of a stride in the original + number of unused dims with that 778 // stride == number of repititions of a stride in the candidate. 
779 std::map<int64_t, unsigned> currUnaccountedStrides = 780 getNumOccurences(originalStrides); 781 std::map<int64_t, unsigned> candidateStridesNumOccurences = 782 getNumOccurences(candidateStrides); 783 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) { 784 if (!unusedDims.test(dim)) 785 continue; 786 int64_t originalStride = originalStrides[dim]; 787 if (currUnaccountedStrides[originalStride] > 788 candidateStridesNumOccurences[originalStride]) { 789 // This dim can be treated as dropped. 790 currUnaccountedStrides[originalStride]--; 791 continue; 792 } 793 if (currUnaccountedStrides[originalStride] == 794 candidateStridesNumOccurences[originalStride]) { 795 // The stride for this is not dropped. Keep as is. 796 unusedDims.reset(dim); 797 continue; 798 } 799 if (currUnaccountedStrides[originalStride] < 800 candidateStridesNumOccurences[originalStride]) { 801 // This should never happen. Cant have a stride in the reduced rank type 802 // that wasnt in the original one. 803 return llvm::None; 804 } 805 } 806 807 if ((int64_t)unusedDims.count() + reducedType.getRank() != 808 originalType.getRank()) 809 return llvm::None; 810 return unusedDims; 811 } 812 813 llvm::SmallBitVector SubViewOp::getDroppedDims() { 814 MemRefType sourceType = getSourceType(); 815 MemRefType resultType = getType(); 816 llvm::Optional<llvm::SmallBitVector> unusedDims = 817 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes()); 818 assert(unusedDims && "unable to find unused dims of subview"); 819 return *unusedDims; 820 } 821 822 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 823 // All forms of folding require a known index. 824 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 825 if (!index) 826 return {}; 827 828 // Folding for unranked types (UnrankedMemRefType) is not supported. 
829 auto memrefType = source().getType().dyn_cast<MemRefType>(); 830 if (!memrefType) 831 return {}; 832 833 // Fold if the shape extent along the given index is known. 834 if (!memrefType.isDynamicDim(index.getInt())) { 835 Builder builder(getContext()); 836 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); 837 } 838 839 // The size at the given index is now known to be a dynamic size. 840 unsigned unsignedIndex = index.getValue().getZExtValue(); 841 842 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 843 Operation *definingOp = source().getDefiningOp(); 844 845 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 846 return *(alloc.getDynamicSizes().begin() + 847 memrefType.getDynamicDimIndex(unsignedIndex)); 848 849 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 850 return *(alloca.getDynamicSizes().begin() + 851 memrefType.getDynamicDimIndex(unsignedIndex)); 852 853 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 854 return *(view.getDynamicSizes().begin() + 855 memrefType.getDynamicDimIndex(unsignedIndex)); 856 857 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 858 llvm::SmallBitVector unusedDims = subview.getDroppedDims(); 859 unsigned resultIndex = 0; 860 unsigned sourceRank = subview.getSourceType().getRank(); 861 unsigned sourceIndex = 0; 862 for (auto i : llvm::seq<unsigned>(0, sourceRank)) { 863 if (unusedDims.test(i)) 864 continue; 865 if (resultIndex == unsignedIndex) { 866 sourceIndex = i; 867 break; 868 } 869 resultIndex++; 870 } 871 assert(subview.isDynamicSize(sourceIndex) && 872 "expected dynamic subview size"); 873 return subview.getDynamicSize(sourceIndex); 874 } 875 876 if (auto sizeInterface = 877 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 878 assert(sizeInterface.isDynamicSize(unsignedIndex) && 879 "Expected dynamic subview size"); 880 return sizeInterface.getDynamicSize(unsignedIndex); 881 } 882 883 // dim(memrefcast) -> dim 884 if 
(succeeded(foldMemRefCast(*this))) 885 return getResult(); 886 887 return {}; 888 } 889 890 namespace { 891 /// Fold dim of a memref reshape operation to a load into the reshape's shape 892 /// operand. 893 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 894 using OpRewritePattern<DimOp>::OpRewritePattern; 895 896 LogicalResult matchAndRewrite(DimOp dim, 897 PatternRewriter &rewriter) const override { 898 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 899 900 if (!reshape) 901 return failure(); 902 903 // Place the load directly after the reshape to ensure that the shape memref 904 // was not mutated. 905 rewriter.setInsertionPointAfter(reshape); 906 Location loc = dim.getLoc(); 907 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 908 if (load.getType() != dim.getType()) 909 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); 910 rewriter.replaceOp(dim, load); 911 return success(); 912 } 913 }; 914 915 } // namespace 916 917 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 918 MLIRContext *context) { 919 results.add<DimOfMemRefReshape>(context); 920 } 921 922 // --------------------------------------------------------------------------- 923 // DmaStartOp 924 // --------------------------------------------------------------------------- 925 926 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 927 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 928 ValueRange destIndices, Value numElements, 929 Value tagMemRef, ValueRange tagIndices, Value stride, 930 Value elementsPerStride) { 931 result.addOperands(srcMemRef); 932 result.addOperands(srcIndices); 933 result.addOperands(destMemRef); 934 result.addOperands(destIndices); 935 result.addOperands({numElements, tagMemRef}); 936 result.addOperands(tagIndices); 937 if (stride) 938 result.addOperands({stride, elementsPerStride}); 939 } 940 941 void DmaStartOp::print(OpAsmPrinter &p) { 942 p << " " << getSrcMemRef() 
<< '[' << getSrcIndices() << "], " 943 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() 944 << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; 945 if (isStrided()) 946 p << ", " << getStride() << ", " << getNumElementsPerStride(); 947 948 p.printOptionalAttrDict((*this)->getAttrs()); 949 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 950 << ", " << getTagMemRef().getType(); 951 } 952 953 // Parse DmaStartOp. 954 // Ex: 955 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 956 // %tag[%index], %stride, %num_elt_per_stride : 957 // : memref<3076 x f32, 0>, 958 // memref<1024 x f32, 2>, 959 // memref<1 x i32> 960 // 961 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 962 OpAsmParser::OperandType srcMemRefInfo; 963 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 964 OpAsmParser::OperandType dstMemRefInfo; 965 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 966 OpAsmParser::OperandType numElementsInfo; 967 OpAsmParser::OperandType tagMemrefInfo; 968 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 969 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 970 971 SmallVector<Type, 3> types; 972 auto indexType = parser.getBuilder().getIndexType(); 973 974 // Parse and resolve the following list of operands: 975 // *) source memref followed by its indices (in square brackets). 976 // *) destination memref followed by its indices (in square brackets). 977 // *) dma size in KiB. 
978 if (parser.parseOperand(srcMemRefInfo) || 979 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 980 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 981 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 982 parser.parseComma() || parser.parseOperand(numElementsInfo) || 983 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 984 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 985 return failure(); 986 987 // Parse optional stride and elements per stride. 988 if (parser.parseTrailingOperandList(strideInfo)) 989 return failure(); 990 991 bool isStrided = strideInfo.size() == 2; 992 if (!strideInfo.empty() && !isStrided) { 993 return parser.emitError(parser.getNameLoc(), 994 "expected two stride related operands"); 995 } 996 997 if (parser.parseColonTypeList(types)) 998 return failure(); 999 if (types.size() != 3) 1000 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 1001 1002 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 1003 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 1004 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 1005 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 1006 // size should be an index. 1007 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 1008 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 1009 // tag indices should be index. 1010 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 1011 return failure(); 1012 1013 if (isStrided) { 1014 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 1015 return failure(); 1016 } 1017 1018 return success(); 1019 } 1020 1021 LogicalResult DmaStartOp::verify() { 1022 unsigned numOperands = getNumOperands(); 1023 1024 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 1025 // the number of elements. 
1026 if (numOperands < 4) 1027 return emitOpError("expected at least 4 operands"); 1028 1029 // Check types of operands. The order of these calls is important: the later 1030 // calls rely on some type properties to compute the operand position. 1031 // 1. Source memref. 1032 if (!getSrcMemRef().getType().isa<MemRefType>()) 1033 return emitOpError("expected source to be of memref type"); 1034 if (numOperands < getSrcMemRefRank() + 4) 1035 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 1036 << " operands"; 1037 if (!getSrcIndices().empty() && 1038 !llvm::all_of(getSrcIndices().getTypes(), 1039 [](Type t) { return t.isIndex(); })) 1040 return emitOpError("expected source indices to be of index type"); 1041 1042 // 2. Destination memref. 1043 if (!getDstMemRef().getType().isa<MemRefType>()) 1044 return emitOpError("expected destination to be of memref type"); 1045 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 1046 if (numOperands < numExpectedOperands) 1047 return emitOpError() << "expected at least " << numExpectedOperands 1048 << " operands"; 1049 if (!getDstIndices().empty() && 1050 !llvm::all_of(getDstIndices().getTypes(), 1051 [](Type t) { return t.isIndex(); })) 1052 return emitOpError("expected destination indices to be of index type"); 1053 1054 // 3. Number of elements. 1055 if (!getNumElements().getType().isIndex()) 1056 return emitOpError("expected num elements to be of index type"); 1057 1058 // 4. Tag memref. 
1059 if (!getTagMemRef().getType().isa<MemRefType>()) 1060 return emitOpError("expected tag to be of memref type"); 1061 numExpectedOperands += getTagMemRefRank(); 1062 if (numOperands < numExpectedOperands) 1063 return emitOpError() << "expected at least " << numExpectedOperands 1064 << " operands"; 1065 if (!getTagIndices().empty() && 1066 !llvm::all_of(getTagIndices().getTypes(), 1067 [](Type t) { return t.isIndex(); })) 1068 return emitOpError("expected tag indices to be of index type"); 1069 1070 // Optional stride-related operands must be either both present or both 1071 // absent. 1072 if (numOperands != numExpectedOperands && 1073 numOperands != numExpectedOperands + 2) 1074 return emitOpError("incorrect number of operands"); 1075 1076 // 5. Strides. 1077 if (isStrided()) { 1078 if (!getStride().getType().isIndex() || 1079 !getNumElementsPerStride().getType().isIndex()) 1080 return emitOpError( 1081 "expected stride and num elements per stride to be of type index"); 1082 } 1083 1084 return success(); 1085 } 1086 1087 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 1088 SmallVectorImpl<OpFoldResult> &results) { 1089 /// dma_start(memrefcast) -> dma_start 1090 return foldMemRefCast(*this); 1091 } 1092 1093 // --------------------------------------------------------------------------- 1094 // DmaWaitOp 1095 // --------------------------------------------------------------------------- 1096 1097 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1098 SmallVectorImpl<OpFoldResult> &results) { 1099 /// dma_wait(memrefcast) -> dma_wait 1100 return foldMemRefCast(*this); 1101 } 1102 1103 LogicalResult DmaWaitOp::verify() { 1104 // Check that the number of tag indices matches the tagMemRef rank. 
1105 unsigned numTagIndices = tagIndices().size(); 1106 unsigned tagMemRefRank = getTagMemRefRank(); 1107 if (numTagIndices != tagMemRefRank) 1108 return emitOpError() << "expected tagIndices to have the same number of " 1109 "elements as the tagMemRef rank, expected " 1110 << tagMemRefRank << ", but got " << numTagIndices; 1111 return success(); 1112 } 1113 1114 //===----------------------------------------------------------------------===// 1115 // GenericAtomicRMWOp 1116 //===----------------------------------------------------------------------===// 1117 1118 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, 1119 Value memref, ValueRange ivs) { 1120 result.addOperands(memref); 1121 result.addOperands(ivs); 1122 1123 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) { 1124 Type elementType = memrefType.getElementType(); 1125 result.addTypes(elementType); 1126 1127 Region *bodyRegion = result.addRegion(); 1128 bodyRegion->push_back(new Block()); 1129 bodyRegion->addArgument(elementType, memref.getLoc()); 1130 } 1131 } 1132 1133 LogicalResult GenericAtomicRMWOp::verify() { 1134 auto &body = getRegion(); 1135 if (body.getNumArguments() != 1) 1136 return emitOpError("expected single number of entry block arguments"); 1137 1138 if (getResult().getType() != body.getArgument(0).getType()) 1139 return emitOpError("expected block argument of the same type result type"); 1140 1141 bool hasSideEffects = 1142 body.walk([&](Operation *nestedOp) { 1143 if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) 1144 return WalkResult::advance(); 1145 nestedOp->emitError( 1146 "body of 'memref.generic_atomic_rmw' should contain " 1147 "only operations with no side effects"); 1148 return WalkResult::interrupt(); 1149 }) 1150 .wasInterrupted(); 1151 return hasSideEffects ? 
failure() : success(); 1152 } 1153 1154 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser, 1155 OperationState &result) { 1156 OpAsmParser::OperandType memref; 1157 Type memrefType; 1158 SmallVector<OpAsmParser::OperandType, 4> ivs; 1159 1160 Type indexType = parser.getBuilder().getIndexType(); 1161 if (parser.parseOperand(memref) || 1162 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || 1163 parser.parseColonType(memrefType) || 1164 parser.resolveOperand(memref, memrefType, result.operands) || 1165 parser.resolveOperands(ivs, indexType, result.operands)) 1166 return failure(); 1167 1168 Region *body = result.addRegion(); 1169 if (parser.parseRegion(*body, llvm::None, llvm::None) || 1170 parser.parseOptionalAttrDict(result.attributes)) 1171 return failure(); 1172 result.types.push_back(memrefType.cast<MemRefType>().getElementType()); 1173 return success(); 1174 } 1175 1176 void GenericAtomicRMWOp::print(OpAsmPrinter &p) { 1177 p << ' ' << memref() << "[" << indices() << "] : " << memref().getType() 1178 << ' '; 1179 p.printRegion(getRegion()); 1180 p.printOptionalAttrDict((*this)->getAttrs()); 1181 } 1182 1183 //===----------------------------------------------------------------------===// 1184 // AtomicYieldOp 1185 //===----------------------------------------------------------------------===// 1186 1187 LogicalResult AtomicYieldOp::verify() { 1188 Type parentType = (*this)->getParentOp()->getResultTypes().front(); 1189 Type resultType = result().getType(); 1190 if (parentType != resultType) 1191 return emitOpError() << "types mismatch between yield op: " << resultType 1192 << " and its parent: " << parentType; 1193 return success(); 1194 } 1195 1196 //===----------------------------------------------------------------------===// 1197 // GlobalOp 1198 //===----------------------------------------------------------------------===// 1199 1200 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1201 TypeAttr type, 
1202 Attribute initialValue) { 1203 p << type; 1204 if (!op.isExternal()) { 1205 p << " = "; 1206 if (op.isUninitialized()) 1207 p << "uninitialized"; 1208 else 1209 p.printAttributeWithoutType(initialValue); 1210 } 1211 } 1212 1213 static ParseResult 1214 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1215 Attribute &initialValue) { 1216 Type type; 1217 if (parser.parseType(type)) 1218 return failure(); 1219 1220 auto memrefType = type.dyn_cast<MemRefType>(); 1221 if (!memrefType || !memrefType.hasStaticShape()) 1222 return parser.emitError(parser.getNameLoc()) 1223 << "type should be static shaped memref, but got " << type; 1224 typeAttr = TypeAttr::get(type); 1225 1226 if (parser.parseOptionalEqual()) 1227 return success(); 1228 1229 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1230 initialValue = UnitAttr::get(parser.getContext()); 1231 return success(); 1232 } 1233 1234 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1235 if (parser.parseAttribute(initialValue, tensorType)) 1236 return failure(); 1237 if (!initialValue.isa<ElementsAttr>()) 1238 return parser.emitError(parser.getNameLoc()) 1239 << "initial value should be a unit or elements attribute"; 1240 return success(); 1241 } 1242 1243 LogicalResult GlobalOp::verify() { 1244 auto memrefType = type().dyn_cast<MemRefType>(); 1245 if (!memrefType || !memrefType.hasStaticShape()) 1246 return emitOpError("type should be static shaped memref, but got ") 1247 << type(); 1248 1249 // Verify that the initial value, if present, is either a unit attribute or 1250 // an elements attribute. 
1251 if (initial_value().hasValue()) { 1252 Attribute initValue = initial_value().getValue(); 1253 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1254 return emitOpError("initial value should be a unit or elements " 1255 "attribute, but got ") 1256 << initValue; 1257 1258 // Check that the type of the initial value is compatible with the type of 1259 // the global variable. 1260 if (initValue.isa<ElementsAttr>()) { 1261 Type initType = initValue.getType(); 1262 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1263 if (initType != tensorType) 1264 return emitOpError("initial value expected to be of type ") 1265 << tensorType << ", but was of type " << initType; 1266 } 1267 } 1268 1269 if (Optional<uint64_t> alignAttr = alignment()) { 1270 uint64_t alignment = alignAttr.getValue(); 1271 1272 if (!llvm::isPowerOf2_64(alignment)) 1273 return emitError() << "alignment attribute value " << alignment 1274 << " is not a power of 2"; 1275 } 1276 1277 // TODO: verify visibility for declarations. 1278 return success(); 1279 } 1280 1281 ElementsAttr GlobalOp::getConstantInitValue() { 1282 auto initVal = initial_value(); 1283 if (constant() && initVal.hasValue()) 1284 return initVal.getValue().cast<ElementsAttr>(); 1285 return {}; 1286 } 1287 1288 //===----------------------------------------------------------------------===// 1289 // GetGlobalOp 1290 //===----------------------------------------------------------------------===// 1291 1292 LogicalResult 1293 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1294 // Verify that the result type is same as the type of the referenced 1295 // memref.global op. 
1296 auto global = 1297 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1298 if (!global) 1299 return emitOpError("'") 1300 << name() << "' does not reference a valid global memref"; 1301 1302 Type resultType = result().getType(); 1303 if (global.type() != resultType) 1304 return emitOpError("result type ") 1305 << resultType << " does not match type " << global.type() 1306 << " of the global memref @" << name(); 1307 return success(); 1308 } 1309 1310 //===----------------------------------------------------------------------===// 1311 // LoadOp 1312 //===----------------------------------------------------------------------===// 1313 1314 LogicalResult LoadOp::verify() { 1315 if (getNumOperands() != 1 + getMemRefType().getRank()) 1316 return emitOpError("incorrect number of indices for load"); 1317 return success(); 1318 } 1319 1320 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1321 /// load(memrefcast) -> load 1322 if (succeeded(foldMemRefCast(*this))) 1323 return getResult(); 1324 return OpFoldResult(); 1325 } 1326 1327 //===----------------------------------------------------------------------===// 1328 // PrefetchOp 1329 //===----------------------------------------------------------------------===// 1330 1331 void PrefetchOp::print(OpAsmPrinter &p) { 1332 p << " " << memref() << '['; 1333 p.printOperands(indices()); 1334 p << ']' << ", " << (isWrite() ? "write" : "read"); 1335 p << ", locality<" << localityHint(); 1336 p << ">, " << (isDataCache() ? 
"data" : "instr"); 1337 p.printOptionalAttrDict( 1338 (*this)->getAttrs(), 1339 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1340 p << " : " << getMemRefType(); 1341 } 1342 1343 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { 1344 OpAsmParser::OperandType memrefInfo; 1345 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1346 IntegerAttr localityHint; 1347 MemRefType type; 1348 StringRef readOrWrite, cacheType; 1349 1350 auto indexTy = parser.getBuilder().getIndexType(); 1351 auto i32Type = parser.getBuilder().getIntegerType(32); 1352 if (parser.parseOperand(memrefInfo) || 1353 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1354 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1355 parser.parseComma() || parser.parseKeyword("locality") || 1356 parser.parseLess() || 1357 parser.parseAttribute(localityHint, i32Type, "localityHint", 1358 result.attributes) || 1359 parser.parseGreater() || parser.parseComma() || 1360 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1361 parser.resolveOperand(memrefInfo, type, result.operands) || 1362 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1363 return failure(); 1364 1365 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1366 return parser.emitError(parser.getNameLoc(), 1367 "rw specifier has to be 'read' or 'write'"); 1368 result.addAttribute( 1369 PrefetchOp::getIsWriteAttrName(), 1370 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1371 1372 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1373 return parser.emitError(parser.getNameLoc(), 1374 "cache type has to be 'data' or 'instr'"); 1375 1376 result.addAttribute( 1377 PrefetchOp::getIsDataCacheAttrName(), 1378 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1379 1380 return success(); 1381 } 1382 1383 LogicalResult PrefetchOp::verify() { 1384 if (getNumOperands() != 1 + getMemRefType().getRank()) 1385 return 
emitOpError("too few indices"); 1386 1387 return success(); 1388 } 1389 1390 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1391 SmallVectorImpl<OpFoldResult> &results) { 1392 // prefetch(memrefcast) -> prefetch 1393 return foldMemRefCast(*this); 1394 } 1395 1396 //===----------------------------------------------------------------------===// 1397 // RankOp 1398 //===----------------------------------------------------------------------===// 1399 1400 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1401 // Constant fold rank when the rank of the operand is known. 1402 auto type = getOperand().getType(); 1403 auto shapedType = type.dyn_cast<ShapedType>(); 1404 if (shapedType && shapedType.hasRank()) 1405 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1406 return IntegerAttr(); 1407 } 1408 1409 //===----------------------------------------------------------------------===// 1410 // ReinterpretCastOp 1411 //===----------------------------------------------------------------------===// 1412 1413 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1414 /// `staticSizes` and `staticStrides` are automatically filled with 1415 /// source-memref-rank sentinel values that encode dynamic entries. 
1416 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1417 MemRefType resultType, Value source, 1418 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1419 ArrayRef<OpFoldResult> strides, 1420 ArrayRef<NamedAttribute> attrs) { 1421 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1422 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1423 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1424 ShapedType::kDynamicStrideOrOffset); 1425 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1426 ShapedType::kDynamicSize); 1427 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1428 ShapedType::kDynamicStrideOrOffset); 1429 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1430 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1431 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1432 result.addAttributes(attrs); 1433 } 1434 1435 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1436 MemRefType resultType, Value source, 1437 int64_t offset, ArrayRef<int64_t> sizes, 1438 ArrayRef<int64_t> strides, 1439 ArrayRef<NamedAttribute> attrs) { 1440 SmallVector<OpFoldResult> sizeValues = 1441 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1442 return b.getI64IntegerAttr(v); 1443 })); 1444 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1445 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1446 return b.getI64IntegerAttr(v); 1447 })); 1448 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1449 strideValues, attrs); 1450 } 1451 1452 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1453 MemRefType resultType, Value source, Value offset, 1454 ValueRange sizes, ValueRange strides, 1455 ArrayRef<NamedAttribute> attrs) { 1456 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1457 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; 
})); 1458 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1459 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1460 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1461 } 1462 1463 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1464 // completed automatically, like we have for subview and extract_slice. 1465 LogicalResult ReinterpretCastOp::verify() { 1466 // The source and result memrefs should be in the same memory space. 1467 auto srcType = source().getType().cast<BaseMemRefType>(); 1468 auto resultType = getType().cast<MemRefType>(); 1469 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1470 return emitError("different memory spaces specified for source type ") 1471 << srcType << " and result memref type " << resultType; 1472 if (srcType.getElementType() != resultType.getElementType()) 1473 return emitError("different element types specified for source type ") 1474 << srcType << " and result memref type " << resultType; 1475 1476 // Match sizes in result memref type and in static_sizes attribute. 1477 for (auto &en : llvm::enumerate(llvm::zip( 1478 resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) { 1479 int64_t resultSize = std::get<0>(en.value()); 1480 int64_t expectedSize = std::get<1>(en.value()); 1481 if (!ShapedType::isDynamic(resultSize) && 1482 !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) 1483 return emitError("expected result type with size = ") 1484 << expectedSize << " instead of " << resultSize 1485 << " in dim = " << en.index(); 1486 } 1487 1488 // Match offset and strides in static_offset and static_strides attributes. If 1489 // result memref type has no affine map specified, this will assume an 1490 // identity layout. 
1491 int64_t resultOffset; 1492 SmallVector<int64_t, 4> resultStrides; 1493 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1494 return emitError("expected result type to have strided layout but found ") 1495 << resultType; 1496 1497 // Match offset in result memref type and in static_offsets attribute. 1498 int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front(); 1499 if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && 1500 !ShapedType::isDynamicStrideOrOffset(expectedOffset) && 1501 resultOffset != expectedOffset) 1502 return emitError("expected result type with offset = ") 1503 << resultOffset << " instead of " << expectedOffset; 1504 1505 // Match strides in result memref type and in static_strides attribute. 1506 for (auto &en : llvm::enumerate(llvm::zip( 1507 resultStrides, extractFromI64ArrayAttr(static_strides())))) { 1508 int64_t resultStride = std::get<0>(en.value()); 1509 int64_t expectedStride = std::get<1>(en.value()); 1510 if (!ShapedType::isDynamicStrideOrOffset(resultStride) && 1511 !ShapedType::isDynamicStrideOrOffset(expectedStride) && 1512 resultStride != expectedStride) 1513 return emitError("expected result type with stride = ") 1514 << expectedStride << " instead of " << resultStride 1515 << " in dim = " << en.index(); 1516 } 1517 1518 return success(); 1519 } 1520 1521 OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) { 1522 Value src = source(); 1523 auto getPrevSrc = [&]() -> Value { 1524 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x). 1525 if (auto prev = src.getDefiningOp<ReinterpretCastOp>()) 1526 return prev.source(); 1527 1528 // reinterpret_cast(cast(x)) -> reinterpret_cast(x). 1529 if (auto prev = src.getDefiningOp<CastOp>()) 1530 return prev.source(); 1531 1532 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets 1533 // are 0. 
1534 if (auto prev = src.getDefiningOp<SubViewOp>()) 1535 if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) { 1536 return isConstantIntValue(val, 0); 1537 })) 1538 return prev.source(); 1539 1540 return nullptr; 1541 }; 1542 1543 if (auto prevSrc = getPrevSrc()) { 1544 sourceMutable().assign(prevSrc); 1545 return getResult(); 1546 } 1547 1548 return nullptr; 1549 } 1550 1551 //===----------------------------------------------------------------------===// 1552 // Reassociative reshape ops 1553 //===----------------------------------------------------------------------===// 1554 1555 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1556 return getSymbolLessAffineMaps(getReassociationExprs()); 1557 } 1558 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1559 return convertReassociationIndicesToExprs(getContext(), 1560 getReassociationIndices()); 1561 } 1562 1563 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1564 return getSymbolLessAffineMaps(getReassociationExprs()); 1565 } 1566 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1567 return convertReassociationIndicesToExprs(getContext(), 1568 getReassociationIndices()); 1569 } 1570 1571 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1572 /// copies. 1573 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1574 ArrayRef<int64_t> sizes, 1575 ArrayRef<AffineExpr> strides) { 1576 // Bands of extent one can be reshaped, as they are not reshaped at all. 1577 if (extent == 1) 1578 return true; 1579 // Otherwise, the size of the first dimension needs to be known. 1580 if (ShapedType::isDynamic(sizes[dim])) 1581 return false; 1582 assert(sizes.size() == strides.size() && "mismatched ranks"); 1583 // off by 1 indexing to avoid out of bounds 1584 // V 1585 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1586 // Only bands of static shapes are reshapable. 
This is due to the fact that 1587 // there is no relation between dynamic sizes and dynamic strides: we do not 1588 // have enough information to know whether a "-1" size corresponds to the 1589 // proper symbol in the AffineExpr of a stride. 1590 if (ShapedType::isDynamic(sizes[idx + 1])) 1591 return false; 1592 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1593 // simplify on the fly and catch more reshapable cases. 1594 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1595 return false; 1596 } 1597 return true; 1598 } 1599 1600 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1601 /// expected to be valid) to `type`. 1602 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1603 /// MemRefType. 1604 static MemRefType 1605 computeReshapeCollapsedType(MemRefType type, 1606 ArrayRef<AffineMap> reassociation) { 1607 auto sizes = type.getShape(); 1608 AffineExpr offset; 1609 SmallVector<AffineExpr, 4> strides; 1610 auto status = getStridesAndOffset(type, strides, offset); 1611 auto isIdentityLayout = type.getLayout().isIdentity(); 1612 (void)status; 1613 assert(succeeded(status) && "expected strided memref"); 1614 1615 SmallVector<int64_t, 4> newSizes; 1616 newSizes.reserve(reassociation.size()); 1617 SmallVector<AffineExpr, 4> newStrides; 1618 newStrides.reserve(reassociation.size()); 1619 1620 // Use the fact that reassociation is valid to simplify the logic: only use 1621 // each map's rank. 
1622 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1623 unsigned currentDim = 0; 1624 for (AffineMap m : reassociation) { 1625 unsigned dim = m.getNumResults(); 1626 int64_t size = 1; 1627 AffineExpr stride = strides[currentDim + dim - 1]; 1628 if (isIdentityLayout || 1629 isReshapableDimBand(currentDim, dim, sizes, strides)) { 1630 for (unsigned d = 0; d < dim; ++d) { 1631 int64_t currentSize = sizes[currentDim + d]; 1632 if (ShapedType::isDynamic(currentSize)) { 1633 size = ShapedType::kDynamicSize; 1634 break; 1635 } 1636 size *= currentSize; 1637 } 1638 } else { 1639 size = ShapedType::kDynamicSize; 1640 stride = AffineExpr(); 1641 } 1642 newSizes.push_back(size); 1643 newStrides.push_back(stride); 1644 currentDim += dim; 1645 } 1646 1647 // Early-exit: if `type` is contiguous, the result must be contiguous. 1648 if (canonicalizeStridedLayout(type).getLayout().isIdentity()) 1649 return MemRefType::Builder(type).setShape(newSizes).setLayout({}); 1650 1651 // Convert back to int64_t because we don't have enough information to create 1652 // new strided layouts from AffineExpr only. This corresponds to a case where 1653 // copies may be necessary. 
1654 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1655 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1656 intOffset = o.getValue(); 1657 SmallVector<int64_t, 4> intStrides; 1658 intStrides.reserve(strides.size()); 1659 for (auto stride : newStrides) { 1660 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1661 intStrides.push_back(cst.getValue()); 1662 else 1663 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1664 } 1665 auto layout = 1666 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1667 return canonicalizeStridedLayout( 1668 MemRefType::Builder(type).setShape(newSizes).setLayout( 1669 AffineMapAttr::get(layout))); 1670 } 1671 1672 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1673 ArrayRef<ReassociationIndices> reassociation, 1674 ArrayRef<NamedAttribute> attrs) { 1675 auto memRefType = src.getType().cast<MemRefType>(); 1676 auto resultType = computeReshapeCollapsedType( 1677 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1678 b.getContext(), reassociation))); 1679 build(b, result, resultType, src, attrs); 1680 result.addAttribute(getReassociationAttrName(), 1681 getReassociationIndicesAttribute(b, reassociation)); 1682 } 1683 1684 template <typename ReshapeOp, 1685 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1686 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1687 MemRefType collapsedType) { 1688 if (failed( 1689 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1690 return failure(); 1691 auto maps = op.getReassociationMaps(); 1692 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1693 if (collapsedType != expectedType) 1694 return op.emitOpError("expected collapsed type to be ") 1695 << expectedType << ", but got " << collapsedType; 1696 return success(); 1697 } 1698 1699 LogicalResult ExpandShapeOp::verify() { 1700 return verifyReshapeOp(*this, 
getResultType(), getSrcType()); 1701 } 1702 1703 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1704 MLIRContext *context) { 1705 results.add<CollapseReshapeOps<ExpandShapeOp>, 1706 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1707 } 1708 1709 LogicalResult CollapseShapeOp::verify() { 1710 return verifyReshapeOp(*this, getSrcType(), getResultType()); 1711 } 1712 1713 struct CollapseShapeOpMemRefCastFolder 1714 : public OpRewritePattern<CollapseShapeOp> { 1715 public: 1716 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1717 1718 LogicalResult matchAndRewrite(CollapseShapeOp op, 1719 PatternRewriter &rewriter) const override { 1720 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1721 if (!cast) 1722 return failure(); 1723 1724 if (!CastOp::canFoldIntoConsumerOp(cast)) 1725 return failure(); 1726 1727 Type newResultType = computeReshapeCollapsedType( 1728 cast.getOperand().getType().cast<MemRefType>(), 1729 op.getReassociationMaps()); 1730 1731 if (newResultType == op.getResultType()) { 1732 rewriter.updateRootInPlace( 1733 op, [&]() { op.srcMutable().assign(cast.source()); }); 1734 } else { 1735 Value newOp = rewriter.create<CollapseShapeOp>( 1736 op->getLoc(), cast.source(), op.getReassociationIndices()); 1737 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1738 } 1739 return success(); 1740 } 1741 }; 1742 1743 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1744 MLIRContext *context) { 1745 results.add<CollapseReshapeOps<CollapseShapeOp>, 1746 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1747 CollapseShapeOpMemRefCastFolder>(context); 1748 } 1749 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1750 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1751 } 1752 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1753 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1754 } 
1755 1756 //===----------------------------------------------------------------------===// 1757 // ReshapeOp 1758 //===----------------------------------------------------------------------===// 1759 1760 LogicalResult ReshapeOp::verify() { 1761 Type operandType = source().getType(); 1762 Type resultType = result().getType(); 1763 1764 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1765 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1766 if (operandElementType != resultElementType) 1767 return emitOpError("element types of source and destination memref " 1768 "types should be the same"); 1769 1770 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1771 if (!operandMemRefType.getLayout().isIdentity()) 1772 return emitOpError("source memref type should have identity affine map"); 1773 1774 int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0); 1775 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1776 if (resultMemRefType) { 1777 if (!resultMemRefType.getLayout().isIdentity()) 1778 return emitOpError("result memref type should have identity affine map"); 1779 if (shapeSize == ShapedType::kDynamicSize) 1780 return emitOpError("cannot use shape operand with dynamic length to " 1781 "reshape to statically-ranked memref type"); 1782 if (shapeSize != resultMemRefType.getRank()) 1783 return emitOpError( 1784 "length of shape operand differs from the result's memref rank"); 1785 } 1786 return success(); 1787 } 1788 1789 //===----------------------------------------------------------------------===// 1790 // StoreOp 1791 //===----------------------------------------------------------------------===// 1792 1793 LogicalResult StoreOp::verify() { 1794 if (getNumOperands() != 2 + getMemRefType().getRank()) 1795 return emitOpError("store index operand count not equal to memref rank"); 1796 1797 return success(); 1798 } 1799 1800 LogicalResult StoreOp::fold(ArrayRef<Attribute> 
cstOperands, 1801 SmallVectorImpl<OpFoldResult> &results) { 1802 /// store(memrefcast) -> store 1803 return foldMemRefCast(*this, getValueToStore()); 1804 } 1805 1806 //===----------------------------------------------------------------------===// 1807 // SubViewOp 1808 //===----------------------------------------------------------------------===// 1809 1810 namespace { 1811 /// Helpers to write more idiomatic operations. 1812 namespace saturated_arith { 1813 struct Wrapper { 1814 explicit Wrapper(int64_t v) : v(v) {} 1815 operator int64_t() { return v; } 1816 int64_t v; 1817 }; 1818 Wrapper operator+(Wrapper a, int64_t b) { 1819 if (ShapedType::isDynamicStrideOrOffset(a) || 1820 ShapedType::isDynamicStrideOrOffset(b)) 1821 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1822 return Wrapper(a.v + b); 1823 } 1824 Wrapper operator*(Wrapper a, int64_t b) { 1825 if (ShapedType::isDynamicStrideOrOffset(a) || 1826 ShapedType::isDynamicStrideOrOffset(b)) 1827 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1828 return Wrapper(a.v * b); 1829 } 1830 } // namespace saturated_arith 1831 } // namespace 1832 1833 /// A subview result type can be fully inferred from the source type and the 1834 /// static representation of offsets, sizes and strides. Special sentinels 1835 /// encode the dynamic case. 1836 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1837 ArrayRef<int64_t> staticOffsets, 1838 ArrayRef<int64_t> staticSizes, 1839 ArrayRef<int64_t> staticStrides) { 1840 unsigned rank = sourceMemRefType.getRank(); 1841 (void)rank; 1842 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 1843 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 1844 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 1845 1846 // Extract source offset and strides. 
1847 int64_t sourceOffset; 1848 SmallVector<int64_t, 4> sourceStrides; 1849 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1850 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1851 (void)res; 1852 1853 // Compute target offset whose value is: 1854 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1855 int64_t targetOffset = sourceOffset; 1856 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1857 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1858 using namespace saturated_arith; 1859 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1860 } 1861 1862 // Compute target stride whose value is: 1863 // `sourceStrides_i * staticStrides_i`. 1864 SmallVector<int64_t, 4> targetStrides; 1865 targetStrides.reserve(staticOffsets.size()); 1866 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1867 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1868 using namespace saturated_arith; 1869 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1870 } 1871 1872 // The type is now known. 
1873 return MemRefType::get( 1874 staticSizes, sourceMemRefType.getElementType(), 1875 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1876 sourceMemRefType.getContext()), 1877 sourceMemRefType.getMemorySpace()); 1878 } 1879 1880 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1881 ArrayRef<OpFoldResult> offsets, 1882 ArrayRef<OpFoldResult> sizes, 1883 ArrayRef<OpFoldResult> strides) { 1884 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1885 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1886 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1887 ShapedType::kDynamicStrideOrOffset); 1888 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1889 ShapedType::kDynamicSize); 1890 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1891 ShapedType::kDynamicStrideOrOffset); 1892 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1893 staticSizes, staticStrides); 1894 } 1895 1896 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1897 MemRefType sourceRankedTensorType, 1898 ArrayRef<int64_t> offsets, 1899 ArrayRef<int64_t> sizes, 1900 ArrayRef<int64_t> strides) { 1901 auto inferredType = 1902 inferResultType(sourceRankedTensorType, offsets, sizes, strides) 1903 .cast<MemRefType>(); 1904 assert(inferredType.getRank() >= resultRank && "expected "); 1905 int rankDiff = inferredType.getRank() - resultRank; 1906 if (rankDiff > 0) { 1907 auto shape = inferredType.getShape(); 1908 llvm::SmallBitVector dimsToProject = 1909 getPositionsOfShapeOne(rankDiff, shape); 1910 SmallVector<int64_t> projectedShape; 1911 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1912 if (!dimsToProject.test(pos)) 1913 projectedShape.push_back(shape[pos]); 1914 1915 AffineMap map = inferredType.getLayout().getAffineMap(); 1916 if (!map.isIdentity()) 1917 map = getProjectedMap(map, dimsToProject); 1918 inferredType = 1919 MemRefType::get(projectedShape, 
inferredType.getElementType(), map, 1920 inferredType.getMemorySpace()); 1921 } 1922 return inferredType; 1923 } 1924 1925 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1926 MemRefType sourceRankedTensorType, 1927 ArrayRef<OpFoldResult> offsets, 1928 ArrayRef<OpFoldResult> sizes, 1929 ArrayRef<OpFoldResult> strides) { 1930 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1931 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1932 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1933 ShapedType::kDynamicStrideOrOffset); 1934 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1935 ShapedType::kDynamicSize); 1936 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1937 ShapedType::kDynamicStrideOrOffset); 1938 return SubViewOp::inferRankReducedResultType( 1939 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1940 staticStrides); 1941 } 1942 // Build a SubViewOp with mixed static and dynamic entries and custom result 1943 // type. If the type passed is nullptr, it is inferred. 1944 void SubViewOp::build(OpBuilder &b, OperationState &result, 1945 MemRefType resultType, Value source, 1946 ArrayRef<OpFoldResult> offsets, 1947 ArrayRef<OpFoldResult> sizes, 1948 ArrayRef<OpFoldResult> strides, 1949 ArrayRef<NamedAttribute> attrs) { 1950 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1951 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1952 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1953 ShapedType::kDynamicStrideOrOffset); 1954 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1955 ShapedType::kDynamicSize); 1956 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1957 ShapedType::kDynamicStrideOrOffset); 1958 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1959 // Structuring implementation this way avoids duplication between builders. 
1960 if (!resultType) { 1961 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1962 staticSizes, staticStrides) 1963 .cast<MemRefType>(); 1964 } 1965 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1966 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1967 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1968 result.addAttributes(attrs); 1969 } 1970 1971 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1972 // type. 1973 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1974 ArrayRef<OpFoldResult> offsets, 1975 ArrayRef<OpFoldResult> sizes, 1976 ArrayRef<OpFoldResult> strides, 1977 ArrayRef<NamedAttribute> attrs) { 1978 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1979 } 1980 1981 // Build a SubViewOp with static entries and inferred result type. 1982 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1983 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1984 ArrayRef<int64_t> strides, 1985 ArrayRef<NamedAttribute> attrs) { 1986 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1987 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1988 return b.getI64IntegerAttr(v); 1989 })); 1990 SmallVector<OpFoldResult> sizeValues = 1991 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1992 return b.getI64IntegerAttr(v); 1993 })); 1994 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1995 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1996 return b.getI64IntegerAttr(v); 1997 })); 1998 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1999 } 2000 2001 // Build a SubViewOp with dynamic entries and custom result type. If the 2002 // type passed is nullptr, it is inferred. 
2003 void SubViewOp::build(OpBuilder &b, OperationState &result, 2004 MemRefType resultType, Value source, 2005 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 2006 ArrayRef<int64_t> strides, 2007 ArrayRef<NamedAttribute> attrs) { 2008 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2009 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 2010 return b.getI64IntegerAttr(v); 2011 })); 2012 SmallVector<OpFoldResult> sizeValues = 2013 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 2014 return b.getI64IntegerAttr(v); 2015 })); 2016 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2017 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 2018 return b.getI64IntegerAttr(v); 2019 })); 2020 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 2021 attrs); 2022 } 2023 2024 // Build a SubViewOp with dynamic entries and custom result type. If the type 2025 // passed is nullptr, it is inferred. 2026 void SubViewOp::build(OpBuilder &b, OperationState &result, 2027 MemRefType resultType, Value source, ValueRange offsets, 2028 ValueRange sizes, ValueRange strides, 2029 ArrayRef<NamedAttribute> attrs) { 2030 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2031 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 2032 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 2033 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 2034 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2035 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 2036 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 2037 } 2038 2039 // Build a SubViewOp with dynamic entries and inferred result type. 
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  // A null MemRefType requests inference of the result type by the
  // delegated-to builder chain.
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

/// Return true if t1 and t2 have equal offsets (both dynamic or of same static
/// value). Also returns false if either type is not strided, i.e. if
/// `getStridesAndOffset` fails on it.
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  AffineExpr t1Offset, t2Offset;
  SmallVector<AffineExpr> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Checks if `originalType` can be rank reduced to
/// `candidateRankReducedType`. This function is a slight variant of the
/// `is subsequence` algorithm, in which every non-matching dimension must be 1.
/// The checks run in this order:
///  1. the generic `isRankReducedType` check;
///  2. whether the dropped dimensions can be determined from `sizes`
///     (LayoutMismatch if not);
///  3. memory-space equality (MemSpaceMismatch);
///  4. offset compatibility (LayoutMismatch).
static SliceVerificationResult
isRankReducedMemRefType(MemRefType originalType,
                        MemRefType candidateRankReducedType,
                        ArrayRef<OpFoldResult> sizes) {
  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
  if (partialRes != SliceVerificationResult::Success)
    return partialRes;

  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
      originalType, candidateRankReducedType, sizes);

  // Sizes cannot be matched if no mask (an empty Optional) is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::LayoutMismatch;

  if (originalType.getMemorySpace() !=
      candidateRankReducedType.getMemorySpace())
    return SliceVerificationResult::MemSpaceMismatch;

  // No amount of stride dropping can reconcile incompatible offsets.
  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
    return SliceVerificationResult::LayoutMismatch;

  return SliceVerificationResult::Success;
}

/// Turn a SliceVerificationResult into a diagnostic emitted on `op`.
/// Returns success() only for SliceVerificationResult::Success.
/// `expectedType` is the inferred (non-rank-reduced) result type and is
/// mentioned in the messages.
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            OpTy op, Type expectedType) {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.");
  case SliceVerificationResult::LayoutMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp. Checks, in order:
///  - source and result memory spaces match;
///  - the source type is strided;
///  - the result type is the inferred type or a valid rank-reduced version of
///    it.
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(static_offsets()),
      extractFromI64ArrayAttr(static_sizes()),
      extractFromI64ArrayAttr(static_strides()));

  auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
                                        subViewType, getMixedSizes());
  return produceSubViewErrorMsg(result, *this, expectedType);
}

/// Print a Range as `range <offset>:<size>:<stride>`.
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // Static entries are materialized as new index constants at `loc`.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
/// deduce the result type for the given `sourceType`. Additionally, reduce the
/// rank of the inferred result type if `currentResultType` is lower rank than
/// `currentSourceType`. Use this signature if `sourceType` is updated together
/// with the result type. In this case, it is important to compute the dropped
/// dimensions using `currentSourceType` whose strides align with
/// `currentResultType`.
/// Returns a null MemRefType if the dropped dimensions cannot be determined.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                                       mixedSizes, mixedStrides)
                                .cast<MemRefType>();
  llvm::Optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                     mixedSizes);
  // Return nullptr as failure mode.
  if (!unusedDims)
    return nullptr;
  // Keep only the dimensions that are not dropped.
  SmallVector<int64_t> shape;
  for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
    if (unusedDims->test(sizes.index()))
      continue;
    shape.push_back(sizes.value());
  }
  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
  if (!layoutMap.isIdentity())
    layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
                         nonRankReducedType.getMemorySpace());
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
/// deduce the result type. Additionally, reduce the rank of the inferred result
/// type if `currentResultType` is lower rank than `sourceType`. This overload
/// uses `sourceType` both for inference and for computing dropped dimensions.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType sourceType,
    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
    ArrayRef<OpFoldResult> mixedStrides) {
  return getCanonicalSubViewResultType(currentResultType, sourceType,
                                       sourceType, mixedOffsets, mixedSizes,
                                       mixedStrides);
}

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are constant zero, all strides are constant 1,
/// the rank is unchanged, and every size is a constant equal to the
/// corresponding source dimension. In such cases, the subview can be folded
/// into its source.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.getValue() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.getValue() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    Optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || intValue.getValue() != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
    // MemRefCastOp source operand type to infer the result type and the current
    // SubViewOp source operand type to compute the dropped dimensions if the
    // operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    // Rebuild the subview on the cast's source with the same offset/size/stride
    // operands and attributes, then cast back to the original result type.
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops (see `isTrivialSubViewOp`). If the
/// result type equals the source type, the subview is replaced by its source.
/// Otherwise, for example when the types differ only in their layout
/// (`affine_map`), it is replaced by a memref.cast of the source to the
/// result type.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.source());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.source());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview, computed from the
/// op's current result and source types; may be null on failure.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
                                         mixedOffsets, mixedSizes,
                                         mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
/// Replaces `op` with a memref.cast of `newOp` back to `op`'s result type.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};

/// Register the subview canonicalizations: constant-argument folding, pushing
/// memref.cast past the subview, and removal of trivial subviews.
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

/// Fold to the source when the result type is fully static and identical to
/// the source type.
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
/// Result sizes are the source sizes permuted by the map's dimension results.
/// The layout is the source's strided layout composed with `permutationMap`.
/// `memRefType` must be strided (asserted).
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(AffineMapAttr::get(map));
}

/// Build a transpose of `in` by `permutation`; the result type is inferred
/// with `inferTransposeResultType`.
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << in() << " " << permutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
  p << " : " << in().getType() << " to " << getType();
}

/// Parse the format printed by TransposeOp::print. The permutation map is
/// stored as an attribute rather than taken from the attr-dict.
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

/// Verify that the map is a permutation of the input's rank, and that the
/// result type equals the type inferred by `inferTransposeResultType`.
LogicalResult TransposeOp::verify() {
  if (!permutation().isPermutation())
    return emitOpError("expected a permutation map");
  if (permutation().getNumDims() != getShapedType().getRank())
    return emitOpError("expected a permutation map of same rank as the input");

  auto srcType = in().getType().cast<MemRefType>();
  auto dstType = getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, permutation());
  if (dstType != transposedType)
    return emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

/// Fold memref.cast producers of the operands into this op, using
/// foldMemRefCast with its default `inner` argument. On success, return the
/// op's own result to signal an in-place fold.
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

/// Verify memref.view:
///  - identity layouts on base and result;
///  - matching memory spaces;
///  - one size operand per dynamic result dimension.
LogicalResult ViewOp::verify() {
  auto baseType = getOperand(0).getType().cast<MemRefType>();
  auto viewType = getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (sizes().size() != numDynamicDims)
    return emitError("incorrect number of size operands for type ") << viewType;

  return success();
}

/// For ViewLikeOpInterface.
Value ViewOp::getViewSource() { return source(); }

namespace {

/// Fold dynamic size operands of a view that are produced by
/// arith.constant index ops into the result type. Then cast the new view back
/// to the original result type.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memrefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    // Size operands that remain dynamic after folding.
    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    // Index into viewOp.sizes(); advances only on dynamic dimensions.
    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
    return success();
  }
};

/// Rewrite view(memref.cast(memref.alloc)) into view(memref.alloc). The new
/// view keeps the original result type, byte shift and size operands.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // namespace

/// Register the view canonicalizations: constant size folding and
/// cast-of-alloc source folding.
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===// 2594 2595 LogicalResult AtomicRMWOp::verify() { 2596 if (getMemRefType().getRank() != getNumOperands() - 2) 2597 return emitOpError( 2598 "expects the number of subscripts to be equal to memref rank"); 2599 switch (kind()) { 2600 case arith::AtomicRMWKind::addf: 2601 case arith::AtomicRMWKind::maxf: 2602 case arith::AtomicRMWKind::minf: 2603 case arith::AtomicRMWKind::mulf: 2604 if (!value().getType().isa<FloatType>()) 2605 return emitOpError() << "with kind '" 2606 << arith::stringifyAtomicRMWKind(kind()) 2607 << "' expects a floating-point type"; 2608 break; 2609 case arith::AtomicRMWKind::addi: 2610 case arith::AtomicRMWKind::maxs: 2611 case arith::AtomicRMWKind::maxu: 2612 case arith::AtomicRMWKind::mins: 2613 case arith::AtomicRMWKind::minu: 2614 case arith::AtomicRMWKind::muli: 2615 case arith::AtomicRMWKind::ori: 2616 case arith::AtomicRMWKind::andi: 2617 if (!value().getType().isa<IntegerType>()) 2618 return emitOpError() << "with kind '" 2619 << arith::stringifyAtomicRMWKind(kind()) 2620 << "' expects an integer type"; 2621 break; 2622 default: 2623 break; 2624 } 2625 return success(); 2626 } 2627 2628 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) { 2629 /// atomicrmw(memrefcast) -> atomicrmw 2630 if (succeeded(foldMemRefCast(*this, value()))) 2631 return getResult(); 2632 return OpFoldResult(); 2633 } 2634 2635 //===----------------------------------------------------------------------===// 2636 // TableGen'd op method definitions 2637 //===----------------------------------------------------------------------===// 2638 2639 #define GET_OP_CLASSES 2640 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2641