1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/InferTypeOpInterface.h" 21 #include "mlir/Interfaces/SideEffectInterfaces.h" 22 #include "mlir/Interfaces/ViewLikeInterface.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/SmallBitVector.h" 25 26 using namespace mlir; 27 using namespace mlir::memref; 28 29 /// Materialize a single constant operation from a given attribute value with 30 /// the desired resultant type. 31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 32 Attribute value, Type type, 33 Location loc) { 34 if (arith::ConstantOp::isBuildableWith(value, type)) 35 return builder.create<arith::ConstantOp>(loc, value, type); 36 return nullptr; 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Common canonicalization pattern support logic 41 //===----------------------------------------------------------------------===// 42 43 /// This is a common class used for patterns of the form 44 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 45 /// into the root operation directly. 46 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { 47 bool folded = false; 48 for (OpOperand &operand : op->getOpOperands()) { 49 auto cast = operand.get().getDefiningOp<CastOp>(); 50 if (cast && operand.get() != inner && 51 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 52 operand.set(cast.getOperand()); 53 folded = true; 54 } 55 } 56 return success(folded); 57 } 58 59 /// Return an unranked/ranked tensor type for the given unranked/ranked memref 60 /// type. 61 Type mlir::memref::getTensorTypeFromMemRefType(Type type) { 62 if (auto memref = type.dyn_cast<MemRefType>()) 63 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 64 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 65 return UnrankedTensorType::get(memref.getElementType()); 66 return NoneType::get(type.getContext()); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // AllocOp / AllocaOp 71 //===----------------------------------------------------------------------===// 72 73 template <typename AllocLikeOp> 74 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 75 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 76 "applies to only alloc or alloca"); 77 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 78 if (!memRefType) 79 return op.emitOpError("result must be a memref"); 80 81 if (static_cast<int64_t>(op.dynamicSizes().size()) != 82 memRefType.getNumDynamicDims()) 83 return op.emitOpError("dimension operand count does not equal memref " 84 "dynamic dimension count"); 85 86 unsigned numSymbols = 0; 87 if (!memRefType.getLayout().isIdentity()) 88 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 89 if (op.symbolOperands().size() != numSymbols) 90 return op.emitOpError("symbol operand count does not equal memref symbol " 91 "count: expected ") 92 << numSymbols << ", got " << op.symbolOperands().size(); 93 94 return success(); 95 } 96 97 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); } 98 99 LogicalResult AllocaOp::verify() { 100 // An alloca op needs to have an ancestor with an allocation scope trait. 101 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 102 return emitOpError( 103 "requires an ancestor op with AutomaticAllocationScope trait"); 104 105 return verifyAllocLikeOp(*this); 106 } 107 108 namespace { 109 /// Fold constant dimensions into an alloc like operation. 110 template <typename AllocLikeOp> 111 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 112 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 113 114 LogicalResult matchAndRewrite(AllocLikeOp alloc, 115 PatternRewriter &rewriter) const override { 116 // Check to see if any dimensions operands are constants. If so, we can 117 // substitute and drop them. 118 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 119 return matchPattern(operand, matchConstantIndex()); 120 })) 121 return failure(); 122 123 auto memrefType = alloc.getType(); 124 125 // Ok, we have one or more constant operands. Collect the non-constant ones 126 // and keep track of the resultant memref type to build. 127 SmallVector<int64_t, 4> newShapeConstants; 128 newShapeConstants.reserve(memrefType.getRank()); 129 SmallVector<Value, 4> dynamicSizes; 130 131 unsigned dynamicDimPos = 0; 132 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 133 int64_t dimSize = memrefType.getDimSize(dim); 134 // If this is already static dimension, keep it. 135 if (dimSize != -1) { 136 newShapeConstants.push_back(dimSize); 137 continue; 138 } 139 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 140 auto *defOp = dynamicSize.getDefiningOp(); 141 if (auto constantIndexOp = 142 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 143 // Dynamic shape dimension will be folded. 144 newShapeConstants.push_back(constantIndexOp.value()); 145 } else { 146 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 147 newShapeConstants.push_back(-1); 148 dynamicSizes.push_back(dynamicSize); 149 } 150 dynamicDimPos++; 151 } 152 153 // Create new memref type (which will have fewer dynamic dimensions). 154 MemRefType newMemRefType = 155 MemRefType::Builder(memrefType).setShape(newShapeConstants); 156 assert(static_cast<int64_t>(dynamicSizes.size()) == 157 newMemRefType.getNumDynamicDims()); 158 159 // Create and insert the alloc op for the new memref. 160 auto newAlloc = rewriter.create<AllocLikeOp>( 161 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 162 alloc.alignmentAttr()); 163 // Insert a cast so we have the same type as the old alloc. 164 auto resultCast = 165 rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc); 166 167 rewriter.replaceOp(alloc, {resultCast}); 168 return success(); 169 } 170 }; 171 172 /// Fold alloc operations with no users or only store and dealloc uses. 173 template <typename T> 174 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 175 using OpRewritePattern<T>::OpRewritePattern; 176 177 LogicalResult matchAndRewrite(T alloc, 178 PatternRewriter &rewriter) const override { 179 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 180 if (auto storeOp = dyn_cast<StoreOp>(op)) 181 return storeOp.value() == alloc; 182 return !isa<DeallocOp>(op); 183 })) 184 return failure(); 185 186 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 187 rewriter.eraseOp(user); 188 189 rewriter.eraseOp(alloc); 190 return success(); 191 } 192 }; 193 } // namespace 194 195 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 196 MLIRContext *context) { 197 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 198 } 199 200 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 201 MLIRContext *context) { 202 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 203 context); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // AllocaScopeOp 208 //===----------------------------------------------------------------------===// 209 210 void AllocaScopeOp::print(OpAsmPrinter &p) { 211 bool printBlockTerminators = false; 212 213 p << ' '; 214 if (!results().empty()) { 215 p << " -> (" << getResultTypes() << ")"; 216 printBlockTerminators = true; 217 } 218 p << ' '; 219 p.printRegion(bodyRegion(), 220 /*printEntryBlockArgs=*/false, 221 /*printBlockTerminators=*/printBlockTerminators); 222 p.printOptionalAttrDict((*this)->getAttrs()); 223 } 224 225 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { 226 // Create a region for the body. 227 result.regions.reserve(1); 228 Region *bodyRegion = result.addRegion(); 229 230 // Parse optional results type list. 231 if (parser.parseOptionalArrowTypeList(result.types)) 232 return failure(); 233 234 // Parse the body region. 235 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 236 return failure(); 237 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 238 result.location); 239 240 // Parse the optional attribute list. 241 if (parser.parseOptionalAttrDict(result.attributes)) 242 return failure(); 243 244 return success(); 245 } 246 247 void AllocaScopeOp::getSuccessorRegions( 248 Optional<unsigned> index, ArrayRef<Attribute> operands, 249 SmallVectorImpl<RegionSuccessor> ®ions) { 250 if (index.hasValue()) { 251 regions.push_back(RegionSuccessor(getResults())); 252 return; 253 } 254 255 regions.push_back(RegionSuccessor(&bodyRegion())); 256 } 257 258 /// Given an operation, return whether this op is guaranteed to 259 /// allocate an AutomaticAllocationScopeResource 260 static bool isGuaranteedAutomaticAllocation(Operation *op) { 261 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); 262 if (!interface) 263 return false; 264 for (auto res : op->getResults()) { 265 if (auto effect = 266 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { 267 if (isa<SideEffects::AutomaticAllocationScopeResource>( 268 effect->getResource())) 269 return true; 270 } 271 } 272 return false; 273 } 274 275 /// Given an operation, return whether this op itself could 276 /// allocate an AutomaticAllocationScopeResource. Note that 277 /// this will not check whether an operation contained within 278 /// the op can allocate. 279 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) { 280 // This op itself doesn't create a stack allocation, 281 // the inner allocation should be handled separately. 282 if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) 283 return false; 284 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); 285 if (!interface) 286 return true; 287 for (auto res : op->getResults()) { 288 if (auto effect = 289 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) { 290 if (isa<SideEffects::AutomaticAllocationScopeResource>( 291 effect->getResource())) 292 return true; 293 } 294 } 295 return false; 296 } 297 298 /// Return whether this op is the last non terminating op 299 /// in a region. That is to say, it is in a one-block region 300 /// and is only followed by a terminator. This prevents 301 /// extending the lifetime of allocations. 302 static bool lastNonTerminatorInRegion(Operation *op) { 303 return op->getNextNode() == op->getBlock()->getTerminator() && 304 op->getParentRegion()->getBlocks().size() == 1; 305 } 306 307 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope 308 /// or it contains no allocation. 309 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> { 310 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern; 311 312 LogicalResult matchAndRewrite(AllocaScopeOp op, 313 PatternRewriter &rewriter) const override { 314 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) { 315 bool hasPotentialAlloca = 316 op->walk([&](Operation *alloc) { 317 if (alloc == op) 318 return WalkResult::advance(); 319 if (isOpItselfPotentialAutomaticAllocation(alloc)) 320 return WalkResult::interrupt(); 321 return WalkResult::advance(); 322 }).wasInterrupted(); 323 if (hasPotentialAlloca) 324 return failure(); 325 } 326 327 // Only apply to if this is this last non-terminator 328 // op in the block (lest lifetime be extended) of a one 329 // block region 330 if (!lastNonTerminatorInRegion(op)) 331 return failure(); 332 333 Block *block = &op.getRegion().front(); 334 Operation *terminator = block->getTerminator(); 335 ValueRange results = terminator->getOperands(); 336 rewriter.mergeBlockBefore(block, op); 337 rewriter.replaceOp(op, results); 338 rewriter.eraseOp(terminator); 339 return success(); 340 } 341 }; 342 343 /// Move allocations into an allocation scope, if it is legal to 344 /// move them (e.g. their operands are available at the location 345 /// the op would be moved to). 346 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> { 347 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern; 348 349 LogicalResult matchAndRewrite(AllocaScopeOp op, 350 PatternRewriter &rewriter) const override { 351 352 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 353 return failure(); 354 355 Operation *lastParentWithoutScope = op->getParentOp(); 356 357 if (!lastParentWithoutScope || 358 lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>()) 359 return failure(); 360 361 // Only apply to if this is this last non-terminator 362 // op in the block (lest lifetime be extended) of a one 363 // block region 364 if (!lastNonTerminatorInRegion(op) || 365 !lastNonTerminatorInRegion(lastParentWithoutScope)) 366 return failure(); 367 368 while (!lastParentWithoutScope->getParentOp() 369 ->hasTrait<OpTrait::AutomaticAllocationScope>()) { 370 lastParentWithoutScope = lastParentWithoutScope->getParentOp(); 371 if (!lastParentWithoutScope || 372 !lastNonTerminatorInRegion(lastParentWithoutScope)) 373 return failure(); 374 } 375 assert(lastParentWithoutScope->getParentOp() 376 ->hasTrait<OpTrait::AutomaticAllocationScope>()); 377 378 Region *containingRegion = nullptr; 379 for (auto &r : lastParentWithoutScope->getRegions()) { 380 if (r.isAncestor(op->getParentRegion())) { 381 assert(containingRegion == nullptr && 382 "only one region can contain the op"); 383 containingRegion = &r; 384 } 385 } 386 assert(containingRegion && "op must be contained in a region"); 387 388 SmallVector<Operation *> toHoist; 389 op->walk([&](Operation *alloc) { 390 if (!isGuaranteedAutomaticAllocation(alloc)) 391 return WalkResult::skip(); 392 393 // If any operand is not defined before the location of 394 // lastParentWithoutScope (i.e. where we would hoist to), skip. 395 if (llvm::any_of(alloc->getOperands(), [&](Value v) { 396 return containingRegion->isAncestor(v.getParentRegion()); 397 })) 398 return WalkResult::skip(); 399 toHoist.push_back(alloc); 400 return WalkResult::advance(); 401 }); 402 403 if (toHoist.empty()) 404 return failure(); 405 rewriter.setInsertionPoint(lastParentWithoutScope); 406 for (auto *op : toHoist) { 407 auto *cloned = rewriter.clone(*op); 408 rewriter.replaceOp(op, cloned->getResults()); 409 } 410 return success(); 411 } 412 }; 413 414 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results, 415 MLIRContext *context) { 416 results.add<AllocaScopeInliner, AllocaScopeHoister>(context); 417 } 418 419 //===----------------------------------------------------------------------===// 420 // AssumeAlignmentOp 421 //===----------------------------------------------------------------------===// 422 423 LogicalResult AssumeAlignmentOp::verify() { 424 if (!llvm::isPowerOf2_32(alignment())) 425 return emitOpError("alignment must be power of 2"); 426 return success(); 427 } 428 429 //===----------------------------------------------------------------------===// 430 // CastOp 431 //===----------------------------------------------------------------------===// 432 433 /// Determines whether MemRef_CastOp casts to a more dynamic version of the 434 /// source memref. This is useful to to fold a memref.cast into a consuming op 435 /// and implement canonicalization patterns for ops in different dialects that 436 /// may consume the results of memref.cast operations. Such foldable memref.cast 437 /// operations are typically inserted as `view` and `subview` ops are 438 /// canonicalized, to preserve the type compatibility of their uses. 439 /// 440 /// Returns true when all conditions are met: 441 /// 1. source and result are ranked memrefs with strided semantics and same 442 /// element type and rank. 443 /// 2. each of the source's size, offset or stride has more static information 444 /// than the corresponding result's size, offset or stride. 445 /// 446 /// Example 1: 447 /// ```mlir 448 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 449 /// %2 = consumer %1 ... : memref<?x?xf32> ... 450 /// ``` 451 /// 452 /// may fold into: 453 /// 454 /// ```mlir 455 /// %2 = consumer %0 ... : memref<8x16xf32> ... 456 /// ``` 457 /// 458 /// Example 2: 459 /// ``` 460 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 461 /// to memref<?x?xf32> 462 /// consumer %1 : memref<?x?xf32> ... 463 /// ``` 464 /// 465 /// may fold into: 466 /// 467 /// ``` 468 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 469 /// ``` 470 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 471 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 472 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 473 474 // Requires ranked MemRefType. 475 if (!sourceType || !resultType) 476 return false; 477 478 // Requires same elemental type. 479 if (sourceType.getElementType() != resultType.getElementType()) 480 return false; 481 482 // Requires same rank. 483 if (sourceType.getRank() != resultType.getRank()) 484 return false; 485 486 // Only fold casts between strided memref forms. 487 int64_t sourceOffset, resultOffset; 488 SmallVector<int64_t, 4> sourceStrides, resultStrides; 489 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 490 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 491 return false; 492 493 // If cast is towards more static sizes along any dimension, don't fold. 494 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 495 auto ss = std::get<0>(it), st = std::get<1>(it); 496 if (ss != st) 497 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) 498 return false; 499 } 500 501 // If cast is towards more static offset along any dimension, don't fold. 502 if (sourceOffset != resultOffset) 503 if (ShapedType::isDynamicStrideOrOffset(sourceOffset) && 504 !ShapedType::isDynamicStrideOrOffset(resultOffset)) 505 return false; 506 507 // If cast is towards more static strides along any dimension, don't fold. 508 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 509 auto ss = std::get<0>(it), st = std::get<1>(it); 510 if (ss != st) 511 if (ShapedType::isDynamicStrideOrOffset(ss) && 512 !ShapedType::isDynamicStrideOrOffset(st)) 513 return false; 514 } 515 516 return true; 517 } 518 519 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 520 if (inputs.size() != 1 || outputs.size() != 1) 521 return false; 522 Type a = inputs.front(), b = outputs.front(); 523 auto aT = a.dyn_cast<MemRefType>(); 524 auto bT = b.dyn_cast<MemRefType>(); 525 526 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 527 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 528 529 if (aT && bT) { 530 if (aT.getElementType() != bT.getElementType()) 531 return false; 532 if (aT.getLayout() != bT.getLayout()) { 533 int64_t aOffset, bOffset; 534 SmallVector<int64_t, 4> aStrides, bStrides; 535 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 536 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 537 aStrides.size() != bStrides.size()) 538 return false; 539 540 // Strides along a dimension/offset are compatible if the value in the 541 // source memref is static and the value in the target memref is the 542 // same. They are also compatible if either one is dynamic (see 543 // description of MemRefCastOp for details). 544 auto checkCompatible = [](int64_t a, int64_t b) { 545 return (a == MemRefType::getDynamicStrideOrOffset() || 546 b == MemRefType::getDynamicStrideOrOffset() || a == b); 547 }; 548 if (!checkCompatible(aOffset, bOffset)) 549 return false; 550 for (const auto &aStride : enumerate(aStrides)) 551 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 552 return false; 553 } 554 if (aT.getMemorySpace() != bT.getMemorySpace()) 555 return false; 556 557 // They must have the same rank, and any specified dimensions must match. 558 if (aT.getRank() != bT.getRank()) 559 return false; 560 561 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 562 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 563 if (aDim != -1 && bDim != -1 && aDim != bDim) 564 return false; 565 } 566 return true; 567 } else { 568 if (!aT && !uaT) 569 return false; 570 if (!bT && !ubT) 571 return false; 572 // Unranked to unranked casting is unsupported 573 if (uaT && ubT) 574 return false; 575 576 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 577 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 578 if (aEltType != bEltType) 579 return false; 580 581 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 582 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 583 return aMemSpace == bMemSpace; 584 } 585 586 return false; 587 } 588 589 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 590 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 591 } 592 593 //===----------------------------------------------------------------------===// 594 // CopyOp 595 //===----------------------------------------------------------------------===// 596 597 namespace { 598 /// If the source/target of a CopyOp is a CastOp that does not modify the shape 599 /// and element type, the cast can be skipped. Such CastOps only cast the layout 600 /// of the type. 601 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> { 602 using OpRewritePattern<CopyOp>::OpRewritePattern; 603 604 LogicalResult matchAndRewrite(CopyOp copyOp, 605 PatternRewriter &rewriter) const override { 606 bool modified = false; 607 608 // Check source. 609 if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) { 610 auto fromType = castOp.source().getType().dyn_cast<MemRefType>(); 611 auto toType = castOp.source().getType().dyn_cast<MemRefType>(); 612 613 if (fromType && toType) { 614 if (fromType.getShape() == toType.getShape() && 615 fromType.getElementType() == toType.getElementType()) { 616 rewriter.updateRootInPlace( 617 copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); }); 618 modified = true; 619 } 620 } 621 } 622 623 // Check target. 624 if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) { 625 auto fromType = castOp.source().getType().dyn_cast<MemRefType>(); 626 auto toType = castOp.source().getType().dyn_cast<MemRefType>(); 627 628 if (fromType && toType) { 629 if (fromType.getShape() == toType.getShape() && 630 fromType.getElementType() == toType.getElementType()) { 631 rewriter.updateRootInPlace( 632 copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); }); 633 modified = true; 634 } 635 } 636 } 637 638 return success(modified); 639 } 640 }; 641 642 /// Fold memref.copy(%x, %x). 643 struct FoldSelfCopy : public OpRewritePattern<CopyOp> { 644 using OpRewritePattern<CopyOp>::OpRewritePattern; 645 646 LogicalResult matchAndRewrite(CopyOp copyOp, 647 PatternRewriter &rewriter) const override { 648 if (copyOp.source() != copyOp.target()) 649 return failure(); 650 651 rewriter.eraseOp(copyOp); 652 return success(); 653 } 654 }; 655 } // namespace 656 657 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, 658 MLIRContext *context) { 659 results.add<FoldCopyOfCast, FoldSelfCopy>(context); 660 } 661 662 LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands, 663 SmallVectorImpl<OpFoldResult> &results) { 664 /// copy(memrefcast) -> copy 665 bool folded = false; 666 Operation *op = *this; 667 for (OpOperand &operand : op->getOpOperands()) { 668 auto castOp = operand.get().getDefiningOp<memref::CastOp>(); 669 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { 670 operand.set(castOp.getOperand()); 671 folded = true; 672 } 673 } 674 return success(folded); 675 } 676 677 //===----------------------------------------------------------------------===// 678 // DeallocOp 679 //===----------------------------------------------------------------------===// 680 681 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 682 SmallVectorImpl<OpFoldResult> &results) { 683 /// dealloc(memrefcast) -> dealloc 684 return foldMemRefCast(*this); 685 } 686 687 //===----------------------------------------------------------------------===// 688 // DimOp 689 //===----------------------------------------------------------------------===// 690 691 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 692 int64_t index) { 693 auto loc = result.location; 694 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index); 695 build(builder, result, source, indexValue); 696 } 697 698 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 699 Value index) { 700 auto indexTy = builder.getIndexType(); 701 build(builder, result, indexTy, source, index); 702 } 703 704 Optional<int64_t> DimOp::getConstantIndex() { 705 if (auto constantOp = index().getDefiningOp<arith::ConstantOp>()) 706 return constantOp.getValue().cast<IntegerAttr>().getInt(); 707 return {}; 708 } 709 710 LogicalResult DimOp::verify() { 711 // Assume unknown index to be in range. 712 Optional<int64_t> index = getConstantIndex(); 713 if (!index.hasValue()) 714 return success(); 715 716 // Check that constant index is not knowingly out of range. 717 auto type = source().getType(); 718 if (auto memrefType = type.dyn_cast<MemRefType>()) { 719 if (index.getValue() >= memrefType.getRank()) 720 return emitOpError("index is out of range"); 721 } else if (type.isa<UnrankedMemRefType>()) { 722 // Assume index to be in range. 723 } else { 724 llvm_unreachable("expected operand with memref type"); 725 } 726 return success(); 727 } 728 729 /// Return a map with key being elements in `vals` and data being number of 730 /// occurences of it. Use std::map, since the `vals` here are strides and the 731 /// dynamic stride value is the same as the tombstone value for 732 /// `DenseMap<int64_t>`. 733 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) { 734 std::map<int64_t, unsigned> numOccurences; 735 for (auto val : vals) 736 numOccurences[val]++; 737 return numOccurences; 738 } 739 740 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed 741 /// to be a subset of `originalType` with some `1` entries erased, return the 742 /// set of indices that specifies which of the entries of `originalShape` are 743 /// dropped to obtain `reducedShape`. 744 /// This accounts for cases where there are multiple unit-dims, but only a 745 /// subset of those are dropped. For MemRefTypes these can be disambiguated 746 /// using the strides. If a dimension is dropped the stride must be dropped too. 747 static llvm::Optional<llvm::SmallBitVector> 748 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, 749 ArrayRef<OpFoldResult> sizes) { 750 llvm::SmallBitVector unusedDims(originalType.getRank()); 751 if (originalType.getRank() == reducedType.getRank()) 752 return unusedDims; 753 754 for (const auto &dim : llvm::enumerate(sizes)) 755 if (auto attr = dim.value().dyn_cast<Attribute>()) 756 if (attr.cast<IntegerAttr>().getInt() == 1) 757 unusedDims.set(dim.index()); 758 759 SmallVector<int64_t> originalStrides, candidateStrides; 760 int64_t originalOffset, candidateOffset; 761 if (failed( 762 getStridesAndOffset(originalType, originalStrides, originalOffset)) || 763 failed( 764 getStridesAndOffset(reducedType, candidateStrides, candidateOffset))) 765 return llvm::None; 766 767 // For memrefs, a dimension is truly dropped if its corresponding stride is 768 // also dropped. This is particularly important when more than one of the dims 769 // is 1. Track the number of occurences of the strides in the original type 770 // and the candidate type. For each unused dim that stride should not be 771 // present in the candidate type. Note that there could be multiple dimensions 772 // that have the same size. We dont need to exactly figure out which dim 773 // corresponds to which stride, we just need to verify that the number of 774 // reptitions of a stride in the original + number of unused dims with that 775 // stride == number of repititions of a stride in the candidate. 776 std::map<int64_t, unsigned> currUnaccountedStrides = 777 getNumOccurences(originalStrides); 778 std::map<int64_t, unsigned> candidateStridesNumOccurences = 779 getNumOccurences(candidateStrides); 780 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) { 781 if (!unusedDims.test(dim)) 782 continue; 783 int64_t originalStride = originalStrides[dim]; 784 if (currUnaccountedStrides[originalStride] > 785 candidateStridesNumOccurences[originalStride]) { 786 // This dim can be treated as dropped. 787 currUnaccountedStrides[originalStride]--; 788 continue; 789 } 790 if (currUnaccountedStrides[originalStride] == 791 candidateStridesNumOccurences[originalStride]) { 792 // The stride for this is not dropped. Keep as is. 793 unusedDims.reset(dim); 794 continue; 795 } 796 if (currUnaccountedStrides[originalStride] < 797 candidateStridesNumOccurences[originalStride]) { 798 // This should never happen. Cant have a stride in the reduced rank type 799 // that wasnt in the original one. 800 return llvm::None; 801 } 802 } 803 804 if ((int64_t)unusedDims.count() + reducedType.getRank() != 805 originalType.getRank()) 806 return llvm::None; 807 return unusedDims; 808 } 809 810 llvm::SmallBitVector SubViewOp::getDroppedDims() { 811 MemRefType sourceType = getSourceType(); 812 MemRefType resultType = getType(); 813 llvm::Optional<llvm::SmallBitVector> unusedDims = 814 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes()); 815 assert(unusedDims && "unable to find unused dims of subview"); 816 return *unusedDims; 817 } 818 819 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 820 // All forms of folding require a known index. 821 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 822 if (!index) 823 return {}; 824 825 // Folding for unranked types (UnrankedMemRefType) is not supported. 826 auto memrefType = source().getType().dyn_cast<MemRefType>(); 827 if (!memrefType) 828 return {}; 829 830 // Fold if the shape extent along the given index is known. 831 if (!memrefType.isDynamicDim(index.getInt())) { 832 Builder builder(getContext()); 833 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); 834 } 835 836 // The size at the given index is now known to be a dynamic size. 837 unsigned unsignedIndex = index.getValue().getZExtValue(); 838 839 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 840 Operation *definingOp = source().getDefiningOp(); 841 842 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 843 return *(alloc.getDynamicSizes().begin() + 844 memrefType.getDynamicDimIndex(unsignedIndex)); 845 846 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 847 return *(alloca.getDynamicSizes().begin() + 848 memrefType.getDynamicDimIndex(unsignedIndex)); 849 850 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 851 return *(view.getDynamicSizes().begin() + 852 memrefType.getDynamicDimIndex(unsignedIndex)); 853 854 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 855 llvm::SmallBitVector unusedDims = subview.getDroppedDims(); 856 unsigned resultIndex = 0; 857 unsigned sourceRank = subview.getSourceType().getRank(); 858 unsigned sourceIndex = 0; 859 for (auto i : llvm::seq<unsigned>(0, sourceRank)) { 860 if (unusedDims.test(i)) 861 continue; 862 if (resultIndex == unsignedIndex) { 863 sourceIndex = i; 864 break; 865 } 866 resultIndex++; 867 } 868 assert(subview.isDynamicSize(sourceIndex) && 869 "expected dynamic subview size"); 870 return subview.getDynamicSize(sourceIndex); 871 } 872 873 if (auto sizeInterface = 874 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 875 assert(sizeInterface.isDynamicSize(unsignedIndex) && 876 "Expected dynamic subview size"); 877 return sizeInterface.getDynamicSize(unsignedIndex); 878 } 879 880 // dim(memrefcast) -> dim 881 if (succeeded(foldMemRefCast(*this))) 882 return getResult(); 883 884 return {}; 885 } 886 887 namespace { 888 /// Fold dim of a memref reshape operation to a load into the reshape's shape 889 /// operand. 890 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 891 using OpRewritePattern<DimOp>::OpRewritePattern; 892 893 LogicalResult matchAndRewrite(DimOp dim, 894 PatternRewriter &rewriter) const override { 895 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 896 897 if (!reshape) 898 return failure(); 899 900 // Place the load directly after the reshape to ensure that the shape memref 901 // was not mutated. 902 rewriter.setInsertionPointAfter(reshape); 903 Location loc = dim.getLoc(); 904 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 905 if (load.getType() != dim.getType()) 906 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); 907 rewriter.replaceOp(dim, load); 908 return success(); 909 } 910 }; 911 912 } // namespace 913 914 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 915 MLIRContext *context) { 916 results.add<DimOfMemRefReshape>(context); 917 } 918 919 // --------------------------------------------------------------------------- 920 // DmaStartOp 921 // --------------------------------------------------------------------------- 922 923 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 924 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 925 ValueRange destIndices, Value numElements, 926 Value tagMemRef, ValueRange tagIndices, Value stride, 927 Value elementsPerStride) { 928 result.addOperands(srcMemRef); 929 result.addOperands(srcIndices); 930 result.addOperands(destMemRef); 931 result.addOperands(destIndices); 932 result.addOperands({numElements, tagMemRef}); 933 result.addOperands(tagIndices); 934 if (stride) 935 result.addOperands({stride, elementsPerStride}); 936 } 937 938 void DmaStartOp::print(OpAsmPrinter &p) { 939 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " 940 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() 941 << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; 942 if (isStrided()) 943 p << ", " << getStride() << ", " << getNumElementsPerStride(); 944 945 p.printOptionalAttrDict((*this)->getAttrs()); 946 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 947 << ", " << getTagMemRef().getType(); 948 } 949 950 // Parse DmaStartOp. 951 // Ex: 952 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 953 // %tag[%index], %stride, %num_elt_per_stride : 954 // : memref<3076 x f32, 0>, 955 // memref<1024 x f32, 2>, 956 // memref<1 x i32> 957 // 958 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 959 OpAsmParser::OperandType srcMemRefInfo; 960 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 961 OpAsmParser::OperandType dstMemRefInfo; 962 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 963 OpAsmParser::OperandType numElementsInfo; 964 OpAsmParser::OperandType tagMemrefInfo; 965 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 966 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 967 968 SmallVector<Type, 3> types; 969 auto indexType = parser.getBuilder().getIndexType(); 970 971 // Parse and resolve the following list of operands: 972 // *) source memref followed by its indices (in square brackets). 973 // *) destination memref followed by its indices (in square brackets). 974 // *) dma size in KiB. 975 if (parser.parseOperand(srcMemRefInfo) || 976 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 977 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 978 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 979 parser.parseComma() || parser.parseOperand(numElementsInfo) || 980 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 981 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 982 return failure(); 983 984 // Parse optional stride and elements per stride. 985 if (parser.parseTrailingOperandList(strideInfo)) 986 return failure(); 987 988 bool isStrided = strideInfo.size() == 2; 989 if (!strideInfo.empty() && !isStrided) { 990 return parser.emitError(parser.getNameLoc(), 991 "expected two stride related operands"); 992 } 993 994 if (parser.parseColonTypeList(types)) 995 return failure(); 996 if (types.size() != 3) 997 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 998 999 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 1000 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 1001 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 1002 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 1003 // size should be an index. 1004 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 1005 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 1006 // tag indices should be index. 1007 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 1008 return failure(); 1009 1010 if (isStrided) { 1011 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 1012 return failure(); 1013 } 1014 1015 return success(); 1016 } 1017 1018 LogicalResult DmaStartOp::verify() { 1019 unsigned numOperands = getNumOperands(); 1020 1021 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 1022 // the number of elements. 1023 if (numOperands < 4) 1024 return emitOpError("expected at least 4 operands"); 1025 1026 // Check types of operands. The order of these calls is important: the later 1027 // calls rely on some type properties to compute the operand position. 1028 // 1. Source memref. 1029 if (!getSrcMemRef().getType().isa<MemRefType>()) 1030 return emitOpError("expected source to be of memref type"); 1031 if (numOperands < getSrcMemRefRank() + 4) 1032 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 1033 << " operands"; 1034 if (!getSrcIndices().empty() && 1035 !llvm::all_of(getSrcIndices().getTypes(), 1036 [](Type t) { return t.isIndex(); })) 1037 return emitOpError("expected source indices to be of index type"); 1038 1039 // 2. Destination memref. 1040 if (!getDstMemRef().getType().isa<MemRefType>()) 1041 return emitOpError("expected destination to be of memref type"); 1042 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 1043 if (numOperands < numExpectedOperands) 1044 return emitOpError() << "expected at least " << numExpectedOperands 1045 << " operands"; 1046 if (!getDstIndices().empty() && 1047 !llvm::all_of(getDstIndices().getTypes(), 1048 [](Type t) { return t.isIndex(); })) 1049 return emitOpError("expected destination indices to be of index type"); 1050 1051 // 3. Number of elements. 1052 if (!getNumElements().getType().isIndex()) 1053 return emitOpError("expected num elements to be of index type"); 1054 1055 // 4. Tag memref. 1056 if (!getTagMemRef().getType().isa<MemRefType>()) 1057 return emitOpError("expected tag to be of memref type"); 1058 numExpectedOperands += getTagMemRefRank(); 1059 if (numOperands < numExpectedOperands) 1060 return emitOpError() << "expected at least " << numExpectedOperands 1061 << " operands"; 1062 if (!getTagIndices().empty() && 1063 !llvm::all_of(getTagIndices().getTypes(), 1064 [](Type t) { return t.isIndex(); })) 1065 return emitOpError("expected tag indices to be of index type"); 1066 1067 // Optional stride-related operands must be either both present or both 1068 // absent. 1069 if (numOperands != numExpectedOperands && 1070 numOperands != numExpectedOperands + 2) 1071 return emitOpError("incorrect number of operands"); 1072 1073 // 5. Strides. 1074 if (isStrided()) { 1075 if (!getStride().getType().isIndex() || 1076 !getNumElementsPerStride().getType().isIndex()) 1077 return emitOpError( 1078 "expected stride and num elements per stride to be of type index"); 1079 } 1080 1081 return success(); 1082 } 1083 1084 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 1085 SmallVectorImpl<OpFoldResult> &results) { 1086 /// dma_start(memrefcast) -> dma_start 1087 return foldMemRefCast(*this); 1088 } 1089 1090 // --------------------------------------------------------------------------- 1091 // DmaWaitOp 1092 // --------------------------------------------------------------------------- 1093 1094 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1095 SmallVectorImpl<OpFoldResult> &results) { 1096 /// dma_wait(memrefcast) -> dma_wait 1097 return foldMemRefCast(*this); 1098 } 1099 1100 LogicalResult DmaWaitOp::verify() { 1101 // Check that the number of tag indices matches the tagMemRef rank. 1102 unsigned numTagIndices = tagIndices().size(); 1103 unsigned tagMemRefRank = getTagMemRefRank(); 1104 if (numTagIndices != tagMemRefRank) 1105 return emitOpError() << "expected tagIndices to have the same number of " 1106 "elements as the tagMemRef rank, expected " 1107 << tagMemRefRank << ", but got " << numTagIndices; 1108 return success(); 1109 } 1110 1111 //===----------------------------------------------------------------------===// 1112 // GenericAtomicRMWOp 1113 //===----------------------------------------------------------------------===// 1114 1115 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, 1116 Value memref, ValueRange ivs) { 1117 result.addOperands(memref); 1118 result.addOperands(ivs); 1119 1120 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) { 1121 Type elementType = memrefType.getElementType(); 1122 result.addTypes(elementType); 1123 1124 Region *bodyRegion = result.addRegion(); 1125 bodyRegion->push_back(new Block()); 1126 bodyRegion->addArgument(elementType, memref.getLoc()); 1127 } 1128 } 1129 1130 LogicalResult GenericAtomicRMWOp::verify() { 1131 auto &body = getRegion(); 1132 if (body.getNumArguments() != 1) 1133 return emitOpError("expected single number of entry block arguments"); 1134 1135 if (getResult().getType() != body.getArgument(0).getType()) 1136 return emitOpError("expected block argument of the same type result type"); 1137 1138 bool hasSideEffects = 1139 body.walk([&](Operation *nestedOp) { 1140 if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) 1141 return WalkResult::advance(); 1142 nestedOp->emitError( 1143 "body of 'memref.generic_atomic_rmw' should contain " 1144 "only operations with no side effects"); 1145 return WalkResult::interrupt(); 1146 }) 1147 .wasInterrupted(); 1148 return hasSideEffects ? failure() : success(); 1149 } 1150 1151 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser, 1152 OperationState &result) { 1153 OpAsmParser::OperandType memref; 1154 Type memrefType; 1155 SmallVector<OpAsmParser::OperandType, 4> ivs; 1156 1157 Type indexType = parser.getBuilder().getIndexType(); 1158 if (parser.parseOperand(memref) || 1159 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || 1160 parser.parseColonType(memrefType) || 1161 parser.resolveOperand(memref, memrefType, result.operands) || 1162 parser.resolveOperands(ivs, indexType, result.operands)) 1163 return failure(); 1164 1165 Region *body = result.addRegion(); 1166 if (parser.parseRegion(*body, llvm::None, llvm::None) || 1167 parser.parseOptionalAttrDict(result.attributes)) 1168 return failure(); 1169 result.types.push_back(memrefType.cast<MemRefType>().getElementType()); 1170 return success(); 1171 } 1172 1173 void GenericAtomicRMWOp::print(OpAsmPrinter &p) { 1174 p << ' ' << memref() << "[" << indices() << "] : " << memref().getType() 1175 << ' '; 1176 p.printRegion(getRegion()); 1177 p.printOptionalAttrDict((*this)->getAttrs()); 1178 } 1179 1180 //===----------------------------------------------------------------------===// 1181 // AtomicYieldOp 1182 //===----------------------------------------------------------------------===// 1183 1184 LogicalResult AtomicYieldOp::verify() { 1185 Type parentType = (*this)->getParentOp()->getResultTypes().front(); 1186 Type resultType = result().getType(); 1187 if (parentType != resultType) 1188 return emitOpError() << "types mismatch between yield op: " << resultType 1189 << " and its parent: " << parentType; 1190 return success(); 1191 } 1192 1193 //===----------------------------------------------------------------------===// 1194 // GlobalOp 1195 //===----------------------------------------------------------------------===// 1196 1197 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1198 TypeAttr type, 1199 Attribute initialValue) { 1200 p << type; 1201 if (!op.isExternal()) { 1202 p << " = "; 1203 if (op.isUninitialized()) 1204 p << "uninitialized"; 1205 else 1206 p.printAttributeWithoutType(initialValue); 1207 } 1208 } 1209 1210 static ParseResult 1211 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1212 Attribute &initialValue) { 1213 Type type; 1214 if (parser.parseType(type)) 1215 return failure(); 1216 1217 auto memrefType = type.dyn_cast<MemRefType>(); 1218 if (!memrefType || !memrefType.hasStaticShape()) 1219 return parser.emitError(parser.getNameLoc()) 1220 << "type should be static shaped memref, but got " << type; 1221 typeAttr = TypeAttr::get(type); 1222 1223 if (parser.parseOptionalEqual()) 1224 return success(); 1225 1226 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1227 initialValue = UnitAttr::get(parser.getContext()); 1228 return success(); 1229 } 1230 1231 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1232 if (parser.parseAttribute(initialValue, tensorType)) 1233 return failure(); 1234 if (!initialValue.isa<ElementsAttr>()) 1235 return parser.emitError(parser.getNameLoc()) 1236 << "initial value should be a unit or elements attribute"; 1237 return success(); 1238 } 1239 1240 LogicalResult GlobalOp::verify() { 1241 auto memrefType = type().dyn_cast<MemRefType>(); 1242 if (!memrefType || !memrefType.hasStaticShape()) 1243 return emitOpError("type should be static shaped memref, but got ") 1244 << type(); 1245 1246 // Verify that the initial value, if present, is either a unit attribute or 1247 // an elements attribute. 1248 if (initial_value().hasValue()) { 1249 Attribute initValue = initial_value().getValue(); 1250 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1251 return emitOpError("initial value should be a unit or elements " 1252 "attribute, but got ") 1253 << initValue; 1254 1255 // Check that the type of the initial value is compatible with the type of 1256 // the global variable. 1257 if (initValue.isa<ElementsAttr>()) { 1258 Type initType = initValue.getType(); 1259 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1260 if (initType != tensorType) 1261 return emitOpError("initial value expected to be of type ") 1262 << tensorType << ", but was of type " << initType; 1263 } 1264 } 1265 1266 if (Optional<uint64_t> alignAttr = alignment()) { 1267 uint64_t alignment = alignAttr.getValue(); 1268 1269 if (!llvm::isPowerOf2_64(alignment)) 1270 return emitError() << "alignment attribute value " << alignment 1271 << " is not a power of 2"; 1272 } 1273 1274 // TODO: verify visibility for declarations. 1275 return success(); 1276 } 1277 1278 ElementsAttr GlobalOp::getConstantInitValue() { 1279 auto initVal = initial_value(); 1280 if (constant() && initVal.hasValue()) 1281 return initVal.getValue().cast<ElementsAttr>(); 1282 return {}; 1283 } 1284 1285 //===----------------------------------------------------------------------===// 1286 // GetGlobalOp 1287 //===----------------------------------------------------------------------===// 1288 1289 LogicalResult 1290 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1291 // Verify that the result type is same as the type of the referenced 1292 // memref.global op. 1293 auto global = 1294 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1295 if (!global) 1296 return emitOpError("'") 1297 << name() << "' does not reference a valid global memref"; 1298 1299 Type resultType = result().getType(); 1300 if (global.type() != resultType) 1301 return emitOpError("result type ") 1302 << resultType << " does not match type " << global.type() 1303 << " of the global memref @" << name(); 1304 return success(); 1305 } 1306 1307 //===----------------------------------------------------------------------===// 1308 // LoadOp 1309 //===----------------------------------------------------------------------===// 1310 1311 LogicalResult LoadOp::verify() { 1312 if (getNumOperands() != 1 + getMemRefType().getRank()) 1313 return emitOpError("incorrect number of indices for load"); 1314 return success(); 1315 } 1316 1317 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1318 /// load(memrefcast) -> load 1319 if (succeeded(foldMemRefCast(*this))) 1320 return getResult(); 1321 return OpFoldResult(); 1322 } 1323 1324 //===----------------------------------------------------------------------===// 1325 // PrefetchOp 1326 //===----------------------------------------------------------------------===// 1327 1328 void PrefetchOp::print(OpAsmPrinter &p) { 1329 p << " " << memref() << '['; 1330 p.printOperands(indices()); 1331 p << ']' << ", " << (isWrite() ? "write" : "read"); 1332 p << ", locality<" << localityHint(); 1333 p << ">, " << (isDataCache() ? "data" : "instr"); 1334 p.printOptionalAttrDict( 1335 (*this)->getAttrs(), 1336 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1337 p << " : " << getMemRefType(); 1338 } 1339 1340 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { 1341 OpAsmParser::OperandType memrefInfo; 1342 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1343 IntegerAttr localityHint; 1344 MemRefType type; 1345 StringRef readOrWrite, cacheType; 1346 1347 auto indexTy = parser.getBuilder().getIndexType(); 1348 auto i32Type = parser.getBuilder().getIntegerType(32); 1349 if (parser.parseOperand(memrefInfo) || 1350 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1351 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1352 parser.parseComma() || parser.parseKeyword("locality") || 1353 parser.parseLess() || 1354 parser.parseAttribute(localityHint, i32Type, "localityHint", 1355 result.attributes) || 1356 parser.parseGreater() || parser.parseComma() || 1357 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1358 parser.resolveOperand(memrefInfo, type, result.operands) || 1359 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1360 return failure(); 1361 1362 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1363 return parser.emitError(parser.getNameLoc(), 1364 "rw specifier has to be 'read' or 'write'"); 1365 result.addAttribute( 1366 PrefetchOp::getIsWriteAttrName(), 1367 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1368 1369 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1370 return parser.emitError(parser.getNameLoc(), 1371 "cache type has to be 'data' or 'instr'"); 1372 1373 result.addAttribute( 1374 PrefetchOp::getIsDataCacheAttrName(), 1375 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1376 1377 return success(); 1378 } 1379 1380 LogicalResult PrefetchOp::verify() { 1381 if (getNumOperands() != 1 + getMemRefType().getRank()) 1382 return emitOpError("too few indices"); 1383 1384 return success(); 1385 } 1386 1387 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1388 SmallVectorImpl<OpFoldResult> &results) { 1389 // prefetch(memrefcast) -> prefetch 1390 return foldMemRefCast(*this); 1391 } 1392 1393 //===----------------------------------------------------------------------===// 1394 // RankOp 1395 //===----------------------------------------------------------------------===// 1396 1397 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1398 // Constant fold rank when the rank of the operand is known. 1399 auto type = getOperand().getType(); 1400 auto shapedType = type.dyn_cast<ShapedType>(); 1401 if (shapedType && shapedType.hasRank()) 1402 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1403 return IntegerAttr(); 1404 } 1405 1406 //===----------------------------------------------------------------------===// 1407 // ReinterpretCastOp 1408 //===----------------------------------------------------------------------===// 1409 1410 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1411 /// `staticSizes` and `staticStrides` are automatically filled with 1412 /// source-memref-rank sentinel values that encode dynamic entries. 1413 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1414 MemRefType resultType, Value source, 1415 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1416 ArrayRef<OpFoldResult> strides, 1417 ArrayRef<NamedAttribute> attrs) { 1418 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1419 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1420 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1421 ShapedType::kDynamicStrideOrOffset); 1422 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1423 ShapedType::kDynamicSize); 1424 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1425 ShapedType::kDynamicStrideOrOffset); 1426 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1427 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1428 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1429 result.addAttributes(attrs); 1430 } 1431 1432 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1433 MemRefType resultType, Value source, 1434 int64_t offset, ArrayRef<int64_t> sizes, 1435 ArrayRef<int64_t> strides, 1436 ArrayRef<NamedAttribute> attrs) { 1437 SmallVector<OpFoldResult> sizeValues = 1438 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1439 return b.getI64IntegerAttr(v); 1440 })); 1441 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1442 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1443 return b.getI64IntegerAttr(v); 1444 })); 1445 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1446 strideValues, attrs); 1447 } 1448 1449 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1450 MemRefType resultType, Value source, Value offset, 1451 ValueRange sizes, ValueRange strides, 1452 ArrayRef<NamedAttribute> attrs) { 1453 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1454 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1455 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1456 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1457 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1458 } 1459 1460 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1461 // completed automatically, like we have for subview and extract_slice. 1462 LogicalResult ReinterpretCastOp::verify() { 1463 // The source and result memrefs should be in the same memory space. 1464 auto srcType = source().getType().cast<BaseMemRefType>(); 1465 auto resultType = getType().cast<MemRefType>(); 1466 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1467 return emitError("different memory spaces specified for source type ") 1468 << srcType << " and result memref type " << resultType; 1469 if (srcType.getElementType() != resultType.getElementType()) 1470 return emitError("different element types specified for source type ") 1471 << srcType << " and result memref type " << resultType; 1472 1473 // Match sizes in result memref type and in static_sizes attribute. 1474 for (auto &en : llvm::enumerate(llvm::zip( 1475 resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) { 1476 int64_t resultSize = std::get<0>(en.value()); 1477 int64_t expectedSize = std::get<1>(en.value()); 1478 if (!ShapedType::isDynamic(resultSize) && 1479 !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) 1480 return emitError("expected result type with size = ") 1481 << expectedSize << " instead of " << resultSize 1482 << " in dim = " << en.index(); 1483 } 1484 1485 // Match offset and strides in static_offset and static_strides attributes. If 1486 // result memref type has no affine map specified, this will assume an 1487 // identity layout. 1488 int64_t resultOffset; 1489 SmallVector<int64_t, 4> resultStrides; 1490 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1491 return emitError("expected result type to have strided layout but found ") 1492 << resultType; 1493 1494 // Match offset in result memref type and in static_offsets attribute. 1495 int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front(); 1496 if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && 1497 !ShapedType::isDynamicStrideOrOffset(expectedOffset) && 1498 resultOffset != expectedOffset) 1499 return emitError("expected result type with offset = ") 1500 << resultOffset << " instead of " << expectedOffset; 1501 1502 // Match strides in result memref type and in static_strides attribute. 1503 for (auto &en : llvm::enumerate(llvm::zip( 1504 resultStrides, extractFromI64ArrayAttr(static_strides())))) { 1505 int64_t resultStride = std::get<0>(en.value()); 1506 int64_t expectedStride = std::get<1>(en.value()); 1507 if (!ShapedType::isDynamicStrideOrOffset(resultStride) && 1508 !ShapedType::isDynamicStrideOrOffset(expectedStride) && 1509 resultStride != expectedStride) 1510 return emitError("expected result type with stride = ") 1511 << expectedStride << " instead of " << resultStride 1512 << " in dim = " << en.index(); 1513 } 1514 1515 return success(); 1516 } 1517 1518 OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) { 1519 Value src = source(); 1520 auto getPrevSrc = [&]() -> Value { 1521 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x). 1522 if (auto prev = src.getDefiningOp<ReinterpretCastOp>()) 1523 return prev.source(); 1524 1525 // reinterpret_cast(cast(x)) -> reinterpret_cast(x). 1526 if (auto prev = src.getDefiningOp<CastOp>()) 1527 return prev.source(); 1528 1529 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets 1530 // are 0. 1531 if (auto prev = src.getDefiningOp<SubViewOp>()) 1532 if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) { 1533 return isConstantIntValue(val, 0); 1534 })) 1535 return prev.source(); 1536 1537 return nullptr; 1538 }; 1539 1540 if (auto prevSrc = getPrevSrc()) { 1541 sourceMutable().assign(prevSrc); 1542 return getResult(); 1543 } 1544 1545 return nullptr; 1546 } 1547 1548 //===----------------------------------------------------------------------===// 1549 // Reassociative reshape ops 1550 //===----------------------------------------------------------------------===// 1551 1552 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1553 return getSymbolLessAffineMaps(getReassociationExprs()); 1554 } 1555 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1556 return convertReassociationIndicesToExprs(getContext(), 1557 getReassociationIndices()); 1558 } 1559 1560 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1561 return getSymbolLessAffineMaps(getReassociationExprs()); 1562 } 1563 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1564 return convertReassociationIndicesToExprs(getContext(), 1565 getReassociationIndices()); 1566 } 1567 1568 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1569 /// copies. 1570 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1571 ArrayRef<int64_t> sizes, 1572 ArrayRef<AffineExpr> strides) { 1573 // Bands of extent one can be reshaped, as they are not reshaped at all. 1574 if (extent == 1) 1575 return true; 1576 // Otherwise, the size of the first dimension needs to be known. 1577 if (ShapedType::isDynamic(sizes[dim])) 1578 return false; 1579 assert(sizes.size() == strides.size() && "mismatched ranks"); 1580 // off by 1 indexing to avoid out of bounds 1581 // V 1582 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1583 // Only bands of static shapes are reshapable. This is due to the fact that 1584 // there is no relation between dynamic sizes and dynamic strides: we do not 1585 // have enough information to know whether a "-1" size corresponds to the 1586 // proper symbol in the AffineExpr of a stride. 1587 if (ShapedType::isDynamic(sizes[idx + 1])) 1588 return false; 1589 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1590 // simplify on the fly and catch more reshapable cases. 1591 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1592 return false; 1593 } 1594 return true; 1595 } 1596 1597 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1598 /// expected to be valid) to `type`. 1599 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1600 /// MemRefType. 1601 static MemRefType 1602 computeReshapeCollapsedType(MemRefType type, 1603 ArrayRef<AffineMap> reassociation) { 1604 auto sizes = type.getShape(); 1605 AffineExpr offset; 1606 SmallVector<AffineExpr, 4> strides; 1607 auto status = getStridesAndOffset(type, strides, offset); 1608 auto isIdentityLayout = type.getLayout().isIdentity(); 1609 (void)status; 1610 assert(succeeded(status) && "expected strided memref"); 1611 1612 SmallVector<int64_t, 4> newSizes; 1613 newSizes.reserve(reassociation.size()); 1614 SmallVector<AffineExpr, 4> newStrides; 1615 newStrides.reserve(reassociation.size()); 1616 1617 // Use the fact that reassociation is valid to simplify the logic: only use 1618 // each map's rank. 1619 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1620 unsigned currentDim = 0; 1621 for (AffineMap m : reassociation) { 1622 unsigned dim = m.getNumResults(); 1623 int64_t size = 1; 1624 AffineExpr stride = strides[currentDim + dim - 1]; 1625 if (isIdentityLayout || 1626 isReshapableDimBand(currentDim, dim, sizes, strides)) { 1627 for (unsigned d = 0; d < dim; ++d) { 1628 int64_t currentSize = sizes[currentDim + d]; 1629 if (ShapedType::isDynamic(currentSize)) { 1630 size = ShapedType::kDynamicSize; 1631 break; 1632 } 1633 size *= currentSize; 1634 } 1635 } else { 1636 size = ShapedType::kDynamicSize; 1637 stride = AffineExpr(); 1638 } 1639 newSizes.push_back(size); 1640 newStrides.push_back(stride); 1641 currentDim += dim; 1642 } 1643 1644 // Early-exit: if `type` is contiguous, the result must be contiguous. 1645 if (canonicalizeStridedLayout(type).getLayout().isIdentity()) 1646 return MemRefType::Builder(type).setShape(newSizes).setLayout({}); 1647 1648 // Convert back to int64_t because we don't have enough information to create 1649 // new strided layouts from AffineExpr only. This corresponds to a case where 1650 // copies may be necessary. 1651 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1652 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1653 intOffset = o.getValue(); 1654 SmallVector<int64_t, 4> intStrides; 1655 intStrides.reserve(strides.size()); 1656 for (auto stride : newStrides) { 1657 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1658 intStrides.push_back(cst.getValue()); 1659 else 1660 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1661 } 1662 auto layout = 1663 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1664 return canonicalizeStridedLayout( 1665 MemRefType::Builder(type).setShape(newSizes).setLayout( 1666 AffineMapAttr::get(layout))); 1667 } 1668 1669 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1670 ArrayRef<ReassociationIndices> reassociation, 1671 ArrayRef<NamedAttribute> attrs) { 1672 auto memRefType = src.getType().cast<MemRefType>(); 1673 auto resultType = computeReshapeCollapsedType( 1674 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1675 b.getContext(), reassociation))); 1676 build(b, result, resultType, src, attrs); 1677 result.addAttribute(getReassociationAttrName(), 1678 getReassociationIndicesAttribute(b, reassociation)); 1679 } 1680 1681 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1682 ArrayRef<ReassociationIndices> reassociation, 1683 ArrayRef<NamedAttribute> attrs) { 1684 auto memRefType = src.getType().cast<MemRefType>(); 1685 auto resultType = computeReshapeCollapsedType( 1686 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1687 b.getContext(), reassociation))); 1688 build(b, result, resultType, src, attrs); 1689 result.addAttribute(getReassociationAttrName(), 1690 getReassociationIndicesAttribute(b, reassociation)); 1691 } 1692 1693 template <typename ReshapeOp, 1694 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1695 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1696 MemRefType collapsedType) { 1697 if (failed( 1698 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1699 return failure(); 1700 auto maps = op.getReassociationMaps(); 1701 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1702 if (collapsedType != expectedType) 1703 return op.emitOpError("expected collapsed type to be ") 1704 << expectedType << ", but got " << collapsedType; 1705 return success(); 1706 } 1707 1708 LogicalResult ExpandShapeOp::verify() { 1709 return verifyReshapeOp(*this, getResultType(), getSrcType()); 1710 } 1711 1712 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1713 MLIRContext *context) { 1714 results.add<CollapseReshapeOps<ExpandShapeOp>, 1715 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1716 } 1717 1718 LogicalResult CollapseShapeOp::verify() { 1719 return verifyReshapeOp(*this, getSrcType(), getResultType()); 1720 } 1721 1722 struct CollapseShapeOpMemRefCastFolder 1723 : public OpRewritePattern<CollapseShapeOp> { 1724 public: 1725 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1726 1727 LogicalResult matchAndRewrite(CollapseShapeOp op, 1728 PatternRewriter &rewriter) const override { 1729 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1730 if (!cast) 1731 return failure(); 1732 1733 if (!CastOp::canFoldIntoConsumerOp(cast)) 1734 return failure(); 1735 1736 Type newResultType = computeReshapeCollapsedType( 1737 cast.getOperand().getType().cast<MemRefType>(), 1738 op.getReassociationMaps()); 1739 1740 if (newResultType == op.getResultType()) { 1741 rewriter.updateRootInPlace( 1742 op, [&]() { op.srcMutable().assign(cast.source()); }); 1743 } else { 1744 Value newOp = rewriter.create<CollapseShapeOp>( 1745 op->getLoc(), cast.source(), op.getReassociationIndices()); 1746 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1747 } 1748 return success(); 1749 } 1750 }; 1751 1752 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1753 MLIRContext *context) { 1754 results.add<CollapseReshapeOps<CollapseShapeOp>, 1755 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1756 CollapseShapeOpMemRefCastFolder>(context); 1757 } 1758 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1759 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1760 } 1761 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1762 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1763 } 1764 1765 //===----------------------------------------------------------------------===// 1766 // ReshapeOp 1767 //===----------------------------------------------------------------------===// 1768 1769 LogicalResult ReshapeOp::verify() { 1770 Type operandType = source().getType(); 1771 Type resultType = result().getType(); 1772 1773 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1774 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1775 if (operandElementType != resultElementType) 1776 return emitOpError("element types of source and destination memref " 1777 "types should be the same"); 1778 1779 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1780 if (!operandMemRefType.getLayout().isIdentity()) 1781 return emitOpError("source memref type should have identity affine map"); 1782 1783 int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0); 1784 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1785 if (resultMemRefType) { 1786 if (!resultMemRefType.getLayout().isIdentity()) 1787 return emitOpError("result memref type should have identity affine map"); 1788 if (shapeSize == ShapedType::kDynamicSize) 1789 return emitOpError("cannot use shape operand with dynamic length to " 1790 "reshape to statically-ranked memref type"); 1791 if (shapeSize != resultMemRefType.getRank()) 1792 return emitOpError( 1793 "length of shape operand differs from the result's memref rank"); 1794 } 1795 return success(); 1796 } 1797 1798 //===----------------------------------------------------------------------===// 1799 // StoreOp 1800 //===----------------------------------------------------------------------===// 1801 1802 LogicalResult StoreOp::verify() { 1803 if (getNumOperands() != 2 + getMemRefType().getRank()) 1804 return emitOpError("store index operand count not equal to memref rank"); 1805 1806 return success(); 1807 } 1808 1809 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1810 SmallVectorImpl<OpFoldResult> &results) { 1811 /// store(memrefcast) -> store 1812 return foldMemRefCast(*this, getValueToStore()); 1813 } 1814 1815 //===----------------------------------------------------------------------===// 1816 // SubViewOp 1817 //===----------------------------------------------------------------------===// 1818 1819 namespace { 1820 /// Helpers to write more idiomatic operations. 1821 namespace saturated_arith { 1822 struct Wrapper { 1823 explicit Wrapper(int64_t v) : v(v) {} 1824 operator int64_t() { return v; } 1825 int64_t v; 1826 }; 1827 Wrapper operator+(Wrapper a, int64_t b) { 1828 if (ShapedType::isDynamicStrideOrOffset(a) || 1829 ShapedType::isDynamicStrideOrOffset(b)) 1830 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1831 return Wrapper(a.v + b); 1832 } 1833 Wrapper operator*(Wrapper a, int64_t b) { 1834 if (ShapedType::isDynamicStrideOrOffset(a) || 1835 ShapedType::isDynamicStrideOrOffset(b)) 1836 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1837 return Wrapper(a.v * b); 1838 } 1839 } // namespace saturated_arith 1840 } // namespace 1841 1842 /// A subview result type can be fully inferred from the source type and the 1843 /// static representation of offsets, sizes and strides. Special sentinels 1844 /// encode the dynamic case. 1845 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1846 ArrayRef<int64_t> staticOffsets, 1847 ArrayRef<int64_t> staticSizes, 1848 ArrayRef<int64_t> staticStrides) { 1849 unsigned rank = sourceMemRefType.getRank(); 1850 (void)rank; 1851 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 1852 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 1853 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 1854 1855 // Extract source offset and strides. 1856 int64_t sourceOffset; 1857 SmallVector<int64_t, 4> sourceStrides; 1858 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1859 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1860 (void)res; 1861 1862 // Compute target offset whose value is: 1863 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1864 int64_t targetOffset = sourceOffset; 1865 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1866 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1867 using namespace saturated_arith; 1868 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1869 } 1870 1871 // Compute target stride whose value is: 1872 // `sourceStrides_i * staticStrides_i`. 1873 SmallVector<int64_t, 4> targetStrides; 1874 targetStrides.reserve(staticOffsets.size()); 1875 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1876 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1877 using namespace saturated_arith; 1878 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1879 } 1880 1881 // The type is now known. 1882 return MemRefType::get( 1883 staticSizes, sourceMemRefType.getElementType(), 1884 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1885 sourceMemRefType.getContext()), 1886 sourceMemRefType.getMemorySpace()); 1887 } 1888 1889 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1890 ArrayRef<OpFoldResult> offsets, 1891 ArrayRef<OpFoldResult> sizes, 1892 ArrayRef<OpFoldResult> strides) { 1893 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1894 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1895 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1896 ShapedType::kDynamicStrideOrOffset); 1897 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1898 ShapedType::kDynamicSize); 1899 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1900 ShapedType::kDynamicStrideOrOffset); 1901 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1902 staticSizes, staticStrides); 1903 } 1904 1905 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1906 MemRefType sourceRankedTensorType, 1907 ArrayRef<int64_t> offsets, 1908 ArrayRef<int64_t> sizes, 1909 ArrayRef<int64_t> strides) { 1910 auto inferredType = 1911 inferResultType(sourceRankedTensorType, offsets, sizes, strides) 1912 .cast<MemRefType>(); 1913 assert(inferredType.getRank() >= resultRank && "expected "); 1914 int rankDiff = inferredType.getRank() - resultRank; 1915 if (rankDiff > 0) { 1916 auto shape = inferredType.getShape(); 1917 llvm::SmallBitVector dimsToProject = 1918 getPositionsOfShapeOne(rankDiff, shape); 1919 SmallVector<int64_t> projectedShape; 1920 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1921 if (!dimsToProject.test(pos)) 1922 projectedShape.push_back(shape[pos]); 1923 1924 AffineMap map = inferredType.getLayout().getAffineMap(); 1925 if (!map.isIdentity()) 1926 map = getProjectedMap(map, dimsToProject); 1927 inferredType = 1928 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1929 inferredType.getMemorySpace()); 1930 } 1931 return inferredType; 1932 } 1933 1934 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1935 MemRefType sourceRankedTensorType, 1936 ArrayRef<OpFoldResult> offsets, 1937 ArrayRef<OpFoldResult> sizes, 1938 ArrayRef<OpFoldResult> strides) { 1939 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1940 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1941 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1942 ShapedType::kDynamicStrideOrOffset); 1943 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1944 ShapedType::kDynamicSize); 1945 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1946 ShapedType::kDynamicStrideOrOffset); 1947 return SubViewOp::inferRankReducedResultType( 1948 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1949 staticStrides); 1950 } 1951 // Build a SubViewOp with mixed static and dynamic entries and custom result 1952 // type. If the type passed is nullptr, it is inferred. 1953 void SubViewOp::build(OpBuilder &b, OperationState &result, 1954 MemRefType resultType, Value source, 1955 ArrayRef<OpFoldResult> offsets, 1956 ArrayRef<OpFoldResult> sizes, 1957 ArrayRef<OpFoldResult> strides, 1958 ArrayRef<NamedAttribute> attrs) { 1959 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1960 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1961 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1962 ShapedType::kDynamicStrideOrOffset); 1963 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1964 ShapedType::kDynamicSize); 1965 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1966 ShapedType::kDynamicStrideOrOffset); 1967 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1968 // Structuring implementation this way avoids duplication between builders. 1969 if (!resultType) { 1970 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1971 staticSizes, staticStrides) 1972 .cast<MemRefType>(); 1973 } 1974 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1975 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1976 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1977 result.addAttributes(attrs); 1978 } 1979 1980 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1981 // type. 1982 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1983 ArrayRef<OpFoldResult> offsets, 1984 ArrayRef<OpFoldResult> sizes, 1985 ArrayRef<OpFoldResult> strides, 1986 ArrayRef<NamedAttribute> attrs) { 1987 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1988 } 1989 1990 // Build a SubViewOp with static entries and inferred result type. 1991 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1992 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1993 ArrayRef<int64_t> strides, 1994 ArrayRef<NamedAttribute> attrs) { 1995 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1996 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1997 return b.getI64IntegerAttr(v); 1998 })); 1999 SmallVector<OpFoldResult> sizeValues = 2000 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 2001 return b.getI64IntegerAttr(v); 2002 })); 2003 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2004 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 2005 return b.getI64IntegerAttr(v); 2006 })); 2007 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 2008 } 2009 2010 // Build a SubViewOp with dynamic entries and custom result type. If the 2011 // type passed is nullptr, it is inferred. 2012 void SubViewOp::build(OpBuilder &b, OperationState &result, 2013 MemRefType resultType, Value source, 2014 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 2015 ArrayRef<int64_t> strides, 2016 ArrayRef<NamedAttribute> attrs) { 2017 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2018 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 2019 return b.getI64IntegerAttr(v); 2020 })); 2021 SmallVector<OpFoldResult> sizeValues = 2022 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 2023 return b.getI64IntegerAttr(v); 2024 })); 2025 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2026 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 2027 return b.getI64IntegerAttr(v); 2028 })); 2029 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 2030 attrs); 2031 } 2032 2033 // Build a SubViewOp with dynamic entries and custom result type. If the type 2034 // passed is nullptr, it is inferred. 2035 void SubViewOp::build(OpBuilder &b, OperationState &result, 2036 MemRefType resultType, Value source, ValueRange offsets, 2037 ValueRange sizes, ValueRange strides, 2038 ArrayRef<NamedAttribute> attrs) { 2039 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2040 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 2041 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 2042 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 2043 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2044 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 2045 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 2046 } 2047 2048 // Build a SubViewOp with dynamic entries and inferred result type. 2049 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 2050 ValueRange offsets, ValueRange sizes, ValueRange strides, 2051 ArrayRef<NamedAttribute> attrs) { 2052 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 2053 } 2054 2055 /// For ViewLikeOpInterface. 2056 Value SubViewOp::getViewSource() { return source(); } 2057 2058 /// Return true if t1 and t2 have equal offsets (both dynamic or of same static 2059 /// value). 2060 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { 2061 AffineExpr t1Offset, t2Offset; 2062 SmallVector<AffineExpr> t1Strides, t2Strides; 2063 auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset); 2064 auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset); 2065 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset; 2066 } 2067 2068 /// Checks if `original` Type type can be rank reduced to `reduced` type. 2069 /// This function is slight variant of `is subsequence` algorithm where 2070 /// not matching dimension must be 1. 2071 static SliceVerificationResult 2072 isRankReducedMemRefType(MemRefType originalType, 2073 MemRefType candidateRankReducedType, 2074 ArrayRef<OpFoldResult> sizes) { 2075 auto partialRes = isRankReducedType(originalType, candidateRankReducedType); 2076 if (partialRes != SliceVerificationResult::Success) 2077 return partialRes; 2078 2079 auto optionalUnusedDimsMask = computeMemRefRankReductionMask( 2080 originalType, candidateRankReducedType, sizes); 2081 2082 // Sizes cannot be matched in case empty vector is returned. 2083 if (!optionalUnusedDimsMask.hasValue()) 2084 return SliceVerificationResult::LayoutMismatch; 2085 2086 if (originalType.getMemorySpace() != 2087 candidateRankReducedType.getMemorySpace()) 2088 return SliceVerificationResult::MemSpaceMismatch; 2089 2090 // No amount of stride dropping can reconcile incompatible offsets. 2091 if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) 2092 return SliceVerificationResult::LayoutMismatch; 2093 2094 return SliceVerificationResult::Success; 2095 } 2096 2097 template <typename OpTy> 2098 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 2099 OpTy op, Type expectedType) { 2100 auto memrefType = expectedType.cast<ShapedType>(); 2101 switch (result) { 2102 case SliceVerificationResult::Success: 2103 return success(); 2104 case SliceVerificationResult::RankTooLarge: 2105 return op.emitError("expected result rank to be smaller or equal to ") 2106 << "the source rank. "; 2107 case SliceVerificationResult::SizeMismatch: 2108 return op.emitError("expected result type to be ") 2109 << expectedType 2110 << " or a rank-reduced version. (mismatch of result sizes) "; 2111 case SliceVerificationResult::ElemTypeMismatch: 2112 return op.emitError("expected result element type to be ") 2113 << memrefType.getElementType(); 2114 case SliceVerificationResult::MemSpaceMismatch: 2115 return op.emitError("expected result and source memory spaces to match."); 2116 case SliceVerificationResult::LayoutMismatch: 2117 return op.emitError("expected result type to be ") 2118 << expectedType 2119 << " or a rank-reduced version. (mismatch of result layout) "; 2120 } 2121 llvm_unreachable("unexpected subview verification result"); 2122 } 2123 2124 /// Verifier for SubViewOp. 2125 LogicalResult SubViewOp::verify() { 2126 MemRefType baseType = getSourceType(); 2127 MemRefType subViewType = getType(); 2128 2129 // The base memref and the view memref should be in the same memory space. 2130 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 2131 return emitError("different memory spaces specified for base memref " 2132 "type ") 2133 << baseType << " and subview memref type " << subViewType; 2134 2135 // Verify that the base memref type has a strided layout map. 2136 if (!isStrided(baseType)) 2137 return emitError("base type ") << baseType << " is not strided"; 2138 2139 // Verify result type against inferred type. 2140 auto expectedType = SubViewOp::inferResultType( 2141 baseType, extractFromI64ArrayAttr(static_offsets()), 2142 extractFromI64ArrayAttr(static_sizes()), 2143 extractFromI64ArrayAttr(static_strides())); 2144 2145 auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), 2146 subViewType, getMixedSizes()); 2147 return produceSubViewErrorMsg(result, *this, expectedType); 2148 } 2149 2150 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 2151 return os << "range " << range.offset << ":" << range.size << ":" 2152 << range.stride; 2153 } 2154 2155 /// Return the list of Range (i.e. offset, size, stride). Each Range 2156 /// entry contains either the dynamic value or a ConstantIndexOp constructed 2157 /// with `b` at location `loc`. 2158 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 2159 OpBuilder &b, Location loc) { 2160 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 2161 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 2162 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 2163 SmallVector<Range, 8> res; 2164 unsigned rank = ranks[0]; 2165 res.reserve(rank); 2166 for (unsigned idx = 0; idx < rank; ++idx) { 2167 Value offset = 2168 op.isDynamicOffset(idx) 2169 ? op.getDynamicOffset(idx) 2170 : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); 2171 Value size = 2172 op.isDynamicSize(idx) 2173 ? op.getDynamicSize(idx) 2174 : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); 2175 Value stride = 2176 op.isDynamicStride(idx) 2177 ? op.getDynamicStride(idx) 2178 : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); 2179 res.emplace_back(Range{offset, size, stride}); 2180 } 2181 return res; 2182 } 2183 2184 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 2185 /// deduce the result type for the given `sourceType`. Additionally, reduce the 2186 /// rank of the inferred result type if `currentResultType` is lower rank than 2187 /// `currentSourceType`. Use this signature if `sourceType` is updated together 2188 /// with the result type. In this case, it is important to compute the dropped 2189 /// dimensions using `currentSourceType` whose strides align with 2190 /// `currentResultType`. 2191 static MemRefType getCanonicalSubViewResultType( 2192 MemRefType currentResultType, MemRefType currentSourceType, 2193 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, 2194 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { 2195 auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, 2196 mixedSizes, mixedStrides) 2197 .cast<MemRefType>(); 2198 llvm::Optional<llvm::SmallBitVector> unusedDims = 2199 computeMemRefRankReductionMask(currentSourceType, currentResultType, 2200 mixedSizes); 2201 // Return nullptr as failure mode. 2202 if (!unusedDims) 2203 return nullptr; 2204 SmallVector<int64_t> shape; 2205 for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) { 2206 if (unusedDims->test(sizes.index())) 2207 continue; 2208 shape.push_back(sizes.value()); 2209 } 2210 AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap(); 2211 if (!layoutMap.isIdentity()) 2212 layoutMap = getProjectedMap(layoutMap, unusedDims.getValue()); 2213 return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap, 2214 nonRankReducedType.getMemorySpace()); 2215 } 2216 2217 /// Compute the canonical result type of a SubViewOp. Call `inferResultType` to 2218 /// deduce the result type. Additionally, reduce the rank of the inferred result 2219 /// type if `currentResultType` is lower rank than `sourceType`. 2220 static MemRefType getCanonicalSubViewResultType( 2221 MemRefType currentResultType, MemRefType sourceType, 2222 ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, 2223 ArrayRef<OpFoldResult> mixedStrides) { 2224 return getCanonicalSubViewResultType(currentResultType, sourceType, 2225 sourceType, mixedOffsets, mixedSizes, 2226 mixedStrides); 2227 } 2228 2229 /// Helper method to check if a `subview` operation is trivially a no-op. This 2230 /// is the case if the all offsets are zero, all strides are 1, and the source 2231 /// shape is same as the size of the subview. In such cases, the subview can be 2232 /// folded into its source. 2233 static bool isTrivialSubViewOp(SubViewOp subViewOp) { 2234 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank()) 2235 return false; 2236 2237 auto mixedOffsets = subViewOp.getMixedOffsets(); 2238 auto mixedSizes = subViewOp.getMixedSizes(); 2239 auto mixedStrides = subViewOp.getMixedStrides(); 2240 2241 // Check offsets are zero. 2242 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) { 2243 Optional<int64_t> intValue = getConstantIntValue(ofr); 2244 return !intValue || intValue.getValue() != 0; 2245 })) 2246 return false; 2247 2248 // Check strides are one. 2249 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) { 2250 Optional<int64_t> intValue = getConstantIntValue(ofr); 2251 return !intValue || intValue.getValue() != 1; 2252 })) 2253 return false; 2254 2255 // Check all size values are static and matches the (static) source shape. 2256 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape(); 2257 for (const auto &size : llvm::enumerate(mixedSizes)) { 2258 Optional<int64_t> intValue = getConstantIntValue(size.value()); 2259 if (!intValue || intValue.getValue() != sourceShape[size.index()]) 2260 return false; 2261 } 2262 // All conditions met. The `SubViewOp` is foldable as a no-op. 2263 return true; 2264 } 2265 2266 namespace { 2267 /// Pattern to rewrite a subview op with MemRefCast arguments. 2268 /// This essentially pushes memref.cast past its consuming subview when 2269 /// `canFoldIntoConsumerOp` is true. 2270 /// 2271 /// Example: 2272 /// ``` 2273 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 2274 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 2275 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 2276 /// ``` 2277 /// is rewritten into: 2278 /// ``` 2279 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 2280 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 2281 /// memref<3x4xf32, offset:?, strides:[?, 1]> 2282 /// ``` 2283 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 2284 public: 2285 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2286 2287 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2288 PatternRewriter &rewriter) const override { 2289 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 2290 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 2291 return matchPattern(operand, matchConstantIndex()); 2292 })) 2293 return failure(); 2294 2295 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 2296 if (!castOp) 2297 return failure(); 2298 2299 if (!CastOp::canFoldIntoConsumerOp(castOp)) 2300 return failure(); 2301 2302 // Compute the SubViewOp result type after folding the MemRefCastOp. Use the 2303 // MemRefCastOp source operand type to infer the result type and the current 2304 // SubViewOp source operand type to compute the dropped dimensions if the 2305 // operation is rank-reducing. 2306 auto resultType = getCanonicalSubViewResultType( 2307 subViewOp.getType(), subViewOp.getSourceType(), 2308 castOp.source().getType().cast<MemRefType>(), 2309 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), 2310 subViewOp.getMixedStrides()); 2311 if (!resultType) 2312 return failure(); 2313 2314 Value newSubView = rewriter.create<SubViewOp>( 2315 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), 2316 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), 2317 subViewOp.static_sizes(), subViewOp.static_strides()); 2318 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(), 2319 newSubView); 2320 return success(); 2321 } 2322 }; 2323 2324 /// Canonicalize subview ops that are no-ops. When the source shape is not same 2325 /// as a result shape due to use of `affine_map`. 2326 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> { 2327 public: 2328 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2329 2330 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2331 PatternRewriter &rewriter) const override { 2332 if (!isTrivialSubViewOp(subViewOp)) 2333 return failure(); 2334 if (subViewOp.getSourceType() == subViewOp.getType()) { 2335 rewriter.replaceOp(subViewOp, subViewOp.source()); 2336 return success(); 2337 } 2338 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(), 2339 subViewOp.source()); 2340 return success(); 2341 } 2342 }; 2343 } // namespace 2344 2345 /// Return the canonical type of the result of a subview. 2346 struct SubViewReturnTypeCanonicalizer { 2347 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets, 2348 ArrayRef<OpFoldResult> mixedSizes, 2349 ArrayRef<OpFoldResult> mixedStrides) { 2350 return getCanonicalSubViewResultType(op.getType(), op.getSourceType(), 2351 mixedOffsets, mixedSizes, 2352 mixedStrides); 2353 } 2354 }; 2355 2356 /// A canonicalizer wrapper to replace SubViewOps. 2357 struct SubViewCanonicalizer { 2358 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 2359 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 2360 } 2361 }; 2362 2363 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2364 MLIRContext *context) { 2365 results 2366 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 2367 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 2368 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context); 2369 } 2370 2371 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 2372 auto resultShapedType = getResult().getType().cast<ShapedType>(); 2373 auto sourceShapedType = source().getType().cast<ShapedType>(); 2374 2375 if (resultShapedType.hasStaticShape() && 2376 resultShapedType == sourceShapedType) { 2377 return getViewSource(); 2378 } 2379 2380 return {}; 2381 } 2382 2383 //===----------------------------------------------------------------------===// 2384 // TransposeOp 2385 //===----------------------------------------------------------------------===// 2386 2387 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 2388 static MemRefType inferTransposeResultType(MemRefType memRefType, 2389 AffineMap permutationMap) { 2390 auto rank = memRefType.getRank(); 2391 auto originalSizes = memRefType.getShape(); 2392 // Compute permuted sizes. 2393 SmallVector<int64_t, 4> sizes(rank, 0); 2394 for (const auto &en : llvm::enumerate(permutationMap.getResults())) 2395 sizes[en.index()] = 2396 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2397 2398 // Compute permuted strides. 2399 int64_t offset; 2400 SmallVector<int64_t, 4> strides; 2401 auto res = getStridesAndOffset(memRefType, strides, offset); 2402 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2403 (void)res; 2404 auto map = 2405 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2406 map = permutationMap ? map.compose(permutationMap) : map; 2407 return MemRefType::Builder(memRefType) 2408 .setShape(sizes) 2409 .setLayout(AffineMapAttr::get(map)); 2410 } 2411 2412 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2413 AffineMapAttr permutation, 2414 ArrayRef<NamedAttribute> attrs) { 2415 auto permutationMap = permutation.getValue(); 2416 assert(permutationMap); 2417 2418 auto memRefType = in.getType().cast<MemRefType>(); 2419 // Compute result type. 2420 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2421 2422 build(b, result, resultType, in, attrs); 2423 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2424 } 2425 2426 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2427 void TransposeOp::print(OpAsmPrinter &p) { 2428 p << " " << in() << " " << permutation(); 2429 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); 2430 p << " : " << in().getType() << " to " << getType(); 2431 } 2432 2433 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { 2434 OpAsmParser::OperandType in; 2435 AffineMap permutation; 2436 MemRefType srcType, dstType; 2437 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2438 parser.parseOptionalAttrDict(result.attributes) || 2439 parser.parseColonType(srcType) || 2440 parser.resolveOperand(in, srcType, result.operands) || 2441 parser.parseKeywordType("to", dstType) || 2442 parser.addTypeToList(dstType, result.types)) 2443 return failure(); 2444 2445 result.addAttribute(TransposeOp::getPermutationAttrName(), 2446 AffineMapAttr::get(permutation)); 2447 return success(); 2448 } 2449 2450 LogicalResult TransposeOp::verify() { 2451 if (!permutation().isPermutation()) 2452 return emitOpError("expected a permutation map"); 2453 if (permutation().getNumDims() != getShapedType().getRank()) 2454 return emitOpError("expected a permutation map of same rank as the input"); 2455 2456 auto srcType = in().getType().cast<MemRefType>(); 2457 auto dstType = getType().cast<MemRefType>(); 2458 auto transposedType = inferTransposeResultType(srcType, permutation()); 2459 if (dstType != transposedType) 2460 return emitOpError("output type ") 2461 << dstType << " does not match transposed input type " << srcType 2462 << ", " << transposedType; 2463 return success(); 2464 } 2465 2466 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2467 if (succeeded(foldMemRefCast(*this))) 2468 return getResult(); 2469 return {}; 2470 } 2471 2472 //===----------------------------------------------------------------------===// 2473 // ViewOp 2474 //===----------------------------------------------------------------------===// 2475 2476 LogicalResult ViewOp::verify() { 2477 auto baseType = getOperand(0).getType().cast<MemRefType>(); 2478 auto viewType = getType(); 2479 2480 // The base memref should have identity layout map (or none). 2481 if (!baseType.getLayout().isIdentity()) 2482 return emitError("unsupported map for base memref type ") << baseType; 2483 2484 // The result memref should have identity layout map (or none). 2485 if (!viewType.getLayout().isIdentity()) 2486 return emitError("unsupported map for result memref type ") << viewType; 2487 2488 // The base memref and the view memref should be in the same memory space. 2489 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2490 return emitError("different memory spaces specified for base memref " 2491 "type ") 2492 << baseType << " and view memref type " << viewType; 2493 2494 // Verify that we have the correct number of sizes for the result type. 2495 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2496 if (sizes().size() != numDynamicDims) 2497 return emitError("incorrect number of size operands for type ") << viewType; 2498 2499 return success(); 2500 } 2501 2502 Value ViewOp::getViewSource() { return source(); } 2503 2504 namespace { 2505 2506 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2507 using OpRewritePattern<ViewOp>::OpRewritePattern; 2508 2509 LogicalResult matchAndRewrite(ViewOp viewOp, 2510 PatternRewriter &rewriter) const override { 2511 // Return if none of the operands are constants. 2512 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2513 return matchPattern(operand, matchConstantIndex()); 2514 })) 2515 return failure(); 2516 2517 // Get result memref type. 2518 auto memrefType = viewOp.getType(); 2519 2520 // Get offset from old memref view type 'memRefType'. 2521 int64_t oldOffset; 2522 SmallVector<int64_t, 4> oldStrides; 2523 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2524 return failure(); 2525 assert(oldOffset == 0 && "Expected 0 offset"); 2526 2527 SmallVector<Value, 4> newOperands; 2528 2529 // Offset cannot be folded into result type. 2530 2531 // Fold any dynamic dim operands which are produced by a constant. 2532 SmallVector<int64_t, 4> newShapeConstants; 2533 newShapeConstants.reserve(memrefType.getRank()); 2534 2535 unsigned dynamicDimPos = 0; 2536 unsigned rank = memrefType.getRank(); 2537 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2538 int64_t dimSize = memrefType.getDimSize(dim); 2539 // If this is already static dimension, keep it. 2540 if (!ShapedType::isDynamic(dimSize)) { 2541 newShapeConstants.push_back(dimSize); 2542 continue; 2543 } 2544 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2545 if (auto constantIndexOp = 2546 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2547 // Dynamic shape dimension will be folded. 2548 newShapeConstants.push_back(constantIndexOp.value()); 2549 } else { 2550 // Dynamic shape dimension not folded; copy operand from old memref. 2551 newShapeConstants.push_back(dimSize); 2552 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2553 } 2554 dynamicDimPos++; 2555 } 2556 2557 // Create new memref type with constant folded dims. 2558 MemRefType newMemRefType = 2559 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2560 // Nothing new, don't fold. 2561 if (newMemRefType == memrefType) 2562 return failure(); 2563 2564 // Create new ViewOp. 2565 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2566 viewOp.getOperand(0), 2567 viewOp.byte_shift(), newOperands); 2568 // Insert a cast so we have the same type as the old memref type. 2569 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp); 2570 return success(); 2571 } 2572 }; 2573 2574 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2575 using OpRewritePattern<ViewOp>::OpRewritePattern; 2576 2577 LogicalResult matchAndRewrite(ViewOp viewOp, 2578 PatternRewriter &rewriter) const override { 2579 Value memrefOperand = viewOp.getOperand(0); 2580 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2581 if (!memrefCastOp) 2582 return failure(); 2583 Value allocOperand = memrefCastOp.getOperand(); 2584 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2585 if (!allocOp) 2586 return failure(); 2587 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2588 viewOp.byte_shift(), viewOp.sizes()); 2589 return success(); 2590 } 2591 }; 2592 2593 } // namespace 2594 2595 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2596 MLIRContext *context) { 2597 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2598 } 2599 2600 //===----------------------------------------------------------------------===// 2601 // AtomicRMWOp 2602 //===----------------------------------------------------------------------===// 2603 2604 LogicalResult AtomicRMWOp::verify() { 2605 if (getMemRefType().getRank() != getNumOperands() - 2) 2606 return emitOpError( 2607 "expects the number of subscripts to be equal to memref rank"); 2608 switch (kind()) { 2609 case arith::AtomicRMWKind::addf: 2610 case arith::AtomicRMWKind::maxf: 2611 case arith::AtomicRMWKind::minf: 2612 case arith::AtomicRMWKind::mulf: 2613 if (!value().getType().isa<FloatType>()) 2614 return emitOpError() << "with kind '" 2615 << arith::stringifyAtomicRMWKind(kind()) 2616 << "' expects a floating-point type"; 2617 break; 2618 case arith::AtomicRMWKind::addi: 2619 case arith::AtomicRMWKind::maxs: 2620 case arith::AtomicRMWKind::maxu: 2621 case arith::AtomicRMWKind::mins: 2622 case arith::AtomicRMWKind::minu: 2623 case arith::AtomicRMWKind::muli: 2624 case arith::AtomicRMWKind::ori: 2625 case arith::AtomicRMWKind::andi: 2626 if (!value().getType().isa<IntegerType>()) 2627 return emitOpError() << "with kind '" 2628 << arith::stringifyAtomicRMWKind(kind()) 2629 << "' expects an integer type"; 2630 break; 2631 default: 2632 break; 2633 } 2634 return success(); 2635 } 2636 2637 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) { 2638 /// atomicrmw(memrefcast) -> atomicrmw 2639 if (succeeded(foldMemRefCast(*this, value()))) 2640 return getResult(); 2641 return OpFoldResult(); 2642 } 2643 2644 //===----------------------------------------------------------------------===// 2645 // TableGen'd op method definitions 2646 //===----------------------------------------------------------------------===// 2647 2648 #define GET_OP_CLASSES 2649 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2650