//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type. Constant materialization for the memref
/// dialect simply delegates to the standard dialect's `ConstantOp`.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
/// Replaces each operand of `op` that is produced by a ranked `memref.cast`
/// with the cast's source operand. The operand equal to `inner` (if provided)
/// is skipped, as are casts whose source is an unranked memref. Returns
/// success iff at least one operand was folded.
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    // Only fold ranked sources; an unranked source would change the operand's
    // rankedness and thus the consumer's expectations.
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

/// Returns the tensor type equivalent of `type`: a ranked/unranked tensor
/// with the memref's shape and element type. Returns NoneType when `type`
/// is not a memref type at all.
static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

/// Shared verifier for memref.alloc and memref.alloca: the result must be a
/// ranked memref, the number of dynamic-size operands must match the number
/// of dynamic dimensions in the result type, and the number of symbol
/// operands must match the symbol count of the type's (leading) affine map.
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimensions operands are constants.  If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands.  Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    // Walk the dimensions in order, consuming one dynamic-size operand per
    // dynamic dimension; fold the ones defined by a constant index op.
    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    // A store is only a removable use when the alloc is the memref being
    // stored *into*; if the alloc's value is itself the stored value, it
    // escapes and the alloc must be kept. Any user other than a store or a
    // dealloc also keeps the alloc alive.
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    // All users are dead stores/deallocs; erase them then the alloc itself.
    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

/// Prints an alloca_scope op: optional result types, then the body region
/// (terminators are only printed when the op has results), then attributes.
static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << AllocaScopeOp::getOperationName() << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

/// Parses an alloca_scope op; inverse of the printer above.
static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Insert the implicit terminator if the parsed region elided it.
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  // Region/result type agreement is fully checked by the interface.
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

/// RegionBranchOpInterface: control flows from the parent into the body
/// region, and from the body region back to the op's results.
void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // A present `index` means we are branching out of the body region.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

/// buffer_cast(tensor_load(m)) folds to `m` when the types match exactly.
OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    // Only ranked tensor sources can be turned into a MemRefType directly.
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    // buffer_cast the pre-cast tensor, then memref.cast to the original type.
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are definitely not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();

    // We already know that the types are potentially cast-compatible. However
    // in case the affine maps are different, we may need to use a copy if we go
    // from dynamic to static offset or stride (the canonicalization cannot know
    // at this point that it is really cast compatible).
    auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
      int64_t sourceOffset, targetOffset;
      SmallVector<int64_t, 4> sourceStrides, targetStrides;
      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
        return false;
      // A dynamic source value flowing into a static target value cannot be
      // proven correct statically.
      auto dynamicToStatic = [](int64_t a, int64_t b) {
        return a == MemRefType::getDynamicStrideOrOffset() &&
               b != MemRefType::getDynamicStrideOrOffset();
      };
      if (dynamicToStatic(sourceOffset, targetOffset))
        return false;
      for (auto it : zip(sourceStrides, targetStrides))
        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
          return false;
      return true;
    };

    auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
    auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
    if (tensorLoadType && bufferCastType &&
        !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
      // Not provably compatible: allocate a buffer of the target type (with
      // sizes taken from the loaded tensor) and copy into it instead.
      MemRefType resultType = bufferCastType;
      auto loc = bufferCast.getLoc();
      SmallVector<Value, 4> dynamicOperands;
      for (int i = 0; i < resultType.getRank(); ++i) {
        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
          continue;
        auto index = rewriter.createOrFold<ConstantIndexOp>(loc, i);
        Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
        dynamicOperands.push_back(size);
      }
      auto copy =
          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
      rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
      rewriter.replaceOp(bufferCast, {copy});
    } else
      rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                          tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ...
///     : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

/// A cast is compatible exactly between two ranked memrefs that agree on
/// element type, memory space, rank, every static dimension, and whose
/// strides/offsets are pairwise equal or dynamic on at least one side; or
/// between a ranked and an unranked memref with the same element type and
/// memory space. Unranked-to-unranked is rejected.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  // Unreachable: both branches of the if/else above return.
  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  // cast(memrefcast) -> cast with the inner cast's source folded in.
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

/// A clone reads its input and both writes and allocates its output.
void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
566 struct SimplifyClones : public OpRewritePattern<CloneOp> { 567 using OpRewritePattern<CloneOp>::OpRewritePattern; 568 569 LogicalResult matchAndRewrite(CloneOp cloneOp, 570 PatternRewriter &rewriter) const override { 571 if (cloneOp.use_empty()) { 572 rewriter.eraseOp(cloneOp); 573 return success(); 574 } 575 576 Value source = cloneOp.input(); 577 578 // This only finds dealloc operations for the immediate value. It should 579 // also consider aliases. That would also make the safety check below 580 // redundant. 581 llvm::Optional<Operation *> maybeCloneDeallocOp = 582 findDealloc(cloneOp.output()); 583 // Skip if either of them has > 1 deallocate operations. 584 if (!maybeCloneDeallocOp.hasValue()) 585 return failure(); 586 llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source); 587 if (!maybeSourceDeallocOp.hasValue()) 588 return failure(); 589 Operation *cloneDeallocOp = *maybeCloneDeallocOp; 590 Operation *sourceDeallocOp = *maybeSourceDeallocOp; 591 592 // If both are deallocated in the same block, their in-block lifetimes 593 // might not fully overlap, so we cannot decide which one to drop. 594 if (cloneDeallocOp && sourceDeallocOp && 595 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock()) 596 return failure(); 597 598 Block *currentBlock = cloneOp->getBlock(); 599 Operation *redundantDealloc = nullptr; 600 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) { 601 redundantDealloc = cloneDeallocOp; 602 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) { 603 redundantDealloc = sourceDeallocOp; 604 } 605 606 if (!redundantDealloc) 607 return failure(); 608 609 // Safety check that there are no other deallocations inbetween 610 // cloneOp and redundantDealloc, as otherwise we might deallocate an alias 611 // of source before the uses of the clone. With alias information, we could 612 // restrict this to only fail of the dealloc's operand is an alias 613 // of the source. 
614 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc; 615 pos = pos->getNextNode()) { 616 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos); 617 if (!effectInterface) 618 continue; 619 if (effectInterface.hasEffect<MemoryEffects::Free>()) 620 return failure(); 621 } 622 623 rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(), 624 source); 625 rewriter.eraseOp(redundantDealloc); 626 return success(); 627 } 628 }; 629 630 } // end anonymous namespace. 631 632 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 633 MLIRContext *context) { 634 results.insert<SimplifyClones>(context); 635 } 636 637 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) { 638 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 639 } 640 641 //===----------------------------------------------------------------------===// 642 // DeallocOp 643 //===----------------------------------------------------------------------===// 644 645 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 646 SmallVectorImpl<OpFoldResult> &results) { 647 /// dealloc(memrefcast) -> dealloc 648 return foldMemRefCast(*this); 649 } 650 651 //===----------------------------------------------------------------------===// 652 // DimOp 653 //===----------------------------------------------------------------------===// 654 655 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 656 int64_t index) { 657 auto loc = result.location; 658 Value indexValue = builder.create<ConstantIndexOp>(loc, index); 659 build(builder, result, source, indexValue); 660 } 661 662 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 663 Value index) { 664 auto indexTy = builder.getIndexType(); 665 build(builder, result, indexTy, source, index); 666 } 667 668 Optional<int64_t> DimOp::getConstantIndex() { 669 if (auto constantOp = index().getDefiningOp<ConstantOp>()) 670 return 
constantOp.getValue().cast<IntegerAttr>().getInt(); 671 return {}; 672 } 673 674 static LogicalResult verify(DimOp op) { 675 // Assume unknown index to be in range. 676 Optional<int64_t> index = op.getConstantIndex(); 677 if (!index.hasValue()) 678 return success(); 679 680 // Check that constant index is not knowingly out of range. 681 auto type = op.source().getType(); 682 if (auto memrefType = type.dyn_cast<MemRefType>()) { 683 if (index.getValue() >= memrefType.getRank()) 684 return op.emitOpError("index is out of range"); 685 } else if (type.isa<UnrankedMemRefType>()) { 686 // Assume index to be in range. 687 } else { 688 llvm_unreachable("expected operand with memref type"); 689 } 690 return success(); 691 } 692 693 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 694 // All forms of folding require a known index. 695 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 696 if (!index) 697 return {}; 698 699 // Folding for unranked types (UnrankedMemRefType) is not supported. 700 auto memrefType = source().getType().dyn_cast<MemRefType>(); 701 if (!memrefType) 702 return {}; 703 704 // Fold if the shape extent along the given index is known. 705 if (!memrefType.isDynamicDim(index.getInt())) { 706 Builder builder(getContext()); 707 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]); 708 } 709 710 // The size at the given index is now known to be a dynamic size. 711 unsigned unsignedIndex = index.getValue().getZExtValue(); 712 713 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 
714 Operation *definingOp = source().getDefiningOp(); 715 716 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 717 return *(alloc.getDynamicSizes().begin() + 718 memrefType.getDynamicDimIndex(unsignedIndex)); 719 720 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 721 return *(alloca.getDynamicSizes().begin() + 722 memrefType.getDynamicDimIndex(unsignedIndex)); 723 724 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 725 return *(view.getDynamicSizes().begin() + 726 memrefType.getDynamicDimIndex(unsignedIndex)); 727 728 if (auto sizeInterface = 729 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 730 assert(sizeInterface.isDynamicSize(unsignedIndex) && 731 "Expected dynamic subview size"); 732 return sizeInterface.getDynamicSize(unsignedIndex); 733 } 734 735 // dim(memrefcast) -> dim 736 if (succeeded(foldMemRefCast(*this))) 737 return getResult(); 738 739 return {}; 740 } 741 742 namespace { 743 /// Fold dim of a memref reshape operation to a load into the reshape's shape 744 /// operand. 745 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 746 using OpRewritePattern<DimOp>::OpRewritePattern; 747 748 LogicalResult matchAndRewrite(DimOp dim, 749 PatternRewriter &rewriter) const override { 750 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 751 752 if (!reshape) 753 return failure(); 754 755 // Place the load directly after the reshape to ensure that the shape memref 756 // was not mutated. 757 rewriter.setInsertionPointAfter(reshape); 758 Location loc = dim.getLoc(); 759 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 760 if (load.getType() != dim.getType()) 761 load = rewriter.create<IndexCastOp>(loc, dim.getType(), load); 762 rewriter.replaceOp(dim, load); 763 return success(); 764 } 765 }; 766 767 /// Fold dim of a cast into the dim of the source of the memref cast. 
768 struct DimOfCastOp : public OpRewritePattern<DimOp> { 769 using OpRewritePattern<DimOp>::OpRewritePattern; 770 771 LogicalResult matchAndRewrite(DimOp dimOp, 772 PatternRewriter &rewriter) const override { 773 auto castOp = dimOp.source().getDefiningOp<BufferCastOp>(); 774 if (!castOp) 775 return failure(); 776 Value newSource = castOp.getOperand(); 777 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index()); 778 return success(); 779 } 780 }; 781 } // end anonymous namespace. 782 783 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 784 MLIRContext *context) { 785 results.add<DimOfMemRefReshape, DimOfCastOp>(context); 786 } 787 788 // --------------------------------------------------------------------------- 789 // DmaStartOp 790 // --------------------------------------------------------------------------- 791 792 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 793 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 794 ValueRange destIndices, Value numElements, 795 Value tagMemRef, ValueRange tagIndices, Value stride, 796 Value elementsPerStride) { 797 result.addOperands(srcMemRef); 798 result.addOperands(srcIndices); 799 result.addOperands(destMemRef); 800 result.addOperands(destIndices); 801 result.addOperands({numElements, tagMemRef}); 802 result.addOperands(tagIndices); 803 if (stride) 804 result.addOperands({stride, elementsPerStride}); 805 } 806 807 void DmaStartOp::print(OpAsmPrinter &p) { 808 p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices() 809 << "], " << getDstMemRef() << '[' << getDstIndices() << "], " 810 << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() 811 << ']'; 812 if (isStrided()) 813 p << ", " << getStride() << ", " << getNumElementsPerStride(); 814 815 p.printOptionalAttrDict((*this)->getAttrs()); 816 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 817 << ", " << getTagMemRef().getType(); 818 } 819 820 
// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size: the number of elements to transfer (resolved below as an
  //    index operand).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Either both stride-related operands are present, or neither is.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  // Exactly three types: source memref, destination memref, tag memref.
  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

/// Structural verification. Operand counts are checked incrementally because
/// the position of each variadic operand group depends on the ranks of the
/// memref operands that precede it.
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

/// Print in the form:
///   dma_wait %tag[indices], %numElements attr-dict : tag-memref-type
void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  // Tag indices are variadic with exactly tag-rank entries.
  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

/// Print the memref type and, unless the global is an external declaration,
/// the initial value: either the keyword `uninitialized` or an elements
/// attribute printed without its type (the type is implied by the memref
/// type).
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

/// Counterpart of the printer above. On success `typeAttr` holds a
/// static-shaped memref type and `initialValue` is either left unset
/// (external declaration), a UnitAttr (`uninitialized`), or an ElementsAttr
/// parsed against the corresponding tensor type.
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // No '=': external declaration, no initial value.
  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  // The elements attribute was printed without its type; reconstruct the
  // tensor type from the memref type to parse it.
  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  // One operand per memref dimension, plus the memref itself.
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.
void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

/// Print in the form:
///   prefetch %memref[indices], (read|write), locality<N>, (data|instr)
///       attr-dict : memref-type
static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  // These three attributes are printed inline above, so elide them from the
  // generic attribute dictionary.
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  // The read/write and data/instr keywords are stored as boolean attributes.
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

static LogicalResult verify(PrefetchOp op) {
  // NOTE(review): the check rejects any count mismatch, so this diagnostic
  // also fires when there are too many indices, not only too few.
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
1252 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1253 MemRefType resultType, Value source, 1254 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1255 ArrayRef<OpFoldResult> strides, 1256 ArrayRef<NamedAttribute> attrs) { 1257 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1258 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1259 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1260 ShapedType::kDynamicStrideOrOffset); 1261 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1262 ShapedType::kDynamicSize); 1263 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1264 ShapedType::kDynamicStrideOrOffset); 1265 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1266 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1267 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1268 result.addAttributes(attrs); 1269 } 1270 1271 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1272 MemRefType resultType, Value source, 1273 int64_t offset, ArrayRef<int64_t> sizes, 1274 ArrayRef<int64_t> strides, 1275 ArrayRef<NamedAttribute> attrs) { 1276 SmallVector<OpFoldResult> sizeValues = 1277 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1278 return b.getI64IntegerAttr(v); 1279 })); 1280 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1281 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1282 return b.getI64IntegerAttr(v); 1283 })); 1284 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1285 strideValues, attrs); 1286 } 1287 1288 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1289 MemRefType resultType, Value source, Value offset, 1290 ValueRange sizes, ValueRange strides, 1291 ArrayRef<NamedAttribute> attrs) { 1292 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1293 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; 
})); 1294 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1295 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1296 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1297 } 1298 1299 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1300 // completed automatically, like we have for subview and extract_slice. 1301 static LogicalResult verify(ReinterpretCastOp op) { 1302 // The source and result memrefs should be in the same memory space. 1303 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1304 auto resultType = op.getType().cast<MemRefType>(); 1305 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1306 return op.emitError("different memory spaces specified for source type ") 1307 << srcType << " and result memref type " << resultType; 1308 if (srcType.getElementType() != resultType.getElementType()) 1309 return op.emitError("different element types specified for source type ") 1310 << srcType << " and result memref type " << resultType; 1311 1312 // Match sizes in result memref type and in static_sizes attribute. 1313 for (auto &en : 1314 llvm::enumerate(llvm::zip(resultType.getShape(), 1315 extractFromI64ArrayAttr(op.static_sizes())))) { 1316 int64_t resultSize = std::get<0>(en.value()); 1317 int64_t expectedSize = std::get<1>(en.value()); 1318 if (resultSize != expectedSize) 1319 return op.emitError("expected result type with size = ") 1320 << expectedSize << " instead of " << resultSize 1321 << " in dim = " << en.index(); 1322 } 1323 1324 // Match offset and strides in static_offset and static_strides attributes if 1325 // result memref type has an affine map specified. 
1326 if (!resultType.getAffineMaps().empty()) { 1327 int64_t resultOffset; 1328 SmallVector<int64_t, 4> resultStrides; 1329 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1330 return failure(); 1331 1332 // Match offset in result memref type and in static_offsets attribute. 1333 int64_t expectedOffset = 1334 extractFromI64ArrayAttr(op.static_offsets()).front(); 1335 if (resultOffset != expectedOffset) 1336 return op.emitError("expected result type with offset = ") 1337 << resultOffset << " instead of " << expectedOffset; 1338 1339 // Match strides in result memref type and in static_strides attribute. 1340 for (auto &en : llvm::enumerate(llvm::zip( 1341 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1342 int64_t resultStride = std::get<0>(en.value()); 1343 int64_t expectedStride = std::get<1>(en.value()); 1344 if (resultStride != expectedStride) 1345 return op.emitError("expected result type with stride = ") 1346 << expectedStride << " instead of " << resultStride 1347 << " in dim = " << en.index(); 1348 } 1349 } 1350 return success(); 1351 } 1352 1353 //===----------------------------------------------------------------------===// 1354 // Reassociative reshape ops 1355 //===----------------------------------------------------------------------===// 1356 1357 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1358 return getSymbolLessAffineMaps(getReassociationExprs()); 1359 } 1360 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1361 return convertReassociationIndicesToExprs(getContext(), 1362 getReassociationIndices()); 1363 } 1364 1365 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1366 return getSymbolLessAffineMaps(getReassociationExprs()); 1367 } 1368 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1369 return convertReassociationIndicesToExprs(getContext(), 1370 getReassociationIndices()); 1371 } 1372 1373 static void 
print(OpAsmPrinter &p, ExpandShapeOp op) { 1374 ::mlir::printReshapeOp<ExpandShapeOp>(p, op); 1375 } 1376 1377 static void print(OpAsmPrinter &p, CollapseShapeOp op) { 1378 ::mlir::printReshapeOp<CollapseShapeOp>(p, op); 1379 } 1380 1381 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1382 /// copies. 1383 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1384 ArrayRef<int64_t> sizes, 1385 ArrayRef<AffineExpr> strides) { 1386 assert(sizes.size() == strides.size() && "mismatched ranks"); 1387 // off by 1 indexing to avoid out of bounds 1388 // V 1389 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1390 // Only bands of static shapes are reshapable. This is due to the fact that 1391 // there is no relation between dynamic sizes and dynamic strides: we do not 1392 // have enough information to know whether a "-1" size corresponds to the 1393 // proper symbol in the AffineExpr of a stride. 1394 if (ShapedType::isDynamic(sizes[dim + 1])) 1395 return false; 1396 // TODO: Refine this by passing the proper nDims and nSymbols so we can 1397 // simplify on the fly and catch more reshapable cases. 1398 if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) 1399 return false; 1400 } 1401 return true; 1402 } 1403 1404 /// Compute the MemRefType obtained by applying the `reassociation` (which is 1405 /// expected to be valid) to `type`. 1406 /// If `type` is Contiguous MemRefType, this always produce a contiguous 1407 /// MemRefType. 
1408 static MemRefType 1409 computeReshapeCollapsedType(MemRefType type, 1410 ArrayRef<AffineMap> reassociation) { 1411 auto sizes = type.getShape(); 1412 AffineExpr offset; 1413 SmallVector<AffineExpr, 4> strides; 1414 auto status = getStridesAndOffset(type, strides, offset); 1415 (void)status; 1416 assert(succeeded(status) && "expected strided memref"); 1417 1418 SmallVector<int64_t, 4> newSizes; 1419 newSizes.reserve(reassociation.size()); 1420 SmallVector<AffineExpr, 4> newStrides; 1421 newStrides.reserve(reassociation.size()); 1422 1423 // Use the fact that reassociation is valid to simplify the logic: only use 1424 // each map's rank. 1425 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1426 unsigned currentDim = 0; 1427 for (AffineMap m : reassociation) { 1428 unsigned dim = m.getNumResults(); 1429 int64_t size = 1; 1430 AffineExpr stride = strides[currentDim + dim - 1]; 1431 if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { 1432 size = ShapedType::kDynamicSize; 1433 stride = AffineExpr(); 1434 } else { 1435 for (unsigned d = 0; d < dim; ++d) 1436 size *= sizes[currentDim + d]; 1437 } 1438 newSizes.push_back(size); 1439 newStrides.push_back(stride); 1440 currentDim += dim; 1441 } 1442 1443 // Early-exit: if `type` is contiguous, the result must be contiguous. 1444 if (canonicalizeStridedLayout(type).getAffineMaps().empty()) 1445 return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({}); 1446 1447 // Convert back to int64_t because we don't have enough information to create 1448 // new strided layouts from AffineExpr only. This corresponds to a case where 1449 // copies may be necessary. 
1450 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1451 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1452 intOffset = o.getValue(); 1453 SmallVector<int64_t, 4> intStrides; 1454 intStrides.reserve(strides.size()); 1455 for (auto stride : newStrides) { 1456 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1457 intStrides.push_back(cst.getValue()); 1458 else 1459 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1460 } 1461 auto layout = 1462 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1463 return canonicalizeStridedLayout( 1464 MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout})); 1465 } 1466 1467 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1468 ArrayRef<ReassociationIndices> reassociation, 1469 ArrayRef<NamedAttribute> attrs) { 1470 auto memRefType = src.getType().cast<MemRefType>(); 1471 auto resultType = computeReshapeCollapsedType( 1472 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1473 b.getContext(), reassociation))); 1474 build(b, result, resultType, src, attrs); 1475 result.addAttribute(getReassociationAttrName(), 1476 getReassociationIndicesAttribute(b, reassociation)); 1477 } 1478 1479 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1480 ArrayRef<ReassociationIndices> reassociation, 1481 ArrayRef<NamedAttribute> attrs) { 1482 auto memRefType = src.getType().cast<MemRefType>(); 1483 auto resultType = computeReshapeCollapsedType( 1484 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1485 b.getContext(), reassociation))); 1486 build(b, result, resultType, src, attrs); 1487 result.addAttribute(getReassociationAttrName(), 1488 getReassociationIndicesAttribute(b, reassociation)); 1489 } 1490 1491 template <typename ReshapeOp, 1492 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1493 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 
1494 MemRefType collapsedType) { 1495 if (failed( 1496 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1497 return failure(); 1498 auto maps = op.getReassociationMaps(); 1499 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1500 if (collapsedType != expectedType) 1501 return op.emitOpError("expected collapsed type to be ") 1502 << expectedType << ", but got " << collapsedType; 1503 return success(); 1504 } 1505 1506 static LogicalResult verify(ExpandShapeOp op) { 1507 return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); 1508 } 1509 1510 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1511 MLIRContext *context) { 1512 results.add<CollapseReshapeOps<ExpandShapeOp>, 1513 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1514 } 1515 1516 static LogicalResult verify(CollapseShapeOp op) { 1517 return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); 1518 } 1519 1520 struct CollapseShapeOpMemRefCastFolder 1521 : public OpRewritePattern<CollapseShapeOp> { 1522 public: 1523 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1524 1525 LogicalResult matchAndRewrite(CollapseShapeOp op, 1526 PatternRewriter &rewriter) const override { 1527 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1528 if (!cast) 1529 return failure(); 1530 1531 if (!CastOp::canFoldIntoConsumerOp(cast)) 1532 return failure(); 1533 1534 Type newResultType = computeReshapeCollapsedType( 1535 cast.getOperand().getType().cast<MemRefType>(), 1536 op.getReassociationMaps()); 1537 1538 if (newResultType == op.getResultType()) { 1539 rewriter.updateRootInPlace( 1540 op, [&]() { op.srcMutable().assign(cast.source()); }); 1541 } else { 1542 Value newOp = rewriter.create<CollapseShapeOp>( 1543 op->getLoc(), cast.source(), op.getReassociationIndices()); 1544 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1545 } 1546 return success(); 1547 } 1548 }; 1549 1550 void 
CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<CollapseReshapeOps<CollapseShapeOp>,
              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
              CollapseShapeOpMemRefCastFolder>(context);
}

/// Fold an expand_shape whose source comes through a memref.cast, or an
/// expand(collapse) round-trip (via foldReshapeOp).
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}

/// Fold a collapse(expand) round-trip (via foldReshapeOp).
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

/// Verify memref.reshape: element types must match, source/result memrefs
/// must have no (i.e. identity) layout maps, and when the result is ranked
/// the shape operand must have a static length equal to the result rank.
static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "source memref type should have identity affine map");

  // The 1-D shape operand's extent gives the rank the reshape produces.
  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

/// Verify memref.store: operands are value, memref, then one index per
/// memref dimension.
static LogicalResult verify(StoreOp op) {
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
    return op.emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

namespace {
/// Helpers to write more idiomatic operations.
namespace saturated_arith {
/// Thin int64_t wrapper whose + and * saturate to the dynamic
/// stride/offset sentinel when either operand is dynamic.
struct Wrapper {
  explicit Wrapper(int64_t v) : v(v) {}
  operator int64_t() { return v; }
  int64_t v;
};
Wrapper operator+(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v + b);
}
Wrapper operator*(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v * b);
}
} // end namespace saturated_arith
} // end namespace

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> leadingStaticOffsets,
                                ArrayRef<int64_t> leadingStaticSizes,
                                ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offset/sizes/strides in
  // which case we complete with offset=0, sizes from memref type and strides=1.
  unsigned rank = sourceMemRefType.getRank();
  assert(leadingStaticOffsets.size() <= rank &&
         "unexpected leadingStaticOffsets overflow");
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  assert(leadingStaticStrides.size() <= rank &&
         "unexpected leadingStaticStrides overflow");
  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
  unsigned numTrailingOffsets = rank - staticOffsets.size();
  unsigned numTrailingSizes = rank - staticSizes.size();
  unsigned numTrailingStrides = rank - staticStrides.size();
  staticOffsets.append(numTrailingOffsets, 0);
  llvm::append_range(staticSizes,
                     sourceMemRefType.getShape().take_back(numTrailingSizes));
  staticStrides.append(numTrailingStrides, 1);

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  // Saturating arithmetic propagates any dynamic sentinel.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

/// OpFoldResult overload: split mixed offsets/sizes/strides into static
/// values and dynamic sentinels, then defer to the int64_t overload.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

/// Infer the result type, then drop unit dimensions until the requested
/// `resultRank` is reached (rank-reducing subview).
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank && "expected ");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    // Select `rankDiff` size-1 dimensions to project away.
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map;
    auto maps = inferredType.getAffineMaps();
    if (!maps.empty() && maps.front())
      map = getProjectedMap(maps.front(), dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

/// OpFoldResult overload of the rank-reduced inference; dispatches mixed
/// entries then defers to the int64_t overload.
Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
1810 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1811 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1812 ArrayRef<int64_t> strides, 1813 ArrayRef<NamedAttribute> attrs) { 1814 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1815 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1816 return b.getI64IntegerAttr(v); 1817 })); 1818 SmallVector<OpFoldResult> sizeValues = 1819 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1820 return b.getI64IntegerAttr(v); 1821 })); 1822 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1823 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1824 return b.getI64IntegerAttr(v); 1825 })); 1826 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1827 } 1828 1829 // Build a SubViewOp with dynamic entries and custom result type. If the 1830 // type passed is nullptr, it is inferred. 1831 void SubViewOp::build(OpBuilder &b, OperationState &result, 1832 MemRefType resultType, Value source, 1833 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1834 ArrayRef<int64_t> strides, 1835 ArrayRef<NamedAttribute> attrs) { 1836 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1837 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1838 return b.getI64IntegerAttr(v); 1839 })); 1840 SmallVector<OpFoldResult> sizeValues = 1841 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1842 return b.getI64IntegerAttr(v); 1843 })); 1844 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1845 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1846 return b.getI64IntegerAttr(v); 1847 })); 1848 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1849 attrs); 1850 } 1851 1852 // Build a SubViewOp with dynamic entries and custom result type. If the type 1853 // passed is nullptr, it is inferred. 
1854 void SubViewOp::build(OpBuilder &b, OperationState &result, 1855 MemRefType resultType, Value source, ValueRange offsets, 1856 ValueRange sizes, ValueRange strides, 1857 ArrayRef<NamedAttribute> attrs) { 1858 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1859 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1860 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1861 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1862 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1863 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1864 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1865 } 1866 1867 // Build a SubViewOp with dynamic entries and inferred result type. 1868 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1869 ValueRange offsets, ValueRange sizes, ValueRange strides, 1870 ArrayRef<NamedAttribute> attrs) { 1871 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1872 } 1873 1874 /// For ViewLikeOpInterface. 1875 Value SubViewOp::getViewSource() { return source(); } 1876 1877 enum SubViewVerificationResult { 1878 Success, 1879 RankTooLarge, 1880 SizeMismatch, 1881 ElemTypeMismatch, 1882 MemSpaceMismatch, 1883 AffineMapMismatch 1884 }; 1885 1886 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1887 /// This function is slight variant of `is subsequence` algorithm where 1888 /// not matching dimension must be 1. 
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  // Trivially compatible cases: identical types, or non-memref originals
  // (no further structure to check from here).
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  // Which original dims were dropped (must all be size-1 dims).
  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  // Project the original's strided layout map onto the kept dims and compare
  // it with the candidate's layout (explicit map, or one reconstructed from
  // its strides when no map is present).
  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

/// Translate an `isRankReducedType` result into a LogicalResult, emitting a
/// diagnostic on `op` (with the expected type and optional `errMsg` detail)
/// for every non-success case.
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  // The declared result type may also be a rank-reduced version of the
  // inferred type; isRankReducedType covers both.
  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

/// Debug printing for Range as "range offset:size:stride".
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // Materialize a ConstantIndexOp for each static entry; reuse the dynamic
    // SSA value otherwise.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  // Rank reduction could not reach `resultRank`; fall back to the full type.
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType`
    /// on the cast source operand type and the SubViewOp static information.
    /// This is the resulting type if the MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    // Cast back to the original (possibly more dynamic) result type.
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  // Replace `op` by `newOp` with a cast back to the original result type.
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

/// Fold a no-op subview: static result type identical to the source type.
OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

/// Fold tensor_load(buffer_cast(x)) -> x.
OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there
    // is no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

namespace {
/// Replace tensor.dim(tensor_load(m), i) with memref.dim(m, i).
struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
    if (!tensorLoadOp)
      return failure();

    rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
                                       dimOp.index());
    return success();
  }
};
} // namespace

void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<DimOfTensorLoadFolder>(context);
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

/// Build a TransposeOp, inferring the result type from the input type and
/// the permutation attribute.
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

/// Parse the custom TransposeOp assembly printed above.
static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

/// Verify that the permutation map is a genuine permutation of the input
/// rank and that the declared result type matches the inferred one.
static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

/// Fold transpose(memref.cast) into transpose on the cast's ranked source.
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

/// Parse: source `[` byte_shift `][` sizes `]` attr-dict `:` src-type `to`
/// result-type.
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

/// Verify memref.view: identity layouts on both sides, matching memory
/// spaces, and one size operand per dynamic result dimension.
static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

/// Fold constant-producing size operands into the view's result type,
/// inserting a cast back to the original type.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        // `dimSize` here is the dynamic sentinel, keeping the dim dynamic.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Bypass a memref.cast of an alloc feeding a view: the view can read the
/// alloc directly since its own result type fixes the interpretation.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===// 2460 2461 #define GET_OP_CLASSES 2462 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2463