//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, value, type);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
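///
/// For example (illustrative; SSA names are made up), a dealloc of a ranked
/// cast folds
///
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// into
///
/// ```mlir
///   memref.dealloc %arg : memref<8xf32>
/// ```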
48 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) { 49 bool folded = false; 50 for (OpOperand &operand : op->getOpOperands()) { 51 auto cast = operand.get().getDefiningOp<CastOp>(); 52 if (cast && operand.get() != inner && 53 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 54 operand.set(cast.getOperand()); 55 folded = true; 56 } 57 } 58 return success(folded); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // Helpers for GlobalOp 63 //===----------------------------------------------------------------------===// 64 65 static Type getTensorTypeFromMemRefType(Type type) { 66 if (auto memref = type.dyn_cast<MemRefType>()) 67 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 68 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 69 return UnrankedTensorType::get(memref.getElementType()); 70 return NoneType::get(type.getContext()); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // AllocOp / AllocaOp 75 //===----------------------------------------------------------------------===// 76 77 template <typename AllocLikeOp> 78 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 79 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 80 "applies to only alloc or alloca"); 81 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 82 if (!memRefType) 83 return op.emitOpError("result must be a memref"); 84 85 if (static_cast<int64_t>(op.dynamicSizes().size()) != 86 memRefType.getNumDynamicDims()) 87 return op.emitOpError("dimension operand count does not equal memref " 88 "dynamic dimension count"); 89 90 unsigned numSymbols = 0; 91 if (!memRefType.getLayout().isIdentity()) 92 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 93 if (op.symbolOperands().size() != numSymbols) 94 return op.emitOpError("symbol operand count does not equal memref symbol " 95 "count: expected ") 96 << numSymbols << ", got " << op.symbolOperands().size(); 97 98 return success(); 99 } 100 101 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 102 103 static LogicalResult verify(AllocaOp op) { 104 // An alloca op needs to have an ancestor with an allocation scope trait. 105 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 106 return op.emitOpError( 107 "requires an ancestor op with AutomaticAllocationScope trait"); 108 109 return verifyAllocLikeOp(op); 110 } 111 112 namespace { 113 /// Fold constant dimensions into an alloc like operation. 114 template <typename AllocLikeOp> 115 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 116 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 117 118 LogicalResult matchAndRewrite(AllocLikeOp alloc, 119 PatternRewriter &rewriter) const override { 120 // Check to see if any dimensions operands are constants. If so, we can 121 // substitute and drop them. 122 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 123 return matchPattern(operand, matchConstantIndex()); 124 })) 125 return failure(); 126 127 auto memrefType = alloc.getType(); 128 129 // Ok, we have one or more constant operands. Collect the non-constant ones 130 // and keep track of the resultant memref type to build. 
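    // For example (illustrative; SSA names are made up): with
    // %c8 = arith.constant 8 : index,
    //   %a = memref.alloc(%c8, %n) : memref<?x?xf32>
    // becomes
    //   %b = memref.alloc(%n) : memref<8x?xf32>
    //   %a = memref.cast %b : memref<8x?xf32> to memref<?x?xf32>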
131 SmallVector<int64_t, 4> newShapeConstants; 132 newShapeConstants.reserve(memrefType.getRank()); 133 SmallVector<Value, 4> dynamicSizes; 134 135 unsigned dynamicDimPos = 0; 136 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 137 int64_t dimSize = memrefType.getDimSize(dim); 138 // If this is already static dimension, keep it. 139 if (dimSize != -1) { 140 newShapeConstants.push_back(dimSize); 141 continue; 142 } 143 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 144 auto *defOp = dynamicSize.getDefiningOp(); 145 if (auto constantIndexOp = 146 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 147 // Dynamic shape dimension will be folded. 148 newShapeConstants.push_back(constantIndexOp.value()); 149 } else { 150 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 151 newShapeConstants.push_back(-1); 152 dynamicSizes.push_back(dynamicSize); 153 } 154 dynamicDimPos++; 155 } 156 157 // Create new memref type (which will have fewer dynamic dimensions). 158 MemRefType newMemRefType = 159 MemRefType::Builder(memrefType).setShape(newShapeConstants); 160 assert(static_cast<int64_t>(dynamicSizes.size()) == 161 newMemRefType.getNumDynamicDims()); 162 163 // Create and insert the alloc op for the new memref. 164 auto newAlloc = rewriter.create<AllocLikeOp>( 165 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 166 alloc.alignmentAttr()); 167 // Insert a cast so we have the same type as the old alloc. 168 auto resultCast = 169 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 170 171 rewriter.replaceOp(alloc, {resultCast}); 172 return success(); 173 } 174 }; 175 176 /// Fold alloc operations with no users or only store and dealloc uses. 177 template <typename T> 178 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 179 using OpRewritePattern<T>::OpRewritePattern; 180 181 LogicalResult matchAndRewrite(T alloc, 182 PatternRewriter &rewriter) const override { 183 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { 184 if (auto storeOp = dyn_cast<StoreOp>(op)) 185 return storeOp.value() == alloc; 186 return !isa<DeallocOp>(op); 187 })) 188 return failure(); 189 190 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 191 rewriter.eraseOp(user); 192 193 rewriter.eraseOp(alloc); 194 return success(); 195 } 196 }; 197 } // end anonymous namespace. 
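/// Builders that create the `memref.dealloc` / `memref.clone` counterpart for
/// an allocated value, e.g. for use by buffer-deallocation style rewrites.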
198 199 Optional<Operation *> AllocOp::buildDealloc(OpBuilder &builder, Value alloc) { 200 return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc) 201 .getOperation(); 202 } 203 204 Optional<Value> AllocOp::buildClone(OpBuilder &builder, Value alloc) { 205 return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult(); 206 } 207 208 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 209 MLIRContext *context) { 210 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 211 } 212 213 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 214 MLIRContext *context) { 215 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 216 context); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // AllocaScopeOp 221 //===----------------------------------------------------------------------===// 222 223 static void print(OpAsmPrinter &p, AllocaScopeOp &op) { 224 bool printBlockTerminators = false; 225 226 p << " "; 227 if (!op.results().empty()) { 228 p << " -> (" << op.getResultTypes() << ")"; 229 printBlockTerminators = true; 230 } 231 p.printRegion(op.bodyRegion(), 232 /*printEntryBlockArgs=*/false, 233 /*printBlockTerminators=*/printBlockTerminators); 234 p.printOptionalAttrDict(op->getAttrs()); 235 } 236 237 static ParseResult parseAllocaScopeOp(OpAsmParser &parser, 238 OperationState &result) { 239 // Create a region for the body. 240 result.regions.reserve(1); 241 Region *bodyRegion = result.addRegion(); 242 243 // Parse optional results type list. 244 if (parser.parseOptionalArrowTypeList(result.types)) 245 return failure(); 246 247 // Parse the body region. 248 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 249 return failure(); 250 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 251 result.location); 252 253 // Parse the optional attribute list. 254 if (parser.parseOptionalAttrDict(result.attributes)) 255 return failure(); 256 257 return success(); 258 } 259 260 static LogicalResult verify(AllocaScopeOp op) { 261 if (failed(RegionBranchOpInterface::verifyTypes(op))) 262 return failure(); 263 264 return success(); 265 } 266 267 void AllocaScopeOp::getSuccessorRegions( 268 Optional<unsigned> index, ArrayRef<Attribute> operands, 269 SmallVectorImpl<RegionSuccessor> ®ions) { 270 if (index.hasValue()) { 271 regions.push_back(RegionSuccessor(getResults())); 272 return; 273 } 274 275 regions.push_back(RegionSuccessor(&bodyRegion())); 276 } 277 278 //===----------------------------------------------------------------------===// 279 // AssumeAlignmentOp 280 //===----------------------------------------------------------------------===// 281 282 static LogicalResult verify(AssumeAlignmentOp op) { 283 unsigned alignment = op.alignment(); 284 if (!llvm::isPowerOf2_32(alignment)) 285 return op.emitOpError("alignment must be power of 2"); 286 return success(); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // BufferCastOp 291 //===----------------------------------------------------------------------===// 292 293 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) { 294 if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) 295 if (tensorLoad.memref().getType() == getType()) 296 return tensorLoad.memref(); 297 return {}; 298 } 299 300 namespace { 301 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast. 
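///
/// For example (illustrative; SSA names are made up):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// is rewritten to:
///
/// ```mlir
///   %0 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```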
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same-type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are definitely not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();

    // We already know that the types are potentially cast-compatible. However,
    // in case the affine maps are different, we may need to use a copy if we go
    // from dynamic to static offset or stride (the canonicalization cannot know
    // at this point that it is really cast compatible).
    auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
      int64_t sourceOffset, targetOffset;
      SmallVector<int64_t, 4> sourceStrides, targetStrides;
      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
        return false;
      auto dynamicToStatic = [](int64_t a, int64_t b) {
        return a == MemRefType::getDynamicStrideOrOffset() &&
               b != MemRefType::getDynamicStrideOrOffset();
      };
      if (dynamicToStatic(sourceOffset, targetOffset))
        return false;
      for (auto it : zip(sourceStrides, targetStrides))
        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
          return false;
      return true;
    };

    auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
    auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
    if (tensorLoadType && bufferCastType &&
        !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
      MemRefType resultType = bufferCastType;
      auto loc = bufferCast.getLoc();
      SmallVector<Value, 4> dynamicOperands;
      for (int i = 0; i < resultType.getRank(); ++i) {
        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
          continue;
        auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
        Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
        dynamicOperands.push_back(size);
      }
      auto copy =
          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
      rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
      rewriter.replaceOp(bufferCast, {copy});
    } else
      rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                          tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and the same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ...
: memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 436 /// ``` 437 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 438 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 439 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 440 441 // Requires ranked MemRefType. 442 if (!sourceType || !resultType) 443 return false; 444 445 // Requires same elemental type. 446 if (sourceType.getElementType() != resultType.getElementType()) 447 return false; 448 449 // Requires same rank. 450 if (sourceType.getRank() != resultType.getRank()) 451 return false; 452 453 // Only fold casts between strided memref forms. 454 int64_t sourceOffset, resultOffset; 455 SmallVector<int64_t, 4> sourceStrides, resultStrides; 456 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 457 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 458 return false; 459 460 // If cast is towards more static sizes along any dimension, don't fold. 461 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 462 auto ss = std::get<0>(it), st = std::get<1>(it); 463 if (ss != st) 464 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 465 return false; 466 } 467 468 // If cast is towards more static offset along any dimension, don't fold. 469 if (sourceOffset != resultOffset) 470 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 471 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 472 return false; 473 474 // If cast is towards more static strides along any dimension, don't fold. 475 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 476 auto ss = std::get<0>(it), st = std::get<1>(it); 477 if (ss != st) 478 if (MemRefType::isDynamicStrideOrOffset(ss) && 479 !MemRefType::isDynamicStrideOrOffset(st)) 480 return false; 481 } 482 483 return true; 484 } 485 486 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 487 if (inputs.size() != 1 || outputs.size() != 1) 488 return false; 489 Type a = inputs.front(), b = outputs.front(); 490 auto aT = a.dyn_cast<MemRefType>(); 491 auto bT = b.dyn_cast<MemRefType>(); 492 493 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 494 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 495 496 if (aT && bT) { 497 if (aT.getElementType() != bT.getElementType()) 498 return false; 499 if (aT.getLayout() != bT.getLayout()) { 500 int64_t aOffset, bOffset; 501 SmallVector<int64_t, 4> aStrides, bStrides; 502 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 503 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 504 aStrides.size() != bStrides.size()) 505 return false; 506 507 // Strides along a dimension/offset are compatible if the value in the 508 // source memref is static and the value in the target memref is the 509 // same. They are also compatible if either one is dynamic (see 510 // description of MemRefCastOp for details). 511 auto checkCompatible = [](int64_t a, int64_t b) { 512 return (a == MemRefType::getDynamicStrideOrOffset() || 513 b == MemRefType::getDynamicStrideOrOffset() || a == b); 514 }; 515 if (!checkCompatible(aOffset, bOffset)) 516 return false; 517 for (auto aStride : enumerate(aStrides)) 518 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 519 return false; 520 } 521 if (aT.getMemorySpace() != bT.getMemorySpace()) 522 return false; 523 524 // They must have the same rank, and any specified dimensions must match. 
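    // For example (illustrative): memref<4x?xf32> and memref<?x8xf32> are
    // cast-compatible, while memref<4xf32> and memref<8xf32> are not.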
525 if (aT.getRank() != bT.getRank()) 526 return false; 527 528 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 529 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 530 if (aDim != -1 && bDim != -1 && aDim != bDim) 531 return false; 532 } 533 return true; 534 } else { 535 if (!aT && !uaT) 536 return false; 537 if (!bT && !ubT) 538 return false; 539 // Unranked to unranked casting is unsupported 540 if (uaT && ubT) 541 return false; 542 543 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 544 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 545 if (aEltType != bEltType) 546 return false; 547 548 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 549 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 550 if (aMemSpace != bMemSpace) 551 return false; 552 553 return true; 554 } 555 556 return false; 557 } 558 559 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 560 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 561 } 562 563 //===----------------------------------------------------------------------===// 564 // CloneOp 565 //===----------------------------------------------------------------------===// 566 567 void CloneOp::getEffects( 568 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 569 &effects) { 570 effects.emplace_back(MemoryEffects::Read::get(), input(), 571 SideEffects::DefaultResource::get()); 572 effects.emplace_back(MemoryEffects::Write::get(), output(), 573 SideEffects::DefaultResource::get()); 574 effects.emplace_back(MemoryEffects::Allocate::get(), output(), 575 SideEffects::DefaultResource::get()); 576 } 577 578 namespace { 579 /// Merge the clone and its source (by converting the clone to a cast) when 580 /// possible. 581 struct SimplifyClones : public OpRewritePattern<CloneOp> { 582 using OpRewritePattern<CloneOp>::OpRewritePattern; 583 584 LogicalResult matchAndRewrite(CloneOp cloneOp, 585 PatternRewriter &rewriter) const override { 586 if (cloneOp.use_empty()) { 587 rewriter.eraseOp(cloneOp); 588 return success(); 589 } 590 591 Value source = cloneOp.input(); 592 593 // This only finds dealloc operations for the immediate value. It should 594 // also consider aliases. That would also make the safety check below 595 // redundant. 596 llvm::Optional<Operation *> maybeCloneDeallocOp = 597 findDealloc(cloneOp.output()); 598 // Skip if either of them has > 1 deallocate operations. 599 if (!maybeCloneDeallocOp.hasValue()) 600 return failure(); 601 llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source); 602 if (!maybeSourceDeallocOp.hasValue()) 603 return failure(); 604 Operation *cloneDeallocOp = *maybeCloneDeallocOp; 605 Operation *sourceDeallocOp = *maybeSourceDeallocOp; 606 607 // If both are deallocated in the same block, their in-block lifetimes 608 // might not fully overlap, so we cannot decide which one to drop. 
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that the constant index is not known to be out of range.
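  // For example (illustrative), a constant index of 3 on a memref<4x8xf32>
  // (rank 2) is rejected below.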
  auto type = op.source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

/// Return a map with keys being the elements of `vals` and values being the
/// number of occurrences of each element. Use std::map, since the `vals` here
/// are strides and the dynamic stride value is the same as the tombstone value
/// for `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the type of the un-rank-reduced subview result and the rank-reduced
/// result type, computes the dropped dimensions. This accounts for cases where
/// there are multiple unit-dims, but only a subset of those are dropped. For
/// MemRefTypes these can be disambiguated using the strides: if a dimension is
/// dropped, its stride must be dropped too.
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayAttr staticSizes) {
  llvm::SmallDenseSet<unsigned> unusedDims;
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (auto dim : llvm::enumerate(staticSizes))
    if (dim.value().cast<IntegerAttr>().getInt() == 1)
      unusedDims.insert(dim.index());
  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each dropped dim, its stride should occur
  // one fewer time in the candidate type. Note that there could be multiple
  // dimensions that have the same size; we don't need to figure out exactly
  // which dim corresponds to which stride, we just need to verify that the
  // number of repetitions of a stride in the original type equals the number
  // of repetitions of that stride in the candidate type plus the number of
  // dropped dims that carry that stride.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
  for (unsigned dim : unusedDims) {
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped. Keep it as is.
      prunedUnusedDims.insert(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. We can't have a stride in the reduced-rank
      // type that wasn't in the original one.
      return llvm::None;
    }
  }

  for (auto prunedDim : prunedUnusedDims)
    unusedDims.erase(prunedDim);
  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
    return llvm::None;
  return unusedDims;
}

llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
  assert(unusedDims && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.count(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold a dim of a memref.reshape operation into a load from the reshape's
/// shape operand.
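///
/// For example (illustrative; SSA names are made up):
///
/// ```mlir
///   %r = memref.reshape %src(%shape)
///          : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// folds to a load from the shape operand:
///
/// ```mlir
///   %d = memref.load %shape[%c1] : memref<2xindex>
/// ```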
874 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 875 using OpRewritePattern<DimOp>::OpRewritePattern; 876 877 LogicalResult matchAndRewrite(DimOp dim, 878 PatternRewriter &rewriter) const override { 879 auto reshape = dim.source().getDefiningOp<ReshapeOp>(); 880 881 if (!reshape) 882 return failure(); 883 884 // Place the load directly after the reshape to ensure that the shape memref 885 // was not mutated. 886 rewriter.setInsertionPointAfter(reshape); 887 Location loc = dim.getLoc(); 888 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 889 if (load.getType() != dim.getType()) 890 load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); 891 rewriter.replaceOp(dim, load); 892 return success(); 893 } 894 }; 895 896 /// Fold dim of a cast into the dim of the source of the memref cast. 897 struct DimOfCastOp : public OpRewritePattern<DimOp> { 898 using OpRewritePattern<DimOp>::OpRewritePattern; 899 900 LogicalResult matchAndRewrite(DimOp dimOp, 901 PatternRewriter &rewriter) const override { 902 auto castOp = dimOp.source().getDefiningOp<BufferCastOp>(); 903 if (!castOp) 904 return failure(); 905 Value newSource = castOp.getOperand(); 906 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index()); 907 return success(); 908 } 909 }; 910 } // end anonymous namespace. 911 912 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 913 MLIRContext *context) { 914 results.add<DimOfMemRefReshape, DimOfCastOp>(context); 915 } 916 917 // --------------------------------------------------------------------------- 918 // DmaStartOp 919 // --------------------------------------------------------------------------- 920 921 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 922 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 923 ValueRange destIndices, Value numElements, 924 Value tagMemRef, ValueRange tagIndices, Value stride, 925 Value elementsPerStride) { 926 result.addOperands(srcMemRef); 927 result.addOperands(srcIndices); 928 result.addOperands(destMemRef); 929 result.addOperands(destIndices); 930 result.addOperands({numElements, tagMemRef}); 931 result.addOperands(tagIndices); 932 if (stride) 933 result.addOperands({stride, elementsPerStride}); 934 } 935 936 static void print(OpAsmPrinter &p, DmaStartOp op) { 937 p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], " 938 << op.getDstMemRef() << '[' << op.getDstIndices() << "], " 939 << op.getNumElements() << ", " << op.getTagMemRef() << '[' 940 << op.getTagIndices() << ']'; 941 if (op.isStrided()) 942 p << ", " << op.getStride() << ", " << op.getNumElementsPerStride(); 943 944 p.printOptionalAttrDict(op->getAttrs()); 945 p << " : " << op.getSrcMemRef().getType() << ", " 946 << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType(); 947 } 948 949 // Parse DmaStartOp. 
950 // Ex: 951 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 952 // %tag[%index], %stride, %num_elt_per_stride : 953 // : memref<3076 x f32, 0>, 954 // memref<1024 x f32, 2>, 955 // memref<1 x i32> 956 // 957 static ParseResult parseDmaStartOp(OpAsmParser &parser, 958 OperationState &result) { 959 OpAsmParser::OperandType srcMemRefInfo; 960 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 961 OpAsmParser::OperandType dstMemRefInfo; 962 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 963 OpAsmParser::OperandType numElementsInfo; 964 OpAsmParser::OperandType tagMemrefInfo; 965 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 966 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 967 968 SmallVector<Type, 3> types; 969 auto indexType = parser.getBuilder().getIndexType(); 970 971 // Parse and resolve the following list of operands: 972 // *) source memref followed by its indices (in square brackets). 973 // *) destination memref followed by its indices (in square brackets). 974 // *) dma size in KiB. 975 if (parser.parseOperand(srcMemRefInfo) || 976 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 977 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 978 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 979 parser.parseComma() || parser.parseOperand(numElementsInfo) || 980 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 981 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 982 return failure(); 983 984 // Parse optional stride and elements per stride. 985 if (parser.parseTrailingOperandList(strideInfo)) 986 return failure(); 987 988 bool isStrided = strideInfo.size() == 2; 989 if (!strideInfo.empty() && !isStrided) { 990 return parser.emitError(parser.getNameLoc(), 991 "expected two stride related operands"); 992 } 993 994 if (parser.parseColonTypeList(types)) 995 return failure(); 996 if (types.size() != 3) 997 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 998 999 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 1000 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 1001 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 1002 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 1003 // size should be an index. 1004 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 1005 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 1006 // tag indices should be index. 1007 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 1008 return failure(); 1009 1010 if (isStrided) { 1011 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 1012 return failure(); 1013 } 1014 1015 return success(); 1016 } 1017 1018 static LogicalResult verify(DmaStartOp op) { 1019 unsigned numOperands = op.getNumOperands(); 1020 1021 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 1022 // the number of elements. 1023 if (numOperands < 4) 1024 return op.emitOpError("expected at least 4 operands"); 1025 1026 // Check types of operands. The order of these calls is important: the later 1027 // calls rely on some type properties to compute the operand position. 1028 // 1. Source memref. 
1029 if (!op.getSrcMemRef().getType().isa<MemRefType>()) 1030 return op.emitOpError("expected source to be of memref type"); 1031 if (numOperands < op.getSrcMemRefRank() + 4) 1032 return op.emitOpError() 1033 << "expected at least " << op.getSrcMemRefRank() + 4 << " operands"; 1034 if (!op.getSrcIndices().empty() && 1035 !llvm::all_of(op.getSrcIndices().getTypes(), 1036 [](Type t) { return t.isIndex(); })) 1037 return op.emitOpError("expected source indices to be of index type"); 1038 1039 // 2. Destination memref. 1040 if (!op.getDstMemRef().getType().isa<MemRefType>()) 1041 return op.emitOpError("expected destination to be of memref type"); 1042 unsigned numExpectedOperands = 1043 op.getSrcMemRefRank() + op.getDstMemRefRank() + 4; 1044 if (numOperands < numExpectedOperands) 1045 return op.emitOpError() 1046 << "expected at least " << numExpectedOperands << " operands"; 1047 if (!op.getDstIndices().empty() && 1048 !llvm::all_of(op.getDstIndices().getTypes(), 1049 [](Type t) { return t.isIndex(); })) 1050 return op.emitOpError("expected destination indices to be of index type"); 1051 1052 // 3. Number of elements. 1053 if (!op.getNumElements().getType().isIndex()) 1054 return op.emitOpError("expected num elements to be of index type"); 1055 1056 // 4. Tag memref. 1057 if (!op.getTagMemRef().getType().isa<MemRefType>()) 1058 return op.emitOpError("expected tag to be of memref type"); 1059 numExpectedOperands += op.getTagMemRefRank(); 1060 if (numOperands < numExpectedOperands) 1061 return op.emitOpError() 1062 << "expected at least " << numExpectedOperands << " operands"; 1063 if (!op.getTagIndices().empty() && 1064 !llvm::all_of(op.getTagIndices().getTypes(), 1065 [](Type t) { return t.isIndex(); })) 1066 return op.emitOpError("expected tag indices to be of index type"); 1067 1068 // Optional stride-related operands must be either both present or both 1069 // absent. 1070 if (numOperands != numExpectedOperands && 1071 numOperands != numExpectedOperands + 2) 1072 return op.emitOpError("incorrect number of operands"); 1073 1074 // 5. Strides. 1075 if (op.isStrided()) { 1076 if (!op.getStride().getType().isIndex() || 1077 !op.getNumElementsPerStride().getType().isIndex()) 1078 return op.emitOpError( 1079 "expected stride and num elements per stride to be of type index"); 1080 } 1081 1082 return success(); 1083 } 1084 1085 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 1086 SmallVectorImpl<OpFoldResult> &results) { 1087 /// dma_start(memrefcast) -> dma_start 1088 return foldMemRefCast(*this); 1089 } 1090 1091 // --------------------------------------------------------------------------- 1092 // DmaWaitOp 1093 // --------------------------------------------------------------------------- 1094 1095 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1096 SmallVectorImpl<OpFoldResult> &results) { 1097 /// dma_wait(memrefcast) -> dma_wait 1098 return foldMemRefCast(*this); 1099 } 1100 1101 static LogicalResult verify(DmaWaitOp op) { 1102 // Check that the number of tag indices matches the tagMemRef rank. 
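  // For example (illustrative), a tag of type memref<1xi32> must be indexed
  // by exactly one index operand.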
1103 unsigned numTagIndices = op.tagIndices().size(); 1104 unsigned tagMemRefRank = op.getTagMemRefRank(); 1105 if (numTagIndices != tagMemRefRank) 1106 return op.emitOpError() << "expected tagIndices to have the same number of " 1107 "elements as the tagMemRef rank, expected " 1108 << tagMemRefRank << ", but got " << numTagIndices; 1109 return success(); 1110 } 1111 1112 //===----------------------------------------------------------------------===// 1113 // GlobalOp 1114 //===----------------------------------------------------------------------===// 1115 1116 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1117 TypeAttr type, 1118 Attribute initialValue) { 1119 p << type; 1120 if (!op.isExternal()) { 1121 p << " = "; 1122 if (op.isUninitialized()) 1123 p << "uninitialized"; 1124 else 1125 p.printAttributeWithoutType(initialValue); 1126 } 1127 } 1128 1129 static ParseResult 1130 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1131 Attribute &initialValue) { 1132 Type type; 1133 if (parser.parseType(type)) 1134 return failure(); 1135 1136 auto memrefType = type.dyn_cast<MemRefType>(); 1137 if (!memrefType || !memrefType.hasStaticShape()) 1138 return parser.emitError(parser.getNameLoc()) 1139 << "type should be static shaped memref, but got " << type; 1140 typeAttr = TypeAttr::get(type); 1141 1142 if (parser.parseOptionalEqual()) 1143 return success(); 1144 1145 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1146 initialValue = UnitAttr::get(parser.getContext()); 1147 return success(); 1148 } 1149 1150 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1151 if (parser.parseAttribute(initialValue, tensorType)) 1152 return failure(); 1153 if (!initialValue.isa<ElementsAttr>()) 1154 return parser.emitError(parser.getNameLoc()) 1155 << "initial value should be a unit or elements attribute"; 1156 return success(); 1157 } 1158 1159 static LogicalResult verify(GlobalOp op) { 1160 auto memrefType = op.type().dyn_cast<MemRefType>(); 1161 if (!memrefType || !memrefType.hasStaticShape()) 1162 return op.emitOpError("type should be static shaped memref, but got ") 1163 << op.type(); 1164 1165 // Verify that the initial value, if present, is either a unit attribute or 1166 // an elements attribute. 1167 if (op.initial_value().hasValue()) { 1168 Attribute initValue = op.initial_value().getValue(); 1169 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1170 return op.emitOpError("initial value should be a unit or elements " 1171 "attribute, but got ") 1172 << initValue; 1173 1174 // Check that the type of the initial value is compatible with the type of 1175 // the global variable. 1176 if (initValue.isa<ElementsAttr>()) { 1177 Type initType = initValue.getType(); 1178 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1179 if (initType != tensorType) 1180 return op.emitOpError("initial value expected to be of type ") 1181 << tensorType << ", but was of type " << initType; 1182 } 1183 } 1184 1185 if (Optional<uint64_t> alignAttr = op.alignment()) { 1186 uint64_t alignment = alignAttr.getValue(); 1187 1188 if (!llvm::isPowerOf2_64(alignment)) 1189 return op->emitError() << "alignment attribute value " << alignment 1190 << " is not a power of 2"; 1191 } 1192 1193 // TODO: verify visibility for declarations. 
1194 return success(); 1195 } 1196 1197 //===----------------------------------------------------------------------===// 1198 // GetGlobalOp 1199 //===----------------------------------------------------------------------===// 1200 1201 LogicalResult 1202 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1203 // Verify that the result type is same as the type of the referenced 1204 // memref.global op. 1205 auto global = 1206 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1207 if (!global) 1208 return emitOpError("'") 1209 << name() << "' does not reference a valid global memref"; 1210 1211 Type resultType = result().getType(); 1212 if (global.type() != resultType) 1213 return emitOpError("result type ") 1214 << resultType << " does not match type " << global.type() 1215 << " of the global memref @" << name(); 1216 return success(); 1217 } 1218 1219 //===----------------------------------------------------------------------===// 1220 // LoadOp 1221 //===----------------------------------------------------------------------===// 1222 1223 static LogicalResult verify(LoadOp op) { 1224 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1225 return op.emitOpError("incorrect number of indices for load"); 1226 return success(); 1227 } 1228 1229 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1230 /// load(memrefcast) -> load 1231 if (succeeded(foldMemRefCast(*this))) 1232 return getResult(); 1233 return OpFoldResult(); 1234 } 1235 1236 namespace { 1237 /// Fold a load on a buffer_cast operation into an tensor.extract on the 1238 /// corresponding tensor. 1239 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> { 1240 using OpRewritePattern<LoadOp>::OpRewritePattern; 1241 1242 LogicalResult matchAndRewrite(LoadOp load, 1243 PatternRewriter &rewriter) const override { 1244 auto buffercast = load.memref().getDefiningOp<BufferCastOp>(); 1245 if (!buffercast) 1246 return failure(); 1247 1248 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(), 1249 load.indices()); 1250 return success(); 1251 } 1252 }; 1253 } // end anonymous namespace. 1254 1255 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 1256 MLIRContext *context) { 1257 results.add<LoadOfBufferCast>(context); 1258 } 1259 1260 //===----------------------------------------------------------------------===// 1261 // PrefetchOp 1262 //===----------------------------------------------------------------------===// 1263 1264 static void print(OpAsmPrinter &p, PrefetchOp op) { 1265 p << " " << op.memref() << '['; 1266 p.printOperands(op.indices()); 1267 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1268 p << ", locality<" << op.localityHint(); 1269 p << ">, " << (op.isDataCache() ? 
"data" : "instr"); 1270 p.printOptionalAttrDict( 1271 op->getAttrs(), 1272 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1273 p << " : " << op.getMemRefType(); 1274 } 1275 1276 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1277 OperationState &result) { 1278 OpAsmParser::OperandType memrefInfo; 1279 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1280 IntegerAttr localityHint; 1281 MemRefType type; 1282 StringRef readOrWrite, cacheType; 1283 1284 auto indexTy = parser.getBuilder().getIndexType(); 1285 auto i32Type = parser.getBuilder().getIntegerType(32); 1286 if (parser.parseOperand(memrefInfo) || 1287 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1288 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1289 parser.parseComma() || parser.parseKeyword("locality") || 1290 parser.parseLess() || 1291 parser.parseAttribute(localityHint, i32Type, "localityHint", 1292 result.attributes) || 1293 parser.parseGreater() || parser.parseComma() || 1294 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1295 parser.resolveOperand(memrefInfo, type, result.operands) || 1296 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1297 return failure(); 1298 1299 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1300 return parser.emitError(parser.getNameLoc(), 1301 "rw specifier has to be 'read' or 'write'"); 1302 result.addAttribute( 1303 PrefetchOp::getIsWriteAttrName(), 1304 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1305 1306 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1307 return parser.emitError(parser.getNameLoc(), 1308 "cache type has to be 'data' or 'instr'"); 1309 1310 result.addAttribute( 1311 PrefetchOp::getIsDataCacheAttrName(), 1312 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1313 1314 return success(); 1315 } 1316 1317 static LogicalResult verify(PrefetchOp op) { 1318 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1319 return op.emitOpError("too few indices"); 1320 1321 return success(); 1322 } 1323 1324 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1325 SmallVectorImpl<OpFoldResult> &results) { 1326 // prefetch(memrefcast) -> prefetch 1327 return foldMemRefCast(*this); 1328 } 1329 1330 //===----------------------------------------------------------------------===// 1331 // ReinterpretCastOp 1332 //===----------------------------------------------------------------------===// 1333 1334 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1335 /// `staticSizes` and `staticStrides` are automatically filled with 1336 /// source-memref-rank sentinel values that encode dynamic entries. 
1337 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1338 MemRefType resultType, Value source, 1339 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1340 ArrayRef<OpFoldResult> strides, 1341 ArrayRef<NamedAttribute> attrs) { 1342 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1343 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1344 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1345 ShapedType::kDynamicStrideOrOffset); 1346 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1347 ShapedType::kDynamicSize); 1348 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1349 ShapedType::kDynamicStrideOrOffset); 1350 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1351 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1352 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1353 result.addAttributes(attrs); 1354 } 1355 1356 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1357 MemRefType resultType, Value source, 1358 int64_t offset, ArrayRef<int64_t> sizes, 1359 ArrayRef<int64_t> strides, 1360 ArrayRef<NamedAttribute> attrs) { 1361 SmallVector<OpFoldResult> sizeValues = 1362 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1363 return b.getI64IntegerAttr(v); 1364 })); 1365 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1366 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1367 return b.getI64IntegerAttr(v); 1368 })); 1369 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1370 strideValues, attrs); 1371 } 1372 1373 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1374 MemRefType resultType, Value source, Value offset, 1375 ValueRange sizes, ValueRange strides, 1376 ArrayRef<NamedAttribute> attrs) { 1377 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1378 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1379 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1380 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1381 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1382 } 1383 1384 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1385 // completed automatically, like we have for subview and extract_slice. 1386 static LogicalResult verify(ReinterpretCastOp op) { 1387 // The source and result memrefs should be in the same memory space. 1388 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1389 auto resultType = op.getType().cast<MemRefType>(); 1390 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1391 return op.emitError("different memory spaces specified for source type ") 1392 << srcType << " and result memref type " << resultType; 1393 if (srcType.getElementType() != resultType.getElementType()) 1394 return op.emitError("different element types specified for source type ") 1395 << srcType << " and result memref type " << resultType; 1396 1397 // Match sizes in result memref type and in static_sizes attribute. 
1398 for (auto &en : 1399 llvm::enumerate(llvm::zip(resultType.getShape(), 1400 extractFromI64ArrayAttr(op.static_sizes())))) { 1401 int64_t resultSize = std::get<0>(en.value()); 1402 int64_t expectedSize = std::get<1>(en.value()); 1403 if (resultSize != expectedSize) 1404 return op.emitError("expected result type with size = ") 1405 << expectedSize << " instead of " << resultSize 1406 << " in dim = " << en.index(); 1407 } 1408 1409 // Match offset and strides in static_offset and static_strides attributes if 1410 // result memref type has an affine map specified. 1411 if (!resultType.getLayout().isIdentity()) { 1412 int64_t resultOffset; 1413 SmallVector<int64_t, 4> resultStrides; 1414 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1415 return failure(); 1416 1417 // Match offset in result memref type and in static_offsets attribute. 1418 int64_t expectedOffset = 1419 extractFromI64ArrayAttr(op.static_offsets()).front(); 1420 if (resultOffset != expectedOffset) 1421 return op.emitError("expected result type with offset = ") 1422 << resultOffset << " instead of " << expectedOffset; 1423 1424 // Match strides in result memref type and in static_strides attribute. 1425 for (auto &en : llvm::enumerate(llvm::zip( 1426 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1427 int64_t resultStride = std::get<0>(en.value()); 1428 int64_t expectedStride = std::get<1>(en.value()); 1429 if (resultStride != expectedStride) 1430 return op.emitError("expected result type with stride = ") 1431 << expectedStride << " instead of " << resultStride 1432 << " in dim = " << en.index(); 1433 } 1434 } 1435 return success(); 1436 } 1437 1438 //===----------------------------------------------------------------------===// 1439 // Reassociative reshape ops 1440 //===----------------------------------------------------------------------===// 1441 1442 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1443 return getSymbolLessAffineMaps(getReassociationExprs()); 1444 } 1445 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1446 return convertReassociationIndicesToExprs(getContext(), 1447 getReassociationIndices()); 1448 } 1449 1450 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1451 return getSymbolLessAffineMaps(getReassociationExprs()); 1452 } 1453 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1454 return convertReassociationIndicesToExprs(getContext(), 1455 getReassociationIndices()); 1456 } 1457 1458 static void print(OpAsmPrinter &p, ExpandShapeOp op) { 1459 ::mlir::printReshapeOp<ExpandShapeOp>(p, op); 1460 } 1461 1462 static void print(OpAsmPrinter &p, CollapseShapeOp op) { 1463 ::mlir::printReshapeOp<CollapseShapeOp>(p, op); 1464 } 1465 1466 /// Detect whether memref dims [dim, dim + extent) can be reshaped without 1467 /// copies. 1468 static bool isReshapableDimBand(unsigned dim, unsigned extent, 1469 ArrayRef<int64_t> sizes, 1470 ArrayRef<AffineExpr> strides) { 1471 assert(sizes.size() == strides.size() && "mismatched ranks"); 1472 // off by 1 indexing to avoid out of bounds 1473 // V 1474 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { 1475 // Only bands of static shapes are reshapable. This is due to the fact that 1476 // there is no relation between dynamic sizes and dynamic strides: we do not 1477 // have enough information to know whether a "-1" size corresponds to the 1478 // proper symbol in the AffineExpr of a stride. 
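    // For example (illustrative): sizes [2, 3, 4] with strides [12, 4, 1] form
    // a reshapable (contiguous) band, since 12 == 4 * 3 and 4 == 1 * 4.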
    if (ShapedType::isDynamic(sizes[dim + 1]))
      return false;
    // TODO: Refine this by passing the proper nDims and nSymbols so we can
    // simplify on the fly and catch more reshapable cases.
    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
      return false;
  }
  return true;
}

/// Compute the MemRefType obtained by applying the `reassociation` (which is
/// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
static MemRefType
computeReshapeCollapsedType(MemRefType type,
                            ArrayRef<AffineMap> reassociation) {
  auto sizes = type.getShape();
  AffineExpr offset;
  SmallVector<AffineExpr, 4> strides;
  auto status = getStridesAndOffset(type, strides, offset);
  (void)status;
  assert(succeeded(status) && "expected strided memref");

  SmallVector<int64_t, 4> newSizes;
  newSizes.reserve(reassociation.size());
  SmallVector<AffineExpr, 4> newStrides;
  newStrides.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    AffineExpr stride = strides[currentDim + dim - 1];
    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
      size = ShapedType::kDynamicSize;
      stride = AffineExpr();
    } else {
      for (unsigned d = 0; d < dim; ++d)
        size *= sizes[currentDim + d];
    }
    newSizes.push_back(size);
    newStrides.push_back(stride);
    currentDim += dim;
  }

  // Early-exit: if `type` is contiguous, the result must be contiguous.
  if (canonicalizeStridedLayout(type).getLayout().isIdentity())
    return MemRefType::Builder(type).setShape(newSizes).setLayout({});

  // Convert back to int64_t because we don't have enough information to create
  // new strided layouts from AffineExpr only. This corresponds to a case where
  // copies may be necessary.
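  // Worked example (illustrative): collapsing
  // memref<4x8x2xf32, offset: 0, strides: [32, 2, 1]> with reassociation
  // [[0, 1], [2]] takes this path: dims 0-1 are not a contiguous band
  // (32 != 8 * 2), so the collapsed dim becomes dynamic with an unknown stride,
  // and the result is roughly memref<?x2xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>.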
1535 int64_t intOffset = ShapedType::kDynamicStrideOrOffset; 1536 if (auto o = offset.dyn_cast<AffineConstantExpr>()) 1537 intOffset = o.getValue(); 1538 SmallVector<int64_t, 4> intStrides; 1539 intStrides.reserve(strides.size()); 1540 for (auto stride : newStrides) { 1541 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) 1542 intStrides.push_back(cst.getValue()); 1543 else 1544 intStrides.push_back(ShapedType::kDynamicStrideOrOffset); 1545 } 1546 auto layout = 1547 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); 1548 return canonicalizeStridedLayout( 1549 MemRefType::Builder(type).setShape(newSizes).setLayout( 1550 AffineMapAttr::get(layout))); 1551 } 1552 1553 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1554 ArrayRef<ReassociationIndices> reassociation, 1555 ArrayRef<NamedAttribute> attrs) { 1556 auto memRefType = src.getType().cast<MemRefType>(); 1557 auto resultType = computeReshapeCollapsedType( 1558 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1559 b.getContext(), reassociation))); 1560 build(b, result, resultType, src, attrs); 1561 result.addAttribute(getReassociationAttrName(), 1562 getReassociationIndicesAttribute(b, reassociation)); 1563 } 1564 1565 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1566 ArrayRef<ReassociationIndices> reassociation, 1567 ArrayRef<NamedAttribute> attrs) { 1568 auto memRefType = src.getType().cast<MemRefType>(); 1569 auto resultType = computeReshapeCollapsedType( 1570 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1571 b.getContext(), reassociation))); 1572 build(b, result, resultType, src, attrs); 1573 result.addAttribute(getReassociationAttrName(), 1574 getReassociationIndicesAttribute(b, reassociation)); 1575 } 1576 1577 template <typename ReshapeOp, 1578 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value> 1579 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, 1580 MemRefType collapsedType) { 1581 if (failed( 1582 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1583 return failure(); 1584 auto maps = op.getReassociationMaps(); 1585 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); 1586 if (collapsedType != expectedType) 1587 return op.emitOpError("expected collapsed type to be ") 1588 << expectedType << ", but got " << collapsedType; 1589 return success(); 1590 } 1591 1592 static LogicalResult verify(ExpandShapeOp op) { 1593 return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); 1594 } 1595 1596 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1597 MLIRContext *context) { 1598 results.add<CollapseReshapeOps<ExpandShapeOp>, 1599 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); 1600 } 1601 1602 static LogicalResult verify(CollapseShapeOp op) { 1603 return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); 1604 } 1605 1606 struct CollapseShapeOpMemRefCastFolder 1607 : public OpRewritePattern<CollapseShapeOp> { 1608 public: 1609 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1610 1611 LogicalResult matchAndRewrite(CollapseShapeOp op, 1612 PatternRewriter &rewriter) const override { 1613 auto cast = op.getOperand().getDefiningOp<CastOp>(); 1614 if (!cast) 1615 return failure(); 1616 1617 if (!CastOp::canFoldIntoConsumerOp(cast)) 1618 return failure(); 1619 1620 Type newResultType = computeReshapeCollapsedType( 1621 
cast.getOperand().getType().cast<MemRefType>(), 1622 op.getReassociationMaps()); 1623 1624 if (newResultType == op.getResultType()) { 1625 rewriter.updateRootInPlace( 1626 op, [&]() { op.srcMutable().assign(cast.source()); }); 1627 } else { 1628 Value newOp = rewriter.create<CollapseShapeOp>( 1629 op->getLoc(), cast.source(), op.getReassociationIndices()); 1630 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); 1631 } 1632 return success(); 1633 } 1634 }; 1635 1636 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 1637 MLIRContext *context) { 1638 results.add<CollapseReshapeOps<CollapseShapeOp>, 1639 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>, 1640 CollapseShapeOpMemRefCastFolder>(context); 1641 } 1642 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) { 1643 if (succeeded(foldMemRefCast(*this))) 1644 return getResult(); 1645 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands); 1646 } 1647 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) { 1648 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands); 1649 } 1650 1651 //===----------------------------------------------------------------------===// 1652 // ReshapeOp 1653 //===----------------------------------------------------------------------===// 1654 1655 static LogicalResult verify(ReshapeOp op) { 1656 Type operandType = op.source().getType(); 1657 Type resultType = op.result().getType(); 1658 1659 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1660 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1661 if (operandElementType != resultElementType) 1662 return op.emitOpError("element types of source and destination memref " 1663 "types should be the same"); 1664 1665 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1666 if (!operandMemRefType.getLayout().isIdentity()) 1667 return op.emitOpError( 1668 "source memref type should have identity affine map"); 1669 1670 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1671 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1672 if (resultMemRefType) { 1673 if (!resultMemRefType.getLayout().isIdentity()) 1674 return op.emitOpError( 1675 "result memref type should have identity affine map"); 1676 if (shapeSize == ShapedType::kDynamicSize) 1677 return op.emitOpError("cannot use shape operand with dynamic length to " 1678 "reshape to statically-ranked memref type"); 1679 if (shapeSize != resultMemRefType.getRank()) 1680 return op.emitOpError( 1681 "length of shape operand differs from the result's memref rank"); 1682 } 1683 return success(); 1684 } 1685 1686 //===----------------------------------------------------------------------===// 1687 // StoreOp 1688 //===----------------------------------------------------------------------===// 1689 1690 static LogicalResult verify(StoreOp op) { 1691 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1692 return op.emitOpError("store index operand count not equal to memref rank"); 1693 1694 return success(); 1695 } 1696 1697 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1698 SmallVectorImpl<OpFoldResult> &results) { 1699 /// store(memrefcast) -> store 1700 return foldMemRefCast(*this, getValueToStore()); 1701 } 1702 1703 //===----------------------------------------------------------------------===// 1704 // SubViewOp 1705 //===----------------------------------------------------------------------===// 1706 1707 
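// The saturated_arith helpers below treat ShapedType::kDynamicStrideOrOffset
// as an absorbing element; for example (illustrative):
//   Wrapper(ShapedType::kDynamicStrideOrOffset) + 4  ->  kDynamicStrideOrOffset
//   Wrapper(3) * 5                                   ->  15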
namespace { 1708 /// Helpers to write more idiomatic operations. 1709 namespace saturated_arith { 1710 struct Wrapper { 1711 explicit Wrapper(int64_t v) : v(v) {} 1712 operator int64_t() { return v; } 1713 int64_t v; 1714 }; 1715 Wrapper operator+(Wrapper a, int64_t b) { 1716 if (ShapedType::isDynamicStrideOrOffset(a) || 1717 ShapedType::isDynamicStrideOrOffset(b)) 1718 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1719 return Wrapper(a.v + b); 1720 } 1721 Wrapper operator*(Wrapper a, int64_t b) { 1722 if (ShapedType::isDynamicStrideOrOffset(a) || 1723 ShapedType::isDynamicStrideOrOffset(b)) 1724 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1725 return Wrapper(a.v * b); 1726 } 1727 } // end namespace saturated_arith 1728 } // end namespace 1729 1730 /// A subview result type can be fully inferred from the source type and the 1731 /// static representation of offsets, sizes and strides. Special sentinels 1732 /// encode the dynamic case. 1733 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1734 ArrayRef<int64_t> leadingStaticOffsets, 1735 ArrayRef<int64_t> leadingStaticSizes, 1736 ArrayRef<int64_t> leadingStaticStrides) { 1737 // A subview may specify only a leading subset of offset/sizes/strides in 1738 // which case we complete with offset=0, sizes from memref type and strides=1. 1739 unsigned rank = sourceMemRefType.getRank(); 1740 assert(leadingStaticOffsets.size() <= rank && 1741 "unexpected leadingStaticOffsets overflow"); 1742 assert(leadingStaticSizes.size() <= rank && 1743 "unexpected leadingStaticSizes overflow"); 1744 assert(leadingStaticStrides.size() <= rank && 1745 "unexpected leadingStaticStrides overflow"); 1746 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1747 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1748 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1749 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1750 unsigned numTrailingSizes = rank - staticSizes.size(); 1751 unsigned numTrailingStrides = rank - staticStrides.size(); 1752 staticOffsets.append(numTrailingOffsets, 0); 1753 llvm::append_range(staticSizes, 1754 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1755 staticStrides.append(numTrailingStrides, 1); 1756 1757 // Extract source offset and strides. 1758 int64_t sourceOffset; 1759 SmallVector<int64_t, 4> sourceStrides; 1760 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1761 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1762 (void)res; 1763 1764 // Compute target offset whose value is: 1765 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1766 int64_t targetOffset = sourceOffset; 1767 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1768 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1769 using namespace saturated_arith; 1770 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1771 } 1772 1773 // Compute target stride whose value is: 1774 // `sourceStrides_i * staticStrides_i`. 1775 SmallVector<int64_t, 4> targetStrides; 1776 targetStrides.reserve(staticOffsets.size()); 1777 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1778 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1779 using namespace saturated_arith; 1780 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1781 } 1782 1783 // The type is now known. 
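  // Worked example (illustrative): for a memref<8x16x4xf32> source (strides
  // [64, 4, 1], offset 0) with offsets [3, 4, 2], sizes [1, 6, 3], and strides
  // [1, 1, 1], the code above yields targetOffset = 3 * 64 + 4 * 4 + 2 * 1 =
  // 210 and targetStrides = [64, 4, 1], i.e. a result type of
  //   memref<1x6x3xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 210)>.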
1784 return MemRefType::get( 1785 staticSizes, sourceMemRefType.getElementType(), 1786 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1787 sourceMemRefType.getContext()), 1788 sourceMemRefType.getMemorySpace()); 1789 } 1790 1791 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1792 ArrayRef<OpFoldResult> leadingStaticOffsets, 1793 ArrayRef<OpFoldResult> leadingStaticSizes, 1794 ArrayRef<OpFoldResult> leadingStaticStrides) { 1795 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1796 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1797 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1798 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1799 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1800 ShapedType::kDynamicSize); 1801 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1802 staticStrides, ShapedType::kDynamicStrideOrOffset); 1803 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1804 staticSizes, staticStrides) 1805 .cast<MemRefType>(); 1806 } 1807 1808 Type SubViewOp::inferRankReducedResultType( 1809 unsigned resultRank, MemRefType sourceRankedTensorType, 1810 ArrayRef<int64_t> leadingStaticOffsets, 1811 ArrayRef<int64_t> leadingStaticSizes, 1812 ArrayRef<int64_t> leadingStaticStrides) { 1813 auto inferredType = 1814 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1815 leadingStaticSizes, leadingStaticStrides) 1816 .cast<MemRefType>(); 1817 assert(inferredType.getRank() >= resultRank && "expected "); 1818 int rankDiff = inferredType.getRank() - resultRank; 1819 if (rankDiff > 0) { 1820 auto shape = inferredType.getShape(); 1821 llvm::SmallDenseSet<unsigned> dimsToProject; 1822 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1823 SmallVector<int64_t> projectedShape; 1824 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1825 if (!dimsToProject.contains(pos)) 1826 projectedShape.push_back(shape[pos]); 1827 1828 AffineMap map = inferredType.getLayout().getAffineMap(); 1829 if (!map.isIdentity()) 1830 map = getProjectedMap(map, dimsToProject); 1831 inferredType = 1832 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1833 inferredType.getMemorySpace()); 1834 } 1835 return inferredType; 1836 } 1837 1838 Type SubViewOp::inferRankReducedResultType( 1839 unsigned resultRank, MemRefType sourceRankedTensorType, 1840 ArrayRef<OpFoldResult> leadingStaticOffsets, 1841 ArrayRef<OpFoldResult> leadingStaticSizes, 1842 ArrayRef<OpFoldResult> leadingStaticStrides) { 1843 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1844 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1845 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1846 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1847 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1848 ShapedType::kDynamicSize); 1849 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1850 staticStrides, ShapedType::kDynamicStrideOrOffset); 1851 return SubViewOp::inferRankReducedResultType( 1852 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1853 staticStrides); 1854 } 1855 // Build a SubViewOp with mixed static and dynamic entries and custom result 1856 // type. If the type passed is nullptr, it is inferred. 
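// A possible call (illustrative only; `loc`, `source`, and `dynSize` are
// hypothetical, and the trailing attribute list is assumed to default to {}):
//   SmallVector<OpFoldResult> offsets = {b.getI64IntegerAttr(0),
//                                        b.getI64IntegerAttr(0)};
//   SmallVector<OpFoldResult> sizes = {b.getI64IntegerAttr(4), dynSize};
//   SmallVector<OpFoldResult> strides = {b.getI64IntegerAttr(1),
//                                        b.getI64IntegerAttr(1)};
//   b.create<SubViewOp>(loc, /*resultType=*/MemRefType(), source, offsets,
//                       sizes, strides);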
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
1916 void SubViewOp::build(OpBuilder &b, OperationState &result, 1917 MemRefType resultType, Value source, 1918 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1919 ArrayRef<int64_t> strides, 1920 ArrayRef<NamedAttribute> attrs) { 1921 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1922 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1923 return b.getI64IntegerAttr(v); 1924 })); 1925 SmallVector<OpFoldResult> sizeValues = 1926 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1927 return b.getI64IntegerAttr(v); 1928 })); 1929 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1930 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1931 return b.getI64IntegerAttr(v); 1932 })); 1933 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1934 attrs); 1935 } 1936 1937 // Build a SubViewOp with dynamic entries and custom result type. If the type 1938 // passed is nullptr, it is inferred. 1939 void SubViewOp::build(OpBuilder &b, OperationState &result, 1940 MemRefType resultType, Value source, ValueRange offsets, 1941 ValueRange sizes, ValueRange strides, 1942 ArrayRef<NamedAttribute> attrs) { 1943 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1944 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1945 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1946 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1947 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1948 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1949 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1950 } 1951 1952 // Build a SubViewOp with dynamic entries and inferred result type. 1953 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1954 ValueRange offsets, ValueRange sizes, ValueRange strides, 1955 ArrayRef<NamedAttribute> attrs) { 1956 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1957 } 1958 1959 /// For ViewLikeOpInterface. 1960 Value SubViewOp::getViewSource() { return source(); } 1961 1962 enum SubViewVerificationResult { 1963 Success, 1964 RankTooLarge, 1965 SizeMismatch, 1966 ElemTypeMismatch, 1967 MemSpaceMismatch, 1968 AffineMapMismatch 1969 }; 1970 1971 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1972 /// This function is slight variant of `is subsequence` algorithm where 1973 /// not matching dimension must be 1. 1974 static SubViewVerificationResult 1975 isRankReducedType(Type originalType, Type candidateReducedType, 1976 ArrayAttr staticSizes, std::string *errMsg = nullptr) { 1977 if (originalType == candidateReducedType) 1978 return SubViewVerificationResult::Success; 1979 if (!originalType.isa<MemRefType>()) 1980 return SubViewVerificationResult::Success; 1981 if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>()) 1982 return SubViewVerificationResult::Success; 1983 1984 ShapedType originalShapedType = originalType.cast<ShapedType>(); 1985 ShapedType candidateReducedShapedType = 1986 candidateReducedType.cast<ShapedType>(); 1987 1988 // Rank and size logic is valid for all ShapedTypes. 
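  // For example (illustrative), an original shape [1, 16, 4] can be reduced to
  // [16, 4] by dropping the unit dimension, but not to [8, 4].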
1989 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 1990 ArrayRef<int64_t> candidateReducedShape = 1991 candidateReducedShapedType.getShape(); 1992 unsigned originalRank = originalShape.size(), 1993 candidateReducedRank = candidateReducedShape.size(); 1994 if (candidateReducedRank > originalRank) 1995 return SubViewVerificationResult::RankTooLarge; 1996 1997 MemRefType original = originalType.cast<MemRefType>(); 1998 MemRefType candidateReduced = candidateReducedType.cast<MemRefType>(); 1999 2000 auto optionalUnusedDimsMask = 2001 computeMemRefRankReductionMask(original, candidateReduced, staticSizes); 2002 2003 // Sizes cannot be matched in case empty vector is returned. 2004 if (!optionalUnusedDimsMask.hasValue()) 2005 return SubViewVerificationResult::SizeMismatch; 2006 2007 if (originalShapedType.getElementType() != 2008 candidateReducedShapedType.getElementType()) 2009 return SubViewVerificationResult::ElemTypeMismatch; 2010 2011 // Strided layout logic is relevant for MemRefType only. 2012 if (original.getMemorySpace() != candidateReduced.getMemorySpace()) 2013 return SubViewVerificationResult::MemSpaceMismatch; 2014 return SubViewVerificationResult::Success; 2015 } 2016 2017 template <typename OpTy> 2018 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result, 2019 OpTy op, Type expectedType, 2020 StringRef errMsg = "") { 2021 auto memrefType = expectedType.cast<ShapedType>(); 2022 switch (result) { 2023 case SubViewVerificationResult::Success: 2024 return success(); 2025 case SubViewVerificationResult::RankTooLarge: 2026 return op.emitError("expected result rank to be smaller or equal to ") 2027 << "the source rank. " << errMsg; 2028 case SubViewVerificationResult::SizeMismatch: 2029 return op.emitError("expected result type to be ") 2030 << expectedType 2031 << " or a rank-reduced version. (mismatch of result sizes) " 2032 << errMsg; 2033 case SubViewVerificationResult::ElemTypeMismatch: 2034 return op.emitError("expected result element type to be ") 2035 << memrefType.getElementType() << errMsg; 2036 case SubViewVerificationResult::MemSpaceMismatch: 2037 return op.emitError("expected result and source memory spaces to match.") 2038 << errMsg; 2039 case SubViewVerificationResult::AffineMapMismatch: 2040 return op.emitError("expected result type to be ") 2041 << expectedType 2042 << " or a rank-reduced version. (mismatch of result affine map) " 2043 << errMsg; 2044 } 2045 llvm_unreachable("unexpected subview verification result"); 2046 } 2047 2048 /// Verifier for SubViewOp. 2049 static LogicalResult verify(SubViewOp op) { 2050 MemRefType baseType = op.getSourceType(); 2051 MemRefType subViewType = op.getType(); 2052 2053 // The base memref and the view memref should be in the same memory space. 2054 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 2055 return op.emitError("different memory spaces specified for base memref " 2056 "type ") 2057 << baseType << " and subview memref type " << subViewType; 2058 2059 // Verify that the base memref type has a strided layout map. 2060 if (!isStrided(baseType)) 2061 return op.emitError("base type ") << baseType << " is not strided"; 2062 2063 // Verify result type against inferred type. 
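  // Illustrative example (syntax approximate): with a memref<8x16x4xf32> base,
  //   memref.subview %base[0, 0, 0] [1, 16, 4] [1, 1, 1]
  //       : memref<8x16x4xf32> to memref<16x4xf32>
  // is accepted because memref<16x4xf32> is a rank-reduced version of the
  // inferred type memref<1x16x4xf32>.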
2064 auto expectedType = SubViewOp::inferResultType( 2065 baseType, extractFromI64ArrayAttr(op.static_offsets()), 2066 extractFromI64ArrayAttr(op.static_sizes()), 2067 extractFromI64ArrayAttr(op.static_strides())); 2068 2069 std::string errMsg; 2070 auto result = 2071 isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg); 2072 return produceSubViewErrorMsg(result, op, expectedType, errMsg); 2073 } 2074 2075 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { 2076 return os << "range " << range.offset << ":" << range.size << ":" 2077 << range.stride; 2078 } 2079 2080 /// Return the list of Range (i.e. offset, size, stride). Each Range 2081 /// entry contains either the dynamic value or a ConstantIndexOp constructed 2082 /// with `b` at location `loc`. 2083 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 2084 OpBuilder &b, Location loc) { 2085 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 2086 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 2087 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 2088 SmallVector<Range, 8> res; 2089 unsigned rank = ranks[0]; 2090 res.reserve(rank); 2091 for (unsigned idx = 0; idx < rank; ++idx) { 2092 Value offset = 2093 op.isDynamicOffset(idx) 2094 ? op.getDynamicOffset(idx) 2095 : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); 2096 Value size = 2097 op.isDynamicSize(idx) 2098 ? op.getDynamicSize(idx) 2099 : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); 2100 Value stride = 2101 op.isDynamicStride(idx) 2102 ? op.getDynamicStride(idx) 2103 : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); 2104 res.emplace_back(Range{offset, size, stride}); 2105 } 2106 return res; 2107 } 2108 2109 /// Infer the canonical type of the result of a subview operation. Returns a 2110 /// type with rank `resultRank` that is either the rank of the rank-reduced 2111 /// type, or the non-rank-reduced type. 2112 static MemRefType 2113 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType, 2114 ArrayRef<OpFoldResult> mixedOffsets, 2115 ArrayRef<OpFoldResult> mixedSizes, 2116 ArrayRef<OpFoldResult> mixedStrides) { 2117 auto resultType = 2118 SubViewOp::inferRankReducedResultType( 2119 resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides) 2120 .cast<MemRefType>(); 2121 if (resultType.getRank() != resultRank) { 2122 resultType = SubViewOp::inferResultType(sourceType, mixedOffsets, 2123 mixedSizes, mixedStrides) 2124 .cast<MemRefType>(); 2125 } 2126 return resultType; 2127 } 2128 2129 namespace { 2130 /// Pattern to rewrite a subview op with MemRefCast arguments. 2131 /// This essentially pushes memref.cast past its consuming subview when 2132 /// `canFoldIntoConsumerOp` is true. 
2133 /// 2134 /// Example: 2135 /// ``` 2136 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 2137 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 2138 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 2139 /// ``` 2140 /// is rewritten into: 2141 /// ``` 2142 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 2143 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 2144 /// memref<3x4xf32, offset:?, strides:[?, 1]> 2145 /// ``` 2146 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 2147 public: 2148 using OpRewritePattern<SubViewOp>::OpRewritePattern; 2149 2150 LogicalResult matchAndRewrite(SubViewOp subViewOp, 2151 PatternRewriter &rewriter) const override { 2152 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 2153 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 2154 return matchPattern(operand, matchConstantIndex()); 2155 })) 2156 return failure(); 2157 2158 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 2159 if (!castOp) 2160 return failure(); 2161 2162 if (!CastOp::canFoldIntoConsumerOp(castOp)) 2163 return failure(); 2164 2165 /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on 2166 /// the cast source operand type and the SubViewOp static information. This 2167 /// is the resulting type if the MemRefCastOp were folded. 2168 auto resultType = getCanonicalSubViewResultType( 2169 subViewOp.getType().getRank(), 2170 castOp.source().getType().cast<MemRefType>(), 2171 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), 2172 subViewOp.getMixedStrides()); 2173 Value newSubView = rewriter.create<SubViewOp>( 2174 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), 2175 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), 2176 subViewOp.static_sizes(), subViewOp.static_strides()); 2177 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(), 2178 newSubView); 2179 return success(); 2180 } 2181 }; 2182 } // namespace 2183 2184 /// Return the canonical type of the result of a subview. 2185 struct SubViewReturnTypeCanonicalizer { 2186 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets, 2187 ArrayRef<OpFoldResult> mixedSizes, 2188 ArrayRef<OpFoldResult> mixedStrides) { 2189 return getCanonicalSubViewResultType(op.getType().getRank(), 2190 op.getSourceType(), mixedOffsets, 2191 mixedSizes, mixedStrides); 2192 } 2193 }; 2194 2195 /// A canonicalizer wrapper to replace SubViewOps. 
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

namespace {
struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
    if (!tensorLoadOp)
      return failure();

    rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
                                       dimOp.index());
    return success();
  }
};
} // namespace

void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<DimOfTensorLoadFolder>(context);
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ?
map.compose(permutationMap) : map; 2282 return MemRefType::Builder(memRefType) 2283 .setShape(sizes) 2284 .setLayout(AffineMapAttr::get(map)); 2285 } 2286 2287 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2288 AffineMapAttr permutation, 2289 ArrayRef<NamedAttribute> attrs) { 2290 auto permutationMap = permutation.getValue(); 2291 assert(permutationMap); 2292 2293 auto memRefType = in.getType().cast<MemRefType>(); 2294 // Compute result type. 2295 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2296 2297 build(b, result, resultType, in, attrs); 2298 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2299 } 2300 2301 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2302 static void print(OpAsmPrinter &p, TransposeOp op) { 2303 p << " " << op.in() << " " << op.permutation(); 2304 p.printOptionalAttrDict(op->getAttrs(), 2305 {TransposeOp::getPermutationAttrName()}); 2306 p << " : " << op.in().getType() << " to " << op.getType(); 2307 } 2308 2309 static ParseResult parseTransposeOp(OpAsmParser &parser, 2310 OperationState &result) { 2311 OpAsmParser::OperandType in; 2312 AffineMap permutation; 2313 MemRefType srcType, dstType; 2314 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2315 parser.parseOptionalAttrDict(result.attributes) || 2316 parser.parseColonType(srcType) || 2317 parser.resolveOperand(in, srcType, result.operands) || 2318 parser.parseKeywordType("to", dstType) || 2319 parser.addTypeToList(dstType, result.types)) 2320 return failure(); 2321 2322 result.addAttribute(TransposeOp::getPermutationAttrName(), 2323 AffineMapAttr::get(permutation)); 2324 return success(); 2325 } 2326 2327 static LogicalResult verify(TransposeOp op) { 2328 if (!op.permutation().isPermutation()) 2329 return op.emitOpError("expected a permutation map"); 2330 if (op.permutation().getNumDims() != op.getShapedType().getRank()) 2331 return op.emitOpError( 2332 "expected a permutation map of same rank as the input"); 2333 2334 auto srcType = op.in().getType().cast<MemRefType>(); 2335 auto dstType = op.getType().cast<MemRefType>(); 2336 auto transposedType = inferTransposeResultType(srcType, op.permutation()); 2337 if (dstType != transposedType) 2338 return op.emitOpError("output type ") 2339 << dstType << " does not match transposed input type " << srcType 2340 << ", " << transposedType; 2341 return success(); 2342 } 2343 2344 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2345 if (succeeded(foldMemRefCast(*this))) 2346 return getResult(); 2347 return {}; 2348 } 2349 2350 //===----------------------------------------------------------------------===// 2351 // ViewOp 2352 //===----------------------------------------------------------------------===// 2353 2354 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { 2355 OpAsmParser::OperandType srcInfo; 2356 SmallVector<OpAsmParser::OperandType, 1> offsetInfo; 2357 SmallVector<OpAsmParser::OperandType, 4> sizesInfo; 2358 auto indexType = parser.getBuilder().getIndexType(); 2359 Type srcType, dstType; 2360 llvm::SMLoc offsetLoc; 2361 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || 2362 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) 2363 return failure(); 2364 2365 if (offsetInfo.size() != 1) 2366 return parser.emitError(offsetLoc) << "expects 1 offset operand"; 2367 2368 return failure( 2369 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || 2370 
parser.parseOptionalAttrDict(result.attributes) || 2371 parser.parseColonType(srcType) || 2372 parser.resolveOperand(srcInfo, srcType, result.operands) || 2373 parser.resolveOperands(offsetInfo, indexType, result.operands) || 2374 parser.resolveOperands(sizesInfo, indexType, result.operands) || 2375 parser.parseKeywordType("to", dstType) || 2376 parser.addTypeToList(dstType, result.types)); 2377 } 2378 2379 static void print(OpAsmPrinter &p, ViewOp op) { 2380 p << ' ' << op.getOperand(0) << '['; 2381 p.printOperand(op.byte_shift()); 2382 p << "][" << op.sizes() << ']'; 2383 p.printOptionalAttrDict(op->getAttrs()); 2384 p << " : " << op.getOperand(0).getType() << " to " << op.getType(); 2385 } 2386 2387 static LogicalResult verify(ViewOp op) { 2388 auto baseType = op.getOperand(0).getType().cast<MemRefType>(); 2389 auto viewType = op.getType(); 2390 2391 // The base memref should have identity layout map (or none). 2392 if (!baseType.getLayout().isIdentity()) 2393 return op.emitError("unsupported map for base memref type ") << baseType; 2394 2395 // The result memref should have identity layout map (or none). 2396 if (!viewType.getLayout().isIdentity()) 2397 return op.emitError("unsupported map for result memref type ") << viewType; 2398 2399 // The base memref and the view memref should be in the same memory space. 2400 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2401 return op.emitError("different memory spaces specified for base memref " 2402 "type ") 2403 << baseType << " and view memref type " << viewType; 2404 2405 // Verify that we have the correct number of sizes for the result type. 2406 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2407 if (op.sizes().size() != numDynamicDims) 2408 return op.emitError("incorrect number of size operands for type ") 2409 << viewType; 2410 2411 return success(); 2412 } 2413 2414 Value ViewOp::getViewSource() { return source(); } 2415 2416 namespace { 2417 2418 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2419 using OpRewritePattern<ViewOp>::OpRewritePattern; 2420 2421 LogicalResult matchAndRewrite(ViewOp viewOp, 2422 PatternRewriter &rewriter) const override { 2423 // Return if none of the operands are constants. 2424 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2425 return matchPattern(operand, matchConstantIndex()); 2426 })) 2427 return failure(); 2428 2429 // Get result memref type. 2430 auto memrefType = viewOp.getType(); 2431 2432 // Get offset from old memref view type 'memRefType'. 2433 int64_t oldOffset; 2434 SmallVector<int64_t, 4> oldStrides; 2435 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2436 return failure(); 2437 assert(oldOffset == 0 && "Expected 0 offset"); 2438 2439 SmallVector<Value, 4> newOperands; 2440 2441 // Offset cannot be folded into result type. 2442 2443 // Fold any dynamic dim operands which are produced by a constant. 2444 SmallVector<int64_t, 4> newShapeConstants; 2445 newShapeConstants.reserve(memrefType.getRank()); 2446 2447 unsigned dynamicDimPos = 0; 2448 unsigned rank = memrefType.getRank(); 2449 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2450 int64_t dimSize = memrefType.getDimSize(dim); 2451 // If this is already static dimension, keep it. 
2452 if (!ShapedType::isDynamic(dimSize)) { 2453 newShapeConstants.push_back(dimSize); 2454 continue; 2455 } 2456 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2457 if (auto constantIndexOp = 2458 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) { 2459 // Dynamic shape dimension will be folded. 2460 newShapeConstants.push_back(constantIndexOp.value()); 2461 } else { 2462 // Dynamic shape dimension not folded; copy operand from old memref. 2463 newShapeConstants.push_back(dimSize); 2464 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2465 } 2466 dynamicDimPos++; 2467 } 2468 2469 // Create new memref type with constant folded dims. 2470 MemRefType newMemRefType = 2471 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2472 // Nothing new, don't fold. 2473 if (newMemRefType == memrefType) 2474 return failure(); 2475 2476 // Create new ViewOp. 2477 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2478 viewOp.getOperand(0), 2479 viewOp.byte_shift(), newOperands); 2480 // Insert a cast so we have the same type as the old memref type. 2481 rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType()); 2482 return success(); 2483 } 2484 }; 2485 2486 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2487 using OpRewritePattern<ViewOp>::OpRewritePattern; 2488 2489 LogicalResult matchAndRewrite(ViewOp viewOp, 2490 PatternRewriter &rewriter) const override { 2491 Value memrefOperand = viewOp.getOperand(0); 2492 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2493 if (!memrefCastOp) 2494 return failure(); 2495 Value allocOperand = memrefCastOp.getOperand(); 2496 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2497 if (!allocOp) 2498 return failure(); 2499 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2500 viewOp.byte_shift(), viewOp.sizes()); 2501 return success(); 2502 } 2503 }; 2504 2505 } // end anonymous namespace 2506 2507 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2508 MLIRContext *context) { 2509 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2510 } 2511 2512 //===----------------------------------------------------------------------===// 2513 // TableGen'd op method definitions 2514 //===----------------------------------------------------------------------===// 2515 2516 #define GET_OP_CLASSES 2517 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2518