1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/Dialect/Utils/StaticValueUtils.h" 15 #include "mlir/IR/AffineMap.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/Matchers.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Interfaces/InferTypeOpInterface.h" 22 #include "mlir/Interfaces/ViewLikeInterface.h" 23 #include "llvm/ADT/STLExtras.h" 24 25 using namespace mlir; 26 using namespace mlir::memref; 27 28 /// Materialize a single constant operation from a given attribute value with 29 /// the desired resultant type. 30 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 31 Attribute value, Type type, 32 Location loc) { 33 return builder.create<mlir::ConstantOp>(loc, type, value); 34 } 35 36 //===----------------------------------------------------------------------===// 37 // Common canonicalization pattern support logic 38 //===----------------------------------------------------------------------===// 39 40 /// This is a common class used for patterns of the form 41 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 42 /// into the root operation directly. 43 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) { 44 bool folded = false; 45 for (OpOperand &operand : op->getOpOperands()) { 46 auto cast = operand.get().getDefiningOp<CastOp>(); 47 if (cast && operand.get() != inner && 48 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 49 operand.set(cast.getOperand()); 50 folded = true; 51 } 52 } 53 return success(folded); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // Helpers for GlobalOp 58 //===----------------------------------------------------------------------===// 59 60 static Type getTensorTypeFromMemRefType(Type type) { 61 if (auto memref = type.dyn_cast<MemRefType>()) 62 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 63 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 64 return UnrankedTensorType::get(memref.getElementType()); 65 return NoneType::get(type.getContext()); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // AllocOp / AllocaOp 70 //===----------------------------------------------------------------------===// 71 72 template <typename AllocLikeOp> 73 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 74 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 75 "applies to only alloc or alloca"); 76 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 77 if (!memRefType) 78 return op.emitOpError("result must be a memref"); 79 80 if (static_cast<int64_t>(op.dynamicSizes().size()) != 81 memRefType.getNumDynamicDims()) 82 return op.emitOpError("dimension operand count does not equal memref " 83 "dynamic dimension count"); 84 85 unsigned numSymbols = 0; 86 if (!memRefType.getAffineMaps().empty()) 87 numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); 88 if (op.symbolOperands().size() != numSymbols) 89 return op.emitOpError("symbol operand count does not equal memref symbol " 90 "count: expected ") 91 << numSymbols << ", got " << op.symbolOperands().size(); 92 93 return success(); 94 } 95 96 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 97 98 static LogicalResult verify(AllocaOp op) { 99 // An alloca op needs to have an ancestor with an allocation scope trait. 100 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 101 return op.emitOpError( 102 "requires an ancestor op with AutomaticAllocationScope trait"); 103 104 return verifyAllocLikeOp(op); 105 } 106 107 namespace { 108 /// Fold constant dimensions into an alloc like operation. 109 template <typename AllocLikeOp> 110 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 111 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 112 113 LogicalResult matchAndRewrite(AllocLikeOp alloc, 114 PatternRewriter &rewriter) const override { 115 // Check to see if any dimensions operands are constants. If so, we can 116 // substitute and drop them. 117 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 118 return matchPattern(operand, matchConstantIndex()); 119 })) 120 return failure(); 121 122 auto memrefType = alloc.getType(); 123 124 // Ok, we have one or more constant operands. Collect the non-constant ones 125 // and keep track of the resultant memref type to build. 126 SmallVector<int64_t, 4> newShapeConstants; 127 newShapeConstants.reserve(memrefType.getRank()); 128 SmallVector<Value, 4> dynamicSizes; 129 130 unsigned dynamicDimPos = 0; 131 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 132 int64_t dimSize = memrefType.getDimSize(dim); 133 // If this is already static dimension, keep it. 134 if (dimSize != -1) { 135 newShapeConstants.push_back(dimSize); 136 continue; 137 } 138 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 139 auto *defOp = dynamicSize.getDefiningOp(); 140 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 141 // Dynamic shape dimension will be folded. 142 newShapeConstants.push_back(constantIndexOp.getValue()); 143 } else { 144 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 145 newShapeConstants.push_back(-1); 146 dynamicSizes.push_back(dynamicSize); 147 } 148 dynamicDimPos++; 149 } 150 151 // Create new memref type (which will have fewer dynamic dimensions). 152 MemRefType newMemRefType = 153 MemRefType::Builder(memrefType).setShape(newShapeConstants); 154 assert(static_cast<int64_t>(dynamicSizes.size()) == 155 newMemRefType.getNumDynamicDims()); 156 157 // Create and insert the alloc op for the new memref. 158 auto newAlloc = rewriter.create<AllocLikeOp>( 159 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 160 alloc.alignmentAttr()); 161 // Insert a cast so we have the same type as the old alloc. 162 auto resultCast = 163 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 164 165 rewriter.replaceOp(alloc, {resultCast}); 166 return success(); 167 } 168 }; 169 170 /// Fold alloc operations with no users or only store and dealloc uses. 171 template <typename T> 172 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 173 using OpRewritePattern<T>::OpRewritePattern; 174 175 LogicalResult matchAndRewrite(T alloc, 176 PatternRewriter &rewriter) const override { 177 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 178 if (auto storeOp = dyn_cast<StoreOp>(op)) 179 return storeOp.value() == alloc; 180 return !isa<DeallocOp>(op); 181 })) 182 return failure(); 183 184 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 185 rewriter.eraseOp(user); 186 187 rewriter.eraseOp(alloc); 188 return success(); 189 } 190 }; 191 } // end anonymous namespace. 192 193 Optional<Operation *> AllocOp::buildDealloc(OpBuilder &builder, Value alloc) { 194 return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc) 195 .getOperation(); 196 } 197 198 Optional<Value> AllocOp::buildClone(OpBuilder &builder, Value alloc) { 199 return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult(); 200 } 201 202 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 203 MLIRContext *context) { 204 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 205 } 206 207 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 208 MLIRContext *context) { 209 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 210 context); 211 } 212 213 //===----------------------------------------------------------------------===// 214 // AllocaScopeOp 215 //===----------------------------------------------------------------------===// 216 217 static void print(OpAsmPrinter &p, AllocaScopeOp &op) { 218 bool printBlockTerminators = false; 219 220 p << " "; 221 if (!op.results().empty()) { 222 p << " -> (" << op.getResultTypes() << ")"; 223 printBlockTerminators = true; 224 } 225 p.printRegion(op.bodyRegion(), 226 /*printEntryBlockArgs=*/false, 227 /*printBlockTerminators=*/printBlockTerminators); 228 p.printOptionalAttrDict(op->getAttrs()); 229 } 230 231 static ParseResult parseAllocaScopeOp(OpAsmParser &parser, 232 OperationState &result) { 233 // Create a region for the body. 234 result.regions.reserve(1); 235 Region *bodyRegion = result.addRegion(); 236 237 // Parse optional results type list. 238 if (parser.parseOptionalArrowTypeList(result.types)) 239 return failure(); 240 241 // Parse the body region. 242 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 243 return failure(); 244 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 245 result.location); 246 247 // Parse the optional attribute list. 248 if (parser.parseOptionalAttrDict(result.attributes)) 249 return failure(); 250 251 return success(); 252 } 253 254 static LogicalResult verify(AllocaScopeOp op) { 255 if (failed(RegionBranchOpInterface::verifyTypes(op))) 256 return failure(); 257 258 return success(); 259 } 260 261 void AllocaScopeOp::getSuccessorRegions( 262 Optional<unsigned> index, ArrayRef<Attribute> operands, 263 SmallVectorImpl<RegionSuccessor> ®ions) { 264 if (index.hasValue()) { 265 regions.push_back(RegionSuccessor(getResults())); 266 return; 267 } 268 269 regions.push_back(RegionSuccessor(&bodyRegion())); 270 } 271 272 //===----------------------------------------------------------------------===// 273 // AssumeAlignmentOp 274 //===----------------------------------------------------------------------===// 275 276 static LogicalResult verify(AssumeAlignmentOp op) { 277 unsigned alignment = op.alignment(); 278 if (!llvm::isPowerOf2_32(alignment)) 279 return op.emitOpError("alignment must be power of 2"); 280 return success(); 281 } 282 283 //===----------------------------------------------------------------------===// 284 // BufferCastOp 285 //===----------------------------------------------------------------------===// 286 287 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) { 288 if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) 289 if (tensorLoad.memref().getType() == getType()) 290 return tensorLoad.memref(); 291 return {}; 292 } 293 294 namespace { 295 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast. 296 struct BufferCast : public OpRewritePattern<BufferCastOp> { 297 using OpRewritePattern<BufferCastOp>::OpRewritePattern; 298 299 LogicalResult matchAndRewrite(BufferCastOp bufferCast, 300 PatternRewriter &rewriter) const final { 301 auto tensorCastOperand = 302 bufferCast.getOperand().getDefiningOp<tensor::CastOp>(); 303 if (!tensorCastOperand) 304 return failure(); 305 auto srcTensorType = 306 tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>(); 307 if (!srcTensorType) 308 return failure(); 309 auto memrefType = MemRefType::get(srcTensorType.getShape(), 310 srcTensorType.getElementType()); 311 Value memref = rewriter.create<BufferCastOp>( 312 bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand()); 313 rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(), 314 memref); 315 return success(); 316 } 317 }; 318 319 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when 320 /// type mismatches prevent `BufferCastOp::fold` to kick in. 321 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> { 322 using OpRewritePattern<BufferCastOp>::OpRewritePattern; 323 324 LogicalResult matchAndRewrite(BufferCastOp bufferCast, 325 PatternRewriter &rewriter) const final { 326 auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>(); 327 // Bail unless we have a tensor_load + memref.buffer_cast with different 328 // types. `BufferCastOp::fold` handles the same type case. 329 if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType()) 330 return failure(); 331 // If types are definitely not cast-compatible, bail. 332 if (!CastOp::areCastCompatible(tensorLoad.memref().getType(), 333 bufferCast.getType())) 334 return failure(); 335 336 // We already know that the types are potentially cast-compatible. However 337 // in case the affine maps are different, we may need to use a copy if we go 338 // from dynamic to static offset or stride (the canonicalization cannot know 339 // at this point that it is really cast compatible). 340 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) { 341 int64_t sourceOffset, targetOffset; 342 SmallVector<int64_t, 4> sourceStrides, targetStrides; 343 if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) || 344 failed(getStridesAndOffset(target, targetStrides, targetOffset))) 345 return false; 346 auto dynamicToStatic = [](int64_t a, int64_t b) { 347 return a == MemRefType::getDynamicStrideOrOffset() && 348 b != MemRefType::getDynamicStrideOrOffset(); 349 }; 350 if (dynamicToStatic(sourceOffset, targetOffset)) 351 return false; 352 for (auto it : zip(sourceStrides, targetStrides)) 353 if (dynamicToStatic(std::get<0>(it), std::get<1>(it))) 354 return false; 355 return true; 356 }; 357 358 auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>(); 359 auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>(); 360 if (tensorLoadType && bufferCastType && 361 !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) { 362 MemRefType resultType = bufferCastType; 363 auto loc = bufferCast.getLoc(); 364 SmallVector<Value, 4> dynamicOperands; 365 for (int i = 0; i < resultType.getRank(); ++i) { 366 if (resultType.getShape()[i] != ShapedType::kDynamicSize) 367 continue; 368 auto index = rewriter.createOrFold<ConstantIndexOp>(loc, i); 369 Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index); 370 dynamicOperands.push_back(size); 371 } 372 auto copy = 373 rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands); 374 rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy); 375 rewriter.replaceOp(bufferCast, {copy}); 376 } else 377 rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(), 378 tensorLoad.memref()); 379 return success(); 380 } 381 }; 382 383 } // namespace 384 385 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 386 MLIRContext *context) { 387 results.add<BufferCast, TensorLoadToMemRef>(context); 388 } 389 390 //===----------------------------------------------------------------------===// 391 // CastOp 392 //===----------------------------------------------------------------------===// 393 394 /// Determines whether MemRef_CastOp casts to a more dynamic version of the 395 /// source memref. This is useful to to fold a memref.cast into a consuming op 396 /// and implement canonicalization patterns for ops in different dialects that 397 /// may consume the results of memref.cast operations. Such foldable memref.cast 398 /// operations are typically inserted as `view` and `subview` ops are 399 /// canonicalized, to preserve the type compatibility of their uses. 400 /// 401 /// Returns true when all conditions are met: 402 /// 1. source and result are ranked memrefs with strided semantics and same 403 /// element type and rank. 404 /// 2. each of the source's size, offset or stride has more static information 405 /// than the corresponding result's size, offset or stride. 406 /// 407 /// Example 1: 408 /// ```mlir 409 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 410 /// %2 = consumer %1 ... : memref<?x?xf32> ... 411 /// ``` 412 /// 413 /// may fold into: 414 /// 415 /// ```mlir 416 /// %2 = consumer %0 ... : memref<8x16xf32> ... 417 /// ``` 418 /// 419 /// Example 2: 420 /// ``` 421 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 422 /// to memref<?x?xf32> 423 /// consumer %1 : memref<?x?xf32> ... 424 /// ``` 425 /// 426 /// may fold into: 427 /// 428 /// ``` 429 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 430 /// ``` 431 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 432 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 433 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 434 435 // Requires ranked MemRefType. 436 if (!sourceType || !resultType) 437 return false; 438 439 // Requires same elemental type. 440 if (sourceType.getElementType() != resultType.getElementType()) 441 return false; 442 443 // Requires same rank. 444 if (sourceType.getRank() != resultType.getRank()) 445 return false; 446 447 // Only fold casts between strided memref forms. 448 int64_t sourceOffset, resultOffset; 449 SmallVector<int64_t, 4> sourceStrides, resultStrides; 450 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 451 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 452 return false; 453 454 // If cast is towards more static sizes along any dimension, don't fold. 455 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 456 auto ss = std::get<0>(it), st = std::get<1>(it); 457 if (ss != st) 458 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 459 return false; 460 } 461 462 // If cast is towards more static offset along any dimension, don't fold. 463 if (sourceOffset != resultOffset) 464 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 465 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 466 return false; 467 468 // If cast is towards more static strides along any dimension, don't fold. 469 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 470 auto ss = std::get<0>(it), st = std::get<1>(it); 471 if (ss != st) 472 if (MemRefType::isDynamicStrideOrOffset(ss) && 473 !MemRefType::isDynamicStrideOrOffset(st)) 474 return false; 475 } 476 477 return true; 478 } 479 480 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 481 if (inputs.size() != 1 || outputs.size() != 1) 482 return false; 483 Type a = inputs.front(), b = outputs.front(); 484 auto aT = a.dyn_cast<MemRefType>(); 485 auto bT = b.dyn_cast<MemRefType>(); 486 487 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 488 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 489 490 if (aT && bT) { 491 if (aT.getElementType() != bT.getElementType()) 492 return false; 493 if (aT.getAffineMaps() != bT.getAffineMaps()) { 494 int64_t aOffset, bOffset; 495 SmallVector<int64_t, 4> aStrides, bStrides; 496 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 497 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 498 aStrides.size() != bStrides.size()) 499 return false; 500 501 // Strides along a dimension/offset are compatible if the value in the 502 // source memref is static and the value in the target memref is the 503 // same. They are also compatible if either one is dynamic (see 504 // description of MemRefCastOp for details). 505 auto checkCompatible = [](int64_t a, int64_t b) { 506 return (a == MemRefType::getDynamicStrideOrOffset() || 507 b == MemRefType::getDynamicStrideOrOffset() || a == b); 508 }; 509 if (!checkCompatible(aOffset, bOffset)) 510 return false; 511 for (auto aStride : enumerate(aStrides)) 512 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 513 return false; 514 } 515 if (aT.getMemorySpace() != bT.getMemorySpace()) 516 return false; 517 518 // They must have the same rank, and any specified dimensions must match. 519 if (aT.getRank() != bT.getRank()) 520 return false; 521 522 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 523 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 524 if (aDim != -1 && bDim != -1 && aDim != bDim) 525 return false; 526 } 527 return true; 528 } else { 529 if (!aT && !uaT) 530 return false; 531 if (!bT && !ubT) 532 return false; 533 // Unranked to unranked casting is unsupported 534 if (uaT && ubT) 535 return false; 536 537 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 538 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 539 if (aEltType != bEltType) 540 return false; 541 542 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 543 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 544 if (aMemSpace != bMemSpace) 545 return false; 546 547 return true; 548 } 549 550 return false; 551 } 552 553 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 554 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 555 } 556 557 //===----------------------------------------------------------------------===// 558 // CloneOp 559 //===----------------------------------------------------------------------===// 560 561 void CloneOp::getEffects( 562 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 563 &effects) { 564 effects.emplace_back(MemoryEffects::Read::get(), input(), 565 SideEffects::DefaultResource::get()); 566 effects.emplace_back(MemoryEffects::Write::get(), output(), 567 SideEffects::DefaultResource::get()); 568 effects.emplace_back(MemoryEffects::Allocate::get(), output(), 569 SideEffects::DefaultResource::get()); 570 } 571 572 namespace { 573 /// Merge the clone and its source (by converting the clone to a cast) when 574 /// possible. 575 struct SimplifyClones : public OpRewritePattern<CloneOp> { 576 using OpRewritePattern<CloneOp>::OpRewritePattern; 577 578 LogicalResult matchAndRewrite(CloneOp cloneOp, 579 PatternRewriter &rewriter) const override { 580 if (cloneOp.use_empty()) { 581 rewriter.eraseOp(cloneOp); 582 return success(); 583 } 584 585 Value source = cloneOp.input(); 586 587 // This only finds dealloc operations for the immediate value. It should 588 // also consider aliases. That would also make the safety check below 589 // redundant. 590 llvm::Optional<Operation *> maybeCloneDeallocOp = 591 findDealloc(cloneOp.output()); 592 // Skip if either of them has > 1 deallocate operations. 593 if (!maybeCloneDeallocOp.hasValue()) 594 return failure(); 595 llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source); 596 if (!maybeSourceDeallocOp.hasValue()) 597 return failure(); 598 Operation *cloneDeallocOp = *maybeCloneDeallocOp; 599 Operation *sourceDeallocOp = *maybeSourceDeallocOp; 600 601 // If both are deallocated in the same block, their in-block lifetimes 602 // might not fully overlap, so we cannot decide which one to drop. 603 if (cloneDeallocOp && sourceDeallocOp && 604 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock()) 605 return failure(); 606 607 Block *currentBlock = cloneOp->getBlock(); 608 Operation *redundantDealloc = nullptr; 609 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) { 610 redundantDealloc = cloneDeallocOp; 611 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) { 612 redundantDealloc = sourceDeallocOp; 613 } 614 615 if (!redundantDealloc) 616 return failure(); 617 618 // Safety check that there are no other deallocations inbetween 619 // cloneOp and redundantDealloc, as otherwise we might deallocate an alias 620 // of source before the uses of the clone. With alias information, we could 621 // restrict this to only fail of the dealloc's operand is an alias 622 // of the source. 623 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc; 624 pos = pos->getNextNode()) { 625 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos); 626 if (!effectInterface) 627 continue; 628 if (effectInterface.hasEffect<MemoryEffects::Free>()) 629 return failure(); 630 } 631 632 rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(), 633 source); 634 rewriter.eraseOp(redundantDealloc); 635 return success(); 636 } 637 }; 638 639 } // end anonymous namespace. 640 641 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 642 MLIRContext *context) { 643 results.insert<SimplifyClones>(context); 644 } 645 646 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) { 647 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 648 } 649 650 Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) { 651 return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc) 652 .getOperation(); 653 } 654 655 Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) { 656 return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult(); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // DeallocOp 661 //===----------------------------------------------------------------------===// 662 663 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 664 SmallVectorImpl<OpFoldResult> &results) { 665 /// dealloc(memrefcast) -> dealloc 666 return foldMemRefCast(*this); 667 } 668 669 //===----------------------------------------------------------------------===// 670 // DimOp 671 //===----------------------------------------------------------------------===// 672 673 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 674 int64_t index) { 675 auto loc = result.location; 676 Value indexValue = builder.create<ConstantIndexOp>(loc, index); 677 build(builder, result, source, indexValue); 678 } 679 680 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 681 Value index) { 682 auto indexTy = builder.getIndexType(); 683 build(builder, result, indexTy, source, index); 684 } 685 686 Optional<int64_t> DimOp::getConstantIndex() { 687 if (auto constantOp = index().getDefiningOp<ConstantOp>()) 688 return constantOp.getValue().cast<IntegerAttr>().getInt(); 689 return {}; 690 } 691 692 static LogicalResult verify(DimOp op) { 693 // Assume unknown index to be in range. 694 Optional<int64_t> index = op.getConstantIndex(); 695 if (!index.hasValue()) 696 return success(); 697 698 // Check that constant index is not knowingly out of range. 699 auto type = op.source().getType(); 700 if (auto memrefType = type.dyn_cast<MemRefType>()) { 701 if (index.getValue() >= memrefType.getRank()) 702 return op.emitOpError("index is out of range"); 703 } else if (type.isa<UnrankedMemRefType>()) { 704 // Assume index to be in range. 705 } else { 706 llvm_unreachable("expected operand with memref type"); 707 } 708 return success(); 709 } 710 711 /// Return a map with key being elements in `vals` and data being number of 712 /// occurences of it. Use std::map, since the `vals` here are strides and the 713 /// dynamic stride value is the same as the tombstone value for 714 /// `DenseMap<int64_t>`. 715 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) { 716 std::map<int64_t, unsigned> numOccurences; 717 for (auto val : vals) 718 numOccurences[val]++; 719 return numOccurences; 720 } 721 722 /// Given the type of the un-rank reduced subview result type and the 723 /// rank-reduced result type, computes the dropped dimensions. This accounts for 724 /// cases where there are multiple unit-dims, but only a subset of those are 725 /// dropped. For MemRefTypes these can be disambiguated using the strides. If a 726 /// dimension is dropped the stride must be dropped too. 727 static llvm::Optional<llvm::SmallDenseSet<unsigned>> 728 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, 729 ArrayAttr staticSizes) { 730 llvm::SmallDenseSet<unsigned> unusedDims; 731 if (originalType.getRank() == reducedType.getRank()) 732 return unusedDims; 733 734 for (auto dim : llvm::enumerate(staticSizes)) 735 if (dim.value().cast<IntegerAttr>().getInt() == 1) 736 unusedDims.insert(dim.index()); 737 SmallVector<int64_t> originalStrides, candidateStrides; 738 int64_t originalOffset, candidateOffset; 739 if (failed( 740 getStridesAndOffset(originalType, originalStrides, originalOffset)) || 741 failed( 742 getStridesAndOffset(reducedType, candidateStrides, candidateOffset))) 743 return llvm::None; 744 745 // For memrefs, a dimension is truly dropped if its corresponding stride is 746 // also dropped. This is particularly important when more than one of the dims 747 // is 1. Track the number of occurences of the strides in the original type 748 // and the candidate type. For each unused dim that stride should not be 749 // present in the candidate type. Note that there could be multiple dimensions 750 // that have the same size. We dont need to exactly figure out which dim 751 // corresponds to which stride, we just need to verify that the number of 752 // reptitions of a stride in the original + number of unused dims with that 753 // stride == number of repititions of a stride in the candidate. 754 std::map<int64_t, unsigned> currUnaccountedStrides = 755 getNumOccurences(originalStrides); 756 std::map<int64_t, unsigned> candidateStridesNumOccurences = 757 getNumOccurences(candidateStrides); 758 llvm::SmallDenseSet<unsigned> prunedUnusedDims; 759 for (unsigned dim : unusedDims) { 760 int64_t originalStride = originalStrides[dim]; 761 if (currUnaccountedStrides[originalStride] > 762 candidateStridesNumOccurences[originalStride]) { 763 // This dim can be treated as dropped. 764 currUnaccountedStrides[originalStride]--; 765 continue; 766 } 767 if (currUnaccountedStrides[originalStride] == 768 candidateStridesNumOccurences[originalStride]) { 769 // The stride for this is not dropped. Keep as is. 770 prunedUnusedDims.insert(dim); 771 continue; 772 } 773 if (currUnaccountedStrides[originalStride] < 774 candidateStridesNumOccurences[originalStride]) { 775 // This should never happen. Cant have a stride in the reduced rank type 776 // that wasnt in the original one. 777 return llvm::None; 778 } 779 } 780 781 for (auto prunedDim : prunedUnusedDims) 782 unusedDims.erase(prunedDim); 783 if (unusedDims.size() + reducedType.getRank() != originalType.getRank()) 784 return llvm::None; 785 return unusedDims; 786 } 787 788 llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() { 789 MemRefType sourceType = getSourceType(); 790 MemRefType resultType = getType(); 791 llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims = 792 computeMemRefRankReductionMask(sourceType, resultType, static_sizes()); 793 assert(unusedDims && "unable to find unused dims of subview"); 794 return *unusedDims; 795 } 796 797 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 798 // All forms of folding require a known index. 799 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 800 if (!index) 801 return {}; 802 803 // Folding for unranked types (UnrankedMemRefType) is not supported. 804 auto memrefType = source().getType().dyn_cast<MemRefType>(); 805 if (!memrefType) 806 return {}; 807 808 // Fold if the shape extent along the given index is known. 809 if (!memrefType.isDynamicDim(index.getInt())) { 810 Builder builder(getContext()); 811 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); 812 } 813 814 // The size at the given index is now known to be a dynamic size. 815 unsigned unsignedIndex = index.getValue().getZExtValue(); 816 817 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 818 Operation *definingOp = source().getDefiningOp(); 819 820 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 821 return *(alloc.getDynamicSizes().begin() + 822 memrefType.getDynamicDimIndex(unsignedIndex)); 823 824 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 825 return *(alloca.getDynamicSizes().begin() + 826 memrefType.getDynamicDimIndex(unsignedIndex)); 827 828 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 829 return *(view.getDynamicSizes().begin() + 830 memrefType.getDynamicDimIndex(unsignedIndex)); 831 832 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 833 llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims(); 834 unsigned resultIndex = 0; 835 unsigned sourceRank = subview.getSourceType().getRank(); 836 unsigned sourceIndex = 0; 837 for (auto i : llvm::seq<unsigned>(0, sourceRank)) { 838 if (unusedDims.count(i)) 839 continue; 840 if (resultIndex == unsignedIndex) { 841 sourceIndex = i; 842 break; 843 } 844 resultIndex++; 845 } 846 assert(subview.isDynamicSize(sourceIndex) && 847 "expected dynamic subview size"); 848 return subview.getDynamicSize(sourceIndex); 849 } 850 851 if (auto sizeInterface = 852 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 853 assert(sizeInterface.isDynamicSize(unsignedIndex) && 854 "Expected dynamic subview size"); 855 return sizeInterface.getDynamicSize(unsignedIndex); 856 } 857 858 // dim(memrefcast) -> dim 859 if (succeeded(foldMemRefCast(*this))) 860 return getResult(); 861 862 return {}; 863 } 864 865 namespace { 866 /// Fold dim of a memref reshape operation to a load into the reshape's shape 867 /// operand. 868 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 869 using OpRewritePattern<DimOp>::OpRewritePattern; 870 871 LogicalResult matchAndRewrite(DimOp dim, 872 PatternRewriter &rewriter) const override { 873 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 874 875 if (!reshape) 876 return failure(); 877 878 // Place the load directly after the reshape to ensure that the shape memref 879 // was not mutated. 880 rewriter.setInsertionPointAfter(reshape); 881 Location loc = dim.getLoc(); 882 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 883 if (load.getType() != dim.getType()) 884 load = rewriter.create<IndexCastOp>(loc, dim.getType(), load); 885 rewriter.replaceOp(dim, load); 886 return success(); 887 } 888 }; 889 890 /// Fold dim of a cast into the dim of the source of the memref cast. 891 struct DimOfCastOp : public OpRewritePattern<DimOp> { 892 using OpRewritePattern<DimOp>::OpRewritePattern; 893 894 LogicalResult matchAndRewrite(DimOp dimOp, 895 PatternRewriter &rewriter) const override { 896 auto castOp = dimOp.source().getDefiningOp<BufferCastOp>(); 897 if (!castOp) 898 return failure(); 899 Value newSource = castOp.getOperand(); 900 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index()); 901 return success(); 902 } 903 }; 904 } // end anonymous namespace. 905 906 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 907 MLIRContext *context) { 908 results.add<DimOfMemRefReshape, DimOfCastOp>(context); 909 } 910 911 // --------------------------------------------------------------------------- 912 // DmaStartOp 913 // --------------------------------------------------------------------------- 914 915 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 916 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 917 ValueRange destIndices, Value numElements, 918 Value tagMemRef, ValueRange tagIndices, Value stride, 919 Value elementsPerStride) { 920 result.addOperands(srcMemRef); 921 result.addOperands(srcIndices); 922 result.addOperands(destMemRef); 923 result.addOperands(destIndices); 924 result.addOperands({numElements, tagMemRef}); 925 result.addOperands(tagIndices); 926 if (stride) 927 result.addOperands({stride, elementsPerStride}); 928 } 929 930 static void print(OpAsmPrinter &p, DmaStartOp op) { 931 p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], " 932 << op.getDstMemRef() << '[' << op.getDstIndices() << "], " 933 << op.getNumElements() << ", " << op.getTagMemRef() << '[' 934 << op.getTagIndices() << ']'; 935 if (op.isStrided()) 936 p << ", " << op.getStride() << ", " << op.getNumElementsPerStride(); 937 938 p.printOptionalAttrDict(op->getAttrs()); 939 p << " : " << op.getSrcMemRef().getType() << ", " 940 << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType(); 941 } 942 943 // Parse DmaStartOp. 944 // Ex: 945 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 946 // %tag[%index], %stride, %num_elt_per_stride : 947 // : memref<3076 x f32, 0>, 948 // memref<1024 x f32, 2>, 949 // memref<1 x i32> 950 // 951 static ParseResult parseDmaStartOp(OpAsmParser &parser, 952 OperationState &result) { 953 OpAsmParser::OperandType srcMemRefInfo; 954 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 955 OpAsmParser::OperandType dstMemRefInfo; 956 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 957 OpAsmParser::OperandType numElementsInfo; 958 OpAsmParser::OperandType tagMemrefInfo; 959 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 960 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 961 962 SmallVector<Type, 3> types; 963 auto indexType = parser.getBuilder().getIndexType(); 964 965 // Parse and resolve the following list of operands: 966 // *) source memref followed by its indices (in square brackets). 967 // *) destination memref followed by its indices (in square brackets). 968 // *) dma size in KiB. 969 if (parser.parseOperand(srcMemRefInfo) || 970 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 971 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 972 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 973 parser.parseComma() || parser.parseOperand(numElementsInfo) || 974 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 975 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 976 return failure(); 977 978 // Parse optional stride and elements per stride. 979 if (parser.parseTrailingOperandList(strideInfo)) 980 return failure(); 981 982 bool isStrided = strideInfo.size() == 2; 983 if (!strideInfo.empty() && !isStrided) { 984 return parser.emitError(parser.getNameLoc(), 985 "expected two stride related operands"); 986 } 987 988 if (parser.parseColonTypeList(types)) 989 return failure(); 990 if (types.size() != 3) 991 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 992 993 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 994 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 995 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 996 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 997 // size should be an index. 998 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 999 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 1000 // tag indices should be index. 1001 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 1002 return failure(); 1003 1004 if (isStrided) { 1005 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 1006 return failure(); 1007 } 1008 1009 return success(); 1010 } 1011 1012 static LogicalResult verify(DmaStartOp op) { 1013 unsigned numOperands = op.getNumOperands(); 1014 1015 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 1016 // the number of elements. 1017 if (numOperands < 4) 1018 return op.emitOpError("expected at least 4 operands"); 1019 1020 // Check types of operands. The order of these calls is important: the later 1021 // calls rely on some type properties to compute the operand position. 1022 // 1. Source memref. 1023 if (!op.getSrcMemRef().getType().isa<MemRefType>()) 1024 return op.emitOpError("expected source to be of memref type"); 1025 if (numOperands < op.getSrcMemRefRank() + 4) 1026 return op.emitOpError() 1027 << "expected at least " << op.getSrcMemRefRank() + 4 << " operands"; 1028 if (!op.getSrcIndices().empty() && 1029 !llvm::all_of(op.getSrcIndices().getTypes(), 1030 [](Type t) { return t.isIndex(); })) 1031 return op.emitOpError("expected source indices to be of index type"); 1032 1033 // 2. Destination memref. 1034 if (!op.getDstMemRef().getType().isa<MemRefType>()) 1035 return op.emitOpError("expected destination to be of memref type"); 1036 unsigned numExpectedOperands = 1037 op.getSrcMemRefRank() + op.getDstMemRefRank() + 4; 1038 if (numOperands < numExpectedOperands) 1039 return op.emitOpError() 1040 << "expected at least " << numExpectedOperands << " operands"; 1041 if (!op.getDstIndices().empty() && 1042 !llvm::all_of(op.getDstIndices().getTypes(), 1043 [](Type t) { return t.isIndex(); })) 1044 return op.emitOpError("expected destination indices to be of index type"); 1045 1046 // 3. Number of elements. 1047 if (!op.getNumElements().getType().isIndex()) 1048 return op.emitOpError("expected num elements to be of index type"); 1049 1050 // 4. Tag memref. 1051 if (!op.getTagMemRef().getType().isa<MemRefType>()) 1052 return op.emitOpError("expected tag to be of memref type"); 1053 numExpectedOperands += op.getTagMemRefRank(); 1054 if (numOperands < numExpectedOperands) 1055 return op.emitOpError() 1056 << "expected at least " << numExpectedOperands << " operands"; 1057 if (!op.getTagIndices().empty() && 1058 !llvm::all_of(op.getTagIndices().getTypes(), 1059 [](Type t) { return t.isIndex(); })) 1060 return op.emitOpError("expected tag indices to be of index type"); 1061 1062 // Optional stride-related operands must be either both present or both 1063 // absent. 1064 if (numOperands != numExpectedOperands && 1065 numOperands != numExpectedOperands + 2) 1066 return op.emitOpError("incorrect number of operands"); 1067 1068 // 5. Strides. 1069 if (op.isStrided()) { 1070 if (!op.getStride().getType().isIndex() || 1071 !op.getNumElementsPerStride().getType().isIndex()) 1072 return op.emitOpError( 1073 "expected stride and num elements per stride to be of type index"); 1074 } 1075 1076 return success(); 1077 } 1078 1079 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 1080 SmallVectorImpl<OpFoldResult> &results) { 1081 /// dma_start(memrefcast) -> dma_start 1082 return foldMemRefCast(*this); 1083 } 1084 1085 // --------------------------------------------------------------------------- 1086 // DmaWaitOp 1087 // --------------------------------------------------------------------------- 1088 1089 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1090 SmallVectorImpl<OpFoldResult> &results) { 1091 /// dma_wait(memrefcast) -> dma_wait 1092 return foldMemRefCast(*this); 1093 } 1094 1095 static LogicalResult verify(DmaWaitOp op) { 1096 // Check that the number of tag indices matches the tagMemRef rank. 1097 unsigned numTagIndices = op.tagIndices().size(); 1098 unsigned tagMemRefRank = op.getTagMemRefRank(); 1099 if (numTagIndices != tagMemRefRank) 1100 return op.emitOpError() << "expected tagIndices to have the same number of " 1101 "elements as the tagMemRef rank, expected " 1102 << tagMemRefRank << ", but got " << numTagIndices; 1103 return success(); 1104 } 1105 1106 //===----------------------------------------------------------------------===// 1107 // GlobalOp 1108 //===----------------------------------------------------------------------===// 1109 1110 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1111 TypeAttr type, 1112 Attribute initialValue) { 1113 p << type; 1114 if (!op.isExternal()) { 1115 p << " = "; 1116 if (op.isUninitialized()) 1117 p << "uninitialized"; 1118 else 1119 p.printAttributeWithoutType(initialValue); 1120 } 1121 } 1122 1123 static ParseResult 1124 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1125 Attribute &initialValue) { 1126 Type type; 1127 if (parser.parseType(type)) 1128 return failure(); 1129 1130 auto memrefType = type.dyn_cast<MemRefType>(); 1131 if (!memrefType || !memrefType.hasStaticShape()) 1132 return parser.emitError(parser.getNameLoc()) 1133 << "type should be static shaped memref, but got " << type; 1134 typeAttr = TypeAttr::get(type); 1135 1136 if (parser.parseOptionalEqual()) 1137 return success(); 1138 1139 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1140 initialValue = UnitAttr::get(parser.getContext()); 1141 return success(); 1142 } 1143 1144 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1145 if (parser.parseAttribute(initialValue, tensorType)) 1146 return failure(); 1147 if (!initialValue.isa<ElementsAttr>()) 1148 return parser.emitError(parser.getNameLoc()) 1149 << "initial value should be a unit or elements attribute"; 1150 return success(); 1151 } 1152 1153 static LogicalResult verify(GlobalOp op) { 1154 auto memrefType = op.type().dyn_cast<MemRefType>(); 1155 if (!memrefType || !memrefType.hasStaticShape()) 1156 return op.emitOpError("type should be static shaped memref, but got ") 1157 << op.type(); 1158 1159 // Verify that the initial value, if present, is either a unit attribute or 1160 // an elements attribute. 1161 if (op.initial_value().hasValue()) { 1162 Attribute initValue = op.initial_value().getValue(); 1163 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1164 return op.emitOpError("initial value should be a unit or elements " 1165 "attribute, but got ") 1166 << initValue; 1167 1168 // Check that the type of the initial value is compatible with the type of 1169 // the global variable. 1170 if (initValue.isa<ElementsAttr>()) { 1171 Type initType = initValue.getType(); 1172 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1173 if (initType != tensorType) 1174 return op.emitOpError("initial value expected to be of type ") 1175 << tensorType << ", but was of type " << initType; 1176 } 1177 } 1178 1179 if (Optional<uint64_t> alignAttr = op.alignment()) { 1180 uint64_t alignment = alignAttr.getValue(); 1181 1182 if (!llvm::isPowerOf2_64(alignment)) 1183 return op->emitError() << "alignment attribute value " << alignment 1184 << " is not a power of 2"; 1185 } 1186 1187 // TODO: verify visibility for declarations. 1188 return success(); 1189 } 1190 1191 //===----------------------------------------------------------------------===// 1192 // GetGlobalOp 1193 //===----------------------------------------------------------------------===// 1194 1195 LogicalResult 1196 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1197 // Verify that the result type is same as the type of the referenced 1198 // memref.global op. 1199 auto global = 1200 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1201 if (!global) 1202 return emitOpError("'") 1203 << name() << "' does not reference a valid global memref"; 1204 1205 Type resultType = result().getType(); 1206 if (global.type() != resultType) 1207 return emitOpError("result type ") 1208 << resultType << " does not match type " << global.type() 1209 << " of the global memref @" << name(); 1210 return success(); 1211 } 1212 1213 //===----------------------------------------------------------------------===// 1214 // LoadOp 1215 //===----------------------------------------------------------------------===// 1216 1217 static LogicalResult verify(LoadOp op) { 1218 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1219 return op.emitOpError("incorrect number of indices for load"); 1220 return success(); 1221 } 1222 1223 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1224 /// load(memrefcast) -> load 1225 if (succeeded(foldMemRefCast(*this))) 1226 return getResult(); 1227 return OpFoldResult(); 1228 } 1229 1230 namespace { 1231 /// Fold a load on a buffer_cast operation into an tensor.extract on the 1232 /// corresponding tensor. 1233 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> { 1234 using OpRewritePattern<LoadOp>::OpRewritePattern; 1235 1236 LogicalResult matchAndRewrite(LoadOp load, 1237 PatternRewriter &rewriter) const override { 1238 auto buffercast = load.memref().getDefiningOp<BufferCastOp>(); 1239 if (!buffercast) 1240 return failure(); 1241 1242 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(), 1243 load.indices()); 1244 return success(); 1245 } 1246 }; 1247 } // end anonymous namespace. 1248 1249 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 1250 MLIRContext *context) { 1251 results.add<LoadOfBufferCast>(context); 1252 } 1253 1254 //===----------------------------------------------------------------------===// 1255 // PrefetchOp 1256 //===----------------------------------------------------------------------===// 1257 1258 static void print(OpAsmPrinter &p, PrefetchOp op) { 1259 p << " " << op.memref() << '['; 1260 p.printOperands(op.indices()); 1261 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1262 p << ", locality<" << op.localityHint(); 1263 p << ">, " << (op.isDataCache() ? "data" : "instr"); 1264 p.printOptionalAttrDict( 1265 op->getAttrs(), 1266 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1267 p << " : " << op.getMemRefType(); 1268 } 1269 1270 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1271 OperationState &result) { 1272 OpAsmParser::OperandType memrefInfo; 1273 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1274 IntegerAttr localityHint; 1275 MemRefType type; 1276 StringRef readOrWrite, cacheType; 1277 1278 auto indexTy = parser.getBuilder().getIndexType(); 1279 auto i32Type = parser.getBuilder().getIntegerType(32); 1280 if (parser.parseOperand(memrefInfo) || 1281 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1282 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1283 parser.parseComma() || parser.parseKeyword("locality") || 1284 parser.parseLess() || 1285 parser.parseAttribute(localityHint, i32Type, "localityHint", 1286 result.attributes) || 1287 parser.parseGreater() || parser.parseComma() || 1288 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1289 parser.resolveOperand(memrefInfo, type, result.operands) || 1290 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1291 return failure(); 1292 1293 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1294 return parser.emitError(parser.getNameLoc(), 1295 "rw specifier has to be 'read' or 'write'"); 1296 result.addAttribute( 1297 PrefetchOp::getIsWriteAttrName(), 1298 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1299 1300 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1301 return parser.emitError(parser.getNameLoc(), 1302 "cache type has to be 'data' or 'instr'"); 1303 1304 result.addAttribute( 1305 PrefetchOp::getIsDataCacheAttrName(), 1306 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1307 1308 return success(); 1309 } 1310 1311 static LogicalResult verify(PrefetchOp op) { 1312 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1313 return op.emitOpError("too few indices"); 1314 1315 return success(); 1316 } 1317 1318 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1319 SmallVectorImpl<OpFoldResult> &results) { 1320 // prefetch(memrefcast) -> prefetch 1321 return foldMemRefCast(*this); 1322 } 1323 1324 //===----------------------------------------------------------------------===// 1325 // ReinterpretCastOp 1326 //===----------------------------------------------------------------------===// 1327 1328 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1329 /// `staticSizes` and `staticStrides` are automatically filled with 1330 /// source-memref-rank sentinel values that encode dynamic entries. 1331 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1332 MemRefType resultType, Value source, 1333 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1334 ArrayRef<OpFoldResult> strides, 1335 ArrayRef<NamedAttribute> attrs) { 1336 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1337 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1338 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1339 ShapedType::kDynamicStrideOrOffset); 1340 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1341 ShapedType::kDynamicSize); 1342 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1343 ShapedType::kDynamicStrideOrOffset); 1344 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1345 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1346 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1347 result.addAttributes(attrs); 1348 } 1349 1350 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1351 MemRefType resultType, Value source, 1352 int64_t offset, ArrayRef<int64_t> sizes, 1353 ArrayRef<int64_t> strides, 1354 ArrayRef<NamedAttribute> attrs) { 1355 SmallVector<OpFoldResult> sizeValues = 1356 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1357 return b.getI64IntegerAttr(v); 1358 })); 1359 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1360 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1361 return b.getI64IntegerAttr(v); 1362 })); 1363 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1364 strideValues, attrs); 1365 } 1366 1367 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1368 MemRefType resultType, Value source, Value offset, 1369 ValueRange sizes, ValueRange strides, 1370 ArrayRef<NamedAttribute> attrs) { 1371 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1372 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1373 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1374 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1375 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1376 } 1377 1378 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1379 // completed automatically, like we have for subview and extract_slice. 1380 static LogicalResult verify(ReinterpretCastOp op) { 1381 // The source and result memrefs should be in the same memory space. 1382 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1383 auto resultType = op.getType().cast<MemRefType>(); 1384 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1385 return op.emitError("different memory spaces specified for source type ") 1386 << srcType << " and result memref type " << resultType; 1387 if (srcType.getElementType() != resultType.getElementType()) 1388 return op.emitError("different element types specified for source type ") 1389 << srcType << " and result memref type " << resultType; 1390 1391 // Match sizes in result memref type and in static_sizes attribute. 1392 for (auto &en : 1393 llvm::enumerate(llvm::zip(resultType.getShape(), 1394 extractFromI64ArrayAttr(op.static_sizes())))) { 1395 int64_t resultSize = std::get<0>(en.value()); 1396 int64_t expectedSize = std::get<1>(en.value()); 1397 if (resultSize != expectedSize) 1398 return op.emitError("expected result type with size = ") 1399 << expectedSize << " instead of " << resultSize 1400 << " in dim = " << en.index(); 1401 } 1402 1403 // Match offset and strides in static_offset and static_strides attributes if 1404 // result memref type has an affine map specified. 1405 if (!resultType.getAffineMaps().empty()) { 1406 int64_t resultOffset; 1407 SmallVector<int64_t, 4> resultStrides; 1408 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1409 return failure(); 1410 1411 // Match offset in result memref type and in static_offsets attribute. 1412 int64_t expectedOffset = 1413 extractFromI64ArrayAttr(op.static_offsets()).front(); 1414 if (resultOffset != expectedOffset) 1415 return op.emitError("expected result type with offset = ") 1416 << resultOffset << " instead of " << expectedOffset; 1417 1418 // Match strides in result memref type and in static_strides attribute. 1419 for (auto &en : llvm::enumerate(llvm::zip( 1420 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1421 int64_t resultStride = std::get<0>(en.value()); 1422 int64_t expectedStride = std::get<1>(en.value()); 1423 if (resultStride != expectedStride) 1424 return op.emitError("expected result type with stride = ") 1425 << expectedStride << " instead of " << resultStride 1426 << " in dim = " << en.index(); 1427 } 1428 } 1429 return success(); 1430 } 1431 1432 //===----------------------------------------------------------------------===// 1433 // Reassociative reshape ops 1434 //===----------------------------------------------------------------------===// 1435 1436 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1437 return getSymbolLessAffineMaps(getReassociationExprs()); 1438 } 1439 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1440 return convertReassociationIndicesToExprs(getContext(), 1441 getReassociationIndices()); 1442 } 1443 1444 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1445 return getSymbolLessAffineMaps(getReassociationExprs()); 1446 } 1447 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1448 return convertReassociationIndicesToExprs(getContext(), 1449 getReassociationIndices()); 1450 } 1451 1452 static void print(OpAsmPrinter &p, ExpandShapeOp op) { 1453 ::mlir::printReshapeOp<ExpandShapeOp>(p, op); 1454 } 1455 1456 static void print(OpAsmPrinter &p, CollapseShapeOp op) { 1457 ::mlir::printReshapeOp<CollapseShapeOp>(p, op); 1458 } 1459 1460 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1461 /// copies. 1462 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1463 ArrayRef<int64_t> sizes, 1464 ArrayRef<AffineExpr> strides) { 1465 assert(sizes.size() == strides.size() && "mismatched ranks"); 1466 // off by 1 indexing to avoid out of bounds 1467 // V 1468 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1469 // Only bands of static shapes are reshapable. This is due to the fact that 1470 // there is no relation between dynamic sizes and dynamic strides: we do not 1471 // have enough information to know whether a "-1" size corresponds to the 1472 // proper symbol in the AffineExpr of a stride. 1473 if (ShapedType::isDynamic(sizes[dim + 1])) 1474 return false; 1475 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1476 // simplify on the fly and catch more reshapable cases. 1477 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1478 return false; 1479 } 1480 return true; 1481 } 1482 1483 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1484 /// expected to be valid) to `type`. 1485 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1486 /// MemRefType. 1487 static MemRefType 1488 computeReshapeCollapsedType(MemRefType type, 1489 ArrayRef<AffineMap> reassociation) { 1490 auto sizes = type.getShape(); 1491 AffineExpr offset; 1492 SmallVector<AffineExpr, 4> strides; 1493 auto status = getStridesAndOffset(type, strides, offset); 1494 (void)status; 1495 assert(succeeded(status) && "expected strided memref"); 1496 1497 SmallVector<int64_t, 4> newSizes; 1498 newSizes.reserve(reassociation.size()); 1499 SmallVector<AffineExpr, 4> newStrides; 1500 newStrides.reserve(reassociation.size()); 1501 1502 // Use the fact that reassociation is valid to simplify the logic: only use 1503 // each map's rank. 1504 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1505 unsigned currentDim = 0; 1506 for (AffineMap m : reassociation) { 1507 unsigned dim = m.getNumResults(); 1508 int64_t size = 1; 1509 AffineExpr stride = strides[currentDim + dim - 1]; 1510 if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { 1511 size = ShapedType::kDynamicSize; 1512 stride = AffineExpr(); 1513 } else { 1514 for (unsigned d = 0; d < dim; ++d) 1515 size *= sizes[currentDim + d]; 1516 } 1517 newSizes.push_back(size); 1518 newStrides.push_back(stride); 1519 currentDim += dim; 1520 } 1521 1522 // Early-exit: if `type` is contiguous, the result must be contiguous. 1523 if (canonicalizeStridedLayout(type).getAffineMaps().empty()) 1524 return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({}); 1525 1526 // Convert back to int64_t because we don't have enough information to create 1527 // new strided layouts from AffineExpr only. This corresponds to a case where 1528 // copies may be necessary. 1529 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1530 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1531 intOffset = o.getValue(); 1532 SmallVector<int64_t, 4> intStrides; 1533 intStrides.reserve(strides.size()); 1534 for (auto stride : newStrides) { 1535 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1536 intStrides.push_back(cst.getValue()); 1537 else 1538 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1539 } 1540 auto layout = 1541 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1542 return canonicalizeStridedLayout( 1543 MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout})); 1544 } 1545 1546 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1547 ArrayRef<ReassociationIndices> reassociation, 1548 ArrayRef<NamedAttribute> attrs) { 1549 auto memRefType = src.getType().cast<MemRefType>(); 1550 auto resultType = computeReshapeCollapsedType( 1551 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1552 b.getContext(), reassociation))); 1553 build(b, result, resultType, src, attrs); 1554 result.addAttribute(getReassociationAttrName(), 1555 getReassociationIndicesAttribute(b, reassociation)); 1556 } 1557 1558 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1559 ArrayRef<ReassociationIndices> reassociation, 1560 ArrayRef<NamedAttribute> attrs) { 1561 auto memRefType = src.getType().cast<MemRefType>(); 1562 auto resultType = computeReshapeCollapsedType( 1563 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1564 b.getContext(), reassociation))); 1565 build(b, result, resultType, src, attrs); 1566 result.addAttribute(getReassociationAttrName(), 1567 getReassociationIndicesAttribute(b, reassociation)); 1568 } 1569 1570 template <typename ReshapeOp, 1571 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1572 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1573 MemRefType collapsedType) { 1574 if (failed( 1575 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1576 return failure(); 1577 auto maps = op.getReassociationMaps(); 1578 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1579 if (collapsedType != expectedType) 1580 return op.emitOpError("expected collapsed type to be ") 1581 << expectedType << ", but got " << collapsedType; 1582 return success(); 1583 } 1584 1585 static LogicalResult verify(ExpandShapeOp op) { 1586 return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); 1587 } 1588 1589 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1590 MLIRContext *context) { 1591 results.add<CollapseReshapeOps<ExpandShapeOp>, 1592 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1593 } 1594 1595 static LogicalResult verify(CollapseShapeOp op) { 1596 return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); 1597 } 1598 1599 struct CollapseShapeOpMemRefCastFolder 1600 : public OpRewritePattern<CollapseShapeOp> { 1601 public: 1602 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1603 1604 LogicalResult matchAndRewrite(CollapseShapeOp op, 1605 PatternRewriter &rewriter) const override { 1606 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1607 if (!cast) 1608 return failure(); 1609 1610 if (!CastOp::canFoldIntoConsumerOp(cast)) 1611 return failure(); 1612 1613 Type newResultType = computeReshapeCollapsedType( 1614 cast.getOperand().getType().cast<MemRefType>(), 1615 op.getReassociationMaps()); 1616 1617 if (newResultType == op.getResultType()) { 1618 rewriter.updateRootInPlace( 1619 op, [&]() { op.srcMutable().assign(cast.source()); }); 1620 } else { 1621 Value newOp = rewriter.create<CollapseShapeOp>( 1622 op->getLoc(), cast.source(), op.getReassociationIndices()); 1623 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1624 } 1625 return success(); 1626 } 1627 }; 1628 1629 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1630 MLIRContext *context) { 1631 results.add<CollapseReshapeOps<CollapseShapeOp>, 1632 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1633 CollapseShapeOpMemRefCastFolder>(context); 1634 } 1635 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1636 if (succeeded(foldMemRefCast(*this))) 1637 return getResult(); 1638 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1639 } 1640 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1641 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1642 } 1643 1644 //===----------------------------------------------------------------------===// 1645 // ReshapeOp 1646 //===----------------------------------------------------------------------===// 1647 1648 static LogicalResult verify(ReshapeOp op) { 1649 Type operandType = op.source().getType(); 1650 Type resultType = op.result().getType(); 1651 1652 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1653 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1654 if (operandElementType != resultElementType) 1655 return op.emitOpError("element types of source and destination memref " 1656 "types should be the same"); 1657 1658 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1659 if (!operandMemRefType.getAffineMaps().empty()) 1660 return op.emitOpError( 1661 "source memref type should have identity affine map"); 1662 1663 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1664 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1665 if (resultMemRefType) { 1666 if (!resultMemRefType.getAffineMaps().empty()) 1667 return op.emitOpError( 1668 "result memref type should have identity affine map"); 1669 if (shapeSize == ShapedType::kDynamicSize) 1670 return op.emitOpError("cannot use shape operand with dynamic length to " 1671 "reshape to statically-ranked memref type"); 1672 if (shapeSize != resultMemRefType.getRank()) 1673 return op.emitOpError( 1674 "length of shape operand differs from the result's memref rank"); 1675 } 1676 return success(); 1677 } 1678 1679 //===----------------------------------------------------------------------===// 1680 // StoreOp 1681 //===----------------------------------------------------------------------===// 1682 1683 static LogicalResult verify(StoreOp op) { 1684 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1685 return op.emitOpError("store index operand count not equal to memref rank"); 1686 1687 return success(); 1688 } 1689 1690 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1691 SmallVectorImpl<OpFoldResult> &results) { 1692 /// store(memrefcast) -> store 1693 return foldMemRefCast(*this, getValueToStore()); 1694 } 1695 1696 //===----------------------------------------------------------------------===// 1697 // SubViewOp 1698 //===----------------------------------------------------------------------===// 1699 1700 namespace { 1701 /// Helpers to write more idiomatic operations. 1702 namespace saturated_arith { 1703 struct Wrapper { 1704 explicit Wrapper(int64_t v) : v(v) {} 1705 operator int64_t() { return v; } 1706 int64_t v; 1707 }; 1708 Wrapper operator+(Wrapper a, int64_t b) { 1709 if (ShapedType::isDynamicStrideOrOffset(a) || 1710 ShapedType::isDynamicStrideOrOffset(b)) 1711 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1712 return Wrapper(a.v + b); 1713 } 1714 Wrapper operator*(Wrapper a, int64_t b) { 1715 if (ShapedType::isDynamicStrideOrOffset(a) || 1716 ShapedType::isDynamicStrideOrOffset(b)) 1717 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1718 return Wrapper(a.v * b); 1719 } 1720 } // end namespace saturated_arith 1721 } // end namespace 1722 1723 /// A subview result type can be fully inferred from the source type and the 1724 /// static representation of offsets, sizes and strides. Special sentinels 1725 /// encode the dynamic case. 1726 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1727 ArrayRef<int64_t> leadingStaticOffsets, 1728 ArrayRef<int64_t> leadingStaticSizes, 1729 ArrayRef<int64_t> leadingStaticStrides) { 1730 // A subview may specify only a leading subset of offset/sizes/strides in 1731 // which case we complete with offset=0, sizes from memref type and strides=1. 1732 unsigned rank = sourceMemRefType.getRank(); 1733 assert(leadingStaticOffsets.size() <= rank && 1734 "unexpected leadingStaticOffsets overflow"); 1735 assert(leadingStaticSizes.size() <= rank && 1736 "unexpected leadingStaticSizes overflow"); 1737 assert(leadingStaticStrides.size() <= rank && 1738 "unexpected leadingStaticStrides overflow"); 1739 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1740 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1741 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1742 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1743 unsigned numTrailingSizes = rank - staticSizes.size(); 1744 unsigned numTrailingStrides = rank - staticStrides.size(); 1745 staticOffsets.append(numTrailingOffsets, 0); 1746 llvm::append_range(staticSizes, 1747 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1748 staticStrides.append(numTrailingStrides, 1); 1749 1750 // Extract source offset and strides. 1751 int64_t sourceOffset; 1752 SmallVector<int64_t, 4> sourceStrides; 1753 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1754 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1755 (void)res; 1756 1757 // Compute target offset whose value is: 1758 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1759 int64_t targetOffset = sourceOffset; 1760 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1761 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1762 using namespace saturated_arith; 1763 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1764 } 1765 1766 // Compute target stride whose value is: 1767 // `sourceStrides_i * staticStrides_i`. 1768 SmallVector<int64_t, 4> targetStrides; 1769 targetStrides.reserve(staticOffsets.size()); 1770 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1771 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1772 using namespace saturated_arith; 1773 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1774 } 1775 1776 // The type is now known. 1777 return MemRefType::get( 1778 staticSizes, sourceMemRefType.getElementType(), 1779 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1780 sourceMemRefType.getContext()), 1781 sourceMemRefType.getMemorySpace()); 1782 } 1783 1784 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1785 ArrayRef<OpFoldResult> leadingStaticOffsets, 1786 ArrayRef<OpFoldResult> leadingStaticSizes, 1787 ArrayRef<OpFoldResult> leadingStaticStrides) { 1788 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1789 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1790 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1791 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1792 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1793 ShapedType::kDynamicSize); 1794 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1795 staticStrides, ShapedType::kDynamicStrideOrOffset); 1796 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1797 staticSizes, staticStrides) 1798 .cast<MemRefType>(); 1799 } 1800 1801 Type SubViewOp::inferRankReducedResultType( 1802 unsigned resultRank, MemRefType sourceRankedTensorType, 1803 ArrayRef<int64_t> leadingStaticOffsets, 1804 ArrayRef<int64_t> leadingStaticSizes, 1805 ArrayRef<int64_t> leadingStaticStrides) { 1806 auto inferredType = 1807 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1808 leadingStaticSizes, leadingStaticStrides) 1809 .cast<MemRefType>(); 1810 assert(inferredType.getRank() >= resultRank && "expected "); 1811 int rankDiff = inferredType.getRank() - resultRank; 1812 if (rankDiff > 0) { 1813 auto shape = inferredType.getShape(); 1814 llvm::SmallDenseSet<unsigned> dimsToProject; 1815 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1816 SmallVector<int64_t> projectedShape; 1817 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1818 if (!dimsToProject.contains(pos)) 1819 projectedShape.push_back(shape[pos]); 1820 1821 AffineMap map; 1822 auto maps = inferredType.getAffineMaps(); 1823 if (!maps.empty() && maps.front()) 1824 map = getProjectedMap(maps.front(), dimsToProject); 1825 inferredType = 1826 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1827 inferredType.getMemorySpace()); 1828 } 1829 return inferredType; 1830 } 1831 1832 Type SubViewOp::inferRankReducedResultType( 1833 unsigned resultRank, MemRefType sourceRankedTensorType, 1834 ArrayRef<OpFoldResult> leadingStaticOffsets, 1835 ArrayRef<OpFoldResult> leadingStaticSizes, 1836 ArrayRef<OpFoldResult> leadingStaticStrides) { 1837 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1838 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1839 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1840 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1841 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1842 ShapedType::kDynamicSize); 1843 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1844 staticStrides, ShapedType::kDynamicStrideOrOffset); 1845 return SubViewOp::inferRankReducedResultType( 1846 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1847 staticStrides); 1848 } 1849 // Build a SubViewOp with mixed static and dynamic entries and custom result 1850 // type. If the type passed is nullptr, it is inferred. 1851 void SubViewOp::build(OpBuilder &b, OperationState &result, 1852 MemRefType resultType, Value source, 1853 ArrayRef<OpFoldResult> offsets, 1854 ArrayRef<OpFoldResult> sizes, 1855 ArrayRef<OpFoldResult> strides, 1856 ArrayRef<NamedAttribute> attrs) { 1857 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1858 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1859 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1860 ShapedType::kDynamicStrideOrOffset); 1861 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1862 ShapedType::kDynamicSize); 1863 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1864 ShapedType::kDynamicStrideOrOffset); 1865 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1866 // Structuring implementation this way avoids duplication between builders. 1867 if (!resultType) { 1868 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1869 staticSizes, staticStrides) 1870 .cast<MemRefType>(); 1871 } 1872 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1873 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1874 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1875 result.addAttributes(attrs); 1876 } 1877 1878 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1879 // type. 1880 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1881 ArrayRef<OpFoldResult> offsets, 1882 ArrayRef<OpFoldResult> sizes, 1883 ArrayRef<OpFoldResult> strides, 1884 ArrayRef<NamedAttribute> attrs) { 1885 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1886 } 1887 1888 // Build a SubViewOp with static entries and inferred result type. 1889 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1890 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1891 ArrayRef<int64_t> strides, 1892 ArrayRef<NamedAttribute> attrs) { 1893 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1894 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1895 return b.getI64IntegerAttr(v); 1896 })); 1897 SmallVector<OpFoldResult> sizeValues = 1898 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1899 return b.getI64IntegerAttr(v); 1900 })); 1901 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1902 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1903 return b.getI64IntegerAttr(v); 1904 })); 1905 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1906 } 1907 1908 // Build a SubViewOp with dynamic entries and custom result type. If the 1909 // type passed is nullptr, it is inferred. 1910 void SubViewOp::build(OpBuilder &b, OperationState &result, 1911 MemRefType resultType, Value source, 1912 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1913 ArrayRef<int64_t> strides, 1914 ArrayRef<NamedAttribute> attrs) { 1915 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1916 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1917 return b.getI64IntegerAttr(v); 1918 })); 1919 SmallVector<OpFoldResult> sizeValues = 1920 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1921 return b.getI64IntegerAttr(v); 1922 })); 1923 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1924 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1925 return b.getI64IntegerAttr(v); 1926 })); 1927 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1928 attrs); 1929 } 1930 1931 // Build a SubViewOp with dynamic entries and custom result type. If the type 1932 // passed is nullptr, it is inferred. 1933 void SubViewOp::build(OpBuilder &b, OperationState &result, 1934 MemRefType resultType, Value source, ValueRange offsets, 1935 ValueRange sizes, ValueRange strides, 1936 ArrayRef<NamedAttribute> attrs) { 1937 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1938 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1939 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1940 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1941 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1942 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1943 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1944 } 1945 1946 // Build a SubViewOp with dynamic entries and inferred result type. 1947 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1948 ValueRange offsets, ValueRange sizes, ValueRange strides, 1949 ArrayRef<NamedAttribute> attrs) { 1950 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1951 } 1952 1953 /// For ViewLikeOpInterface. 1954 Value SubViewOp::getViewSource() { return source(); } 1955 1956 enum SubViewVerificationResult { 1957 Success, 1958 RankTooLarge, 1959 SizeMismatch, 1960 ElemTypeMismatch, 1961 MemSpaceMismatch, 1962 AffineMapMismatch 1963 }; 1964 1965 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1966 /// This function is slight variant of `is subsequence` algorithm where 1967 /// not matching dimension must be 1. 1968 static SubViewVerificationResult 1969 isRankReducedType(Type originalType, Type candidateReducedType, 1970 ArrayAttr staticSizes, std::string *errMsg = nullptr) { 1971 if (originalType == candidateReducedType) 1972 return SubViewVerificationResult::Success; 1973 if (!originalType.isa<MemRefType>()) 1974 return SubViewVerificationResult::Success; 1975 if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>()) 1976 return SubViewVerificationResult::Success; 1977 1978 ShapedType originalShapedType = originalType.cast<ShapedType>(); 1979 ShapedType candidateReducedShapedType = 1980 candidateReducedType.cast<ShapedType>(); 1981 1982 // Rank and size logic is valid for all ShapedTypes. 1983 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 1984 ArrayRef<int64_t> candidateReducedShape = 1985 candidateReducedShapedType.getShape(); 1986 unsigned originalRank = originalShape.size(), 1987 candidateReducedRank = candidateReducedShape.size(); 1988 if (candidateReducedRank > originalRank) 1989 return SubViewVerificationResult::RankTooLarge; 1990 1991 MemRefType original = originalType.cast<MemRefType>(); 1992 MemRefType candidateReduced = candidateReducedType.cast<MemRefType>(); 1993 1994 auto optionalUnusedDimsMask = 1995 computeMemRefRankReductionMask(original, candidateReduced, staticSizes); 1996 1997 // Sizes cannot be matched in case empty vector is returned. 1998 if (!optionalUnusedDimsMask.hasValue()) 1999 return SubViewVerificationResult::SizeMismatch; 2000 2001 if (originalShapedType.getElementType() != 2002 candidateReducedShapedType.getElementType()) 2003 return SubViewVerificationResult::ElemTypeMismatch; 2004 2005 // Strided layout logic is relevant for MemRefType only. 2006 if (original.getMemorySpace() != candidateReduced.getMemorySpace()) 2007 return SubViewVerificationResult::MemSpaceMismatch; 2008 return SubViewVerificationResult::Success; 2009 } 2010 2011 template <typename OpTy> 2012 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result, 2013 OpTy op, Type expectedType, 2014 StringRef errMsg = "") { 2015 auto memrefType = expectedType.cast<ShapedType>(); 2016 switch (result) { 2017 case SubViewVerificationResult::Success: 2018 return success(); 2019 case SubViewVerificationResult::RankTooLarge: 2020 return op.emitError("expected result rank to be smaller or equal to ") 2021 << "the source rank. " << errMsg; 2022 case SubViewVerificationResult::SizeMismatch: 2023 return op.emitError("expected result type to be ") 2024 << expectedType 2025 << " or a rank-reduced version. (mismatch of result sizes) " 2026 << errMsg; 2027 case SubViewVerificationResult::ElemTypeMismatch: 2028 return op.emitError("expected result element type to be ") 2029 << memrefType.getElementType() << errMsg; 2030 case SubViewVerificationResult::MemSpaceMismatch: 2031 return op.emitError("expected result and source memory spaces to match.") 2032 << errMsg; 2033 case SubViewVerificationResult::AffineMapMismatch: 2034 return op.emitError("expected result type to be ") 2035 << expectedType 2036 << " or a rank-reduced version. (mismatch of result affine map) " 2037 << errMsg; 2038 } 2039 llvm_unreachable("unexpected subview verification result"); 2040 } 2041 2042 /// Verifier for SubViewOp. 2043 static LogicalResult verify(SubViewOp op) { 2044 MemRefType baseType = op.getSourceType(); 2045 MemRefType subViewType = op.getType(); 2046 2047 // The base memref and the view memref should be in the same memory space. 2048 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 2049 return op.emitError("different memory spaces specified for base memref " 2050 "type ") 2051 << baseType << " and subview memref type " << subViewType; 2052 2053 // Verify that the base memref type has a strided layout map. 2054 if (!isStrided(baseType)) 2055 return op.emitError("base type ") << baseType << " is not strided"; 2056 2057 // Verify result type against inferred type. 2058 auto expectedType = SubViewOp::inferResultType( 2059 baseType, extractFromI64ArrayAttr(op.static_offsets()), 2060 extractFromI64ArrayAttr(op.static_sizes()), 2061 extractFromI64ArrayAttr(op.static_strides())); 2062 2063 std::string errMsg; 2064 auto result = 2065 isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg); 2066 return produceSubViewErrorMsg(result, op, expectedType, errMsg); 2067 } 2068 2069 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) { 2070 return os << "range " << range.offset << ":" << range.size << ":" 2071 << range.stride; 2072 } 2073 2074 /// Return the list of Range (i.e. offset, size, stride). Each Range 2075 /// entry contains either the dynamic value or a ConstantIndexOp constructed 2076 /// with `b` at location `loc`. 2077 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 2078 OpBuilder &b, Location loc) { 2079 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 2080 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 2081 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 2082 SmallVector<Range, 8> res; 2083 unsigned rank = ranks[0]; 2084 res.reserve(rank); 2085 for (unsigned idx = 0; idx < rank; ++idx) { 2086 Value offset = 2087 op.isDynamicOffset(idx) 2088 ? op.getDynamicOffset(idx) 2089 : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx)); 2090 Value size = op.isDynamicSize(idx) 2091 ? op.getDynamicSize(idx) 2092 : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx)); 2093 Value stride = 2094 op.isDynamicStride(idx) 2095 ? op.getDynamicStride(idx) 2096 : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx)); 2097 res.emplace_back(Range{offset, size, stride}); 2098 } 2099 return res; 2100 } 2101 2102 /// Infer the canonical type of the result of a subview operation. Returns a 2103 /// type with rank `resultRank` that is either the rank of the rank-reduced 2104 /// type, or the non-rank-reduced type. 2105 static MemRefType 2106 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType, 2107 ArrayRef<OpFoldResult> mixedOffsets, 2108 ArrayRef<OpFoldResult> mixedSizes, 2109 ArrayRef<OpFoldResult> mixedStrides) { 2110 auto resultType = 2111 SubViewOp::inferRankReducedResultType( 2112 resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides) 2113 .cast<MemRefType>(); 2114 if (resultType.getRank() != resultRank) { 2115 resultType = SubViewOp::inferResultType(sourceType, mixedOffsets, 2116 mixedSizes, mixedStrides) 2117 .cast<MemRefType>(); 2118 } 2119 return resultType; 2120 } 2121 2122 namespace { 2123 /// Pattern to rewrite a subview op with MemRefCast arguments. 2124 /// This essentially pushes memref.cast past its consuming subview when 2125 /// `canFoldIntoConsumerOp` is true. 2126 /// 2127 /// Example: 2128 /// ``` 2129 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 2130 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 2131 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 2132 /// ``` 2133 /// is rewritten into: 2134 /// ``` 2135 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 2136 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 2137 /// memref<3x4xf32, offset:?, strides:[?, 1]> 2138 /// ``` 2139 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 2140 public: 2141 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2142 2143 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2144 PatternRewriter &rewriter) const override { 2145 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 2146 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 2147 return matchPattern(operand, matchConstantIndex()); 2148 })) 2149 return failure(); 2150 2151 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 2152 if (!castOp) 2153 return failure(); 2154 2155 if (!CastOp::canFoldIntoConsumerOp(castOp)) 2156 return failure(); 2157 2158 /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on 2159 /// the cast source operand type and the SubViewOp static information. This 2160 /// is the resulting type if the MemRefCastOp were folded. 2161 auto resultType = getCanonicalSubViewResultType( 2162 subViewOp.getType().getRank(), 2163 castOp.source().getType().cast<MemRefType>(), 2164 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), 2165 subViewOp.getMixedStrides()); 2166 Value newSubView = rewriter.create<SubViewOp>( 2167 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), 2168 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), 2169 subViewOp.static_sizes(), subViewOp.static_strides()); 2170 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(), 2171 newSubView); 2172 return success(); 2173 } 2174 }; 2175 } // namespace 2176 2177 /// Return the canonical type of the result of a subview. 2178 struct SubViewReturnTypeCanonicalizer { 2179 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets, 2180 ArrayRef<OpFoldResult> mixedSizes, 2181 ArrayRef<OpFoldResult> mixedStrides) { 2182 return getCanonicalSubViewResultType(op.getType().getRank(), 2183 op.getSourceType(), mixedOffsets, 2184 mixedSizes, mixedStrides); 2185 } 2186 }; 2187 2188 /// A canonicalizer wrapper to replace SubViewOps. 2189 struct SubViewCanonicalizer { 2190 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 2191 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType()); 2192 } 2193 }; 2194 2195 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2196 MLIRContext *context) { 2197 results 2198 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 2199 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 2200 SubViewOpMemRefCastFolder>(context); 2201 } 2202 2203 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 2204 auto resultShapedType = getResult().getType().cast<ShapedType>(); 2205 auto sourceShapedType = source().getType().cast<ShapedType>(); 2206 2207 if (resultShapedType.hasStaticShape() && 2208 resultShapedType == sourceShapedType) { 2209 return getViewSource(); 2210 } 2211 2212 return {}; 2213 } 2214 2215 //===----------------------------------------------------------------------===// 2216 // TensorLoadOp 2217 //===----------------------------------------------------------------------===// 2218 2219 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) { 2220 if (auto bufferCast = memref().getDefiningOp<BufferCastOp>()) 2221 // Approximate alias analysis by conservatively folding only when no there 2222 // is no interleaved operation. 2223 if (bufferCast->getBlock() == this->getOperation()->getBlock() && 2224 bufferCast->getNextNode() == this->getOperation()) 2225 return bufferCast.tensor(); 2226 return {}; 2227 } 2228 2229 namespace { 2230 struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> { 2231 using OpRewritePattern<tensor::DimOp>::OpRewritePattern; 2232 2233 LogicalResult matchAndRewrite(tensor::DimOp dimOp, 2234 PatternRewriter &rewriter) const override { 2235 auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>(); 2236 if (!tensorLoadOp) 2237 return failure(); 2238 2239 rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(), 2240 dimOp.index()); 2241 return success(); 2242 } 2243 }; 2244 } // namespace 2245 2246 void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 2247 MLIRContext *context) { 2248 results.add<DimOfTensorLoadFolder>(context); 2249 } 2250 2251 //===----------------------------------------------------------------------===// 2252 // TransposeOp 2253 //===----------------------------------------------------------------------===// 2254 2255 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 2256 static MemRefType inferTransposeResultType(MemRefType memRefType, 2257 AffineMap permutationMap) { 2258 auto rank = memRefType.getRank(); 2259 auto originalSizes = memRefType.getShape(); 2260 // Compute permuted sizes. 2261 SmallVector<int64_t, 4> sizes(rank, 0); 2262 for (auto en : llvm::enumerate(permutationMap.getResults())) 2263 sizes[en.index()] = 2264 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2265 2266 // Compute permuted strides. 2267 int64_t offset; 2268 SmallVector<int64_t, 4> strides; 2269 auto res = getStridesAndOffset(memRefType, strides, offset); 2270 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2271 (void)res; 2272 auto map = 2273 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2274 map = permutationMap ? map.compose(permutationMap) : map; 2275 return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); 2276 } 2277 2278 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2279 AffineMapAttr permutation, 2280 ArrayRef<NamedAttribute> attrs) { 2281 auto permutationMap = permutation.getValue(); 2282 assert(permutationMap); 2283 2284 auto memRefType = in.getType().cast<MemRefType>(); 2285 // Compute result type. 2286 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2287 2288 build(b, result, resultType, in, attrs); 2289 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2290 } 2291 2292 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2293 static void print(OpAsmPrinter &p, TransposeOp op) { 2294 p << " " << op.in() << " " << op.permutation(); 2295 p.printOptionalAttrDict(op->getAttrs(), 2296 {TransposeOp::getPermutationAttrName()}); 2297 p << " : " << op.in().getType() << " to " << op.getType(); 2298 } 2299 2300 static ParseResult parseTransposeOp(OpAsmParser &parser, 2301 OperationState &result) { 2302 OpAsmParser::OperandType in; 2303 AffineMap permutation; 2304 MemRefType srcType, dstType; 2305 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2306 parser.parseOptionalAttrDict(result.attributes) || 2307 parser.parseColonType(srcType) || 2308 parser.resolveOperand(in, srcType, result.operands) || 2309 parser.parseKeywordType("to", dstType) || 2310 parser.addTypeToList(dstType, result.types)) 2311 return failure(); 2312 2313 result.addAttribute(TransposeOp::getPermutationAttrName(), 2314 AffineMapAttr::get(permutation)); 2315 return success(); 2316 } 2317 2318 static LogicalResult verify(TransposeOp op) { 2319 if (!op.permutation().isPermutation()) 2320 return op.emitOpError("expected a permutation map"); 2321 if (op.permutation().getNumDims() != op.getShapedType().getRank()) 2322 return op.emitOpError( 2323 "expected a permutation map of same rank as the input"); 2324 2325 auto srcType = op.in().getType().cast<MemRefType>(); 2326 auto dstType = op.getType().cast<MemRefType>(); 2327 auto transposedType = inferTransposeResultType(srcType, op.permutation()); 2328 if (dstType != transposedType) 2329 return op.emitOpError("output type ") 2330 << dstType << " does not match transposed input type " << srcType 2331 << ", " << transposedType; 2332 return success(); 2333 } 2334 2335 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2336 if (succeeded(foldMemRefCast(*this))) 2337 return getResult(); 2338 return {}; 2339 } 2340 2341 //===----------------------------------------------------------------------===// 2342 // ViewOp 2343 //===----------------------------------------------------------------------===// 2344 2345 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { 2346 OpAsmParser::OperandType srcInfo; 2347 SmallVector<OpAsmParser::OperandType, 1> offsetInfo; 2348 SmallVector<OpAsmParser::OperandType, 4> sizesInfo; 2349 auto indexType = parser.getBuilder().getIndexType(); 2350 Type srcType, dstType; 2351 llvm::SMLoc offsetLoc; 2352 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || 2353 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) 2354 return failure(); 2355 2356 if (offsetInfo.size() != 1) 2357 return parser.emitError(offsetLoc) << "expects 1 offset operand"; 2358 2359 return failure( 2360 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || 2361 parser.parseOptionalAttrDict(result.attributes) || 2362 parser.parseColonType(srcType) || 2363 parser.resolveOperand(srcInfo, srcType, result.operands) || 2364 parser.resolveOperands(offsetInfo, indexType, result.operands) || 2365 parser.resolveOperands(sizesInfo, indexType, result.operands) || 2366 parser.parseKeywordType("to", dstType) || 2367 parser.addTypeToList(dstType, result.types)); 2368 } 2369 2370 static void print(OpAsmPrinter &p, ViewOp op) { 2371 p << ' ' << op.getOperand(0) << '['; 2372 p.printOperand(op.byte_shift()); 2373 p << "][" << op.sizes() << ']'; 2374 p.printOptionalAttrDict(op->getAttrs()); 2375 p << " : " << op.getOperand(0).getType() << " to " << op.getType(); 2376 } 2377 2378 static LogicalResult verify(ViewOp op) { 2379 auto baseType = op.getOperand(0).getType().cast<MemRefType>(); 2380 auto viewType = op.getType(); 2381 2382 // The base memref should have identity layout map (or none). 2383 if (baseType.getAffineMaps().size() > 1 || 2384 (baseType.getAffineMaps().size() == 1 && 2385 !baseType.getAffineMaps()[0].isIdentity())) 2386 return op.emitError("unsupported map for base memref type ") << baseType; 2387 2388 // The result memref should have identity layout map (or none). 2389 if (viewType.getAffineMaps().size() > 1 || 2390 (viewType.getAffineMaps().size() == 1 && 2391 !viewType.getAffineMaps()[0].isIdentity())) 2392 return op.emitError("unsupported map for result memref type ") << viewType; 2393 2394 // The base memref and the view memref should be in the same memory space. 2395 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2396 return op.emitError("different memory spaces specified for base memref " 2397 "type ") 2398 << baseType << " and view memref type " << viewType; 2399 2400 // Verify that we have the correct number of sizes for the result type. 2401 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2402 if (op.sizes().size() != numDynamicDims) 2403 return op.emitError("incorrect number of size operands for type ") 2404 << viewType; 2405 2406 return success(); 2407 } 2408 2409 Value ViewOp::getViewSource() { return source(); } 2410 2411 namespace { 2412 2413 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2414 using OpRewritePattern<ViewOp>::OpRewritePattern; 2415 2416 LogicalResult matchAndRewrite(ViewOp viewOp, 2417 PatternRewriter &rewriter) const override { 2418 // Return if none of the operands are constants. 2419 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2420 return matchPattern(operand, matchConstantIndex()); 2421 })) 2422 return failure(); 2423 2424 // Get result memref type. 2425 auto memrefType = viewOp.getType(); 2426 2427 // Get offset from old memref view type 'memRefType'. 2428 int64_t oldOffset; 2429 SmallVector<int64_t, 4> oldStrides; 2430 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2431 return failure(); 2432 assert(oldOffset == 0 && "Expected 0 offset"); 2433 2434 SmallVector<Value, 4> newOperands; 2435 2436 // Offset cannot be folded into result type. 2437 2438 // Fold any dynamic dim operands which are produced by a constant. 2439 SmallVector<int64_t, 4> newShapeConstants; 2440 newShapeConstants.reserve(memrefType.getRank()); 2441 2442 unsigned dynamicDimPos = 0; 2443 unsigned rank = memrefType.getRank(); 2444 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2445 int64_t dimSize = memrefType.getDimSize(dim); 2446 // If this is already static dimension, keep it. 2447 if (!ShapedType::isDynamic(dimSize)) { 2448 newShapeConstants.push_back(dimSize); 2449 continue; 2450 } 2451 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2452 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 2453 // Dynamic shape dimension will be folded. 2454 newShapeConstants.push_back(constantIndexOp.getValue()); 2455 } else { 2456 // Dynamic shape dimension not folded; copy operand from old memref. 2457 newShapeConstants.push_back(dimSize); 2458 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2459 } 2460 dynamicDimPos++; 2461 } 2462 2463 // Create new memref type with constant folded dims. 2464 MemRefType newMemRefType = 2465 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2466 // Nothing new, don't fold. 2467 if (newMemRefType == memrefType) 2468 return failure(); 2469 2470 // Create new ViewOp. 2471 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2472 viewOp.getOperand(0), 2473 viewOp.byte_shift(), newOperands); 2474 // Insert a cast so we have the same type as the old memref type. 2475 rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType()); 2476 return success(); 2477 } 2478 }; 2479 2480 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2481 using OpRewritePattern<ViewOp>::OpRewritePattern; 2482 2483 LogicalResult matchAndRewrite(ViewOp viewOp, 2484 PatternRewriter &rewriter) const override { 2485 Value memrefOperand = viewOp.getOperand(0); 2486 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2487 if (!memrefCastOp) 2488 return failure(); 2489 Value allocOperand = memrefCastOp.getOperand(); 2490 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2491 if (!allocOp) 2492 return failure(); 2493 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2494 viewOp.byte_shift(), viewOp.sizes()); 2495 return success(); 2496 } 2497 }; 2498 2499 } // end anonymous namespace 2500 2501 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2502 MLIRContext *context) { 2503 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2504 } 2505 2506 //===----------------------------------------------------------------------===// 2507 // TableGen'd op method definitions 2508 //===----------------------------------------------------------------------===// 2509 2510 #define GET_OP_CLASSES 2511 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2512