//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
///
/// Prefers `arith::ConstantOp` when it can represent (value, type); otherwise
/// falls back to memref's own `ConstantOp`. Returns nullptr when neither op
/// can be built, signalling the caller that no constant materialization is
/// possible for this (value, type) pair.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, value, type);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop".
It folds the source of any memref.cast 47 /// into the root operation directly. 48 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) { 49 bool folded = false; 50 for (OpOperand &operand : op->getOpOperands()) { 51 auto cast = operand.get().getDefiningOp<CastOp>(); 52 if (cast && operand.get() != inner && 53 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 54 operand.set(cast.getOperand()); 55 folded = true; 56 } 57 } 58 return success(folded); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // Helpers for GlobalOp 63 //===----------------------------------------------------------------------===// 64 65 static Type getTensorTypeFromMemRefType(Type type) { 66 if (auto memref = type.dyn_cast<MemRefType>()) 67 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 68 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 69 return UnrankedTensorType::get(memref.getElementType()); 70 return NoneType::get(type.getContext()); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // AllocOp / AllocaOp 75 //===----------------------------------------------------------------------===// 76 77 template <typename AllocLikeOp> 78 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 79 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 80 "applies to only alloc or alloca"); 81 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 82 if (!memRefType) 83 return op.emitOpError("result must be a memref"); 84 85 if (static_cast<int64_t>(op.dynamicSizes().size()) != 86 memRefType.getNumDynamicDims()) 87 return op.emitOpError("dimension operand count does not equal memref " 88 "dynamic dimension count"); 89 90 unsigned numSymbols = 0; 91 if (!memRefType.getLayout().isIdentity()) 92 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 93 if (op.symbolOperands().size() != numSymbols) 94 
return op.emitOpError("symbol operand count does not equal memref symbol " 95 "count: expected ") 96 << numSymbols << ", got " << op.symbolOperands().size(); 97 98 return success(); 99 } 100 101 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 102 103 static LogicalResult verify(AllocaOp op) { 104 // An alloca op needs to have an ancestor with an allocation scope trait. 105 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 106 return op.emitOpError( 107 "requires an ancestor op with AutomaticAllocationScope trait"); 108 109 return verifyAllocLikeOp(op); 110 } 111 112 namespace { 113 /// Fold constant dimensions into an alloc like operation. 114 template <typename AllocLikeOp> 115 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 116 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 117 118 LogicalResult matchAndRewrite(AllocLikeOp alloc, 119 PatternRewriter &rewriter) const override { 120 // Check to see if any dimensions operands are constants. If so, we can 121 // substitute and drop them. 122 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 123 return matchPattern(operand, matchConstantIndex()); 124 })) 125 return failure(); 126 127 auto memrefType = alloc.getType(); 128 129 // Ok, we have one or more constant operands. Collect the non-constant ones 130 // and keep track of the resultant memref type to build. 131 SmallVector<int64_t, 4> newShapeConstants; 132 newShapeConstants.reserve(memrefType.getRank()); 133 SmallVector<Value, 4> dynamicSizes; 134 135 unsigned dynamicDimPos = 0; 136 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 137 int64_t dimSize = memrefType.getDimSize(dim); 138 // If this is already static dimension, keep it. 
139 if (dimSize != -1) { 140 newShapeConstants.push_back(dimSize); 141 continue; 142 } 143 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 144 auto *defOp = dynamicSize.getDefiningOp(); 145 if (auto constantIndexOp = 146 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 147 // Dynamic shape dimension will be folded. 148 newShapeConstants.push_back(constantIndexOp.value()); 149 } else { 150 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 151 newShapeConstants.push_back(-1); 152 dynamicSizes.push_back(dynamicSize); 153 } 154 dynamicDimPos++; 155 } 156 157 // Create new memref type (which will have fewer dynamic dimensions). 158 MemRefType newMemRefType = 159 MemRefType::Builder(memrefType).setShape(newShapeConstants); 160 assert(static_cast<int64_t>(dynamicSizes.size()) == 161 newMemRefType.getNumDynamicDims()); 162 163 // Create and insert the alloc op for the new memref. 164 auto newAlloc = rewriter.create<AllocLikeOp>( 165 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 166 alloc.alignmentAttr()); 167 // Insert a cast so we have the same type as the old alloc. 168 auto resultCast = 169 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 170 171 rewriter.replaceOp(alloc, {resultCast}); 172 return success(); 173 } 174 }; 175 176 /// Fold alloc operations with no users or only store and dealloc uses. 
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    // Bail out if any user is something other than a dealloc or a store
    // *into* the allocation. Note that a store whose stored value is the
    // allocation itself counts as a real use (the pointer escapes).
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    // Every remaining user is a store into the alloc or a dealloc of it;
    // erase them before erasing the alloc itself.
    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    // When the op produces results the terminator carries the yielded values,
    // so it must be printed explicitly.
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Make sure the region has a terminator even when the textual form omitted
  // it (the inverse of the printer's printBlockTerminators handling).
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // `index` set means control is leaving the body region: it flows to the
  // parent op's results. Otherwise control enters the body region.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  // buffer_cast(tensor_load(m)) folds to m, but only when the memref types
  // match exactly; the mismatched-type case is handled by the
  // TensorLoadToMemRef canonicalization pattern below.
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    // Only handle ranked sources; the shape is needed to build the new
    // memref type below.
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    // buffer_cast the pre-cast tensor, then memref.cast to the original
    // result type so all existing uses keep their types.
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` to kick in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are definitely not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();

    // We already know that the types are potentially cast-compatible. However
    // in case the affine maps are different, we may need to use a copy if we go
    // from dynamic to static offset or stride (the canonicalization cannot know
    // at this point that it is really cast compatible).
    auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
      int64_t sourceOffset, targetOffset;
      SmallVector<int64_t, 4> sourceStrides, targetStrides;
      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
        return false;
      // A dynamic source offset/stride cast to a static one cannot be proven
      // compatible statically.
      auto dynamicToStatic = [](int64_t a, int64_t b) {
        return a == MemRefType::getDynamicStrideOrOffset() &&
               b != MemRefType::getDynamicStrideOrOffset();
      };
      if (dynamicToStatic(sourceOffset, targetOffset))
        return false;
      for (auto it : zip(sourceStrides, targetStrides))
        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
          return false;
      return true;
    };

    auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
    auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
    if (tensorLoadType && bufferCastType &&
        !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
      // Not provably compatible: materialize a fresh alloc of the result type
      // and copy into it instead of casting.
      MemRefType resultType = bufferCastType;
      auto loc = bufferCast.getLoc();
      SmallVector<Value, 4> dynamicOperands;
      for (int i = 0; i < resultType.getRank(); ++i) {
        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
          continue;
        auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
        Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
        dynamicOperands.push_back(size);
      }
      auto copy =
          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
      rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
      rewriter.replaceOp(bufferCast, {copy});
    } else
      rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                          tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ...
///     : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    // Ranked -> ranked cast.
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    // At least one side is unranked.
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  // NOTE(review): unreachable — both branches above return. Kept for safety.
  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // A clone reads its input and both allocates and writes its output.
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    // An unused clone is trivially dead.
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        findDealloc(cloneOp.output());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    // Pick the dealloc that lives in the clone's own block as the one made
    // redundant by merging clone and source.
    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations inbetween
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail of the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    // Replace the clone by a cast of its source and drop the now-redundant
    // dealloc.
    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  // Convenience builder: materialize the static index as a constant op.
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

/// Return the dimension index when it is defined by an arith constant,
/// llvm::None otherwise.
Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

/// Return a map with key being elements in `vals` and data being number of
/// occurences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the type of the un-rank reduced subview result type and the
/// rank-reduced result type, computes the dropped dimensions. This accounts for
/// cases where there are multiple unit-dims, but only a subset of those are
/// dropped. For MemRefTypes these can be disambiguated using the strides. If a
/// dimension is dropped the stride must be dropped too.
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayAttr staticSizes) {
  llvm::SmallDenseSet<unsigned> unusedDims;
  // Same rank means no dimensions were dropped.
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  // Candidate dropped dims: every static unit dimension.
  for (auto dim : llvm::enumerate(staticSizes))
    if (dim.value().cast<IntegerAttr>().getInt() == 1)
      unusedDims.insert(dim.index());
  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the dims
  // is 1. Track the number of occurences of the strides in the original type
  // and the candidate type. For each unused dim that stride should not be
  // present in the candidate type. Note that there could be multiple dimensions
  // that have the same size. We dont need to exactly figure out which dim
  // corresponds to which stride, we just need to verify that the number of
  // reptitions of a stride in the original + number of unused dims with that
  // stride == number of repititions of a stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
  for (unsigned dim : unusedDims) {
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      prunedUnusedDims.insert(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Cant have a stride in the reduced rank type
      // that wasnt in the original one.
      return llvm::None;
    }
  }

  // Dims whose stride survived are not actually dropped.
  for (auto prunedDim : prunedUnusedDims)
    unusedDims.erase(prunedDim);
  // Sanity: dropped dim count must account exactly for the rank difference.
  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
    return llvm::None;
  return unusedDims;
}

llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
  assert(unusedDims && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    // Map the queried result dimension back to the corresponding source
    // dimension, skipping dims dropped by rank reduction.
    llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.count(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    // Only rewrite dims whose source is produced by a memref.reshape.
    auto reshape = dim.source().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    // The reshape's shape operand holds the result extents, so
    // dim(reshape, i) is exactly load(shape, i).
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    // The shape memref's element type may differ from `index`; insert a cast
    // so the replacement value matches the dim result type.
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the memref cast.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    // The matched cast is a BufferCastOp; its operand is the tensor the
    // buffer was created from, so the replacement is a tensor.dim on it.
    auto castOp = dimOp.source().getDefiningOp<BufferCastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

/// Build a DmaStartOp. Operand layout is: source memref + its indices,
/// destination memref + its indices, number of elements, tag memref + its
/// indices, then an optional (stride, elementsPerStride) pair that is only
/// appended when `stride` is non-null.
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  // Stride-related operands are either both present or both absent.
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

/// Print a DmaStartOp in the form:
///   src[srcIndices], dst[dstIndices], numElements, tag[tagIndices]
///   (, stride, elementsPerStride)? attr-dict : src-type, dst-type, tag-type
static void print(OpAsmPrinter &p, DmaStartOp op) {
  p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
    << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
    << op.getNumElements() << ", " << op.getTagMemRef() << '['
    << op.getTagIndices() << ']';
  if (op.isStrided())
    p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();

  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getSrcMemRef().getType() << ", "
    << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//             %tag[%index], %stride, %num_elt_per_stride :
//           : memref<3076 x f32, 0>,
//             memref<1024 x f32, 2>,
//             memref<1 x i32>
//
static ParseResult parseDmaStartOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Strides come as a pair or not at all: exactly 0 or 2 trailing operands.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  // Exactly three types are expected: source, destination and tag memrefs.
  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

/// Verify a DmaStartOp: check the operand count against the memref ranks and
/// that every operand has the type its position requires.
static LogicalResult verify(DmaStartOp op) {
  unsigned numOperands = op.getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return op.emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!op.getSrcMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected source to be of memref type");
  if (numOperands < op.getSrcMemRefRank() + 4)
    return op.emitOpError()
           << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
  if (!op.getSrcIndices().empty() &&
      !llvm::all_of(op.getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!op.getDstMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands =
      op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return op.emitOpError()
           << "expected at least " << numExpectedOperands << " operands";
  if (!op.getDstIndices().empty() &&
      !llvm::all_of(op.getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!op.getNumElements().getType().isIndex())
    return op.emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!op.getTagMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected tag to be of memref type");
  numExpectedOperands += op.getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return op.emitOpError()
           << "expected at least " << numExpectedOperands << " operands";
  if (!op.getTagIndices().empty() &&
      !llvm::all_of(op.getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return op.emitOpError("incorrect number of operands");

  // 5. Strides.
  if (op.isStrided()) {
    if (!op.getStride().getType().isIndex() ||
        !op.getNumElementsPerStride().getType().isIndex())
      return op.emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

static LogicalResult verify(DmaWaitOp op) {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = op.tagIndices().size();
  unsigned tagMemRefRank = op.getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return op.emitOpError() << "expected tagIndices to have the same number of "
                               "elements as the tagMemRef rank, expected "
                            << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

/// Print the global's type and, unless it is external, its initial value:
/// either the keyword `uninitialized` or the elements attribute printed
/// without its type (the type is implied by the memref type).
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

/// Parse the global's type and optional initial value. Accepts either no
/// initializer (external global), `= uninitialized`, or `= <elements-attr>`.
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // No `=` means no initial value (external global).
  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    // `uninitialized` is encoded as a UnitAttr initial value.
    initialValue = UnitAttr::get(parser.getContext());
    return success();
  }

  // Elements attributes are typed with the tensor equivalent of the memref
  // type.
  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // The alignment, if specified, must be a power of two.
  if (Optional<uint64_t> alignAttr = op.alignment()) {
    uint64_t alignment = alignAttr.getValue();

    if (!llvm::isPowerOf2_64(alignment))
      return op->emitError() << "alignment attribute value " << alignment
                             << " is not a power of 2";
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  // Operands are the memref plus one index per memref dimension.
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into an tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    // Only loads whose memref comes directly from a buffer_cast are matched.
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

/// Print a PrefetchOp in the form:
///   memref[indices], (read|write), locality<N>, (data|instr) attr-dict
///   : memref-type
static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  // These attributes are printed as keywords above, so elide them from the
  // generic attribute dictionary.
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  // The read/write keyword is re-encoded as the boolean `isWrite` attribute.
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  // Likewise the cache-type keyword becomes the boolean `isDataCache`.
  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

static LogicalResult verify(PrefetchOp op) {
  // Operands are the memref plus one index per memref dimension.
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
1319 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1320 MemRefType resultType, Value source, 1321 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1322 ArrayRef<OpFoldResult> strides, 1323 ArrayRef<NamedAttribute> attrs) { 1324 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1325 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1326 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1327 ShapedType::kDynamicStrideOrOffset); 1328 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1329 ShapedType::kDynamicSize); 1330 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1331 ShapedType::kDynamicStrideOrOffset); 1332 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1333 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1334 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1335 result.addAttributes(attrs); 1336 } 1337 1338 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1339 MemRefType resultType, Value source, 1340 int64_t offset, ArrayRef<int64_t> sizes, 1341 ArrayRef<int64_t> strides, 1342 ArrayRef<NamedAttribute> attrs) { 1343 SmallVector<OpFoldResult> sizeValues = 1344 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1345 return b.getI64IntegerAttr(v); 1346 })); 1347 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1348 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1349 return b.getI64IntegerAttr(v); 1350 })); 1351 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1352 strideValues, attrs); 1353 } 1354 1355 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1356 MemRefType resultType, Value source, Value offset, 1357 ValueRange sizes, ValueRange strides, 1358 ArrayRef<NamedAttribute> attrs) { 1359 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1360 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; 
})); 1361 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1362 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1363 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1364 } 1365 1366 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1367 // completed automatically, like we have for subview and extract_slice. 1368 static LogicalResult verify(ReinterpretCastOp op) { 1369 // The source and result memrefs should be in the same memory space. 1370 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1371 auto resultType = op.getType().cast<MemRefType>(); 1372 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1373 return op.emitError("different memory spaces specified for source type ") 1374 << srcType << " and result memref type " << resultType; 1375 if (srcType.getElementType() != resultType.getElementType()) 1376 return op.emitError("different element types specified for source type ") 1377 << srcType << " and result memref type " << resultType; 1378 1379 // Match sizes in result memref type and in static_sizes attribute. 1380 for (auto &en : 1381 llvm::enumerate(llvm::zip(resultType.getShape(), 1382 extractFromI64ArrayAttr(op.static_sizes())))) { 1383 int64_t resultSize = std::get<0>(en.value()); 1384 int64_t expectedSize = std::get<1>(en.value()); 1385 if (resultSize != expectedSize) 1386 return op.emitError("expected result type with size = ") 1387 << expectedSize << " instead of " << resultSize 1388 << " in dim = " << en.index(); 1389 } 1390 1391 // Match offset and strides in static_offset and static_strides attributes if 1392 // result memref type has an affine map specified. 
1393 if (!resultType.getLayout().isIdentity()) { 1394 int64_t resultOffset; 1395 SmallVector<int64_t, 4> resultStrides; 1396 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1397 return failure(); 1398 1399 // Match offset in result memref type and in static_offsets attribute. 1400 int64_t expectedOffset = 1401 extractFromI64ArrayAttr(op.static_offsets()).front(); 1402 if (resultOffset != expectedOffset) 1403 return op.emitError("expected result type with offset = ") 1404 << resultOffset << " instead of " << expectedOffset; 1405 1406 // Match strides in result memref type and in static_strides attribute. 1407 for (auto &en : llvm::enumerate(llvm::zip( 1408 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1409 int64_t resultStride = std::get<0>(en.value()); 1410 int64_t expectedStride = std::get<1>(en.value()); 1411 if (resultStride != expectedStride) 1412 return op.emitError("expected result type with stride = ") 1413 << expectedStride << " instead of " << resultStride 1414 << " in dim = " << en.index(); 1415 } 1416 } 1417 return success(); 1418 } 1419 1420 //===----------------------------------------------------------------------===// 1421 // Reassociative reshape ops 1422 //===----------------------------------------------------------------------===// 1423 1424 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1425 return getSymbolLessAffineMaps(getReassociationExprs()); 1426 } 1427 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1428 return convertReassociationIndicesToExprs(getContext(), 1429 getReassociationIndices()); 1430 } 1431 1432 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1433 return getSymbolLessAffineMaps(getReassociationExprs()); 1434 } 1435 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1436 return convertReassociationIndicesToExprs(getContext(), 1437 getReassociationIndices()); 1438 } 1439 1440 static void 
print(OpAsmPrinter &p, ExpandShapeOp op) { 1441 ::mlir::printReshapeOp<ExpandShapeOp>(p, op); 1442 } 1443 1444 static void print(OpAsmPrinter &p, CollapseShapeOp op) { 1445 ::mlir::printReshapeOp<CollapseShapeOp>(p, op); 1446 } 1447 1448 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1449 /// copies. 1450 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1451 ArrayRef<int64_t> sizes, 1452 ArrayRef<AffineExpr> strides) { 1453 assert(sizes.size() == strides.size() && "mismatched ranks"); 1454 // off by 1 indexing to avoid out of bounds 1455 // V 1456 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1457 // Only bands of static shapes are reshapable. This is due to the fact that 1458 // there is no relation between dynamic sizes and dynamic strides: we do not 1459 // have enough information to know whether a "-1" size corresponds to the 1460 // proper symbol in the AffineExpr of a stride. 1461 if (ShapedType::isDynamic(sizes[dim + 1])) 1462 return false; 1463 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1464 // simplify on the fly and catch more reshapable cases. 1465 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1466 return false; 1467 } 1468 return true; 1469 } 1470 1471 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1472 /// expected to be valid) to `type`. 1473 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1474 /// MemRefType. 
1475 static MemRefType 1476 computeReshapeCollapsedType(MemRefType type, 1477 ArrayRef<AffineMap> reassociation) { 1478 auto sizes = type.getShape(); 1479 AffineExpr offset; 1480 SmallVector<AffineExpr, 4> strides; 1481 auto status = getStridesAndOffset(type, strides, offset); 1482 (void)status; 1483 assert(succeeded(status) && "expected strided memref"); 1484 1485 SmallVector<int64_t, 4> newSizes; 1486 newSizes.reserve(reassociation.size()); 1487 SmallVector<AffineExpr, 4> newStrides; 1488 newStrides.reserve(reassociation.size()); 1489 1490 // Use the fact that reassociation is valid to simplify the logic: only use 1491 // each map's rank. 1492 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1493 unsigned currentDim = 0; 1494 for (AffineMap m : reassociation) { 1495 unsigned dim = m.getNumResults(); 1496 int64_t size = 1; 1497 AffineExpr stride = strides[currentDim + dim - 1]; 1498 if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { 1499 size = ShapedType::kDynamicSize; 1500 stride = AffineExpr(); 1501 } else { 1502 for (unsigned d = 0; d < dim; ++d) 1503 size *= sizes[currentDim + d]; 1504 } 1505 newSizes.push_back(size); 1506 newStrides.push_back(stride); 1507 currentDim += dim; 1508 } 1509 1510 // Early-exit: if `type` is contiguous, the result must be contiguous. 1511 if (canonicalizeStridedLayout(type).getLayout().isIdentity()) 1512 return MemRefType::Builder(type).setShape(newSizes).setLayout({}); 1513 1514 // Convert back to int64_t because we don't have enough information to create 1515 // new strided layouts from AffineExpr only. This corresponds to a case where 1516 // copies may be necessary. 
1517 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1518 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1519 intOffset = o.getValue(); 1520 SmallVector<int64_t, 4> intStrides; 1521 intStrides.reserve(strides.size()); 1522 for (auto stride : newStrides) { 1523 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1524 intStrides.push_back(cst.getValue()); 1525 else 1526 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1527 } 1528 auto layout = 1529 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1530 return canonicalizeStridedLayout( 1531 MemRefType::Builder(type).setShape(newSizes).setLayout( 1532 AffineMapAttr::get(layout))); 1533 } 1534 1535 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1536 ArrayRef<ReassociationIndices> reassociation, 1537 ArrayRef<NamedAttribute> attrs) { 1538 auto memRefType = src.getType().cast<MemRefType>(); 1539 auto resultType = computeReshapeCollapsedType( 1540 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1541 b.getContext(), reassociation))); 1542 build(b, result, resultType, src, attrs); 1543 result.addAttribute(getReassociationAttrName(), 1544 getReassociationIndicesAttribute(b, reassociation)); 1545 } 1546 1547 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1548 ArrayRef<ReassociationIndices> reassociation, 1549 ArrayRef<NamedAttribute> attrs) { 1550 auto memRefType = src.getType().cast<MemRefType>(); 1551 auto resultType = computeReshapeCollapsedType( 1552 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1553 b.getContext(), reassociation))); 1554 build(b, result, resultType, src, attrs); 1555 result.addAttribute(getReassociationAttrName(), 1556 getReassociationIndicesAttribute(b, reassociation)); 1557 } 1558 1559 template <typename ReshapeOp, 1560 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1561 static LogicalResult verifyReshapeOp(ReshapeOp op, 
MemRefType expandedType, 1562 MemRefType collapsedType) { 1563 if (failed( 1564 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1565 return failure(); 1566 auto maps = op.getReassociationMaps(); 1567 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1568 if (collapsedType != expectedType) 1569 return op.emitOpError("expected collapsed type to be ") 1570 << expectedType << ", but got " << collapsedType; 1571 return success(); 1572 } 1573 1574 static LogicalResult verify(ExpandShapeOp op) { 1575 return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); 1576 } 1577 1578 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1579 MLIRContext *context) { 1580 results.add<CollapseReshapeOps<ExpandShapeOp>, 1581 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1582 } 1583 1584 static LogicalResult verify(CollapseShapeOp op) { 1585 return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); 1586 } 1587 1588 struct CollapseShapeOpMemRefCastFolder 1589 : public OpRewritePattern<CollapseShapeOp> { 1590 public: 1591 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1592 1593 LogicalResult matchAndRewrite(CollapseShapeOp op, 1594 PatternRewriter &rewriter) const override { 1595 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1596 if (!cast) 1597 return failure(); 1598 1599 if (!CastOp::canFoldIntoConsumerOp(cast)) 1600 return failure(); 1601 1602 Type newResultType = computeReshapeCollapsedType( 1603 cast.getOperand().getType().cast<MemRefType>(), 1604 op.getReassociationMaps()); 1605 1606 if (newResultType == op.getResultType()) { 1607 rewriter.updateRootInPlace( 1608 op, [&]() { op.srcMutable().assign(cast.source()); }); 1609 } else { 1610 Value newOp = rewriter.create<CollapseShapeOp>( 1611 op->getLoc(), cast.source(), op.getReassociationIndices()); 1612 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1613 } 1614 return success(); 1615 } 1616 }; 1617 
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CollapseReshapeOps<CollapseShapeOp>,
              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
              CollapseShapeOpMemRefCastFolder>(context);
}
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  // Source and result must share an element type.
  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  // Only identity-layout (contiguous) source memrefs can be reshaped.
  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getLayout().isIdentity())
      return op.emitOpError(
          "source memref type should have identity affine map");

  // The 1-D shape operand's extent must match the (ranked) result's rank.
  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(StoreOp op) {
  // Operands are the value to store, the memref, and one index per memref
  // dimension.
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
    return op.emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

namespace {
/// Helpers to write more idiomatic operations.
namespace saturated_arith {
// Wraps an int64_t so that + and * saturate to the dynamic stride/offset
// sentinel whenever either operand is already that sentinel.
struct Wrapper {
  explicit Wrapper(int64_t v) : v(v) {}
  operator int64_t() { return v; }
  int64_t v;
};
Wrapper operator+(Wrapper a, int64_t b) {
  // Dynamic + anything = dynamic.
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v + b);
}
Wrapper operator*(Wrapper a, int64_t b) {
  // Dynamic * anything = dynamic.
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v * b);
}
} // end namespace saturated_arith
} // end namespace

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> leadingStaticOffsets,
                                ArrayRef<int64_t> leadingStaticSizes,
                                ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offset/sizes/strides in
  // which case we complete with offset=0, sizes from memref type and strides=1.
  unsigned rank = sourceMemRefType.getRank();
  assert(leadingStaticOffsets.size() <= rank &&
         "unexpected leadingStaticOffsets overflow");
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  assert(leadingStaticStrides.size() <= rank &&
         "unexpected leadingStaticStrides overflow");
  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
  unsigned numTrailingOffsets = rank - staticOffsets.size();
  unsigned numTrailingSizes = rank - staticSizes.size();
  unsigned numTrailingStrides = rank - staticStrides.size();
  staticOffsets.append(numTrailingOffsets, 0);
  llvm::append_range(staticSizes,
                     sourceMemRefType.getShape().take_back(numTrailingSizes));
  staticStrides.append(numTrailingStrides, 1);

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  // saturated_arith::Wrapper propagates the dynamic sentinel through + and *.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

/// OpFoldResult overload: splits each mixed offset/size/stride into its
/// static value (or the dynamic sentinel) and delegates to the static-only
/// overload above. The collected dynamic Values are not needed for type
/// inference and are discarded.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

/// Infer the subview result type, then drop up to `rankDiff` unit dimensions
/// to reach `resultRank` (rank-reducing subview). The layout map is projected
/// accordingly. NOTE(review): despite its name, `sourceRankedTensorType` is a
/// MemRefType.
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank && "expected ");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    // Select `rankDiff` positions holding size-1 dims to project away.
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map = inferredType.getLayout().getAffineMap();
    if (!map.isIdentity())
      map = getProjectedMap(map, dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

/// OpFoldResult overload of the rank-reduced inference; same dispatch scheme
/// as the OpFoldResult overload of `inferResultType`.
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
1875 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1876 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1877 ArrayRef<int64_t> strides, 1878 ArrayRef<NamedAttribute> attrs) { 1879 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1880 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1881 return b.getI64IntegerAttr(v); 1882 })); 1883 SmallVector<OpFoldResult> sizeValues = 1884 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1885 return b.getI64IntegerAttr(v); 1886 })); 1887 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1888 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1889 return b.getI64IntegerAttr(v); 1890 })); 1891 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1892 } 1893 1894 // Build a SubViewOp with dynamic entries and custom result type. If the 1895 // type passed is nullptr, it is inferred. 1896 void SubViewOp::build(OpBuilder &b, OperationState &result, 1897 MemRefType resultType, Value source, 1898 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1899 ArrayRef<int64_t> strides, 1900 ArrayRef<NamedAttribute> attrs) { 1901 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1902 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1903 return b.getI64IntegerAttr(v); 1904 })); 1905 SmallVector<OpFoldResult> sizeValues = 1906 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1907 return b.getI64IntegerAttr(v); 1908 })); 1909 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1910 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1911 return b.getI64IntegerAttr(v); 1912 })); 1913 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1914 attrs); 1915 } 1916 1917 // Build a SubViewOp with dynamic entries and custom result type. If the type 1918 // passed is nullptr, it is inferred. 
1919 void SubViewOp::build(OpBuilder &b, OperationState &result, 1920 MemRefType resultType, Value source, ValueRange offsets, 1921 ValueRange sizes, ValueRange strides, 1922 ArrayRef<NamedAttribute> attrs) { 1923 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1924 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1925 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1926 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1927 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1928 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1929 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1930 } 1931 1932 // Build a SubViewOp with dynamic entries and inferred result type. 1933 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1934 ValueRange offsets, ValueRange sizes, ValueRange strides, 1935 ArrayRef<NamedAttribute> attrs) { 1936 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1937 } 1938 1939 /// For ViewLikeOpInterface. 1940 Value SubViewOp::getViewSource() { return source(); } 1941 1942 enum SubViewVerificationResult { 1943 Success, 1944 RankTooLarge, 1945 SizeMismatch, 1946 ElemTypeMismatch, 1947 MemSpaceMismatch, 1948 AffineMapMismatch 1949 }; 1950 1951 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1952 /// This function is slight variant of `is subsequence` algorithm where 1953 /// not matching dimension must be 1. 
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  ArrayAttr staticSizes, std::string *errMsg = nullptr) {
  // NOTE(review): `errMsg` is accepted but never populated in this function.
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  // Non-memref originals are accepted unconditionally.
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  // Past the previous check `originalType` is a MemRefType, so only the
  // candidate's type matters here.
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();

  // Which dims of `original` are dropped to reach `candidateReduced`.
  auto optionalUnusedDimsMask =
      computeMemRefRankReductionMask(original, candidateReduced, staticSizes);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;
  return SubViewVerificationResult::Success;
}

/// Translate an `isRankReducedType` result into a diagnostic on `op`.
/// `expectedType` is the fully inferred (non-rank-reduced) result type.
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result =
      isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // For each position, use the dynamic SSA value if present, otherwise
    // materialize the static value as a constant index op.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  // Fall back to the non-rank-reduced type when the reduction did not land on
  // the requested rank.
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
    /// the cast source operand type and the SubViewOp static information. This
    /// is the resulting type if the MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    // Cast back so users keep seeing the original result type.
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  // Replace `op` by `newOp` followed by a cast back to `op`'s original type.
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

/// Fold a subview that covers its whole statically-shaped source to the
/// source itself.
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there
    // is no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

namespace {
/// Replace `tensor.dim(tensor_load(m), i)` with `memref.dim(m, i)`.
struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
    if (!tensorLoadOp)
      return failure();

    rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
                                       dimOp.index());
    return success();
  }
};
} // namespace

void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<DimOfTensorLoadFolder>(context);
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType)
      .setShape(sizes)
      .setLayout(AffineMapAttr::get(map));
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << " " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

/// Verifier for TransposeOp: the map must be a permutation of the input's
/// rank, and the declared result type must match the inferred transposed type.
static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  // transpose(memrefcast) -> transpose: fold casts on all operands.
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

/// Verifier for ViewOp: identity layouts on both sides, matching memory
/// spaces, and one size operand per dynamic result dimension.
static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (!baseType.getLayout().isIdentity())
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (!viewType.getLayout().isIdentity())
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

/// Fold constant size operands of a ViewOp into its result type, casting back
/// to the original type for existing users.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp =
              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.value());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Bypass a memref.cast of an alloc feeding a view: the view can consume the
/// cast's (alloc-produced) source directly.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"