1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 11 #include "mlir/Dialect/MemRef/IR/MemRef.h" 12 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 13 #include "mlir/Dialect/Utils/StaticValueUtils.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/InferTypeOpInterface.h" 21 #include "mlir/Interfaces/SideEffectInterfaces.h" 22 #include "mlir/Interfaces/ViewLikeInterface.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/SmallBitVector.h" 25 26 using namespace mlir; 27 using namespace mlir::memref; 28 29 /// Materialize a single constant operation from a given attribute value with 30 /// the desired resultant type. 31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 32 Attribute value, Type type, 33 Location loc) { 34 if (arith::ConstantOp::isBuildableWith(value, type)) 35 return builder.create<arith::ConstantOp>(loc, value, type); 36 return nullptr; 37 } 38 39 //===----------------------------------------------------------------------===// 40 // Common canonicalization pattern support logic 41 //===----------------------------------------------------------------------===// 42 43 /// This is a common class used for patterns of the form 44 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 45 /// into the root operation directly. 46 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) { 47 bool folded = false; 48 for (OpOperand &operand : op->getOpOperands()) { 49 auto cast = operand.get().getDefiningOp<CastOp>(); 50 if (cast && operand.get() != inner && 51 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 52 operand.set(cast.getOperand()); 53 folded = true; 54 } 55 } 56 return success(folded); 57 } 58 59 /// Return an unranked/ranked tensor type for the given unranked/ranked memref 60 /// type. 
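/// A couple of illustrative mappings (sketch, matching the cases handled
/// below):
///   memref<4x?xf32>     -> tensor<4x?xf32>
///   memref<*xf32>       -> tensor<*xf32>
///   any non-memref type -> NoneType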
61 Type mlir::memref::getTensorTypeFromMemRefType(Type type) { 62 if (auto memref = type.dyn_cast<MemRefType>()) 63 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 64 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 65 return UnrankedTensorType::get(memref.getElementType()); 66 return NoneType::get(type.getContext()); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // AllocOp / AllocaOp 71 //===----------------------------------------------------------------------===// 72 73 template <typename AllocLikeOp> 74 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 75 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 76 "applies to only alloc or alloca"); 77 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 78 if (!memRefType) 79 return op.emitOpError("result must be a memref"); 80 81 if (static_cast<int64_t>(op.dynamicSizes().size()) != 82 memRefType.getNumDynamicDims()) 83 return op.emitOpError("dimension operand count does not equal memref " 84 "dynamic dimension count"); 85 86 unsigned numSymbols = 0; 87 if (!memRefType.getLayout().isIdentity()) 88 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 89 if (op.symbolOperands().size() != numSymbols) 90 return op.emitOpError("symbol operand count does not equal memref symbol " 91 "count: expected ") 92 << numSymbols << ", got " << op.symbolOperands().size(); 93 94 return success(); 95 } 96 97 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); } 98 99 LogicalResult AllocaOp::verify() { 100 // An alloca op needs to have an ancestor with an allocation scope trait. 101 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 102 return emitOpError( 103 "requires an ancestor op with AutomaticAllocationScope trait"); 104 105 return verifyAllocLikeOp(*this); 106 } 107 108 namespace { 109 /// Fold constant dimensions into an alloc like operation. 110 template <typename AllocLikeOp> 111 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 112 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 113 114 LogicalResult matchAndRewrite(AllocLikeOp alloc, 115 PatternRewriter &rewriter) const override { 116 // Check to see if any dimensions operands are constants. If so, we can 117 // substitute and drop them. 118 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 119 return matchPattern(operand, matchConstantIndex()); 120 })) 121 return failure(); 122 123 auto memrefType = alloc.getType(); 124 125 // Ok, we have one or more constant operands. Collect the non-constant ones 126 // and keep track of the resultant memref type to build. 127 SmallVector<int64_t, 4> newShapeConstants; 128 newShapeConstants.reserve(memrefType.getRank()); 129 SmallVector<Value, 4> dynamicSizes; 130 131 unsigned dynamicDimPos = 0; 132 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 133 int64_t dimSize = memrefType.getDimSize(dim); 134 // If this is already static dimension, keep it. 135 if (dimSize != -1) { 136 newShapeConstants.push_back(dimSize); 137 continue; 138 } 139 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 140 auto *defOp = dynamicSize.getDefiningOp(); 141 if (auto constantIndexOp = 142 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 143 // Dynamic shape dimension will be folded. 144 newShapeConstants.push_back(constantIndexOp.value()); 145 } else { 146 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 
147 newShapeConstants.push_back(-1); 148 dynamicSizes.push_back(dynamicSize); 149 } 150 dynamicDimPos++; 151 } 152 153 // Create new memref type (which will have fewer dynamic dimensions). 154 MemRefType newMemRefType = 155 MemRefType::Builder(memrefType).setShape(newShapeConstants); 156 assert(static_cast<int64_t>(dynamicSizes.size()) == 157 newMemRefType.getNumDynamicDims()); 158 159 // Create and insert the alloc op for the new memref. 160 auto newAlloc = rewriter.create<AllocLikeOp>( 161 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 162 alloc.alignmentAttr()); 163 // Insert a cast so we have the same type as the old alloc. 164 auto resultCast = 165 rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc); 166 167 rewriter.replaceOp(alloc, {resultCast}); 168 return success(); 169 } 170 }; 171 172 /// Fold alloc operations with no users or only store and dealloc uses. 173 template <typename T> 174 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 175 using OpRewritePattern<T>::OpRewritePattern; 176 177 LogicalResult matchAndRewrite(T alloc, 178 PatternRewriter &rewriter) const override { 179 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 180 if (auto storeOp = dyn_cast<StoreOp>(op)) 181 return storeOp.value() == alloc; 182 return !isa<DeallocOp>(op); 183 })) 184 return failure(); 185 186 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 187 rewriter.eraseOp(user); 188 189 rewriter.eraseOp(alloc); 190 return success(); 191 } 192 }; 193 } // namespace 194 195 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 196 MLIRContext *context) { 197 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 198 } 199 200 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 201 MLIRContext *context) { 202 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 203 context); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // AllocaScopeOp 208 //===----------------------------------------------------------------------===// 209 210 void AllocaScopeOp::print(OpAsmPrinter &p) { 211 bool printBlockTerminators = false; 212 213 p << ' '; 214 if (!results().empty()) { 215 p << " -> (" << getResultTypes() << ")"; 216 printBlockTerminators = true; 217 } 218 p << ' '; 219 p.printRegion(bodyRegion(), 220 /*printEntryBlockArgs=*/false, 221 /*printBlockTerminators=*/printBlockTerminators); 222 p.printOptionalAttrDict((*this)->getAttrs()); 223 } 224 225 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { 226 // Create a region for the body. 227 result.regions.reserve(1); 228 Region *bodyRegion = result.addRegion(); 229 230 // Parse optional results type list. 231 if (parser.parseOptionalArrowTypeList(result.types)) 232 return failure(); 233 234 // Parse the body region. 235 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 236 return failure(); 237 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 238 result.location); 239 240 // Parse the optional attribute list. 
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

LogicalResult AllocaScopeOp::verify() {
  return RegionBranchOpInterface::verifyTypes(*this);
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocationScope(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op could allocate an
/// AutomaticAllocationScopeResource.
static bool isPotentialAutomaticAllocationScope(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Return whether this op is the last non-terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This prevents
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
  return op->getNextNode() == op->getBlock()->getTerminator() &&
         op->getParentRegion()->getBlocks().size() == 1;
}

/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      bool hasPotentialAlloca =
          op->walk([&](Operation *alloc) {
              if (isPotentialAutomaticAllocationScope(alloc))
                return WalkResult::interrupt();
              return WalkResult::skip();
            }).wasInterrupted();
      if (hasPotentialAlloca)
        return failure();
    }

    // Only apply if this is the last non-terminator op in the block of a
    // one-block region (lest allocation lifetimes be extended).
    if (!lastNonTerminatorInRegion(op))
      return failure();

    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.mergeBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};

/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
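///
/// A rough sketch of the intended effect (illustrative IR only; `some.op`
/// stands for an ancestor op that does not have the
/// AutomaticAllocationScope trait):
///
/// ```mlir
/// func @f() {                      // automatic allocation scope
///   some.op {
///     memref.alloca_scope {
///       %a = memref.alloca() : memref<16xf32>
///       ...
///     }
///   }
/// }
/// ```
///
/// may become, provided %a has no operands defined inside `some.op`:
///
/// ```mlir
/// func @f() {
///   %a = memref.alloca() : memref<16xf32>
///   some.op {
///     memref.alloca_scope {
///       ...
///     }
///   }
/// }
/// ```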
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {

    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator op in the block of a
    // one-block region (lest allocation lifetimes be extended).
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocationScope(alloc))
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto op : toHoist) {
      auto cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(alignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations.
Such foldable memref.cast 433 /// operations are typically inserted as `view` and `subview` ops are 434 /// canonicalized, to preserve the type compatibility of their uses. 435 /// 436 /// Returns true when all conditions are met: 437 /// 1. source and result are ranked memrefs with strided semantics and same 438 /// element type and rank. 439 /// 2. each of the source's size, offset or stride has more static information 440 /// than the corresponding result's size, offset or stride. 441 /// 442 /// Example 1: 443 /// ```mlir 444 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 445 /// %2 = consumer %1 ... : memref<?x?xf32> ... 446 /// ``` 447 /// 448 /// may fold into: 449 /// 450 /// ```mlir 451 /// %2 = consumer %0 ... : memref<8x16xf32> ... 452 /// ``` 453 /// 454 /// Example 2: 455 /// ``` 456 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 457 /// to memref<?x?xf32> 458 /// consumer %1 : memref<?x?xf32> ... 459 /// ``` 460 /// 461 /// may fold into: 462 /// 463 /// ``` 464 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 465 /// ``` 466 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 467 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 468 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 469 470 // Requires ranked MemRefType. 471 if (!sourceType || !resultType) 472 return false; 473 474 // Requires same elemental type. 475 if (sourceType.getElementType() != resultType.getElementType()) 476 return false; 477 478 // Requires same rank. 479 if (sourceType.getRank() != resultType.getRank()) 480 return false; 481 482 // Only fold casts between strided memref forms. 483 int64_t sourceOffset, resultOffset; 484 SmallVector<int64_t, 4> sourceStrides, resultStrides; 485 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 486 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 487 return false; 488 489 // If cast is towards more static sizes along any dimension, don't fold. 490 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 491 auto ss = std::get<0>(it), st = std::get<1>(it); 492 if (ss != st) 493 if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) 494 return false; 495 } 496 497 // If cast is towards more static offset along any dimension, don't fold. 498 if (sourceOffset != resultOffset) 499 if (ShapedType::isDynamicStrideOrOffset(sourceOffset) && 500 !ShapedType::isDynamicStrideOrOffset(resultOffset)) 501 return false; 502 503 // If cast is towards more static strides along any dimension, don't fold. 
504 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 505 auto ss = std::get<0>(it), st = std::get<1>(it); 506 if (ss != st) 507 if (ShapedType::isDynamicStrideOrOffset(ss) && 508 !ShapedType::isDynamicStrideOrOffset(st)) 509 return false; 510 } 511 512 return true; 513 } 514 515 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 516 if (inputs.size() != 1 || outputs.size() != 1) 517 return false; 518 Type a = inputs.front(), b = outputs.front(); 519 auto aT = a.dyn_cast<MemRefType>(); 520 auto bT = b.dyn_cast<MemRefType>(); 521 522 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 523 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 524 525 if (aT && bT) { 526 if (aT.getElementType() != bT.getElementType()) 527 return false; 528 if (aT.getLayout() != bT.getLayout()) { 529 int64_t aOffset, bOffset; 530 SmallVector<int64_t, 4> aStrides, bStrides; 531 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 532 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 533 aStrides.size() != bStrides.size()) 534 return false; 535 536 // Strides along a dimension/offset are compatible if the value in the 537 // source memref is static and the value in the target memref is the 538 // same. They are also compatible if either one is dynamic (see 539 // description of MemRefCastOp for details). 540 auto checkCompatible = [](int64_t a, int64_t b) { 541 return (a == MemRefType::getDynamicStrideOrOffset() || 542 b == MemRefType::getDynamicStrideOrOffset() || a == b); 543 }; 544 if (!checkCompatible(aOffset, bOffset)) 545 return false; 546 for (const auto &aStride : enumerate(aStrides)) 547 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 548 return false; 549 } 550 if (aT.getMemorySpace() != bT.getMemorySpace()) 551 return false; 552 553 // They must have the same rank, and any specified dimensions must match. 554 if (aT.getRank() != bT.getRank()) 555 return false; 556 557 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 558 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 559 if (aDim != -1 && bDim != -1 && aDim != bDim) 560 return false; 561 } 562 return true; 563 } else { 564 if (!aT && !uaT) 565 return false; 566 if (!bT && !ubT) 567 return false; 568 // Unranked to unranked casting is unsupported 569 if (uaT && ubT) 570 return false; 571 572 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 573 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 574 if (aEltType != bEltType) 575 return false; 576 577 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 578 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 579 return aMemSpace == bMemSpace; 580 } 581 582 return false; 583 } 584 585 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 586 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 587 } 588 589 //===----------------------------------------------------------------------===// 590 // CopyOp 591 //===----------------------------------------------------------------------===// 592 593 namespace { 594 /// If the source/target of a CopyOp is a CastOp that does not modify the shape 595 /// and element type, the cast can be skipped. Such CastOps only cast the layout 596 /// of the type. 597 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> { 598 using OpRewritePattern<CopyOp>::OpRewritePattern; 599 600 LogicalResult matchAndRewrite(CopyOp copyOp, 601 PatternRewriter &rewriter) const override { 602 bool modified = false; 603 604 // Check source. 
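    // Illustrative sketch: if %s was produced by
    //   %s = memref.cast %a : memref<8xf32> to memref<8xf32, #dyn_strided>
    // (where #dyn_strided names some dynamically strided layout, so the cast
    // changes only the layout, not the shape or element type), the copy can
    // use %a directly as its source.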
    if (auto castOp = copyOp.source().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = copyOp.source().getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.sourceMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    // Check target.
    if (auto castOp = copyOp.target().getDefiningOp<CastOp>()) {
      auto fromType = castOp.source().getType().dyn_cast<MemRefType>();
      auto toType = copyOp.target().getType().dyn_cast<MemRefType>();

      if (fromType && toType) {
        if (fromType.getShape() == toType.getShape() &&
            fromType.getElementType() == toType.getElementType()) {
          rewriter.updateRootInPlace(
              copyOp, [&] { copyOp.targetMutable().assign(castOp.source()); });
          modified = true;
        }
      }
    }

    return success(modified);
  }
};

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.source() != copyOp.target())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
                           SmallVectorImpl<OpFoldResult> &results) {
  /// copy(memrefcast) -> copy
  bool folded = false;
  Operation *op = *this;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

LogicalResult DimOp::verify() {
  // Assume unknown index to be in range.
  Optional<int64_t> index = getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `candidateReducedType` whose shape is
/// assumed to be a subset of `originalType` with some `1` entries erased,
/// return the set of indices that specifies which of the entries of
/// `originalShape` are dropped to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped
/// too.
static llvm::Optional<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = dim.value().dyn_cast<Attribute>())
      if (attr.cast<IntegerAttr>().getInt() == 1)
        unusedDims.set(dim.index());

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim, that stride should not
  // be present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + the number of unused
  // dims with that stride == the number of repetitions of that stride in the
  // candidate.
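  //
  // Illustrative example (sketch): for an original memref<1x1x4xf32> with
  // strides [4, 4, 1] reduced to memref<1x4xf32> with strides [4, 1] and
  // subview sizes [1, 1, 4], both leading dims are unit-sized, but the
  // original has two dims with stride 4 while the candidate has only one, so
  // exactly one of the two unit dims is considered dropped and the other is
  // kept.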
772 std::map<int64_t, unsigned> currUnaccountedStrides = 773 getNumOccurences(originalStrides); 774 std::map<int64_t, unsigned> candidateStridesNumOccurences = 775 getNumOccurences(candidateStrides); 776 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) { 777 if (!unusedDims.test(dim)) 778 continue; 779 int64_t originalStride = originalStrides[dim]; 780 if (currUnaccountedStrides[originalStride] > 781 candidateStridesNumOccurences[originalStride]) { 782 // This dim can be treated as dropped. 783 currUnaccountedStrides[originalStride]--; 784 continue; 785 } 786 if (currUnaccountedStrides[originalStride] == 787 candidateStridesNumOccurences[originalStride]) { 788 // The stride for this is not dropped. Keep as is. 789 unusedDims.reset(dim); 790 continue; 791 } 792 if (currUnaccountedStrides[originalStride] < 793 candidateStridesNumOccurences[originalStride]) { 794 // This should never happen. Cant have a stride in the reduced rank type 795 // that wasnt in the original one. 796 return llvm::None; 797 } 798 } 799 800 if ((int64_t)unusedDims.count() + reducedType.getRank() != 801 originalType.getRank()) 802 return llvm::None; 803 return unusedDims; 804 } 805 806 llvm::SmallBitVector SubViewOp::getDroppedDims() { 807 MemRefType sourceType = getSourceType(); 808 MemRefType resultType = getType(); 809 llvm::Optional<llvm::SmallBitVector> unusedDims = 810 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes()); 811 assert(unusedDims && "unable to find unused dims of subview"); 812 return *unusedDims; 813 } 814 815 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 816 // All forms of folding require a known index. 817 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 818 if (!index) 819 return {}; 820 821 // Folding for unranked types (UnrankedMemRefType) is not supported. 822 auto memrefType = source().getType().dyn_cast<MemRefType>(); 823 if (!memrefType) 824 return {}; 825 826 // Fold if the shape extent along the given index is known. 827 if (!memrefType.isDynamicDim(index.getInt())) { 828 Builder builder(getContext()); 829 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); 830 } 831 832 // The size at the given index is now known to be a dynamic size. 833 unsigned unsignedIndex = index.getValue().getZExtValue(); 834 835 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 
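  // For example (sketch): with %m = memref.alloc(%n) : memref<4x?xf32>,
  // `memref.dim %m, %c1` folds to the size operand %n.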
836 Operation *definingOp = source().getDefiningOp(); 837 838 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 839 return *(alloc.getDynamicSizes().begin() + 840 memrefType.getDynamicDimIndex(unsignedIndex)); 841 842 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 843 return *(alloca.getDynamicSizes().begin() + 844 memrefType.getDynamicDimIndex(unsignedIndex)); 845 846 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 847 return *(view.getDynamicSizes().begin() + 848 memrefType.getDynamicDimIndex(unsignedIndex)); 849 850 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 851 llvm::SmallBitVector unusedDims = subview.getDroppedDims(); 852 unsigned resultIndex = 0; 853 unsigned sourceRank = subview.getSourceType().getRank(); 854 unsigned sourceIndex = 0; 855 for (auto i : llvm::seq<unsigned>(0, sourceRank)) { 856 if (unusedDims.test(i)) 857 continue; 858 if (resultIndex == unsignedIndex) { 859 sourceIndex = i; 860 break; 861 } 862 resultIndex++; 863 } 864 assert(subview.isDynamicSize(sourceIndex) && 865 "expected dynamic subview size"); 866 return subview.getDynamicSize(sourceIndex); 867 } 868 869 if (auto sizeInterface = 870 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 871 assert(sizeInterface.isDynamicSize(unsignedIndex) && 872 "Expected dynamic subview size"); 873 return sizeInterface.getDynamicSize(unsignedIndex); 874 } 875 876 // dim(memrefcast) -> dim 877 if (succeeded(foldMemRefCast(*this))) 878 return getResult(); 879 880 return {}; 881 } 882 883 namespace { 884 /// Fold dim of a memref reshape operation to a load into the reshape's shape 885 /// operand. 886 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 887 using OpRewritePattern<DimOp>::OpRewritePattern; 888 889 LogicalResult matchAndRewrite(DimOp dim, 890 PatternRewriter &rewriter) const override { 891 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 892 893 if (!reshape) 894 return failure(); 895 896 // Place the load directly after the reshape to ensure that the shape memref 897 // was not mutated. 
898 rewriter.setInsertionPointAfter(reshape); 899 Location loc = dim.getLoc(); 900 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 901 if (load.getType() != dim.getType()) 902 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); 903 rewriter.replaceOp(dim, load); 904 return success(); 905 } 906 }; 907 908 } // namespace 909 910 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 911 MLIRContext *context) { 912 results.add<DimOfMemRefReshape>(context); 913 } 914 915 // --------------------------------------------------------------------------- 916 // DmaStartOp 917 // --------------------------------------------------------------------------- 918 919 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 920 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 921 ValueRange destIndices, Value numElements, 922 Value tagMemRef, ValueRange tagIndices, Value stride, 923 Value elementsPerStride) { 924 result.addOperands(srcMemRef); 925 result.addOperands(srcIndices); 926 result.addOperands(destMemRef); 927 result.addOperands(destIndices); 928 result.addOperands({numElements, tagMemRef}); 929 result.addOperands(tagIndices); 930 if (stride) 931 result.addOperands({stride, elementsPerStride}); 932 } 933 934 void DmaStartOp::print(OpAsmPrinter &p) { 935 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " 936 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() 937 << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; 938 if (isStrided()) 939 p << ", " << getStride() << ", " << getNumElementsPerStride(); 940 941 p.printOptionalAttrDict((*this)->getAttrs()); 942 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 943 << ", " << getTagMemRef().getType(); 944 } 945 946 // Parse DmaStartOp. 947 // Ex: 948 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 949 // %tag[%index], %stride, %num_elt_per_stride : 950 // : memref<3076 x f32, 0>, 951 // memref<1024 x f32, 2>, 952 // memref<1 x i32> 953 // 954 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 955 OpAsmParser::OperandType srcMemRefInfo; 956 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 957 OpAsmParser::OperandType dstMemRefInfo; 958 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 959 OpAsmParser::OperandType numElementsInfo; 960 OpAsmParser::OperandType tagMemrefInfo; 961 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 962 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 963 964 SmallVector<Type, 3> types; 965 auto indexType = parser.getBuilder().getIndexType(); 966 967 // Parse and resolve the following list of operands: 968 // *) source memref followed by its indices (in square brackets). 969 // *) destination memref followed by its indices (in square brackets). 970 // *) dma size in KiB. 971 if (parser.parseOperand(srcMemRefInfo) || 972 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 973 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 974 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 975 parser.parseComma() || parser.parseOperand(numElementsInfo) || 976 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 977 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 978 return failure(); 979 980 // Parse optional stride and elements per stride. 
981 if (parser.parseTrailingOperandList(strideInfo)) 982 return failure(); 983 984 bool isStrided = strideInfo.size() == 2; 985 if (!strideInfo.empty() && !isStrided) { 986 return parser.emitError(parser.getNameLoc(), 987 "expected two stride related operands"); 988 } 989 990 if (parser.parseColonTypeList(types)) 991 return failure(); 992 if (types.size() != 3) 993 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 994 995 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 996 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 997 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 998 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 999 // size should be an index. 1000 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 1001 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 1002 // tag indices should be index. 1003 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 1004 return failure(); 1005 1006 if (isStrided) { 1007 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 1008 return failure(); 1009 } 1010 1011 return success(); 1012 } 1013 1014 LogicalResult DmaStartOp::verify() { 1015 unsigned numOperands = getNumOperands(); 1016 1017 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 1018 // the number of elements. 1019 if (numOperands < 4) 1020 return emitOpError("expected at least 4 operands"); 1021 1022 // Check types of operands. The order of these calls is important: the later 1023 // calls rely on some type properties to compute the operand position. 1024 // 1. Source memref. 1025 if (!getSrcMemRef().getType().isa<MemRefType>()) 1026 return emitOpError("expected source to be of memref type"); 1027 if (numOperands < getSrcMemRefRank() + 4) 1028 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 1029 << " operands"; 1030 if (!getSrcIndices().empty() && 1031 !llvm::all_of(getSrcIndices().getTypes(), 1032 [](Type t) { return t.isIndex(); })) 1033 return emitOpError("expected source indices to be of index type"); 1034 1035 // 2. Destination memref. 1036 if (!getDstMemRef().getType().isa<MemRefType>()) 1037 return emitOpError("expected destination to be of memref type"); 1038 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 1039 if (numOperands < numExpectedOperands) 1040 return emitOpError() << "expected at least " << numExpectedOperands 1041 << " operands"; 1042 if (!getDstIndices().empty() && 1043 !llvm::all_of(getDstIndices().getTypes(), 1044 [](Type t) { return t.isIndex(); })) 1045 return emitOpError("expected destination indices to be of index type"); 1046 1047 // 3. Number of elements. 1048 if (!getNumElements().getType().isIndex()) 1049 return emitOpError("expected num elements to be of index type"); 1050 1051 // 4. Tag memref. 1052 if (!getTagMemRef().getType().isa<MemRefType>()) 1053 return emitOpError("expected tag to be of memref type"); 1054 numExpectedOperands += getTagMemRefRank(); 1055 if (numOperands < numExpectedOperands) 1056 return emitOpError() << "expected at least " << numExpectedOperands 1057 << " operands"; 1058 if (!getTagIndices().empty() && 1059 !llvm::all_of(getTagIndices().getTypes(), 1060 [](Type t) { return t.isIndex(); })) 1061 return emitOpError("expected tag indices to be of index type"); 1062 1063 // Optional stride-related operands must be either both present or both 1064 // absent. 
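  // For example (sketch): a 2-D source, a 2-D destination and a 1-D tag give
  // numExpectedOperands = 2 + 2 + 1 + 4 = 9, or 11 with the optional stride
  // and elements-per-stride operands.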
1065 if (numOperands != numExpectedOperands && 1066 numOperands != numExpectedOperands + 2) 1067 return emitOpError("incorrect number of operands"); 1068 1069 // 5. Strides. 1070 if (isStrided()) { 1071 if (!getStride().getType().isIndex() || 1072 !getNumElementsPerStride().getType().isIndex()) 1073 return emitOpError( 1074 "expected stride and num elements per stride to be of type index"); 1075 } 1076 1077 return success(); 1078 } 1079 1080 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 1081 SmallVectorImpl<OpFoldResult> &results) { 1082 /// dma_start(memrefcast) -> dma_start 1083 return foldMemRefCast(*this); 1084 } 1085 1086 // --------------------------------------------------------------------------- 1087 // DmaWaitOp 1088 // --------------------------------------------------------------------------- 1089 1090 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1091 SmallVectorImpl<OpFoldResult> &results) { 1092 /// dma_wait(memrefcast) -> dma_wait 1093 return foldMemRefCast(*this); 1094 } 1095 1096 LogicalResult DmaWaitOp::verify() { 1097 // Check that the number of tag indices matches the tagMemRef rank. 1098 unsigned numTagIndices = tagIndices().size(); 1099 unsigned tagMemRefRank = getTagMemRefRank(); 1100 if (numTagIndices != tagMemRefRank) 1101 return emitOpError() << "expected tagIndices to have the same number of " 1102 "elements as the tagMemRef rank, expected " 1103 << tagMemRefRank << ", but got " << numTagIndices; 1104 return success(); 1105 } 1106 1107 //===----------------------------------------------------------------------===// 1108 // GenericAtomicRMWOp 1109 //===----------------------------------------------------------------------===// 1110 1111 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, 1112 Value memref, ValueRange ivs) { 1113 result.addOperands(memref); 1114 result.addOperands(ivs); 1115 1116 if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) { 1117 Type elementType = memrefType.getElementType(); 1118 result.addTypes(elementType); 1119 1120 Region *bodyRegion = result.addRegion(); 1121 bodyRegion->push_back(new Block()); 1122 bodyRegion->addArgument(elementType, memref.getLoc()); 1123 } 1124 } 1125 1126 LogicalResult GenericAtomicRMWOp::verify() { 1127 auto &body = getRegion(); 1128 if (body.getNumArguments() != 1) 1129 return emitOpError("expected single number of entry block arguments"); 1130 1131 if (getResult().getType() != body.getArgument(0).getType()) 1132 return emitOpError("expected block argument of the same type result type"); 1133 1134 bool hasSideEffects = 1135 body.walk([&](Operation *nestedOp) { 1136 if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) 1137 return WalkResult::advance(); 1138 nestedOp->emitError( 1139 "body of 'memref.generic_atomic_rmw' should contain " 1140 "only operations with no side effects"); 1141 return WalkResult::interrupt(); 1142 }) 1143 .wasInterrupted(); 1144 return hasSideEffects ? 
failure() : success(); 1145 } 1146 1147 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser, 1148 OperationState &result) { 1149 OpAsmParser::OperandType memref; 1150 Type memrefType; 1151 SmallVector<OpAsmParser::OperandType, 4> ivs; 1152 1153 Type indexType = parser.getBuilder().getIndexType(); 1154 if (parser.parseOperand(memref) || 1155 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || 1156 parser.parseColonType(memrefType) || 1157 parser.resolveOperand(memref, memrefType, result.operands) || 1158 parser.resolveOperands(ivs, indexType, result.operands)) 1159 return failure(); 1160 1161 Region *body = result.addRegion(); 1162 if (parser.parseRegion(*body, llvm::None, llvm::None) || 1163 parser.parseOptionalAttrDict(result.attributes)) 1164 return failure(); 1165 result.types.push_back(memrefType.cast<MemRefType>().getElementType()); 1166 return success(); 1167 } 1168 1169 void GenericAtomicRMWOp::print(OpAsmPrinter &p) { 1170 p << ' ' << memref() << "[" << indices() << "] : " << memref().getType() 1171 << ' '; 1172 p.printRegion(getRegion()); 1173 p.printOptionalAttrDict((*this)->getAttrs()); 1174 } 1175 1176 //===----------------------------------------------------------------------===// 1177 // AtomicYieldOp 1178 //===----------------------------------------------------------------------===// 1179 1180 LogicalResult AtomicYieldOp::verify() { 1181 Type parentType = (*this)->getParentOp()->getResultTypes().front(); 1182 Type resultType = result().getType(); 1183 if (parentType != resultType) 1184 return emitOpError() << "types mismatch between yield op: " << resultType 1185 << " and its parent: " << parentType; 1186 return success(); 1187 } 1188 1189 //===----------------------------------------------------------------------===// 1190 // GlobalOp 1191 //===----------------------------------------------------------------------===// 1192 1193 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1194 TypeAttr type, 1195 Attribute initialValue) { 1196 p << type; 1197 if (!op.isExternal()) { 1198 p << " = "; 1199 if (op.isUninitialized()) 1200 p << "uninitialized"; 1201 else 1202 p.printAttributeWithoutType(initialValue); 1203 } 1204 } 1205 1206 static ParseResult 1207 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1208 Attribute &initialValue) { 1209 Type type; 1210 if (parser.parseType(type)) 1211 return failure(); 1212 1213 auto memrefType = type.dyn_cast<MemRefType>(); 1214 if (!memrefType || !memrefType.hasStaticShape()) 1215 return parser.emitError(parser.getNameLoc()) 1216 << "type should be static shaped memref, but got " << type; 1217 typeAttr = TypeAttr::get(type); 1218 1219 if (parser.parseOptionalEqual()) 1220 return success(); 1221 1222 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1223 initialValue = UnitAttr::get(parser.getContext()); 1224 return success(); 1225 } 1226 1227 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1228 if (parser.parseAttribute(initialValue, tensorType)) 1229 return failure(); 1230 if (!initialValue.isa<ElementsAttr>()) 1231 return parser.emitError(parser.getNameLoc()) 1232 << "initial value should be a unit or elements attribute"; 1233 return success(); 1234 } 1235 1236 LogicalResult GlobalOp::verify() { 1237 auto memrefType = type().dyn_cast<MemRefType>(); 1238 if (!memrefType || !memrefType.hasStaticShape()) 1239 return emitOpError("type should be static shaped memref, but got ") 1240 << type(); 1241 1242 // Verify that the initial value, if 
present, is either a unit attribute or 1243 // an elements attribute. 1244 if (initial_value().hasValue()) { 1245 Attribute initValue = initial_value().getValue(); 1246 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1247 return emitOpError("initial value should be a unit or elements " 1248 "attribute, but got ") 1249 << initValue; 1250 1251 // Check that the type of the initial value is compatible with the type of 1252 // the global variable. 1253 if (initValue.isa<ElementsAttr>()) { 1254 Type initType = initValue.getType(); 1255 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1256 if (initType != tensorType) 1257 return emitOpError("initial value expected to be of type ") 1258 << tensorType << ", but was of type " << initType; 1259 } 1260 } 1261 1262 if (Optional<uint64_t> alignAttr = alignment()) { 1263 uint64_t alignment = alignAttr.getValue(); 1264 1265 if (!llvm::isPowerOf2_64(alignment)) 1266 return emitError() << "alignment attribute value " << alignment 1267 << " is not a power of 2"; 1268 } 1269 1270 // TODO: verify visibility for declarations. 1271 return success(); 1272 } 1273 1274 //===----------------------------------------------------------------------===// 1275 // GetGlobalOp 1276 //===----------------------------------------------------------------------===// 1277 1278 LogicalResult 1279 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1280 // Verify that the result type is same as the type of the referenced 1281 // memref.global op. 1282 auto global = 1283 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1284 if (!global) 1285 return emitOpError("'") 1286 << name() << "' does not reference a valid global memref"; 1287 1288 Type resultType = result().getType(); 1289 if (global.type() != resultType) 1290 return emitOpError("result type ") 1291 << resultType << " does not match type " << global.type() 1292 << " of the global memref @" << name(); 1293 return success(); 1294 } 1295 1296 //===----------------------------------------------------------------------===// 1297 // LoadOp 1298 //===----------------------------------------------------------------------===// 1299 1300 LogicalResult LoadOp::verify() { 1301 if (getNumOperands() != 1 + getMemRefType().getRank()) 1302 return emitOpError("incorrect number of indices for load"); 1303 return success(); 1304 } 1305 1306 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1307 /// load(memrefcast) -> load 1308 if (succeeded(foldMemRefCast(*this))) 1309 return getResult(); 1310 return OpFoldResult(); 1311 } 1312 1313 //===----------------------------------------------------------------------===// 1314 // PrefetchOp 1315 //===----------------------------------------------------------------------===// 1316 1317 void PrefetchOp::print(OpAsmPrinter &p) { 1318 p << " " << memref() << '['; 1319 p.printOperands(indices()); 1320 p << ']' << ", " << (isWrite() ? "write" : "read"); 1321 p << ", locality<" << localityHint(); 1322 p << ">, " << (isDataCache() ? 
"data" : "instr"); 1323 p.printOptionalAttrDict( 1324 (*this)->getAttrs(), 1325 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1326 p << " : " << getMemRefType(); 1327 } 1328 1329 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { 1330 OpAsmParser::OperandType memrefInfo; 1331 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1332 IntegerAttr localityHint; 1333 MemRefType type; 1334 StringRef readOrWrite, cacheType; 1335 1336 auto indexTy = parser.getBuilder().getIndexType(); 1337 auto i32Type = parser.getBuilder().getIntegerType(32); 1338 if (parser.parseOperand(memrefInfo) || 1339 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1340 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1341 parser.parseComma() || parser.parseKeyword("locality") || 1342 parser.parseLess() || 1343 parser.parseAttribute(localityHint, i32Type, "localityHint", 1344 result.attributes) || 1345 parser.parseGreater() || parser.parseComma() || 1346 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1347 parser.resolveOperand(memrefInfo, type, result.operands) || 1348 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1349 return failure(); 1350 1351 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1352 return parser.emitError(parser.getNameLoc(), 1353 "rw specifier has to be 'read' or 'write'"); 1354 result.addAttribute( 1355 PrefetchOp::getIsWriteAttrName(), 1356 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1357 1358 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1359 return parser.emitError(parser.getNameLoc(), 1360 "cache type has to be 'data' or 'instr'"); 1361 1362 result.addAttribute( 1363 PrefetchOp::getIsDataCacheAttrName(), 1364 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1365 1366 return success(); 1367 } 1368 1369 LogicalResult PrefetchOp::verify() { 1370 if (getNumOperands() != 1 + getMemRefType().getRank()) 1371 return emitOpError("too few indices"); 1372 1373 return success(); 1374 } 1375 1376 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1377 SmallVectorImpl<OpFoldResult> &results) { 1378 // prefetch(memrefcast) -> prefetch 1379 return foldMemRefCast(*this); 1380 } 1381 1382 //===----------------------------------------------------------------------===// 1383 // RankOp 1384 //===----------------------------------------------------------------------===// 1385 1386 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1387 // Constant fold rank when the rank of the operand is known. 1388 auto type = getOperand().getType(); 1389 auto shapedType = type.dyn_cast<ShapedType>(); 1390 if (shapedType && shapedType.hasRank()) 1391 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1392 return IntegerAttr(); 1393 } 1394 1395 //===----------------------------------------------------------------------===// 1396 // ReinterpretCastOp 1397 //===----------------------------------------------------------------------===// 1398 1399 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1400 /// `staticSizes` and `staticStrides` are automatically filled with 1401 /// source-memref-rank sentinel values that encode dynamic entries. 
1402 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1403 MemRefType resultType, Value source, 1404 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1405 ArrayRef<OpFoldResult> strides, 1406 ArrayRef<NamedAttribute> attrs) { 1407 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1408 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1409 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1410 ShapedType::kDynamicStrideOrOffset); 1411 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1412 ShapedType::kDynamicSize); 1413 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1414 ShapedType::kDynamicStrideOrOffset); 1415 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1416 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1417 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1418 result.addAttributes(attrs); 1419 } 1420 1421 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1422 MemRefType resultType, Value source, 1423 int64_t offset, ArrayRef<int64_t> sizes, 1424 ArrayRef<int64_t> strides, 1425 ArrayRef<NamedAttribute> attrs) { 1426 SmallVector<OpFoldResult> sizeValues = 1427 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1428 return b.getI64IntegerAttr(v); 1429 })); 1430 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1431 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1432 return b.getI64IntegerAttr(v); 1433 })); 1434 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1435 strideValues, attrs); 1436 } 1437 1438 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1439 MemRefType resultType, Value source, Value offset, 1440 ValueRange sizes, ValueRange strides, 1441 ArrayRef<NamedAttribute> attrs) { 1442 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1443 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1444 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1445 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1446 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1447 } 1448 1449 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1450 // completed automatically, like we have for subview and extract_slice. 1451 LogicalResult ReinterpretCastOp::verify() { 1452 // The source and result memrefs should be in the same memory space. 1453 auto srcType = source().getType().cast<BaseMemRefType>(); 1454 auto resultType = getType().cast<MemRefType>(); 1455 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1456 return emitError("different memory spaces specified for source type ") 1457 << srcType << " and result memref type " << resultType; 1458 if (srcType.getElementType() != resultType.getElementType()) 1459 return emitError("different element types specified for source type ") 1460 << srcType << " and result memref type " << resultType; 1461 1462 // Match sizes in result memref type and in static_sizes attribute. 
1463 for (auto &en : llvm::enumerate(llvm::zip( 1464 resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) { 1465 int64_t resultSize = std::get<0>(en.value()); 1466 int64_t expectedSize = std::get<1>(en.value()); 1467 if (!ShapedType::isDynamic(resultSize) && 1468 !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) 1469 return emitError("expected result type with size = ") 1470 << expectedSize << " instead of " << resultSize 1471 << " in dim = " << en.index(); 1472 } 1473 1474 // Match offset and strides in static_offset and static_strides attributes. If 1475 // result memref type has no affine map specified, this will assume an 1476 // identity layout. 1477 int64_t resultOffset; 1478 SmallVector<int64_t, 4> resultStrides; 1479 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1480 return emitError("expected result type to have strided layout but found ") 1481 << resultType; 1482 1483 // Match offset in result memref type and in static_offsets attribute. 1484 int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front(); 1485 if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && 1486 !ShapedType::isDynamicStrideOrOffset(expectedOffset) && 1487 resultOffset != expectedOffset) 1488 return emitError("expected result type with offset = ") 1489 << resultOffset << " instead of " << expectedOffset; 1490 1491 // Match strides in result memref type and in static_strides attribute. 1492 for (auto &en : llvm::enumerate(llvm::zip( 1493 resultStrides, extractFromI64ArrayAttr(static_strides())))) { 1494 int64_t resultStride = std::get<0>(en.value()); 1495 int64_t expectedStride = std::get<1>(en.value()); 1496 if (!ShapedType::isDynamicStrideOrOffset(resultStride) && 1497 !ShapedType::isDynamicStrideOrOffset(expectedStride) && 1498 resultStride != expectedStride) 1499 return emitError("expected result type with stride = ") 1500 << expectedStride << " instead of " << resultStride 1501 << " in dim = " << en.index(); 1502 } 1503 1504 return success(); 1505 } 1506 1507 //===----------------------------------------------------------------------===// 1508 // Reassociative reshape ops 1509 //===----------------------------------------------------------------------===// 1510 1511 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1512 return getSymbolLessAffineMaps(getReassociationExprs()); 1513 } 1514 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1515 return convertReassociationIndicesToExprs(getContext(), 1516 getReassociationIndices()); 1517 } 1518 1519 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1520 return getSymbolLessAffineMaps(getReassociationExprs()); 1521 } 1522 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1523 return convertReassociationIndicesToExprs(getContext(), 1524 getReassociationIndices()); 1525 } 1526 1527 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1528 /// copies. 1529 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1530 ArrayRef<int64_t> sizes, 1531 ArrayRef<AffineExpr> strides) { 1532 // Bands of extent one can be reshaped, as they are not reshaped at all. 1533 if (extent == 1) 1534 return true; 1535 // Otherwise, the size of the first dimension needs to be known. 
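  // Illustrative example of the band condition checked below: sizes [4, 8]
  // with strides [8, 1] form a contiguous band (8 == 1 * 8) and can collapse
  // into a single dimension of size 32 with stride 1; strides [16, 1] cannot.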
1536 if (ShapedType::isDynamic(sizes[dim])) 1537 return false; 1538 assert(sizes.size() == strides.size() && "mismatched ranks"); 1539 // off by 1 indexing to avoid out of bounds 1540 // V 1541 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1542 // Only bands of static shapes are reshapable. This is due to the fact that 1543 // there is no relation between dynamic sizes and dynamic strides: we do not 1544 // have enough information to know whether a "-1" size corresponds to the 1545 // proper symbol in the AffineExpr of a stride. 1546 if (ShapedType::isDynamic(sizes[idx + 1])) 1547 return false; 1548 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1549 // simplify on the fly and catch more reshapable cases. 1550 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1551 return false; 1552 } 1553 return true; 1554 } 1555 1556 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1557 /// expected to be valid) to `type`. 1558 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1559 /// MemRefType. 1560 static MemRefType 1561 computeReshapeCollapsedType(MemRefType type, 1562 ArrayRef<AffineMap> reassociation) { 1563 auto sizes = type.getShape(); 1564 AffineExpr offset; 1565 SmallVector<AffineExpr, 4> strides; 1566 auto status = getStridesAndOffset(type, strides, offset); 1567 auto isIdentityLayout = type.getLayout().isIdentity(); 1568 (void)status; 1569 assert(succeeded(status) && "expected strided memref"); 1570 1571 SmallVector<int64_t, 4> newSizes; 1572 newSizes.reserve(reassociation.size()); 1573 SmallVector<AffineExpr, 4> newStrides; 1574 newStrides.reserve(reassociation.size()); 1575 1576 // Use the fact that reassociation is valid to simplify the logic: only use 1577 // each map's rank. 1578 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1579 unsigned currentDim = 0; 1580 for (AffineMap m : reassociation) { 1581 unsigned dim = m.getNumResults(); 1582 int64_t size = 1; 1583 AffineExpr stride = strides[currentDim + dim - 1]; 1584 if (isIdentityLayout || 1585 isReshapableDimBand(currentDim, dim, sizes, strides)) { 1586 for (unsigned d = 0; d < dim; ++d) { 1587 int64_t currentSize = sizes[currentDim + d]; 1588 if (ShapedType::isDynamic(currentSize)) { 1589 size = ShapedType::kDynamicSize; 1590 break; 1591 } 1592 size *= currentSize; 1593 } 1594 } else { 1595 size = ShapedType::kDynamicSize; 1596 stride = AffineExpr(); 1597 } 1598 newSizes.push_back(size); 1599 newStrides.push_back(stride); 1600 currentDim += dim; 1601 } 1602 1603 // Early-exit: if `type` is contiguous, the result must be contiguous. 1604 if (canonicalizeStridedLayout(type).getLayout().isIdentity()) 1605 return MemRefType::Builder(type).setShape(newSizes).setLayout({}); 1606 1607 // Convert back to int64_t because we don't have enough information to create 1608 // new strided layouts from AffineExpr only. This corresponds to a case where 1609 // copies may be necessary. 
1610 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1611 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1612 intOffset = o.getValue(); 1613 SmallVector<int64_t, 4> intStrides; 1614 intStrides.reserve(strides.size()); 1615 for (auto stride : newStrides) { 1616 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1617 intStrides.push_back(cst.getValue()); 1618 else 1619 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1620 } 1621 auto layout = 1622 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1623 return canonicalizeStridedLayout( 1624 MemRefType::Builder(type).setShape(newSizes).setLayout( 1625 AffineMapAttr::get(layout))); 1626 } 1627 1628 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1629 ArrayRef<ReassociationIndices> reassociation, 1630 ArrayRef<NamedAttribute> attrs) { 1631 auto memRefType = src.getType().cast<MemRefType>(); 1632 auto resultType = computeReshapeCollapsedType( 1633 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1634 b.getContext(), reassociation))); 1635 build(b, result, resultType, src, attrs); 1636 result.addAttribute(getReassociationAttrName(), 1637 getReassociationIndicesAttribute(b, reassociation)); 1638 } 1639 1640 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1641 ArrayRef<ReassociationIndices> reassociation, 1642 ArrayRef<NamedAttribute> attrs) { 1643 auto memRefType = src.getType().cast<MemRefType>(); 1644 auto resultType = computeReshapeCollapsedType( 1645 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1646 b.getContext(), reassociation))); 1647 build(b, result, resultType, src, attrs); 1648 result.addAttribute(getReassociationAttrName(), 1649 getReassociationIndicesAttribute(b, reassociation)); 1650 } 1651 1652 template <typename ReshapeOp, 1653 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1654 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1655 MemRefType collapsedType) { 1656 if (failed( 1657 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1658 return failure(); 1659 auto maps = op.getReassociationMaps(); 1660 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1661 if (collapsedType != expectedType) 1662 return op.emitOpError("expected collapsed type to be ") 1663 << expectedType << ", but got " << collapsedType; 1664 return success(); 1665 } 1666 1667 LogicalResult ExpandShapeOp::verify() { 1668 return verifyReshapeOp(*this, getResultType(), getSrcType()); 1669 } 1670 1671 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1672 MLIRContext *context) { 1673 results.add<CollapseReshapeOps<ExpandShapeOp>, 1674 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1675 } 1676 1677 LogicalResult CollapseShapeOp::verify() { 1678 return verifyReshapeOp(*this, getSrcType(), getResultType()); 1679 } 1680 1681 struct CollapseShapeOpMemRefCastFolder 1682 : public OpRewritePattern<CollapseShapeOp> { 1683 public: 1684 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1685 1686 LogicalResult matchAndRewrite(CollapseShapeOp op, 1687 PatternRewriter &rewriter) const override { 1688 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1689 if (!cast) 1690 return failure(); 1691 1692 if (!CastOp::canFoldIntoConsumerOp(cast)) 1693 return failure(); 1694 1695 Type newResultType = computeReshapeCollapsedType( 1696 cast.getOperand().getType().cast<MemRefType>(), 1697 
op.getReassociationMaps()); 1698 1699 if (newResultType == op.getResultType()) { 1700 rewriter.updateRootInPlace( 1701 op, [&]() { op.srcMutable().assign(cast.source()); }); 1702 } else { 1703 Value newOp = rewriter.create<CollapseShapeOp>( 1704 op->getLoc(), cast.source(), op.getReassociationIndices()); 1705 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1706 } 1707 return success(); 1708 } 1709 }; 1710 1711 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1712 MLIRContext *context) { 1713 results.add<CollapseReshapeOps<CollapseShapeOp>, 1714 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1715 CollapseShapeOpMemRefCastFolder>(context); 1716 } 1717 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1718 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1719 } 1720 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1721 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1722 } 1723 1724 //===----------------------------------------------------------------------===// 1725 // ReshapeOp 1726 //===----------------------------------------------------------------------===// 1727 1728 LogicalResult ReshapeOp::verify() { 1729 Type operandType = source().getType(); 1730 Type resultType = result().getType(); 1731 1732 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1733 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1734 if (operandElementType != resultElementType) 1735 return emitOpError("element types of source and destination memref " 1736 "types should be the same"); 1737 1738 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1739 if (!operandMemRefType.getLayout().isIdentity()) 1740 return emitOpError("source memref type should have identity affine map"); 1741 1742 int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0); 1743 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1744 if (resultMemRefType) { 1745 if (!resultMemRefType.getLayout().isIdentity()) 1746 return emitOpError("result memref type should have identity affine map"); 1747 if (shapeSize == ShapedType::kDynamicSize) 1748 return emitOpError("cannot use shape operand with dynamic length to " 1749 "reshape to statically-ranked memref type"); 1750 if (shapeSize != resultMemRefType.getRank()) 1751 return emitOpError( 1752 "length of shape operand differs from the result's memref rank"); 1753 } 1754 return success(); 1755 } 1756 1757 //===----------------------------------------------------------------------===// 1758 // StoreOp 1759 //===----------------------------------------------------------------------===// 1760 1761 LogicalResult StoreOp::verify() { 1762 if (getNumOperands() != 2 + getMemRefType().getRank()) 1763 return emitOpError("store index operand count not equal to memref rank"); 1764 1765 return success(); 1766 } 1767 1768 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1769 SmallVectorImpl<OpFoldResult> &results) { 1770 /// store(memrefcast) -> store 1771 return foldMemRefCast(*this, getValueToStore()); 1772 } 1773 1774 //===----------------------------------------------------------------------===// 1775 // SubViewOp 1776 //===----------------------------------------------------------------------===// 1777 1778 namespace { 1779 /// Helpers to write more idiomatic operations. 
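/// A small sketch of the intended semantics (illustrative, not from the
/// original source): the wrappers below propagate the "dynamic" sentinel
/// through arithmetic, e.g.
///   Wrapper(4) * 3                                   // == 12
///   Wrapper(ShapedType::kDynamicStrideOrOffset) * 3  // stays dynamic
/// so offset/stride folding never invents a static value from a dynamic one.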
1780 namespace saturated_arith { 1781 struct Wrapper { 1782 explicit Wrapper(int64_t v) : v(v) {} 1783 operator int64_t() { return v; } 1784 int64_t v; 1785 }; 1786 Wrapper operator+(Wrapper a, int64_t b) { 1787 if (ShapedType::isDynamicStrideOrOffset(a) || 1788 ShapedType::isDynamicStrideOrOffset(b)) 1789 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1790 return Wrapper(a.v + b); 1791 } 1792 Wrapper operator*(Wrapper a, int64_t b) { 1793 if (ShapedType::isDynamicStrideOrOffset(a) || 1794 ShapedType::isDynamicStrideOrOffset(b)) 1795 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1796 return Wrapper(a.v * b); 1797 } 1798 } // namespace saturated_arith 1799 } // namespace 1800 1801 /// A subview result type can be fully inferred from the source type and the 1802 /// static representation of offsets, sizes and strides. Special sentinels 1803 /// encode the dynamic case. 1804 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1805 ArrayRef<int64_t> staticOffsets, 1806 ArrayRef<int64_t> staticSizes, 1807 ArrayRef<int64_t> staticStrides) { 1808 unsigned rank = sourceMemRefType.getRank(); 1809 (void)rank; 1810 assert(staticOffsets.size() == rank && "staticOffsets length mismatch"); 1811 assert(staticSizes.size() == rank && "staticSizes length mismatch"); 1812 assert(staticStrides.size() == rank && "staticStrides length mismatch"); 1813 1814 // Extract source offset and strides. 1815 int64_t sourceOffset; 1816 SmallVector<int64_t, 4> sourceStrides; 1817 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1818 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1819 (void)res; 1820 1821 // Compute target offset whose value is: 1822 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1823 int64_t targetOffset = sourceOffset; 1824 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1825 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1826 using namespace saturated_arith; 1827 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1828 } 1829 1830 // Compute target stride whose value is: 1831 // `sourceStrides_i * staticStrides_i`. 1832 SmallVector<int64_t, 4> targetStrides; 1833 targetStrides.reserve(staticOffsets.size()); 1834 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1835 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1836 using namespace saturated_arith; 1837 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1838 } 1839 1840 // The type is now known. 
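  // Worked example (illustrative, not from the original source): for a source
  // memref<8x16xf32> (strides [16, 1], offset 0) with offsets [2, 3],
  // sizes [4, 4] and strides [1, 2], the computation above gives
  //   targetOffset  = 0 + 2 * 16 + 3 * 1 = 35
  //   targetStrides = [16 * 1, 1 * 2]    = [16, 2]
  // i.e. memref<4x4xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 * 2 + 35)>>.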
1841 return MemRefType::get( 1842 staticSizes, sourceMemRefType.getElementType(), 1843 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1844 sourceMemRefType.getContext()), 1845 sourceMemRefType.getMemorySpace()); 1846 } 1847 1848 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1849 ArrayRef<OpFoldResult> offsets, 1850 ArrayRef<OpFoldResult> sizes, 1851 ArrayRef<OpFoldResult> strides) { 1852 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1853 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1854 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1855 ShapedType::kDynamicStrideOrOffset); 1856 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1857 ShapedType::kDynamicSize); 1858 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1859 ShapedType::kDynamicStrideOrOffset); 1860 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1861 staticSizes, staticStrides); 1862 } 1863 1864 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1865 MemRefType sourceRankedTensorType, 1866 ArrayRef<int64_t> offsets, 1867 ArrayRef<int64_t> sizes, 1868 ArrayRef<int64_t> strides) { 1869 auto inferredType = 1870 inferResultType(sourceRankedTensorType, offsets, sizes, strides) 1871 .cast<MemRefType>(); 1872 assert(inferredType.getRank() >= resultRank && "expected "); 1873 int rankDiff = inferredType.getRank() - resultRank; 1874 if (rankDiff > 0) { 1875 auto shape = inferredType.getShape(); 1876 llvm::SmallBitVector dimsToProject = 1877 getPositionsOfShapeOne(rankDiff, shape); 1878 SmallVector<int64_t> projectedShape; 1879 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1880 if (!dimsToProject.test(pos)) 1881 projectedShape.push_back(shape[pos]); 1882 1883 AffineMap map = inferredType.getLayout().getAffineMap(); 1884 if (!map.isIdentity()) 1885 map = getProjectedMap(map, dimsToProject); 1886 inferredType = 1887 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1888 inferredType.getMemorySpace()); 1889 } 1890 return inferredType; 1891 } 1892 1893 Type SubViewOp::inferRankReducedResultType(unsigned resultRank, 1894 MemRefType sourceRankedTensorType, 1895 ArrayRef<OpFoldResult> offsets, 1896 ArrayRef<OpFoldResult> sizes, 1897 ArrayRef<OpFoldResult> strides) { 1898 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1899 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1900 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1901 ShapedType::kDynamicStrideOrOffset); 1902 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1903 ShapedType::kDynamicSize); 1904 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1905 ShapedType::kDynamicStrideOrOffset); 1906 return SubViewOp::inferRankReducedResultType( 1907 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1908 staticStrides); 1909 } 1910 // Build a SubViewOp with mixed static and dynamic entries and custom result 1911 // type. If the type passed is nullptr, it is inferred. 
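// A minimal usage sketch (illustrative; `b`, `loc`, `source` and `dynSize` are
// assumed to exist in the caller):
//   SmallVector<OpFoldResult> offsets(2, b.getIndexAttr(0));
//   SmallVector<OpFoldResult> sizes = {b.getIndexAttr(4), dynSize};
//   SmallVector<OpFoldResult> strides(2, b.getIndexAttr(1));
//   b.create<SubViewOp>(loc, /*resultType=*/MemRefType(), source, offsets,
//                       sizes, strides);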
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

/// Return true if t1 and t2 have equal offsets (both dynamic or of same static
/// value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  AffineExpr t1Offset, t2Offset;
  SmallVector<AffineExpr> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Check whether `originalType` can be rank-reduced to
/// `candidateRankReducedType`. This is a slight variant of an "is subsequence"
/// check, where any non-matching dimension must have size 1.
static SliceVerificationResult
isRankReducedMemRefType(MemRefType originalType,
                        MemRefType candidateRankReducedType,
                        ArrayRef<OpFoldResult> sizes) {
  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
  if (partialRes != SliceVerificationResult::Success)
    return partialRes;

  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
      originalType, candidateRankReducedType, sizes);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
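  // Illustrative example (not from the original source): reducing
  // memref<4x1x1x5xf32> to memref<4x5xf32> yields a mask with dims 1 and 2
  // set, i.e. only unit dimensions may be dropped by a rank-reducing subview.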
2042 if (!optionalUnusedDimsMask.hasValue()) 2043 return SliceVerificationResult::LayoutMismatch; 2044 2045 if (originalType.getMemorySpace() != 2046 candidateRankReducedType.getMemorySpace()) 2047 return SliceVerificationResult::MemSpaceMismatch; 2048 2049 // No amount of stride dropping can reconcile incompatible offsets. 2050 if (!haveCompatibleOffsets(originalType, candidateRankReducedType)) 2051 return SliceVerificationResult::LayoutMismatch; 2052 2053 return SliceVerificationResult::Success; 2054 } 2055 2056 template <typename OpTy> 2057 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, 2058 OpTy op, Type expectedType) { 2059 auto memrefType = expectedType.cast<ShapedType>(); 2060 switch (result) { 2061 case SliceVerificationResult::Success: 2062 return success(); 2063 case SliceVerificationResult::RankTooLarge: 2064 return op.emitError("expected result rank to be smaller or equal to ") 2065 << "the source rank. "; 2066 case SliceVerificationResult::SizeMismatch: 2067 return op.emitError("expected result type to be ") 2068 << expectedType 2069 << " or a rank-reduced version. (mismatch of result sizes) "; 2070 case SliceVerificationResult::ElemTypeMismatch: 2071 return op.emitError("expected result element type to be ") 2072 << memrefType.getElementType(); 2073 case SliceVerificationResult::MemSpaceMismatch: 2074 return op.emitError("expected result and source memory spaces to match."); 2075 case SliceVerificationResult::LayoutMismatch: 2076 return op.emitError("expected result type to be ") 2077 << expectedType 2078 << " or a rank-reduced version. (mismatch of result layout) "; 2079 } 2080 llvm_unreachable("unexpected subview verification result"); 2081 } 2082 2083 /// Verifier for SubViewOp. 2084 LogicalResult SubViewOp::verify() { 2085 MemRefType baseType = getSourceType(); 2086 MemRefType subViewType = getType(); 2087 2088 // The base memref and the view memref should be in the same memory space. 2089 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 2090 return emitError("different memory spaces specified for base memref " 2091 "type ") 2092 << baseType << " and subview memref type " << subViewType; 2093 2094 // Verify that the base memref type has a strided layout map. 2095 if (!isStrided(baseType)) 2096 return emitError("base type ") << baseType << " is not strided"; 2097 2098 // Verify result type against inferred type. 2099 auto expectedType = SubViewOp::inferResultType( 2100 baseType, extractFromI64ArrayAttr(static_offsets()), 2101 extractFromI64ArrayAttr(static_sizes()), 2102 extractFromI64ArrayAttr(static_strides())); 2103 2104 auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), 2105 subViewType, getMixedSizes()); 2106 return produceSubViewErrorMsg(result, *this, expectedType); 2107 } 2108 2109 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 2110 return os << "range " << range.offset << ":" << range.size << ":" 2111 << range.stride; 2112 } 2113 2114 /// Return the list of Range (i.e. offset, size, stride). Each Range 2115 /// entry contains either the dynamic value or a ConstantIndexOp constructed 2116 /// with `b` at location `loc`. 
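/// For example (illustrative), a subview with static offset 0, dynamic size
/// %sz and static stride 1 in some dimension produces the Range
/// {arith.constant 0, %sz, arith.constant 1} for that dimension.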
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
/// deduce the result type for the given `sourceType`. Additionally, reduce the
/// rank of the inferred result type if `currentResultType` is lower rank than
/// `currentSourceType`. Use this signature if `sourceType` is updated together
/// with the result type. In this case, it is important to compute the dropped
/// dimensions using `currentSourceType` whose strides align with
/// `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                                       mixedSizes, mixedStrides)
                                .cast<MemRefType>();
  llvm::Optional<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                     mixedSizes);
  // Return nullptr as failure mode.
  if (!unusedDims)
    return nullptr;
  SmallVector<int64_t> shape;
  for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
    if (unusedDims->test(sizes.index()))
      continue;
    shape.push_back(sizes.value());
  }
  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
  if (!layoutMap.isIdentity())
    layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
                         nonRankReducedType.getMemorySpace());
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
/// deduce the result type. Additionally, reduce the rank of the inferred result
/// type if `currentResultType` is lower rank than `sourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType sourceType,
    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
    ArrayRef<OpFoldResult> mixedStrides) {
  return getCanonicalSubViewResultType(currentResultType, sourceType,
                                       sourceType, mixedOffsets, mixedSizes,
                                       mixedStrides);
}

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are one, and the source
/// shape is the same as the size of the subview.
/// In such cases, the subview can be folded into its source.
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check that all offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.getValue() != 0;
      }))
    return false;

  // Check that all strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.getValue() != 1;
      }))
    return false;

  // Check that all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    Optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || intValue.getValue() != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, return failure and let
    // SubViewOpConstantFolder kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops into either their source (when the
/// types match) or a cast of the source (when the source shape differs from
/// the result shape only through the use of an `affine_map` layout).
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.source());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.source());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
                                         mixedOffsets, mixedSizes,
                                         mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
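  // Illustrative example (not from the original source): applying the
  // permutation (d0, d1) -> (d1, d0) to memref<3x4xf32> (strides [4, 1])
  // permutes the sizes to [4, 3] and composes the layout to strides [1, 4],
  // i.e. memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>.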
2352 SmallVector<int64_t, 4> sizes(rank, 0); 2353 for (const auto &en : llvm::enumerate(permutationMap.getResults())) 2354 sizes[en.index()] = 2355 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2356 2357 // Compute permuted strides. 2358 int64_t offset; 2359 SmallVector<int64_t, 4> strides; 2360 auto res = getStridesAndOffset(memRefType, strides, offset); 2361 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2362 (void)res; 2363 auto map = 2364 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2365 map = permutationMap ? map.compose(permutationMap) : map; 2366 return MemRefType::Builder(memRefType) 2367 .setShape(sizes) 2368 .setLayout(AffineMapAttr::get(map)); 2369 } 2370 2371 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2372 AffineMapAttr permutation, 2373 ArrayRef<NamedAttribute> attrs) { 2374 auto permutationMap = permutation.getValue(); 2375 assert(permutationMap); 2376 2377 auto memRefType = in.getType().cast<MemRefType>(); 2378 // Compute result type. 2379 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2380 2381 build(b, result, resultType, in, attrs); 2382 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2383 } 2384 2385 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2386 void TransposeOp::print(OpAsmPrinter &p) { 2387 p << " " << in() << " " << permutation(); 2388 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); 2389 p << " : " << in().getType() << " to " << getType(); 2390 } 2391 2392 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { 2393 OpAsmParser::OperandType in; 2394 AffineMap permutation; 2395 MemRefType srcType, dstType; 2396 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2397 parser.parseOptionalAttrDict(result.attributes) || 2398 parser.parseColonType(srcType) || 2399 parser.resolveOperand(in, srcType, result.operands) || 2400 parser.parseKeywordType("to", dstType) || 2401 parser.addTypeToList(dstType, result.types)) 2402 return failure(); 2403 2404 result.addAttribute(TransposeOp::getPermutationAttrName(), 2405 AffineMapAttr::get(permutation)); 2406 return success(); 2407 } 2408 2409 LogicalResult TransposeOp::verify() { 2410 if (!permutation().isPermutation()) 2411 return emitOpError("expected a permutation map"); 2412 if (permutation().getNumDims() != getShapedType().getRank()) 2413 return emitOpError("expected a permutation map of same rank as the input"); 2414 2415 auto srcType = in().getType().cast<MemRefType>(); 2416 auto dstType = getType().cast<MemRefType>(); 2417 auto transposedType = inferTransposeResultType(srcType, permutation()); 2418 if (dstType != transposedType) 2419 return emitOpError("output type ") 2420 << dstType << " does not match transposed input type " << srcType 2421 << ", " << transposedType; 2422 return success(); 2423 } 2424 2425 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2426 if (succeeded(foldMemRefCast(*this))) 2427 return getResult(); 2428 return {}; 2429 } 2430 2431 //===----------------------------------------------------------------------===// 2432 // ViewOp 2433 //===----------------------------------------------------------------------===// 2434 2435 LogicalResult ViewOp::verify() { 2436 auto baseType = getOperand(0).getType().cast<MemRefType>(); 2437 auto viewType = getType(); 2438 2439 // The base memref should have identity layout map (or none). 
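  // A well-formed example (illustrative, not from the original source):
  //   %0 = memref.view %base[%byte_shift][%dyn_size]
  //        : memref<2048xi8> to memref<?x4xf32>
  // Both memref types use the default identity layout, and exactly one size
  // operand is provided for the single dynamic dimension of the result.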
2440 if (!baseType.getLayout().isIdentity()) 2441 return emitError("unsupported map for base memref type ") << baseType; 2442 2443 // The result memref should have identity layout map (or none). 2444 if (!viewType.getLayout().isIdentity()) 2445 return emitError("unsupported map for result memref type ") << viewType; 2446 2447 // The base memref and the view memref should be in the same memory space. 2448 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2449 return emitError("different memory spaces specified for base memref " 2450 "type ") 2451 << baseType << " and view memref type " << viewType; 2452 2453 // Verify that we have the correct number of sizes for the result type. 2454 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2455 if (sizes().size() != numDynamicDims) 2456 return emitError("incorrect number of size operands for type ") << viewType; 2457 2458 return success(); 2459 } 2460 2461 Value ViewOp::getViewSource() { return source(); } 2462 2463 namespace { 2464 2465 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2466 using OpRewritePattern<ViewOp>::OpRewritePattern; 2467 2468 LogicalResult matchAndRewrite(ViewOp viewOp, 2469 PatternRewriter &rewriter) const override { 2470 // Return if none of the operands are constants. 2471 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2472 return matchPattern(operand, matchConstantIndex()); 2473 })) 2474 return failure(); 2475 2476 // Get result memref type. 2477 auto memrefType = viewOp.getType(); 2478 2479 // Get offset from old memref view type 'memRefType'. 2480 int64_t oldOffset; 2481 SmallVector<int64_t, 4> oldStrides; 2482 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2483 return failure(); 2484 assert(oldOffset == 0 && "Expected 0 offset"); 2485 2486 SmallVector<Value, 4> newOperands; 2487 2488 // Offset cannot be folded into result type. 2489 2490 // Fold any dynamic dim operands which are produced by a constant. 2491 SmallVector<int64_t, 4> newShapeConstants; 2492 newShapeConstants.reserve(memrefType.getRank()); 2493 2494 unsigned dynamicDimPos = 0; 2495 unsigned rank = memrefType.getRank(); 2496 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2497 int64_t dimSize = memrefType.getDimSize(dim); 2498 // If this is already static dimension, keep it. 2499 if (!ShapedType::isDynamic(dimSize)) { 2500 newShapeConstants.push_back(dimSize); 2501 continue; 2502 } 2503 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2504 if (auto constantIndexOp = 2505 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2506 // Dynamic shape dimension will be folded. 2507 newShapeConstants.push_back(constantIndexOp.value()); 2508 } else { 2509 // Dynamic shape dimension not folded; copy operand from old memref. 2510 newShapeConstants.push_back(dimSize); 2511 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2512 } 2513 dynamicDimPos++; 2514 } 2515 2516 // Create new memref type with constant folded dims. 2517 MemRefType newMemRefType = 2518 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2519 // Nothing new, don't fold. 2520 if (newMemRefType == memrefType) 2521 return failure(); 2522 2523 // Create new ViewOp. 2524 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2525 viewOp.getOperand(0), 2526 viewOp.byte_shift(), newOperands); 2527 // Insert a cast so we have the same type as the old memref type. 
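    // Illustrative effect of this pattern (names are hypothetical): with
    //   %c16 = arith.constant 16 : index
    //   %v = memref.view %base[%shift][%c16] : memref<2048xi8> to memref<?x4xf32>
    // the folded view becomes
    //   %v2 = memref.view %base[%shift][] : memref<2048xi8> to memref<16x4xf32>
    //   %v  = memref.cast %v2 : memref<16x4xf32> to memref<?x4xf32>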
2528 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp); 2529 return success(); 2530 } 2531 }; 2532 2533 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2534 using OpRewritePattern<ViewOp>::OpRewritePattern; 2535 2536 LogicalResult matchAndRewrite(ViewOp viewOp, 2537 PatternRewriter &rewriter) const override { 2538 Value memrefOperand = viewOp.getOperand(0); 2539 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2540 if (!memrefCastOp) 2541 return failure(); 2542 Value allocOperand = memrefCastOp.getOperand(); 2543 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2544 if (!allocOp) 2545 return failure(); 2546 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2547 viewOp.byte_shift(), viewOp.sizes()); 2548 return success(); 2549 } 2550 }; 2551 2552 } // namespace 2553 2554 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2555 MLIRContext *context) { 2556 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2557 } 2558 2559 //===----------------------------------------------------------------------===// 2560 // AtomicRMWOp 2561 //===----------------------------------------------------------------------===// 2562 2563 LogicalResult AtomicRMWOp::verify() { 2564 if (getMemRefType().getRank() != getNumOperands() - 2) 2565 return emitOpError( 2566 "expects the number of subscripts to be equal to memref rank"); 2567 switch (kind()) { 2568 case arith::AtomicRMWKind::addf: 2569 case arith::AtomicRMWKind::maxf: 2570 case arith::AtomicRMWKind::minf: 2571 case arith::AtomicRMWKind::mulf: 2572 if (!value().getType().isa<FloatType>()) 2573 return emitOpError() << "with kind '" 2574 << arith::stringifyAtomicRMWKind(kind()) 2575 << "' expects a floating-point type"; 2576 break; 2577 case arith::AtomicRMWKind::addi: 2578 case arith::AtomicRMWKind::maxs: 2579 case arith::AtomicRMWKind::maxu: 2580 case arith::AtomicRMWKind::mins: 2581 case arith::AtomicRMWKind::minu: 2582 case arith::AtomicRMWKind::muli: 2583 case arith::AtomicRMWKind::ori: 2584 case arith::AtomicRMWKind::andi: 2585 if (!value().getType().isa<IntegerType>()) 2586 return emitOpError() << "with kind '" 2587 << arith::stringifyAtomicRMWKind(kind()) 2588 << "' expects an integer type"; 2589 break; 2590 default: 2591 break; 2592 } 2593 return success(); 2594 } 2595 2596 OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) { 2597 /// atomicrmw(memrefcast) -> atomicrmw 2598 if (succeeded(foldMemRefCast(*this, value()))) 2599 return getResult(); 2600 return OpFoldResult(); 2601 } 2602 2603 //===----------------------------------------------------------------------===// 2604 // TableGen'd op method definitions 2605 //===----------------------------------------------------------------------===// 2606 2607 #define GET_OP_CLASSES 2608 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2609