//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
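///
/// Illustrative sketch, using memref.dealloc as the consuming op:
///
///   %0 = memref.cast %arg : memref<4xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
///
/// folds to:
///
///   memref.dealloc %arg : memref<4xf32>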
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.
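
// Illustrative sketch of what SimplifyAllocConst achieves, assuming a
// constant-defined dynamic size:
//
//   %c16 = constant 16 : index
//   %0 = memref.alloc(%c16) : memref<?xf32>
//
// is rewritten to an alloc of a static memref plus a cast back to the
// original type:
//
//   %1 = memref.alloc() : memref<16xf32>
//   %0 = memref.cast %1 : memref<16xf32> to memref<?xf32>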

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << AllocaScopeOp::getOperationName() << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
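///
/// Illustrative sketch of the rewrite (types abridged):
///
///   %t = tensor.cast %src : tensor<4xf32> to tensor<?xf32>
///   %m = memref.buffer_cast %t : memref<?xf32>
///
/// becomes:
///
///   %m0 = memref.buffer_cast %src : memref<4xf32>
///   %m  = memref.cast %m0 : memref<4xf32> to memref<?xf32>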
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable
/// memref.cast operations are typically inserted as `view` and `subview` ops
/// are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
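    // For example (illustrative): memref<4x?xf32> and memref<?x8xf32> pass
    // this check, while memref<4x8xf32> and memref<4x16xf32> do not.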
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.source().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    if (load.getType() != dim.getType())
      load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

/// Fold dim of a buffer_cast into a tensor.dim on the cast's source tensor.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<BufferCastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // end anonymous namespace.
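
// Illustrative sketch of DimOfMemRefReshape, assuming %c1 is a constant index
// and the shape operand has index element type:
//
//   %r = memref.reshape %src(%shape)
//            : (memref<16x4xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %c1 : memref<?x?xf32>
//
// is rewritten so that %d is loaded directly from the shape operand:
//
//   %d = memref.load %shape[%c1] : memref<2xindex>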

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices,
                       Value destMemRef, ValueRange destIndices,
                       Value numElements, Value tagMemRef,
                       ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size (the number of elements to transfer).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ?
"data" : "instr"); 1143 p.printOptionalAttrDict( 1144 op->getAttrs(), 1145 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1146 p << " : " << op.getMemRefType(); 1147 } 1148 1149 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1150 OperationState &result) { 1151 OpAsmParser::OperandType memrefInfo; 1152 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1153 IntegerAttr localityHint; 1154 MemRefType type; 1155 StringRef readOrWrite, cacheType; 1156 1157 auto indexTy = parser.getBuilder().getIndexType(); 1158 auto i32Type = parser.getBuilder().getIntegerType(32); 1159 if (parser.parseOperand(memrefInfo) || 1160 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1161 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1162 parser.parseComma() || parser.parseKeyword("locality") || 1163 parser.parseLess() || 1164 parser.parseAttribute(localityHint, i32Type, "localityHint", 1165 result.attributes) || 1166 parser.parseGreater() || parser.parseComma() || 1167 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1168 parser.resolveOperand(memrefInfo, type, result.operands) || 1169 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1170 return failure(); 1171 1172 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1173 return parser.emitError(parser.getNameLoc(), 1174 "rw specifier has to be 'read' or 'write'"); 1175 result.addAttribute( 1176 PrefetchOp::getIsWriteAttrName(), 1177 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1178 1179 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1180 return parser.emitError(parser.getNameLoc(), 1181 "cache type has to be 'data' or 'instr'"); 1182 1183 result.addAttribute( 1184 PrefetchOp::getIsDataCacheAttrName(), 1185 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1186 1187 return success(); 1188 } 1189 1190 static LogicalResult verify(PrefetchOp op) { 1191 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1192 return op.emitOpError("too few indices"); 1193 1194 return success(); 1195 } 1196 1197 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1198 SmallVectorImpl<OpFoldResult> &results) { 1199 // prefetch(memrefcast) -> prefetch 1200 return foldMemRefCast(*this); 1201 } 1202 1203 //===----------------------------------------------------------------------===// 1204 // ReinterpretCastOp 1205 //===----------------------------------------------------------------------===// 1206 1207 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1208 /// `staticSizes` and `staticStrides` are automatically filled with 1209 /// source-memref-rank sentinel values that encode dynamic entries. 
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

// TODO: ponder whether we want to allow missing trailing sizes/strides that
// are completed automatically, like we have for subview and extract_slice.
static LogicalResult verify(ReinterpretCastOp op) {
  // The source and result memrefs should be in the same memory space.
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto resultType = op.getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return op.emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return op.emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
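  // For example (illustrative): a result type of memref<4x?xf32> is only
  // accepted when static_sizes holds [4, -1], with the sentinel encoding the
  // dynamic size.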
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offset and static_strides attributes if
  // result memref type has an affine map specified.
  if (!resultType.getAffineMaps().empty()) {
    int64_t resultOffset;
    SmallVector<int64_t, 4> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return failure();

    // Match offset in result memref type and in static_offsets attribute.
    int64_t expectedOffset =
        extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << resultOffset << " instead of " << expectedOffset;

    // Match strides in result memref type and in static_strides attribute.
    for (auto &en : llvm::enumerate(llvm::zip(
             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
      int64_t resultStride = std::get<0>(en.value());
      int64_t expectedStride = std::get<1>(en.value());
      if (resultStride != expectedStride)
        return op.emitError("expected result type with stride = ")
               << expectedStride << " instead of " << resultStride
               << " in dim = " << en.index();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  OpBuilder b(this->getContext());
  return convertReassociationIndicesToExprs(b, getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  OpBuilder b(this->getContext());
  return convertReassociationIndicesToExprs(b, getReassociationIndices());
}

static void print(OpAsmPrinter &p, ExpandShapeOp op) {
  ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
}

static void print(OpAsmPrinter &p, CollapseShapeOp op) {
  ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
}

/// Detect whether memref dims [dim, dim + extent) can be reshaped without
/// copies.
static bool isReshapableDimBand(unsigned dim, unsigned extent,
                                ArrayRef<int64_t> sizes,
                                ArrayRef<AffineExpr> strides) {
  assert(sizes.size() == strides.size() && "mismatched ranks");
  // off by 1 indexing to avoid out of bounds
  //                       V
  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
    // Only bands of static shapes are reshapable.
    // This is due to the fact that there is no relation between dynamic sizes
    // and dynamic strides: we do not have enough information to know whether
    // a "-1" size corresponds to the proper symbol in the AffineExpr of a
    // stride.
    if (ShapedType::isDynamic(sizes[dim + 1]))
      return false;
    // TODO: Refine this by passing the proper nDims and nSymbols so we can
    // simplify on the fly and catch more reshapable cases.
    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
      return false;
  }
  return true;
}

/// Compute the MemRefType obtained by applying the `reassociation` (which is
/// expected to be valid) to `type`.
/// If `type` is a contiguous MemRefType, this always produces a contiguous
/// MemRefType.
static MemRefType
computeReshapeCollapsedType(MemRefType type,
                            ArrayRef<AffineMap> reassociation) {
  auto sizes = type.getShape();
  AffineExpr offset;
  SmallVector<AffineExpr, 4> strides;
  auto status = getStridesAndOffset(type, strides, offset);
  (void)status;
  assert(succeeded(status) && "expected strided memref");

  SmallVector<int64_t, 4> newSizes;
  newSizes.reserve(reassociation.size());
  SmallVector<AffineExpr, 4> newStrides;
  newStrides.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    AffineExpr stride = strides[currentDim + dim - 1];
    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
      size = ShapedType::kDynamicSize;
      stride = AffineExpr();
    } else {
      for (unsigned d = 0; d < dim; ++d)
        size *= sizes[currentDim + d];
    }
    newSizes.push_back(size);
    newStrides.push_back(stride);
    currentDim += dim;
  }

  // Early-exit: if `type` is contiguous, the result must be contiguous.
  if (canonicalizeStridedLayout(type).getAffineMaps().empty())
    return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});

  // Convert back to int64_t because we don't have enough information to create
  // new strided layouts from AffineExpr only. This corresponds to a case where
  // copies may be necessary.
  int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
  if (auto o = offset.dyn_cast<AffineConstantExpr>())
    intOffset = o.getValue();
  SmallVector<int64_t, 4> intStrides;
  intStrides.reserve(strides.size());
  for (auto stride : newStrides) {
    if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
      intStrides.push_back(cst.getValue());
    else
      intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  auto layout =
      makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
  return canonicalizeStridedLayout(
      MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
}

void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<NamedAttribute> attrs) {
  auto memRefType = src.getType().cast<MemRefType>();
  auto resultType = computeReshapeCollapsedType(
      memRefType, getSymbolLessAffineMaps(
                      convertReassociationIndicesToExprs(b, reassociation)));
  build(b, result, resultType, src, attrs);
  result.addAttribute(getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto memRefType = src.getType().cast<MemRefType>();
  auto resultType = computeReshapeCollapsedType(
      memRefType, getSymbolLessAffineMaps(
                      convertReassociationIndicesToExprs(b, reassociation)));
  build(b, result, resultType, src, attrs);
  result.addAttribute(getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}

template <typename ReshapeOp,
          bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
                                     MemRefType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();
  auto maps = op.getReassociationMaps();
  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
  if (collapsedType != expectedType)
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

static LogicalResult verify(ExpandShapeOp op) {
  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
}

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CollapseReshapeOps<ExpandShapeOp>,
              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
}

static LogicalResult verify(CollapseShapeOp op) {
  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CollapseReshapeOps<CollapseShapeOp>,
              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
}

OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}

OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  if (succeeded(foldMemRefCast(*this)))
    return
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "source memref type should have identity affine map");

  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(StoreOp op) {
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
    return op.emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

namespace {
/// Helpers to write more idiomatic operations.
namespace saturated_arith {
struct Wrapper {
  explicit Wrapper(int64_t v) : v(v) {}
  operator int64_t() { return v; }
  int64_t v;
};
Wrapper operator+(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v + b);
}
Wrapper operator*(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v * b);
}
} // end namespace saturated_arith
} // end namespace

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
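///
/// For example (an illustrative sketch, not normative documentation): a
/// subview of a `memref<8x16xf32>` with identity layout (strides [16, 1],
/// offset 0) taking static offsets [2, 4], sizes [4, 8] and strides [1, 2]
/// infers
///   `memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 16 + d1 * 2 + 36)>>`
/// since offset = 0 + 2 * 16 + 4 * 1 = 36 and strides = [16 * 1, 1 * 2].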
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> leadingStaticOffsets,
                                ArrayRef<int64_t> leadingStaticSizes,
                                ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offsets/sizes/strides, in
  // which case we complete with offset = 0, sizes from the memref type, and
  // strides = 1.
  unsigned rank = sourceMemRefType.getRank();
  assert(leadingStaticOffsets.size() <= rank &&
         "unexpected leadingStaticOffsets overflow");
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  assert(leadingStaticStrides.size() <= rank &&
         "unexpected leadingStaticStrides overflow");
  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
  unsigned numTrailingOffsets = rank - staticOffsets.size();
  unsigned numTrailingSizes = rank - staticSizes.size();
  unsigned numTrailingStrides = rank - staticStrides.size();
  staticOffsets.append(numTrailingOffsets, 0);
  llvm::append_range(staticSizes,
                     sourceMemRefType.getShape().take_back(numTrailingSizes));
  staticStrides.append(numTrailingStrides, 1);

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute the target offset, whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
  }

  // Compute the target strides, whose values are:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceMemRefType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected the inferred rank to be no smaller than the result rank");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map;
    auto maps = inferredType.getAffineMaps();
    if (!maps.empty() && maps.front())
      map = getProjectedMap(maps.front(), dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceMemRefType, staticOffsets, staticSizes, staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
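// As an illustration (hypothetical values, not taken from the op definition):
// an `offsets` list holding {IntegerAttr 2, Value %off} is dispatched by
// dispatchIndexOpFoldResults below into staticOffsets = {2,
// ShapedType::kDynamicStrideOrOffset} and dynamicOffsets = {%off}, so the
// attributes carry the static entries and the operands carry the dynamic ones.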
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

enum SubViewVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  MemSpaceMismatch,
  AffineMapMismatch
};

/// Checks whether `originalType` can be rank-reduced to
/// `candidateReducedType`. This function is a slight variant of the
/// `is subsequence` algorithm, where each non-matching dimension must be 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
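  // For example (illustrative): a shape [4, 1, 8] can be rank-reduced to
  // [4, 8] because the only dropped dimension has size 1, but it cannot be
  // rank-reduced to [4, 2].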
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size and stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns the
/// rank-reduced type with rank `resultRank` when rank reduction is possible,
/// and the non-rank-reduced inferred type otherwise.
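///
/// For instance (an illustrative sketch): with `resultRank` = 2, a subview of
/// a memref<8x16x4xf32> whose sizes are [1, 16, 4] yields the rank-reduced
/// shape 16x4 (the unit dimension is projected away); if rank reduction
/// cannot produce a type of rank `resultRank`, the full-rank inferred type is
/// returned instead.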
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V : memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0 : memref<3x4xf32, offset:0, strides:[16, 1]> to
///   memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, return failure and let
    // SubViewOpConstantFolder kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the result type of the SubViewOp from the cast source operand
    /// type and the SubViewOp static information. This is the type that would
    /// result if the memref.cast were folded away.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
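/// The replacement keeps the original result type visible to existing users:
/// the folded subview is created with its canonical type and then cast back
/// to the old type with a memref.cast.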
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

namespace {
struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
    if (!tensorLoadOp)
      return failure();

    rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
                                       dimOp.index());
    return success();
  }
};
} // namespace

void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<DimOfTensorLoadFolder>(context);
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
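  // For instance (illustrative): transposing a memref<3x4xf32> with identity
  // layout (strides [4, 1]) by (d0, d1) -> (d1, d0) produces shape [4, 3] and
  // layout (d0, d1) -> (d0 + d1 * 4).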
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have an identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have an identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get the offset from the old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // The offset cannot be folded into the result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"