//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, value, type);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
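
// As an illustrative sketch (hypothetical IR, not taken from the tests):
// operations that call this helper, such as memref.dealloc, can fold away a
// ranked memref.cast feeding them, e.g.
//
//   %0 = memref.cast %arg0 : memref<8xf32> to memref<?xf32>
//   memref.dealloc %0 : memref<?xf32>
//
// may become:
//
//   memref.dealloc %arg0 : memref<8xf32>
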
/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};
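
// A minimal sketch of what SimplifyAllocConst does (hypothetical IR):
//
//   %c8 = arith.constant 8 : index
//   %0 = memref.alloc(%c8) : memref<?x16xf32>
//
// is rewritten to:
//
//   %0 = memref.alloc() : memref<8x16xf32>
//   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x16xf32>
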
/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
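
// For reference, a sketch of the textual form handled by the printer and
// parser above (illustrative IR, not an exhaustive grammar):
//
//   %0 = memref.alloca_scope -> (f32) {
//     %a = memref.alloca() : memref<f32>
//     ...
//     memref.alloca_scope.return %value : f32
//   }
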
static LogicalResult verify(AllocaScopeOp op) {
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
/// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
/// %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
/// %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
/// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///        to memref<?x?xf32>
/// consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
        !ShapedType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamicStrideOrOffset(ss) &&
          !ShapedType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}
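
// A couple of casts that are considered compatible under the rules above
// (illustrative examples, assuming the default memory space):
//
//   memref.cast %a : memref<4xf32> to memref<?xf32>   // ranked -> ranked
//   memref.cast %b : memref<?xf32> to memref<*xf32>   // ranked -> unranked
//
// whereas an unranked-to-unranked cast (memref<*xf32> to memref<*xf32>) is
// rejected.
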
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
/// to be a subset of `originalType` with some `1` entries erased, return the
/// set of indices that specifies which of the entries of `originalShape` are
/// dropped to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped too.
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallDenseSet<unsigned> unusedDims;
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = dim.value().dyn_cast<Attribute>())
      if (attr.cast<IntegerAttr>().getInt() == 1)
        unusedDims.insert(dim.index());

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim that stride should not be
  // present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to exactly figure out
  // which dim corresponds to which stride, we just need to verify that the
  // number of repetitions of a stride in the original + number of unused dims
  // with that stride == number of repetitions of a stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
  for (unsigned dim : unusedDims) {
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      prunedUnusedDims.insert(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank type
      // that wasn't in the original one.
      return llvm::None;
    }
  }

  for (auto prunedDim : prunedUnusedDims)
    unusedDims.erase(prunedDim);
  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
    return llvm::None;
  return unusedDims;
}

llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(unusedDims && "unable to find unused dims of subview");
  return *unusedDims;
}
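
// A small worked example of the rank-reduction mask (hypothetical IR): for
//
//   %0 = memref.subview %arg0[0, 0, 0] [1, 16, 4] [1, 1, 1]
//          : memref<8x16x4xf32> to memref<16x4xf32>
//
// the unit-size dimension 0 is dropped together with its stride, so
// getDroppedDims() would return {0}.
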
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.count(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
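
// To illustrate the alloc case handled above (a hypothetical example):
//
//   %0 = memref.alloc(%sz) : memref<4x?xf32>
//   %d = memref.dim %0, %c1 : memref<4x?xf32>
//
// folds %d to %sz, while memref.dim %0, %c0 folds to the constant 4 via the
// static-shape path.
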
namespace {
/// Fold dim of a memref reshape operation to a load from the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.source().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

static void print(OpAsmPrinter &p, DmaStartOp op) {
  p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
    << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
    << op.getNumElements() << ", " << op.getTagMemRef() << '['
    << op.getTagIndices() << ']';
  if (op.isStrided())
    p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();

  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getSrcMemRef().getType() << ", "
    << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
static ParseResult parseDmaStartOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

static LogicalResult verify(DmaStartOp op) {
  unsigned numOperands = op.getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return op.emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!op.getSrcMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected source to be of memref type");
  if (numOperands < op.getSrcMemRefRank() + 4)
    return op.emitOpError()
           << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
  if (!op.getSrcIndices().empty() &&
      !llvm::all_of(op.getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!op.getDstMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands =
      op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return op.emitOpError()
           << "expected at least " << numExpectedOperands << " operands";
  if (!op.getDstIndices().empty() &&
      !llvm::all_of(op.getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!op.getNumElements().getType().isIndex())
    return op.emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!op.getTagMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected tag to be of memref type");
  numExpectedOperands += op.getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return op.emitOpError()
           << "expected at least " << numExpectedOperands << " operands";
  if (!op.getTagIndices().empty() &&
      !llvm::all_of(op.getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return op.emitOpError("incorrect number of operands");

  // 5. Strides.
  if (op.isStrided()) {
    if (!op.getStride().getType().isIndex() ||
        !op.getNumElementsPerStride().getType().isIndex())
      return op.emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

static LogicalResult verify(DmaWaitOp op) {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = op.tagIndices().size();
  unsigned tagMemRefRank = op.getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return op.emitOpError() << "expected tagIndices to have the same number of "
                               "elements as the tagMemRef rank, expected "
                            << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}
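
// For reference, the initial-value forms this printer/parser pair handles
// look roughly like the following (illustrative IR):
//
//   memref.global "private" @x : memref<2xf32> = dense<[0.0, 1.0]>
//   memref.global "private" @y : memref<4xi32> = uninitialized
//   memref.global @z : memref<8xf32>   // external: no initializer at all
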
static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (Optional<uint64_t> alignAttr = op.alignment()) {
    uint64_t alignment = alignAttr.getValue();

    if (!llvm::isPowerOf2_64(alignment))
      return op->emitError() << "alignment attribute value " << alignment
                             << " is not a power of 2";
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

static LogicalResult verify(PrefetchOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = type.dyn_cast<ShapedType>();
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}
//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
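
// To make the mixed static/dynamic handling concrete, a sketch of the kind of
// op these builders produce (hypothetical IR, syntax abbreviated):
//
//   %0 = memref.reinterpret_cast %src to offset: [0], sizes: [%sz, 4],
//            strides: [4, 1]
//        : memref<?xf32> to memref<?x4xf32, offset: 0, strides: [4, 1]>
//
// Static entries land in the static_offsets/static_sizes/static_strides
// attributes; dynamic ones (here %sz) are kept as SSA operands.
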
// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and extract_slice.
static LogicalResult verify(ReinterpretCastOp op) {
  // The source and result memrefs should be in the same memory space.
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto resultType = op.getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return op.emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return op.emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (!ShapedType::isDynamic(resultSize) &&
        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offset and static_strides attributes. If
  // result memref type has no affine map specified, this will assume an
  // identity layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return op.emitError(
               "expected result type to have strided layout but found ")
           << resultType;

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
  if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
      !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
      resultOffset != expectedOffset)
    return op.emitError("expected result type with offset = ")
           << expectedOffset << " instead of " << resultOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto &en : llvm::enumerate(llvm::zip(
           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
    int64_t resultStride = std::get<0>(en.value());
    int64_t expectedStride = std::get<1>(en.value());
    if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
        !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
        resultStride != expectedStride)
      return op.emitError("expected result type with stride = ")
             << expectedStride << " instead of " << resultStride
             << " in dim = " << en.index();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

static void print(OpAsmPrinter &p, ExpandShapeOp op) {
  ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
}

static void print(OpAsmPrinter &p, CollapseShapeOp op) {
  ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
}

/// Detect whether memref dims [dim, dim + extent) can be reshaped without
/// copies.
static bool isReshapableDimBand(unsigned dim, unsigned extent,
                                ArrayRef<int64_t> sizes,
                                ArrayRef<AffineExpr> strides) {
  // Bands of extent one can be reshaped, as they are not reshaped at all.
  if (extent == 1)
    return true;
  // Otherwise, the size of the first dimension needs to be known.
  if (ShapedType::isDynamic(sizes[dim]))
    return false;
  assert(sizes.size() == strides.size() && "mismatched ranks");
  // off by 1 indexing to avoid out of bounds
  //                       V
  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
    // Only bands of static shapes are reshapable. This is due to the fact that
    // there is no relation between dynamic sizes and dynamic strides: we do not
    // have enough information to know whether a "-1" size corresponds to the
    // proper symbol in the AffineExpr of a stride.
    if (ShapedType::isDynamic(sizes[idx + 1]))
      return false;
    // TODO: Refine this by passing the proper nDims and nSymbols so we can
    // simplify on the fly and catch more reshapable cases.
    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
      return false;
  }
  return true;
}

/// Compute the MemRefType obtained by applying the `reassociation` (which is
/// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
static MemRefType
computeReshapeCollapsedType(MemRefType type,
                            ArrayRef<AffineMap> reassociation) {
  auto sizes = type.getShape();
  AffineExpr offset;
  SmallVector<AffineExpr, 4> strides;
  auto status = getStridesAndOffset(type, strides, offset);
  (void)status;
  assert(succeeded(status) && "expected strided memref");

  SmallVector<int64_t, 4> newSizes;
  newSizes.reserve(reassociation.size());
  SmallVector<AffineExpr, 4> newStrides;
  newStrides.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    AffineExpr stride = strides[currentDim + dim - 1];
    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
      size = ShapedType::kDynamicSize;
      stride = AffineExpr();
    } else {
      for (unsigned d = 0; d < dim; ++d)
        size *= sizes[currentDim + d];
    }
    newSizes.push_back(size);
    newStrides.push_back(stride);
    currentDim += dim;
  }

  // Early-exit: if `type` is contiguous, the result must be contiguous.
  if (canonicalizeStridedLayout(type).getLayout().isIdentity())
    return MemRefType::Builder(type).setShape(newSizes).setLayout({});

  // Convert back to int64_t because we don't have enough information to create
  // new strided layouts from AffineExpr only. This corresponds to a case where
  // copies may be necessary.
  int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
  if (auto o = offset.dyn_cast<AffineConstantExpr>())
    intOffset = o.getValue();
  SmallVector<int64_t, 4> intStrides;
  intStrides.reserve(strides.size());
  for (auto stride : newStrides) {
    if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
      intStrides.push_back(cst.getValue());
    else
      intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  auto layout =
      makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
  return canonicalizeStridedLayout(
      MemRefType::Builder(type).setShape(newSizes).setLayout(
          AffineMapAttr::get(layout)));
}

void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<NamedAttribute> attrs) {
  auto memRefType = src.getType().cast<MemRefType>();
  auto resultType = computeReshapeCollapsedType(
      memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                      b.getContext(), reassociation)));
  build(b, result, resultType, src, attrs);
  result.addAttribute(getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto memRefType = src.getType().cast<MemRefType>();
  auto resultType = computeReshapeCollapsedType(
      memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                      b.getContext(), reassociation)));
  build(b, result, resultType, src, attrs);
  result.addAttribute(getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}

template <typename ReshapeOp,
          bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
                                     MemRefType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();
  auto maps = op.getReassociationMaps();
  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
  if (collapsedType != expectedType)
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

static LogicalResult verify(ExpandShapeOp op) {
  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
}

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CollapseReshapeOps<ExpandShapeOp>,
              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
}

static LogicalResult verify(CollapseShapeOp op) {
  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
}

struct CollapseShapeOpMemRefCastFolder
    : public OpRewritePattern<CollapseShapeOp> {
public:
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp op,
                                PatternRewriter &rewriter) const override {
    auto cast = op.getOperand().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(cast))
      return failure();

    Type newResultType = computeReshapeCollapsedType(
        cast.getOperand().getType().cast<MemRefType>(),
        op.getReassociationMaps());
    if (newResultType == op.getResultType()) {
      rewriter.updateRootInPlace(
          op, [&]() { op.srcMutable().assign(cast.source()); });
    } else {
      Value newOp = rewriter.create<CollapseShapeOp>(
          op->getLoc(), cast.source(), op.getReassociationIndices());
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
    }
    return success();
  }
};

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CollapseReshapeOps<CollapseShapeOp>,
              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
              CollapseShapeOpMemRefCastFolder>(context);
}
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getLayout().isIdentity())
      return op.emitOpError(
          "source memref type should have identity affine map");

  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(StoreOp op) {
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
    return op.emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

namespace {
/// Helpers to write more idiomatic operations.
namespace saturated_arith {
struct Wrapper {
  explicit Wrapper(int64_t v) : v(v) {}
  operator int64_t() { return v; }
  int64_t v;
};
Wrapper operator+(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v + b);
}
Wrapper operator*(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v * b);
}
} // namespace saturated_arith
} // namespace

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> staticOffsets,
                                ArrayRef<int64_t> staticSizes,
                                ArrayRef<int64_t> staticStrides) {
  unsigned rank = sourceMemRefType.getRank();
  (void)rank;
  assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
  assert(staticSizes.size() == rank && "staticSizes length mismatch");
  assert(staticStrides.size() == rank && "staticStrides length mismatch");

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
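  // For instance (illustrative numbers, not from a test): taking offsets
  // [2, 3], sizes [4, 4] and strides [1, 1] of a memref<8x16xf32> (source
  // strides [16, 1], offset 0) gives targetOffset = 0 + 2*16 + 3*1 = 35 and
  // targetStrides = [16, 1]; any dynamic input saturates the corresponding
  // result to ShapedType::kDynamicStrideOrOffset.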
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> offsets,
                                ArrayRef<OpFoldResult> sizes,
                                ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides);
}

Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
                                           MemRefType sourceRankedTensorType,
                                           ArrayRef<int64_t> offsets,
                                           ArrayRef<int64_t> sizes,
                                           ArrayRef<int64_t> strides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the inferred rank to be greater than or equal to the "
         "result rank");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map = inferredType.getLayout().getAffineMap();
    if (!map.isIdentity())
      map = getProjectedMap(map, dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
                                           MemRefType sourceRankedTensorType,
                                           ArrayRef<OpFoldResult> offsets,
                                           ArrayRef<OpFoldResult> sizes,
                                           ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
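
// Note on rank reduction (illustrative example, not from a test): with
// resultRank = 1, offsets [0, 0], sizes [1, 4] and strides [1, 1] on a
// memref<8x4xf32>, the inferred 2-D type has shape 1x4; the unit dimension at
// position 0 is projected away, yielding a rank-1 memref of 4 elements whose
// layout map is projected accordingly.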

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}
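
// Illustrative usage sketch of the static-entries builder above (assumes an
// OpBuilder `b`, a Location `loc` and a Value `src` of type memref<8x16xf32>;
// this snippet is not part of the original file):
//
//   auto sub = b.create<memref::SubViewOp>(
//       loc, src, /*offsets=*/ArrayRef<int64_t>{0, 8},
//       /*sizes=*/ArrayRef<int64_t>{8, 8}, /*strides=*/ArrayRef<int64_t>{1, 1});
//
// The result type is then inferred via SubViewOp::inferResultType.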

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

/// Return true if t1 and t2 have equal offsets (both dynamic or of same static
/// value).
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
  AffineExpr t1Offset, t2Offset;
  SmallVector<AffineExpr> t1Strides, t2Strides;
  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
  return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
}

/// Checks if `originalType` can be rank-reduced to `candidateRankReducedType`.
/// This is a slight variant of the `is subsequence` algorithm, where any
/// non-matching dimension of the original type must have size 1.
static SliceVerificationResult
isRankReducedMemRefType(MemRefType originalType,
                        MemRefType candidateRankReducedType,
                        ArrayRef<OpFoldResult> sizes) {
  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
  if (partialRes != SliceVerificationResult::Success)
    return partialRes;

  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
      originalType, candidateRankReducedType, sizes);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::LayoutMismatch;

  if (originalType.getMemorySpace() !=
      candidateRankReducedType.getMemorySpace())
    return SliceVerificationResult::MemSpaceMismatch;

  // No amount of stride dropping can reconcile incompatible offsets.
  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
    return SliceVerificationResult::LayoutMismatch;

  return SliceVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
                                            OpTy op, Type expectedType) {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType();
  case SliceVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.");
  case SliceVerificationResult::LayoutMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout) ";
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
                                        subViewType, op.getMixedSizes());
  return produceSubViewErrorMsg(result, op, expectedType);
}

raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
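/// For instance (illustrative, not from a test), for
/// `memref.subview %src[%o, 4][8, %s][1, 1]` the returned ranges would be
/// {%o, 8, 1} and {4, %s, 1}, with the constants 8, 4 and 1 materialized as
/// arith.constant index operations.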
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
/// deduce the result type for the given `sourceType`. Additionally, reduce the
/// rank of the inferred result type if `currentResultType` is lower rank than
/// `currentSourceType`. Use this signature if `sourceType` is updated together
/// with the result type. In this case, it is important to compute the dropped
/// dimensions using `currentSourceType` whose strides align with
/// `currentResultType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                                       mixedSizes, mixedStrides)
                                .cast<MemRefType>();
  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
      computeMemRefRankReductionMask(currentSourceType, currentResultType,
                                     mixedSizes);
  // Return nullptr as failure mode.
  if (!unusedDims)
    return nullptr;
  SmallVector<int64_t> shape;
  for (const auto &sizes : llvm::enumerate(nonRankReducedType.getShape())) {
    if (unusedDims->count(sizes.index()))
      continue;
    shape.push_back(sizes.value());
  }
  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
  if (!layoutMap.isIdentity())
    layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
                         nonRankReducedType.getMemorySpace());
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
/// deduce the result type. Additionally, reduce the rank of the inferred result
/// type if `currentResultType` is lower rank than `sourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType sourceType,
    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
    ArrayRef<OpFoldResult> mixedStrides) {
  return getCanonicalSubViewResultType(currentResultType, sourceType,
                                       sourceType, mixedOffsets, mixedSizes,
                                       mixedStrides);
}

/// Helper method to check if a `subview` operation is trivially a no-op. This
/// is the case if all offsets are zero, all strides are 1, and the source
/// shape matches the sizes of the subview. In such cases, the subview can be
/// folded into its source.
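/// For example (illustrative):
///   %1 = memref.subview %0[0, 0][8, 16][1, 1]
///          : memref<8x16xf32> to memref<8x16xf32>
/// is a no-op and folds to %0.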
static bool isTrivialSubViewOp(SubViewOp subViewOp) {
  if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
    return false;

  auto mixedOffsets = subViewOp.getMixedOffsets();
  auto mixedSizes = subViewOp.getMixedSizes();
  auto mixedStrides = subViewOp.getMixedStrides();

  // Check offsets are zero.
  if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.getValue() != 0;
      }))
    return false;

  // Check strides are one.
  if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
        Optional<int64_t> intValue = getConstantIntValue(ofr);
        return !intValue || intValue.getValue() != 1;
      }))
    return false;

  // Check all size values are static and match the (static) source shape.
  ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
  for (const auto &size : llvm::enumerate(mixedSizes)) {
    Optional<int64_t> intValue = getConstantIntValue(size.value());
    if (!intValue || intValue.getValue() != sourceShape[size.index()])
      return false;
  }
  // All conditions met. The `SubViewOp` is foldable as a no-op.
  return true;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///        memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///        memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops. If the result type differs from
/// the source type only in its layout (e.g. due to an `affine_map`), the
/// subview is replaced by a cast instead of being folded away directly.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.source());
      return success();
    }
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
                                        subViewOp.getType());
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
                                         mixedOffsets, mixedSizes,
                                         mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
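/// For example (illustrative), a transpose with permutation
/// (d0, d1) -> (d1, d0) maps memref<4x8xf32> (strides [8, 1]) to a
/// memref<8x4xf32> whose layout map encodes the permuted strides [1, 8].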
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(AffineMapAttr::get(map));
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << " " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

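// Illustrative form of the custom assembly handled by the parser/printer
// below (example IR, not from a test):
//   %view = memref.view %buffer[%byte_shift][%size0, %size1]
//             : memref<2048xi8> to memref<?x?xf32>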
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memrefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
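    // For example (illustrative): a view producing memref<?x?xf32> whose size
    // operands are (%c16, %n) folds to a view producing memref<16x?xf32>,
    // followed by a cast back to the original memref<?x?xf32> type.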
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicRMWOp op) {
  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
    return op.emitOpError(
        "expects the number of subscripts to be equal to memref rank");
  switch (op.kind()) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::maxf:
  case arith::AtomicRMWKind::minf:
  case arith::AtomicRMWKind::mulf:
    if (!op.value().getType().isa<FloatType>())
      return op.emitOpError()
             << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
             << "' expects a floating-point type";
    break;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
  case arith::AtomicRMWKind::muli:
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::andi:
    if (!op.value().getType().isa<IntegerType>())
      return op.emitOpError()
             << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
             << "' expects an integer type";
    break;
  default:
    break;
  }
  return success();
}

OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
  /// atomicrmw(memrefcast) -> atomicrmw
  if (succeeded(foldMemRefCast(*this, value())))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"