//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  // Constants for the memref dialect are materialized via the standard
  // dialect's constant op.
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common helper used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    // Skip the caller-designated `inner` operand (which must not be folded)
    // and casts from unranked memrefs, which cannot be forwarded directly.
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  // Report success only when at least one operand was rewritten.
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

/// Return the tensor type equivalent to the given memref type: a ranked
/// tensor for a ranked memref, an unranked tensor for an unranked memref,
/// and NoneType for any other type.
static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

/// Shared verifier for alloc and alloca: the number of dynamic-size operands
/// must match the number of dynamic dimensions in the result type, and the
/// number of symbol operands must match the symbol count of the type's first
/// affine map (zero when no layout map is present).
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimensions operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    // `dynamicDimPos` indexes only the dynamic dimensions, matching the op's
    // dynamic-size operand list.
    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    // Bail out if any user keeps the allocation alive: a store that stores
    // the allocated value itself (not a store into the buffer), or any user
    // other than a dealloc.
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    // Remaining users are stores into the buffer or deallocs; erase them
    // first, then the alloc itself.
    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.
192 193 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 194 MLIRContext *context) { 195 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 196 } 197 198 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 199 MLIRContext *context) { 200 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 201 context); 202 } 203 204 //===----------------------------------------------------------------------===// 205 // AllocaScopeOp 206 //===----------------------------------------------------------------------===// 207 208 static void print(OpAsmPrinter &p, AllocaScopeOp &op) { 209 bool printBlockTerminators = false; 210 211 p << AllocaScopeOp::getOperationName() << " "; 212 if (!op.results().empty()) { 213 p << " -> (" << op.getResultTypes() << ")"; 214 printBlockTerminators = true; 215 } 216 p.printRegion(op.bodyRegion(), 217 /*printEntryBlockArgs=*/false, 218 /*printBlockTerminators=*/printBlockTerminators); 219 p.printOptionalAttrDict(op->getAttrs()); 220 } 221 222 static ParseResult parseAllocaScopeOp(OpAsmParser &parser, 223 OperationState &result) { 224 // Create a region for the body. 225 result.regions.reserve(1); 226 Region *bodyRegion = result.addRegion(); 227 228 // Parse optional results type list. 229 if (parser.parseOptionalArrowTypeList(result.types)) 230 return failure(); 231 232 // Parse the body region. 233 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 234 return failure(); 235 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 236 result.location); 237 238 // Parse the optional attribute list. 
239 if (parser.parseOptionalAttrDict(result.attributes)) 240 return failure(); 241 242 return success(); 243 } 244 245 static LogicalResult verify(AllocaScopeOp op) { 246 if (failed(RegionBranchOpInterface::verifyTypes(op))) 247 return failure(); 248 249 return success(); 250 } 251 252 void AllocaScopeOp::getSuccessorRegions( 253 Optional<unsigned> index, ArrayRef<Attribute> operands, 254 SmallVectorImpl<RegionSuccessor> ®ions) { 255 if (index.hasValue()) { 256 regions.push_back(RegionSuccessor(getResults())); 257 return; 258 } 259 260 regions.push_back(RegionSuccessor(&bodyRegion())); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // AssumeAlignmentOp 265 //===----------------------------------------------------------------------===// 266 267 static LogicalResult verify(AssumeAlignmentOp op) { 268 unsigned alignment = op.alignment(); 269 if (!llvm::isPowerOf2_32(alignment)) 270 return op.emitOpError("alignment must be power of 2"); 271 return success(); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // BufferCastOp 276 //===----------------------------------------------------------------------===// 277 278 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) { 279 if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) 280 if (tensorLoad.memref().getType() == getType()) 281 return tensorLoad.memref(); 282 return {}; 283 } 284 285 namespace { 286 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast. 
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    // Only fires when the operand comes from a tensor.cast whose source is a
    // ranked tensor.
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    // Buffer-cast the pre-cast tensor, then memref.cast the result back to
    // the type produced by the original buffer_cast.
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` to kick in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

/// A cast between two ranked memrefs is compatible when element types, ranks,
/// memory spaces and (up to relaxing static information) address computations
/// agree; a cast between a ranked and an unranked memref (in either direction)
/// is compatible with matching element type and memory space. Unranked to
/// unranked casts are rejected.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

/// cast(memrefcast) -> cast, absorbing a ranked producer cast.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

/// clone reads its input and both allocates and writes its output buffer.
void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Simplify memref.clone: erase clones with no uses, and replace a clone by a
/// cast of its source when a dealloc of either value in the clone's block
/// makes the copy redundant.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    // A clone with no uses can simply be removed.
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
    Operation *sourceDeallocOp = findDealloc(source);

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    // Pick whichever dealloc lives in the clone's own block; it becomes
    // redundant once the clone is turned into a cast.
    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations inbetween
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail of the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    // Replace the clone by a cast of its source and drop the now-redundant
    // dealloc.
    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

/// clone(memrefcast) -> clone, absorbing a ranked cast on the input.
OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

/// Build a dim op from a constant index, materializing the index operand.
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

/// Build a dim op from an SSA index; the result type is always `index`.
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

/// Return the index if it is defined by a constant op, None otherwise.
Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.source().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    // The shape element type may differ from `index`; cast if necessary.
    if (load.getType() != dim.getType())
      load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the memref cast.
718 struct DimOfCastOp : public OpRewritePattern<DimOp> { 719 using OpRewritePattern<DimOp>::OpRewritePattern; 720 721 LogicalResult matchAndRewrite(DimOp dimOp, 722 PatternRewriter &rewriter) const override { 723 auto castOp = dimOp.source().getDefiningOp<BufferCastOp>(); 724 if (!castOp) 725 return failure(); 726 Value newSource = castOp.getOperand(); 727 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index()); 728 return success(); 729 } 730 }; 731 } // end anonymous namespace. 732 733 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 734 MLIRContext *context) { 735 results.add<DimOfMemRefReshape, DimOfCastOp>(context); 736 } 737 738 // --------------------------------------------------------------------------- 739 // DmaStartOp 740 // --------------------------------------------------------------------------- 741 742 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 743 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 744 ValueRange destIndices, Value numElements, 745 Value tagMemRef, ValueRange tagIndices, Value stride, 746 Value elementsPerStride) { 747 result.addOperands(srcMemRef); 748 result.addOperands(srcIndices); 749 result.addOperands(destMemRef); 750 result.addOperands(destIndices); 751 result.addOperands({numElements, tagMemRef}); 752 result.addOperands(tagIndices); 753 if (stride) 754 result.addOperands({stride, elementsPerStride}); 755 } 756 757 void DmaStartOp::print(OpAsmPrinter &p) { 758 p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices() 759 << "], " << getDstMemRef() << '[' << getDstIndices() << "], " 760 << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() 761 << ']'; 762 if (isStrided()) 763 p << ", " << getStride() << ", " << getNumElementsPerStride(); 764 765 p.printOptionalAttrDict((*this)->getAttrs()); 766 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 767 << ", " << getTagMemRef().getType(); 768 } 769 770 
// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Strides are all-or-nothing: exactly two trailing operands or none.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  // Exactly three types are expected: src, dst and tag memref types.
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

/// Structural verification of dma_start: operand counts are checked
/// incrementally because positions of later operands depend on the ranks of
/// the earlier memref operands.
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

/// Assemble the operand list for a dma_wait: tag memref, its indices, then
/// the number of elements.
void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

/// Print in the custom form parsed by DmaWaitOp::parse below.
void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

/// Structural verification of dma_wait; tag-index operand positions depend on
/// the tag memref rank, so the tag type is validated first.
LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

/// Print the "type [= initial-value]" tail of a memref.global. External
/// globals print no initializer; uninitialized ones print the keyword.
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

/// Parse the "type [= initial-value]" tail of a memref.global. The type must
/// be a static-shaped memref; the initializer, when present, is either the
/// `uninitialized` keyword (stored as a UnitAttr) or an elements attribute
/// typed with the tensor equivalent of the memref type.
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // No '=' means an external global: no initial value at all.
  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  // One index operand per memref dimension (operand 0 is the memref itself).
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    // Extract directly from the tensor feeding the buffer_cast.
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

/// Print in the custom form parsed by parsePrefetchOp below; the rw flag,
/// locality hint and cache kind are printed as keywords, so the backing
/// attributes are elided from the attribute dictionary.
static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

/// Parse a prefetch op of the form:
///   memref.prefetch %m[%i], read, locality<3>, data : memref<...>
/// The read/write and data/instr keywords are converted to bool attributes.
static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

static LogicalResult verify(PrefetchOp op) {
  // One index operand per memref dimension (operand 0 is the memref itself).
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
1202 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1203 MemRefType resultType, Value source, 1204 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1205 ArrayRef<OpFoldResult> strides, 1206 ArrayRef<NamedAttribute> attrs) { 1207 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1208 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1209 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1210 ShapedType::kDynamicStrideOrOffset); 1211 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1212 ShapedType::kDynamicSize); 1213 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1214 ShapedType::kDynamicStrideOrOffset); 1215 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1216 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1217 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1218 result.addAttributes(attrs); 1219 } 1220 1221 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1222 MemRefType resultType, Value source, 1223 int64_t offset, ArrayRef<int64_t> sizes, 1224 ArrayRef<int64_t> strides, 1225 ArrayRef<NamedAttribute> attrs) { 1226 SmallVector<OpFoldResult> sizeValues = 1227 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1228 return b.getI64IntegerAttr(v); 1229 })); 1230 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1231 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1232 return b.getI64IntegerAttr(v); 1233 })); 1234 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1235 strideValues, attrs); 1236 } 1237 1238 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1239 MemRefType resultType, Value source, Value offset, 1240 ValueRange sizes, ValueRange strides, 1241 ArrayRef<NamedAttribute> attrs) { 1242 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1243 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; 
})); 1244 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1245 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1246 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1247 } 1248 1249 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1250 // completed automatically, like we have for subview and extract_slice. 1251 static LogicalResult verify(ReinterpretCastOp op) { 1252 // The source and result memrefs should be in the same memory space. 1253 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1254 auto resultType = op.getType().cast<MemRefType>(); 1255 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1256 return op.emitError("different memory spaces specified for source type ") 1257 << srcType << " and result memref type " << resultType; 1258 if (srcType.getElementType() != resultType.getElementType()) 1259 return op.emitError("different element types specified for source type ") 1260 << srcType << " and result memref type " << resultType; 1261 1262 // Match sizes in result memref type and in static_sizes attribute. 1263 for (auto &en : 1264 llvm::enumerate(llvm::zip(resultType.getShape(), 1265 extractFromI64ArrayAttr(op.static_sizes())))) { 1266 int64_t resultSize = std::get<0>(en.value()); 1267 int64_t expectedSize = std::get<1>(en.value()); 1268 if (resultSize != expectedSize) 1269 return op.emitError("expected result type with size = ") 1270 << expectedSize << " instead of " << resultSize 1271 << " in dim = " << en.index(); 1272 } 1273 1274 // Match offset and strides in static_offset and static_strides attributes if 1275 // result memref type has an affine map specified. 
1276 if (!resultType.getAffineMaps().empty()) { 1277 int64_t resultOffset; 1278 SmallVector<int64_t, 4> resultStrides; 1279 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1280 return failure(); 1281 1282 // Match offset in result memref type and in static_offsets attribute. 1283 int64_t expectedOffset = 1284 extractFromI64ArrayAttr(op.static_offsets()).front(); 1285 if (resultOffset != expectedOffset) 1286 return op.emitError("expected result type with offset = ") 1287 << resultOffset << " instead of " << expectedOffset; 1288 1289 // Match strides in result memref type and in static_strides attribute. 1290 for (auto &en : llvm::enumerate(llvm::zip( 1291 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1292 int64_t resultStride = std::get<0>(en.value()); 1293 int64_t expectedStride = std::get<1>(en.value()); 1294 if (resultStride != expectedStride) 1295 return op.emitError("expected result type with stride = ") 1296 << expectedStride << " instead of " << resultStride 1297 << " in dim = " << en.index(); 1298 } 1299 } 1300 return success(); 1301 } 1302 1303 //===----------------------------------------------------------------------===// 1304 // ReshapeOp 1305 //===----------------------------------------------------------------------===// 1306 1307 static LogicalResult verify(ReshapeOp op) { 1308 Type operandType = op.source().getType(); 1309 Type resultType = op.result().getType(); 1310 1311 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1312 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1313 if (operandElementType != resultElementType) 1314 return op.emitOpError("element types of source and destination memref " 1315 "types should be the same"); 1316 1317 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1318 if (!operandMemRefType.getAffineMaps().empty()) 1319 return op.emitOpError( 1320 "source memref type should have identity affine map"); 1321 1322 
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1323 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1324 if (resultMemRefType) { 1325 if (!resultMemRefType.getAffineMaps().empty()) 1326 return op.emitOpError( 1327 "result memref type should have identity affine map"); 1328 if (shapeSize == ShapedType::kDynamicSize) 1329 return op.emitOpError("cannot use shape operand with dynamic length to " 1330 "reshape to statically-ranked memref type"); 1331 if (shapeSize != resultMemRefType.getRank()) 1332 return op.emitOpError( 1333 "length of shape operand differs from the result's memref rank"); 1334 } 1335 return success(); 1336 } 1337 1338 //===----------------------------------------------------------------------===// 1339 // StoreOp 1340 //===----------------------------------------------------------------------===// 1341 1342 static LogicalResult verify(StoreOp op) { 1343 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1344 return op.emitOpError("store index operand count not equal to memref rank"); 1345 1346 return success(); 1347 } 1348 1349 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1350 SmallVectorImpl<OpFoldResult> &results) { 1351 /// store(memrefcast) -> store 1352 return foldMemRefCast(*this, getValueToStore()); 1353 } 1354 1355 //===----------------------------------------------------------------------===// 1356 // SubViewOp 1357 //===----------------------------------------------------------------------===// 1358 1359 namespace { 1360 /// Helpers to write more idiomatic operations. 
1361 namespace saturated_arith { 1362 struct Wrapper { 1363 explicit Wrapper(int64_t v) : v(v) {} 1364 operator int64_t() { return v; } 1365 int64_t v; 1366 }; 1367 Wrapper operator+(Wrapper a, int64_t b) { 1368 if (ShapedType::isDynamicStrideOrOffset(a) || 1369 ShapedType::isDynamicStrideOrOffset(b)) 1370 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1371 return Wrapper(a.v + b); 1372 } 1373 Wrapper operator*(Wrapper a, int64_t b) { 1374 if (ShapedType::isDynamicStrideOrOffset(a) || 1375 ShapedType::isDynamicStrideOrOffset(b)) 1376 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1377 return Wrapper(a.v * b); 1378 } 1379 } // end namespace saturated_arith 1380 } // end namespace 1381 1382 /// A subview result type can be fully inferred from the source type and the 1383 /// static representation of offsets, sizes and strides. Special sentinels 1384 /// encode the dynamic case. 1385 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1386 ArrayRef<int64_t> leadingStaticOffsets, 1387 ArrayRef<int64_t> leadingStaticSizes, 1388 ArrayRef<int64_t> leadingStaticStrides) { 1389 // A subview may specify only a leading subset of offset/sizes/strides in 1390 // which case we complete with offset=0, sizes from memref type and strides=1. 
1391 unsigned rank = sourceMemRefType.getRank(); 1392 assert(leadingStaticOffsets.size() <= rank && 1393 "unexpected leadingStaticOffsets overflow"); 1394 assert(leadingStaticSizes.size() <= rank && 1395 "unexpected leadingStaticSizes overflow"); 1396 assert(leadingStaticStrides.size() <= rank && 1397 "unexpected leadingStaticStrides overflow"); 1398 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1399 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1400 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1401 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1402 unsigned numTrailingSizes = rank - staticSizes.size(); 1403 unsigned numTrailingStrides = rank - staticStrides.size(); 1404 staticOffsets.append(numTrailingOffsets, 0); 1405 llvm::append_range(staticSizes, 1406 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1407 staticStrides.append(numTrailingStrides, 1); 1408 1409 // Extract source offset and strides. 1410 int64_t sourceOffset; 1411 SmallVector<int64_t, 4> sourceStrides; 1412 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1413 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1414 (void)res; 1415 1416 // Compute target offset whose value is: 1417 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1418 int64_t targetOffset = sourceOffset; 1419 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1420 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1421 using namespace saturated_arith; 1422 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1423 } 1424 1425 // Compute target stride whose value is: 1426 // `sourceStrides_i * staticStrides_i`. 
1427 SmallVector<int64_t, 4> targetStrides; 1428 targetStrides.reserve(staticOffsets.size()); 1429 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1430 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1431 using namespace saturated_arith; 1432 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1433 } 1434 1435 // The type is now known. 1436 return MemRefType::get( 1437 staticSizes, sourceMemRefType.getElementType(), 1438 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1439 sourceMemRefType.getContext()), 1440 sourceMemRefType.getMemorySpace()); 1441 } 1442 1443 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1444 ArrayRef<OpFoldResult> leadingStaticOffsets, 1445 ArrayRef<OpFoldResult> leadingStaticSizes, 1446 ArrayRef<OpFoldResult> leadingStaticStrides) { 1447 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1448 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1449 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1450 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1451 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1452 ShapedType::kDynamicSize); 1453 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1454 staticStrides, ShapedType::kDynamicStrideOrOffset); 1455 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1456 staticSizes, staticStrides) 1457 .cast<MemRefType>(); 1458 } 1459 1460 Type SubViewOp::inferRankReducedResultType( 1461 unsigned resultRank, MemRefType sourceRankedTensorType, 1462 ArrayRef<int64_t> leadingStaticOffsets, 1463 ArrayRef<int64_t> leadingStaticSizes, 1464 ArrayRef<int64_t> leadingStaticStrides) { 1465 auto inferredType = 1466 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1467 leadingStaticSizes, leadingStaticStrides) 1468 .cast<MemRefType>(); 1469 assert(inferredType.getRank() >= resultRank && "expected "); 1470 int rankDiff = inferredType.getRank() - 
resultRank; 1471 if (rankDiff > 0) { 1472 auto shape = inferredType.getShape(); 1473 llvm::SmallDenseSet<unsigned> dimsToProject; 1474 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1475 SmallVector<int64_t> projectedShape; 1476 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1477 if (!dimsToProject.contains(pos)) 1478 projectedShape.push_back(shape[pos]); 1479 1480 AffineMap map; 1481 auto maps = inferredType.getAffineMaps(); 1482 if (!maps.empty() && maps.front()) 1483 map = getProjectedMap(maps.front(), dimsToProject); 1484 inferredType = 1485 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1486 inferredType.getMemorySpace()); 1487 } 1488 return inferredType; 1489 } 1490 1491 Type SubViewOp::inferRankReducedResultType( 1492 unsigned resultRank, MemRefType sourceRankedTensorType, 1493 ArrayRef<OpFoldResult> leadingStaticOffsets, 1494 ArrayRef<OpFoldResult> leadingStaticSizes, 1495 ArrayRef<OpFoldResult> leadingStaticStrides) { 1496 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1497 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1498 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1499 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1500 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1501 ShapedType::kDynamicSize); 1502 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1503 staticStrides, ShapedType::kDynamicStrideOrOffset); 1504 return SubViewOp::inferRankReducedResultType( 1505 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1506 staticStrides); 1507 } 1508 // Build a SubViewOp with mixed static and dynamic entries and custom result 1509 // type. If the type passed is nullptr, it is inferred. 
1510 void SubViewOp::build(OpBuilder &b, OperationState &result, 1511 MemRefType resultType, Value source, 1512 ArrayRef<OpFoldResult> offsets, 1513 ArrayRef<OpFoldResult> sizes, 1514 ArrayRef<OpFoldResult> strides, 1515 ArrayRef<NamedAttribute> attrs) { 1516 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1517 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1518 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1519 ShapedType::kDynamicStrideOrOffset); 1520 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1521 ShapedType::kDynamicSize); 1522 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1523 ShapedType::kDynamicStrideOrOffset); 1524 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1525 // Structuring implementation this way avoids duplication between builders. 1526 if (!resultType) { 1527 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1528 staticSizes, staticStrides) 1529 .cast<MemRefType>(); 1530 } 1531 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1532 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1533 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1534 result.addAttributes(attrs); 1535 } 1536 1537 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1538 // type. 1539 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1540 ArrayRef<OpFoldResult> offsets, 1541 ArrayRef<OpFoldResult> sizes, 1542 ArrayRef<OpFoldResult> strides, 1543 ArrayRef<NamedAttribute> attrs) { 1544 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1545 } 1546 1547 // Build a SubViewOp with static entries and inferred result type. 
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  // Wrap every static entry in an I64 attribute and delegate to the
  // OpFoldResult-based builder.
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  // Same attribute-wrapping as above, but forwarding an explicit (possibly
  // null, i.e. to-be-inferred) result type.
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
1592 void SubViewOp::build(OpBuilder &b, OperationState &result, 1593 MemRefType resultType, Value source, ValueRange offsets, 1594 ValueRange sizes, ValueRange strides, 1595 ArrayRef<NamedAttribute> attrs) { 1596 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1597 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1598 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1599 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1600 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1601 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1602 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1603 } 1604 1605 // Build a SubViewOp with dynamic entries and inferred result type. 1606 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1607 ValueRange offsets, ValueRange sizes, ValueRange strides, 1608 ArrayRef<NamedAttribute> attrs) { 1609 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1610 } 1611 1612 /// For ViewLikeOpInterface. 1613 Value SubViewOp::getViewSource() { return source(); } 1614 1615 enum SubViewVerificationResult { 1616 Success, 1617 RankTooLarge, 1618 SizeMismatch, 1619 ElemTypeMismatch, 1620 MemSpaceMismatch, 1621 AffineMapMismatch 1622 }; 1623 1624 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1625 /// This function is slight variant of `is subsequence` algorithm where 1626 /// not matching dimension must be 1. 
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  // Identical types trivially match.
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  // Only memref-to-memref comparisons are checked; everything else passes.
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  // Compute which dimensions of `originalShape` (all of size 1) must be
  // dropped so the remaining ones match `candidateReducedShape`.
  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  // Project the original strided layout onto the kept dimensions and compare
  // against the candidate's layout (empty maps means the canonical strided
  // form of the candidate type).
  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

/// Translate an `isRankReducedType` verdict into the corresponding op error,
/// appending `errMsg` (the inferred-type detail) where available.
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  // The declared result type may be a rank-reduced version of the inferred
  // type; delegate the comparison and the diagnostics.
  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

/// Debug printing for Range as "range offset:size:stride".
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // For each position, take the dynamic SSA value if present, otherwise
    // materialize the static entry as a constant index op.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  // Try the rank-reduced inference first; fall back to the full-rank type if
  // the reduction did not reach `resultRank`.
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    // Only fold casts whose source type is compatible with the consumer.
    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
    /// the cast source operand type and the SubViewOp static information. This
    /// is the resulting type if the MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    // Recreate the subview directly on the cast's source, then cast the new
    // result back to the original (more dynamic) result type.
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    // Cast the canonicalized subview back to the original result type.
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  // A subview whose static result type equals the source type is a no-op;
  // fold it to the source.
  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there
    // is no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

namespace {
/// Rewrite tensor.dim(tensor_load(m)) as memref.dim(m).
struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
    if (!tensorLoadOp)
      return failure();

    rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
                                       dimOp.index());
    return success();
  }
};
} // namespace

void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<DimOfTensorLoadFolder>(context);
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  // Compose the source's strided layout with the permutation to obtain the
  // result layout.
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  // The permutation is printed inline, so elide it from the attr-dict.
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  // The declared result type must equal the type inferred by applying the
  // permutation to the source type.
  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  // Fold memref.cast feeding the operand, when safe (see foldMemRefCast).
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  // Exactly one byte-shift operand is expected in the first bracket group.
  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

/// Fold constant size operands of a ViewOp into its result type, casting the
/// new (more static) view back to the original type.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Rewrite view(cast(alloc)) as view(alloc) directly on the allocation.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===// 2198 2199 #define GET_OP_CLASSES 2200 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2201