1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/InferTypeOpInterface.h" 21 #include "mlir/Interfaces/ViewLikeInterface.h" 22 #include "llvm/ADT/STLExtras.h" 23 24 using namespace mlir; 25 using namespace mlir::memref; 26 27 /// Materialize a single constant operation from a given attribute value with 28 /// the desired resultant type. 29 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 30 Attribute value, Type type, 31 Location loc) { 32 return builder.create<mlir::ConstantOp>(loc, type, value); 33 } 34 35 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. 36 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) { 37 return llvm::to_vector<4>( 38 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t { 39 return a.cast<IntegerAttr>().getInt(); 40 })); 41 } 42 43 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if 44 /// it is a Value or into `staticVec` if it is an IntegerAttr. 45 /// In the case of a Value, a copy of the `sentinel` value is also pushed to 46 /// `staticVec`. 
This is useful to extract mixed static and dynamic entries that 47 /// come from an AttrSizedOperandSegments trait. 48 static void dispatchIndexOpFoldResult(OpFoldResult ofr, 49 SmallVectorImpl<Value> &dynamicVec, 50 SmallVectorImpl<int64_t> &staticVec, 51 int64_t sentinel) { 52 if (auto v = ofr.dyn_cast<Value>()) { 53 dynamicVec.push_back(v); 54 staticVec.push_back(sentinel); 55 return; 56 } 57 APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue(); 58 staticVec.push_back(apInt.getSExtValue()); 59 } 60 61 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, 62 SmallVectorImpl<Value> &dynamicVec, 63 SmallVectorImpl<int64_t> &staticVec, 64 int64_t sentinel) { 65 for (auto ofr : ofrs) 66 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // Common canonicalization pattern support logic 71 //===----------------------------------------------------------------------===// 72 73 /// This is a common class used for patterns of the form 74 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 75 /// into the root operation directly. 
76 static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) { 77 bool folded = false; 78 for (OpOperand &operand : op->getOpOperands()) { 79 auto cast = operand.get().getDefiningOp<CastOp>(); 80 if (cast && operand.get() != inner && 81 !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 82 operand.set(cast.getOperand()); 83 folded = true; 84 } 85 } 86 return success(folded); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Helpers for GlobalOp 91 //===----------------------------------------------------------------------===// 92 93 static Type getTensorTypeFromMemRefType(Type type) { 94 if (auto memref = type.dyn_cast<MemRefType>()) 95 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 96 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 97 return UnrankedTensorType::get(memref.getElementType()); 98 return NoneType::get(type.getContext()); 99 } 100 101 //===----------------------------------------------------------------------===// 102 // AllocOp / AllocaOp 103 //===----------------------------------------------------------------------===// 104 105 template <typename AllocLikeOp> 106 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 107 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 108 "applies to only alloc or alloca"); 109 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 110 if (!memRefType) 111 return op.emitOpError("result must be a memref"); 112 113 if (static_cast<int64_t>(op.dynamicSizes().size()) != 114 memRefType.getNumDynamicDims()) 115 return op.emitOpError("dimension operand count does not equal memref " 116 "dynamic dimension count"); 117 118 unsigned numSymbols = 0; 119 if (!memRefType.getAffineMaps().empty()) 120 numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); 121 if (op.symbolOperands().size() != numSymbols) 122 return op.emitOpError("symbol operand count does not equal memref 
symbol " 123 "count: expected ") 124 << numSymbols << ", got " << op.symbolOperands().size(); 125 126 return success(); 127 } 128 129 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 130 131 static LogicalResult verify(AllocaOp op) { 132 // An alloca op needs to have an ancestor with an allocation scope trait. 133 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 134 return op.emitOpError( 135 "requires an ancestor op with AutomaticAllocationScope trait"); 136 137 return verifyAllocLikeOp(op); 138 } 139 140 namespace { 141 /// Fold constant dimensions into an alloc like operation. 142 template <typename AllocLikeOp> 143 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 144 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 145 146 LogicalResult matchAndRewrite(AllocLikeOp alloc, 147 PatternRewriter &rewriter) const override { 148 // Check to see if any dimensions operands are constants. If so, we can 149 // substitute and drop them. 150 if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) { 151 return matchPattern(operand, matchConstantIndex()); 152 })) 153 return failure(); 154 155 auto memrefType = alloc.getType(); 156 157 // Ok, we have one or more constant operands. Collect the non-constant ones 158 // and keep track of the resultant memref type to build. 159 SmallVector<int64_t, 4> newShapeConstants; 160 newShapeConstants.reserve(memrefType.getRank()); 161 SmallVector<Value, 4> dynamicSizes; 162 163 unsigned dynamicDimPos = 0; 164 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 165 int64_t dimSize = memrefType.getDimSize(dim); 166 // If this is already static dimension, keep it. 
167 if (dimSize != -1) { 168 newShapeConstants.push_back(dimSize); 169 continue; 170 } 171 auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos]; 172 auto *defOp = dynamicSize.getDefiningOp(); 173 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 174 // Dynamic shape dimension will be folded. 175 newShapeConstants.push_back(constantIndexOp.getValue()); 176 } else { 177 // Dynamic shape dimension not folded; copy dynamicSize from old memref. 178 newShapeConstants.push_back(-1); 179 dynamicSizes.push_back(dynamicSize); 180 } 181 dynamicDimPos++; 182 } 183 184 // Create new memref type (which will have fewer dynamic dimensions). 185 MemRefType newMemRefType = 186 MemRefType::Builder(memrefType).setShape(newShapeConstants); 187 assert(static_cast<int64_t>(dynamicSizes.size()) == 188 newMemRefType.getNumDynamicDims()); 189 190 // Create and insert the alloc op for the new memref. 191 auto newAlloc = rewriter.create<AllocLikeOp>( 192 alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(), 193 alloc.alignmentAttr()); 194 // Insert a cast so we have the same type as the old alloc. 195 auto resultCast = 196 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 197 198 rewriter.replaceOp(alloc, {resultCast}); 199 return success(); 200 } 201 }; 202 203 /// Fold alloc operations with no users or only store and dealloc uses. 204 template <typename T> 205 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 206 using OpRewritePattern<T>::OpRewritePattern; 207 208 LogicalResult matchAndRewrite(T alloc, 209 PatternRewriter &rewriter) const override { 210 if (llvm::any_of(alloc->getUsers(), [](Operation *op) { 211 return !isa<StoreOp, DeallocOp>(op); 212 })) 213 return failure(); 214 215 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 216 rewriter.eraseOp(user); 217 218 rewriter.eraseOp(alloc); 219 return success(); 220 } 221 }; 222 } // end anonymous namespace. 
223 224 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 225 MLIRContext *context) { 226 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 227 } 228 229 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 230 MLIRContext *context) { 231 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 232 context); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // AllocaScopeOp 237 //===----------------------------------------------------------------------===// 238 239 static void print(OpAsmPrinter &p, AllocaScopeOp &op) { 240 bool printBlockTerminators = false; 241 242 p << AllocaScopeOp::getOperationName() << " "; 243 if (!op.results().empty()) { 244 p << " -> (" << op.getResultTypes() << ")"; 245 printBlockTerminators = true; 246 } 247 p.printRegion(op.bodyRegion(), 248 /*printEntryBlockArgs=*/false, 249 /*printBlockTerminators=*/printBlockTerminators); 250 p.printOptionalAttrDict(op->getAttrs()); 251 } 252 253 static ParseResult parseAllocaScopeOp(OpAsmParser &parser, 254 OperationState &result) { 255 // Create a region for the body. 256 result.regions.reserve(1); 257 Region *bodyRegion = result.addRegion(); 258 259 // Parse optional results type list. 260 if (parser.parseOptionalArrowTypeList(result.types)) 261 return failure(); 262 263 // Parse the body region. 264 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{})) 265 return failure(); 266 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(), 267 result.location); 268 269 // Parse the optional attribute list. 
270 if (parser.parseOptionalAttrDict(result.attributes)) 271 return failure(); 272 273 return success(); 274 } 275 276 static LogicalResult verify(AllocaScopeOp op) { 277 if (failed(RegionBranchOpInterface::verifyTypes(op))) 278 return failure(); 279 280 return success(); 281 } 282 283 void AllocaScopeOp::getSuccessorRegions( 284 Optional<unsigned> index, ArrayRef<Attribute> operands, 285 SmallVectorImpl<RegionSuccessor> ®ions) { 286 if (index.hasValue()) { 287 regions.push_back(RegionSuccessor(getResults())); 288 return; 289 } 290 291 regions.push_back(RegionSuccessor(&bodyRegion())); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // AssumeAlignmentOp 296 //===----------------------------------------------------------------------===// 297 298 static LogicalResult verify(AssumeAlignmentOp op) { 299 unsigned alignment = op.alignment(); 300 if (!llvm::isPowerOf2_32(alignment)) 301 return op.emitOpError("alignment must be power of 2"); 302 return success(); 303 } 304 305 //===----------------------------------------------------------------------===// 306 // BufferCastOp 307 //===----------------------------------------------------------------------===// 308 309 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) { 310 if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) 311 if (tensorLoad.memref().getType() == getType()) 312 return tensorLoad.memref(); 313 return {}; 314 } 315 316 namespace { 317 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast. 
318 struct BufferCast : public OpRewritePattern<BufferCastOp> { 319 using OpRewritePattern<BufferCastOp>::OpRewritePattern; 320 321 LogicalResult matchAndRewrite(BufferCastOp bufferCast, 322 PatternRewriter &rewriter) const final { 323 auto tensorCastOperand = 324 bufferCast.getOperand().getDefiningOp<tensor::CastOp>(); 325 if (!tensorCastOperand) 326 return failure(); 327 auto srcTensorType = 328 tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>(); 329 if (!srcTensorType) 330 return failure(); 331 auto memrefType = MemRefType::get(srcTensorType.getShape(), 332 srcTensorType.getElementType()); 333 Value memref = rewriter.create<BufferCastOp>( 334 bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand()); 335 rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(), 336 memref); 337 return success(); 338 } 339 }; 340 341 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when 342 /// type mismatches prevent `BufferCastOp::fold` to kick in. 343 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> { 344 using OpRewritePattern<BufferCastOp>::OpRewritePattern; 345 346 LogicalResult matchAndRewrite(BufferCastOp bufferCast, 347 PatternRewriter &rewriter) const final { 348 auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>(); 349 // Bail unless we have a tensor_load + memref.buffer_cast with different 350 // types. `BufferCastOp::fold` handles the same type case. 351 if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType()) 352 return failure(); 353 // If types are not cast-compatible, bail. 
354 if (!CastOp::areCastCompatible(tensorLoad.memref().getType(), 355 bufferCast.getType())) 356 return failure(); 357 rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(), 358 tensorLoad.memref()); 359 return success(); 360 } 361 }; 362 363 } // namespace 364 365 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 366 MLIRContext *context) { 367 results.add<BufferCast, TensorLoadToMemRef>(context); 368 } 369 370 //===----------------------------------------------------------------------===// 371 // CastOp 372 //===----------------------------------------------------------------------===// 373 374 /// Determines whether MemRef_CastOp casts to a more dynamic version of the 375 /// source memref. This is useful to to fold a memref.cast into a consuming op 376 /// and implement canonicalization patterns for ops in different dialects that 377 /// may consume the results of memref.cast operations. Such foldable memref.cast 378 /// operations are typically inserted as `view` and `subview` ops are 379 /// canonicalized, to preserve the type compatibility of their uses. 380 /// 381 /// Returns true when all conditions are met: 382 /// 1. source and result are ranked memrefs with strided semantics and same 383 /// element type and rank. 384 /// 2. each of the source's size, offset or stride has more static information 385 /// than the corresponding result's size, offset or stride. 386 /// 387 /// Example 1: 388 /// ```mlir 389 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 390 /// %2 = consumer %1 ... : memref<?x?xf32> ... 391 /// ``` 392 /// 393 /// may fold into: 394 /// 395 /// ```mlir 396 /// %2 = consumer %0 ... : memref<8x16xf32> ... 397 /// ``` 398 /// 399 /// Example 2: 400 /// ``` 401 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 402 /// to memref<?x?xf32> 403 /// consumer %1 : memref<?x?xf32> ... 
404 /// ``` 405 /// 406 /// may fold into: 407 /// 408 /// ``` 409 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 410 /// ``` 411 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 412 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 413 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 414 415 // Requires ranked MemRefType. 416 if (!sourceType || !resultType) 417 return false; 418 419 // Requires same elemental type. 420 if (sourceType.getElementType() != resultType.getElementType()) 421 return false; 422 423 // Requires same rank. 424 if (sourceType.getRank() != resultType.getRank()) 425 return false; 426 427 // Only fold casts between strided memref forms. 428 int64_t sourceOffset, resultOffset; 429 SmallVector<int64_t, 4> sourceStrides, resultStrides; 430 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 431 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 432 return false; 433 434 // If cast is towards more static sizes along any dimension, don't fold. 435 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 436 auto ss = std::get<0>(it), st = std::get<1>(it); 437 if (ss != st) 438 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 439 return false; 440 } 441 442 // If cast is towards more static offset along any dimension, don't fold. 443 if (sourceOffset != resultOffset) 444 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 445 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 446 return false; 447 448 // If cast is towards more static strides along any dimension, don't fold. 
449 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 450 auto ss = std::get<0>(it), st = std::get<1>(it); 451 if (ss != st) 452 if (MemRefType::isDynamicStrideOrOffset(ss) && 453 !MemRefType::isDynamicStrideOrOffset(st)) 454 return false; 455 } 456 457 return true; 458 } 459 460 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 461 if (inputs.size() != 1 || outputs.size() != 1) 462 return false; 463 Type a = inputs.front(), b = outputs.front(); 464 auto aT = a.dyn_cast<MemRefType>(); 465 auto bT = b.dyn_cast<MemRefType>(); 466 467 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 468 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 469 470 if (aT && bT) { 471 if (aT.getElementType() != bT.getElementType()) 472 return false; 473 if (aT.getAffineMaps() != bT.getAffineMaps()) { 474 int64_t aOffset, bOffset; 475 SmallVector<int64_t, 4> aStrides, bStrides; 476 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 477 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 478 aStrides.size() != bStrides.size()) 479 return false; 480 481 // Strides along a dimension/offset are compatible if the value in the 482 // source memref is static and the value in the target memref is the 483 // same. They are also compatible if either one is dynamic (see 484 // description of MemRefCastOp for details). 485 auto checkCompatible = [](int64_t a, int64_t b) { 486 return (a == MemRefType::getDynamicStrideOrOffset() || 487 b == MemRefType::getDynamicStrideOrOffset() || a == b); 488 }; 489 if (!checkCompatible(aOffset, bOffset)) 490 return false; 491 for (auto aStride : enumerate(aStrides)) 492 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 493 return false; 494 } 495 if (aT.getMemorySpace() != bT.getMemorySpace()) 496 return false; 497 498 // They must have the same rank, and any specified dimensions must match. 
499 if (aT.getRank() != bT.getRank()) 500 return false; 501 502 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 503 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 504 if (aDim != -1 && bDim != -1 && aDim != bDim) 505 return false; 506 } 507 return true; 508 } else { 509 if (!aT && !uaT) 510 return false; 511 if (!bT && !ubT) 512 return false; 513 // Unranked to unranked casting is unsupported 514 if (uaT && ubT) 515 return false; 516 517 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 518 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 519 if (aEltType != bEltType) 520 return false; 521 522 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 523 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 524 if (aMemSpace != bMemSpace) 525 return false; 526 527 return true; 528 } 529 530 return false; 531 } 532 533 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 534 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 535 } 536 537 //===----------------------------------------------------------------------===// 538 // CloneOp 539 //===----------------------------------------------------------------------===// 540 541 void CloneOp::getEffects( 542 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 543 &effects) { 544 effects.emplace_back(MemoryEffects::Read::get(), input(), 545 SideEffects::DefaultResource::get()); 546 effects.emplace_back(MemoryEffects::Write::get(), output(), 547 SideEffects::DefaultResource::get()); 548 } 549 550 namespace { 551 /// Fold Dealloc operations that are deallocating an AllocOp that is only used 552 /// by other Dealloc operations. 
553 struct SimplifyClones : public OpRewritePattern<CloneOp> { 554 using OpRewritePattern<CloneOp>::OpRewritePattern; 555 556 LogicalResult matchAndRewrite(CloneOp cloneOp, 557 PatternRewriter &rewriter) const override { 558 if (cloneOp.use_empty()) { 559 rewriter.eraseOp(cloneOp); 560 return success(); 561 } 562 563 Value source = cloneOp.input(); 564 565 // This only finds dealloc operations for the immediate value. It should 566 // also consider aliases. That would also make the safety check below 567 // redundant. 568 Operation *cloneDeallocOp = findDealloc(cloneOp.output()); 569 Operation *sourceDeallocOp = findDealloc(source); 570 571 // If both are deallocated in the same block, their in-block lifetimes 572 // might not fully overlap, so we cannot decide which one to drop. 573 if (cloneDeallocOp && sourceDeallocOp && 574 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock()) 575 return failure(); 576 577 Block *currentBlock = cloneOp->getBlock(); 578 Operation *redundantDealloc = nullptr; 579 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) { 580 redundantDealloc = cloneDeallocOp; 581 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) { 582 redundantDealloc = sourceDeallocOp; 583 } 584 585 if (!redundantDealloc) 586 return failure(); 587 588 // Safety check that there are no other deallocations inbetween 589 // cloneOp and redundantDealloc, as otherwise we might deallocate an alias 590 // of source before the uses of the clone. With alias information, we could 591 // restrict this to only fail of the dealloc's operand is an alias 592 // of the source. 
593 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc; 594 pos = pos->getNextNode()) { 595 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos); 596 if (!effectInterface) 597 continue; 598 if (effectInterface.hasEffect<MemoryEffects::Free>()) 599 return failure(); 600 } 601 602 rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(), 603 source); 604 rewriter.eraseOp(redundantDealloc); 605 return success(); 606 } 607 }; 608 609 } // end anonymous namespace. 610 611 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 612 MLIRContext *context) { 613 results.insert<SimplifyClones>(context); 614 } 615 616 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) { 617 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // DeallocOp 622 //===----------------------------------------------------------------------===// 623 624 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 625 SmallVectorImpl<OpFoldResult> &results) { 626 /// dealloc(memrefcast) -> dealloc 627 return foldMemRefCast(*this); 628 } 629 630 //===----------------------------------------------------------------------===// 631 // DimOp 632 //===----------------------------------------------------------------------===// 633 634 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 635 int64_t index) { 636 auto loc = result.location; 637 Value indexValue = builder.create<ConstantIndexOp>(loc, index); 638 build(builder, result, memref, indexValue); 639 } 640 641 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 642 Value index) { 643 auto indexTy = builder.getIndexType(); 644 build(builder, result, indexTy, memref, index); 645 } 646 647 Optional<int64_t> DimOp::getConstantIndex() { 648 if (auto constantOp = index().getDefiningOp<ConstantOp>()) 649 return 
constantOp.getValue().cast<IntegerAttr>().getInt(); 650 return {}; 651 } 652 653 static LogicalResult verify(DimOp op) { 654 // Assume unknown index to be in range. 655 Optional<int64_t> index = op.getConstantIndex(); 656 if (!index.hasValue()) 657 return success(); 658 659 // Check that constant index is not knowingly out of range. 660 auto type = op.memrefOrTensor().getType(); 661 if (auto memrefType = type.dyn_cast<MemRefType>()) { 662 if (index.getValue() >= memrefType.getRank()) 663 return op.emitOpError("index is out of range"); 664 } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { 665 if (index.getValue() >= tensorType.getRank()) 666 return op.emitOpError("index is out of range"); 667 } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) { 668 // Assume index to be in range. 669 } else { 670 llvm_unreachable("expected operand with memref type"); 671 } 672 return success(); 673 } 674 675 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 676 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 677 678 // All forms of folding require a known index. 679 if (!index) 680 return {}; 681 682 auto argTy = memrefOrTensor().getType(); 683 // Fold if the shape extent along the given index is known. 684 if (auto shapedTy = argTy.dyn_cast<ShapedType>()) { 685 // Folding for unranked types (UnrankedMemRefType) is not supported. 686 if (!shapedTy.hasRank()) 687 return {}; 688 if (!shapedTy.isDynamicDim(index.getInt())) { 689 Builder builder(getContext()); 690 return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]); 691 } 692 } 693 694 Operation *definingOp = memrefOrTensor().getDefiningOp(); 695 696 // dim(memref.tensor_load(memref)) -> dim(memref) 697 if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) { 698 setOperand(0, tensorLoadOp.memref()); 699 return getResult(); 700 } 701 702 // Fold dim to the operand of tensor.generate. 
703 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { 704 auto resultType = 705 fromElements.getResult().getType().cast<RankedTensorType>(); 706 // The case where the type encodes the size of the dimension is handled 707 // above. 708 assert(resultType.getShape()[index.getInt()] == 709 RankedTensorType::kDynamicSize); 710 711 // Find the operand of the fromElements that corresponds to this index. 712 auto dynExtents = fromElements.dynamicExtents().begin(); 713 for (auto dim : resultType.getShape().take_front(index.getInt())) 714 if (dim == RankedTensorType::kDynamicSize) 715 dynExtents++; 716 717 return Value{*dynExtents}; 718 } 719 720 // The size at the given index is now known to be a dynamic size. 721 unsigned unsignedIndex = index.getValue().getZExtValue(); 722 723 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) { 724 assert(sliceOp.isDynamicSize(unsignedIndex) && 725 "Expected dynamic slice size"); 726 return sliceOp.getDynamicSize(unsignedIndex); 727 } 728 729 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 
730 auto memrefType = argTy.dyn_cast<MemRefType>(); 731 if (!memrefType) 732 return {}; 733 734 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 735 return *(alloc.getDynamicSizes().begin() + 736 memrefType.getDynamicDimIndex(unsignedIndex)); 737 738 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 739 return *(alloca.getDynamicSizes().begin() + 740 memrefType.getDynamicDimIndex(unsignedIndex)); 741 742 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 743 return *(view.getDynamicSizes().begin() + 744 memrefType.getDynamicDimIndex(unsignedIndex)); 745 746 if (auto sizeInterface = 747 dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { 748 assert(sizeInterface.isDynamicSize(unsignedIndex) && 749 "Expected dynamic subview size"); 750 return sizeInterface.getDynamicSize(unsignedIndex); 751 } 752 753 // dim(memrefcast) -> dim 754 if (succeeded(foldMemRefCast(*this))) 755 return getResult(); 756 757 return {}; 758 } 759 760 namespace { 761 /// Fold dim of a memref reshape operation to a load into the reshape's shape 762 /// operand. 763 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 764 using OpRewritePattern<DimOp>::OpRewritePattern; 765 766 LogicalResult matchAndRewrite(DimOp dim, 767 PatternRewriter &rewriter) const override { 768 auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>(); 769 770 if (!reshape) 771 return failure(); 772 773 // Place the load directly after the reshape to ensure that the shape memref 774 // was not mutated. 775 rewriter.setInsertionPointAfter(reshape); 776 Location loc = dim.getLoc(); 777 Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index()); 778 if (load.getType() != dim.getType()) 779 load = rewriter.create<IndexCastOp>(loc, dim.getType(), load); 780 rewriter.replaceOp(dim, load); 781 return success(); 782 } 783 }; 784 785 /// Fold dim of a dim of a cast into the dim of the source of the tensor cast. 
/// Rewrites `dim(cast(x), i)` into `dim(x, i)`: the DimOp is replaced by a new
/// DimOp whose source is the (single) operand of the defining `CastOpTy`.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // end anonymous namespace.

/// Registers the DimOp canonicalizations: dim of memref.reshape, and dim
/// through buffer_cast / tensor.cast.
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
              DimOfCastOp<tensor::CastOp>>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

/// Operands are appended in the fixed order: src memref, src indices, dst
/// memref, dst indices, num elements, tag memref, tag indices and, only when
/// `stride` is non-null, the (stride, elementsPerStride) pair.
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

/// Prints the custom form parsed by DmaStartOp::parse; the stride pair is
/// printed only for strided DMAs.
void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
// The trailing `%stride, %num_elt_per_stride` pair is optional. Exactly three
// types (source, destination and tag memref) must follow the single colon.
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse the following list of operands (they are resolved further below):
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer (resolved as `index`).
  // *) tag memref followed by its indices (in square brackets).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Either zero or exactly two trailing stride operands are accepted.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  // Operands are resolved in the same order DmaStartOp::build adds them.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

/// Checks operand counts and types. Operand positions past the source memref
/// depend on the ranks of the preceding memrefs, so the groups are validated
/// one at a time, in order.
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

/// Operands are appended in order: tag memref, tag indices, num elements.
void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

/// Prints the custom form parsed by DmaWaitOp::parse.
void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices and the number of elements, then resolve
  // the tag against the parsed type and the rest as `index`.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

/// Checks that the op has exactly one index per tag-memref dimension, all
/// indices are `index`, and the element count is `index`.
LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

/// Prints `type`, and for non-external globals ` = ` followed by either the
/// `uninitialized` keyword or the initial value (printed without its type).
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

/// Parses a static-shaped memref type, optionally followed by
/// `= uninitialized` (stored as a UnitAttr) or `= <elements attribute>`, which
/// is parsed against the tensor type corresponding to the memref type.
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // No `=` means no initial value (initialValue is left untouched).
  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

/// Checks that the global has a static-shaped memref type and that any
/// initial value is a UnitAttr or an ElementsAttr of the matching tensor type.
static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

/// A load takes the memref plus exactly one index per memref dimension.
static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.
1189 1190 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 1191 MLIRContext *context) { 1192 results.add<LoadOfBufferCast>(context); 1193 } 1194 1195 //===----------------------------------------------------------------------===// 1196 // PrefetchOp 1197 //===----------------------------------------------------------------------===// 1198 1199 static void print(OpAsmPrinter &p, PrefetchOp op) { 1200 p << PrefetchOp::getOperationName() << " " << op.memref() << '['; 1201 p.printOperands(op.indices()); 1202 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1203 p << ", locality<" << op.localityHint(); 1204 p << ">, " << (op.isDataCache() ? "data" : "instr"); 1205 p.printOptionalAttrDict( 1206 op->getAttrs(), 1207 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1208 p << " : " << op.getMemRefType(); 1209 } 1210 1211 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1212 OperationState &result) { 1213 OpAsmParser::OperandType memrefInfo; 1214 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1215 IntegerAttr localityHint; 1216 MemRefType type; 1217 StringRef readOrWrite, cacheType; 1218 1219 auto indexTy = parser.getBuilder().getIndexType(); 1220 auto i32Type = parser.getBuilder().getIntegerType(32); 1221 if (parser.parseOperand(memrefInfo) || 1222 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1223 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1224 parser.parseComma() || parser.parseKeyword("locality") || 1225 parser.parseLess() || 1226 parser.parseAttribute(localityHint, i32Type, "localityHint", 1227 result.attributes) || 1228 parser.parseGreater() || parser.parseComma() || 1229 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1230 parser.resolveOperand(memrefInfo, type, result.operands) || 1231 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1232 return failure(); 1233 1234 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1235 return 
parser.emitError(parser.getNameLoc(), 1236 "rw specifier has to be 'read' or 'write'"); 1237 result.addAttribute( 1238 PrefetchOp::getIsWriteAttrName(), 1239 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1240 1241 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1242 return parser.emitError(parser.getNameLoc(), 1243 "cache type has to be 'data' or 'instr'"); 1244 1245 result.addAttribute( 1246 PrefetchOp::getIsDataCacheAttrName(), 1247 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1248 1249 return success(); 1250 } 1251 1252 static LogicalResult verify(PrefetchOp op) { 1253 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1254 return op.emitOpError("too few indices"); 1255 1256 return success(); 1257 } 1258 1259 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1260 SmallVectorImpl<OpFoldResult> &results) { 1261 // prefetch(memrefcast) -> prefetch 1262 return foldMemRefCast(*this); 1263 } 1264 1265 //===----------------------------------------------------------------------===// 1266 // ReinterpretCastOp 1267 //===----------------------------------------------------------------------===// 1268 1269 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1270 /// `staticSizes` and `staticStrides` are automatically filled with 1271 /// source-memref-rank sentinel values that encode dynamic entries. 
1272 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1273 MemRefType resultType, Value source, 1274 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1275 ArrayRef<OpFoldResult> strides, 1276 ArrayRef<NamedAttribute> attrs) { 1277 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1278 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1279 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1280 ShapedType::kDynamicStrideOrOffset); 1281 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1282 ShapedType::kDynamicSize); 1283 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1284 ShapedType::kDynamicStrideOrOffset); 1285 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1286 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1287 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1288 result.addAttributes(attrs); 1289 } 1290 1291 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1292 MemRefType resultType, Value source, 1293 int64_t offset, ArrayRef<int64_t> sizes, 1294 ArrayRef<int64_t> strides, 1295 ArrayRef<NamedAttribute> attrs) { 1296 SmallVector<OpFoldResult> sizeValues = 1297 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1298 return b.getI64IntegerAttr(v); 1299 })); 1300 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1301 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1302 return b.getI64IntegerAttr(v); 1303 })); 1304 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1305 strideValues, attrs); 1306 } 1307 1308 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1309 MemRefType resultType, Value source, Value offset, 1310 ValueRange sizes, ValueRange strides, 1311 ArrayRef<NamedAttribute> attrs) { 1312 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1313 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; 
})); 1314 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1315 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1316 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1317 } 1318 1319 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1320 // completed automatically, like we have for subview and extract_slice. 1321 static LogicalResult verify(ReinterpretCastOp op) { 1322 // The source and result memrefs should be in the same memory space. 1323 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1324 auto resultType = op.getType().cast<MemRefType>(); 1325 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1326 return op.emitError("different memory spaces specified for source type ") 1327 << srcType << " and result memref type " << resultType; 1328 if (srcType.getElementType() != resultType.getElementType()) 1329 return op.emitError("different element types specified for source type ") 1330 << srcType << " and result memref type " << resultType; 1331 1332 // Match sizes in result memref type and in static_sizes attribute. 1333 for (auto &en : 1334 llvm::enumerate(llvm::zip(resultType.getShape(), 1335 extractFromI64ArrayAttr(op.static_sizes())))) { 1336 int64_t resultSize = std::get<0>(en.value()); 1337 int64_t expectedSize = std::get<1>(en.value()); 1338 if (resultSize != expectedSize) 1339 return op.emitError("expected result type with size = ") 1340 << expectedSize << " instead of " << resultSize 1341 << " in dim = " << en.index(); 1342 } 1343 1344 // Match offset and strides in static_offset and static_strides attributes if 1345 // result memref type has an affine map specified. 
1346 if (!resultType.getAffineMaps().empty()) { 1347 int64_t resultOffset; 1348 SmallVector<int64_t, 4> resultStrides; 1349 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1350 return failure(); 1351 1352 // Match offset in result memref type and in static_offsets attribute. 1353 int64_t expectedOffset = 1354 extractFromI64ArrayAttr(op.static_offsets()).front(); 1355 if (resultOffset != expectedOffset) 1356 return op.emitError("expected result type with offset = ") 1357 << resultOffset << " instead of " << expectedOffset; 1358 1359 // Match strides in result memref type and in static_strides attribute. 1360 for (auto &en : llvm::enumerate(llvm::zip( 1361 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1362 int64_t resultStride = std::get<0>(en.value()); 1363 int64_t expectedStride = std::get<1>(en.value()); 1364 if (resultStride != expectedStride) 1365 return op.emitError("expected result type with stride = ") 1366 << expectedStride << " instead of " << resultStride 1367 << " in dim = " << en.index(); 1368 } 1369 } 1370 return success(); 1371 } 1372 1373 //===----------------------------------------------------------------------===// 1374 // ReshapeOp 1375 //===----------------------------------------------------------------------===// 1376 1377 static LogicalResult verify(ReshapeOp op) { 1378 Type operandType = op.source().getType(); 1379 Type resultType = op.result().getType(); 1380 1381 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1382 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1383 if (operandElementType != resultElementType) 1384 return op.emitOpError("element types of source and destination memref " 1385 "types should be the same"); 1386 1387 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1388 if (!operandMemRefType.getAffineMaps().empty()) 1389 return op.emitOpError( 1390 "source memref type should have identity affine map"); 1391 1392 
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1393 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1394 if (resultMemRefType) { 1395 if (!resultMemRefType.getAffineMaps().empty()) 1396 return op.emitOpError( 1397 "result memref type should have identity affine map"); 1398 if (shapeSize == ShapedType::kDynamicSize) 1399 return op.emitOpError("cannot use shape operand with dynamic length to " 1400 "reshape to statically-ranked memref type"); 1401 if (shapeSize != resultMemRefType.getRank()) 1402 return op.emitOpError( 1403 "length of shape operand differs from the result's memref rank"); 1404 } 1405 return success(); 1406 } 1407 1408 //===----------------------------------------------------------------------===// 1409 // StoreOp 1410 //===----------------------------------------------------------------------===// 1411 1412 static LogicalResult verify(StoreOp op) { 1413 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1414 return op.emitOpError("store index operand count not equal to memref rank"); 1415 1416 return success(); 1417 } 1418 1419 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1420 SmallVectorImpl<OpFoldResult> &results) { 1421 /// store(memrefcast) -> store 1422 return foldMemRefCast(*this, getValueToStore()); 1423 } 1424 1425 //===----------------------------------------------------------------------===// 1426 // SubViewOp 1427 //===----------------------------------------------------------------------===// 1428 1429 namespace { 1430 /// Helpers to write more idiomatic operations. 
1431 namespace saturated_arith { 1432 struct Wrapper { 1433 explicit Wrapper(int64_t v) : v(v) {} 1434 operator int64_t() { return v; } 1435 int64_t v; 1436 }; 1437 Wrapper operator+(Wrapper a, int64_t b) { 1438 if (ShapedType::isDynamicStrideOrOffset(a) || 1439 ShapedType::isDynamicStrideOrOffset(b)) 1440 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1441 return Wrapper(a.v + b); 1442 } 1443 Wrapper operator*(Wrapper a, int64_t b) { 1444 if (ShapedType::isDynamicStrideOrOffset(a) || 1445 ShapedType::isDynamicStrideOrOffset(b)) 1446 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1447 return Wrapper(a.v * b); 1448 } 1449 } // end namespace saturated_arith 1450 } // end namespace 1451 1452 /// A subview result type can be fully inferred from the source type and the 1453 /// static representation of offsets, sizes and strides. Special sentinels 1454 /// encode the dynamic case. 1455 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1456 ArrayRef<int64_t> leadingStaticOffsets, 1457 ArrayRef<int64_t> leadingStaticSizes, 1458 ArrayRef<int64_t> leadingStaticStrides) { 1459 // A subview may specify only a leading subset of offset/sizes/strides in 1460 // which case we complete with offset=0, sizes from memref type and strides=1. 
1461 unsigned rank = sourceMemRefType.getRank(); 1462 assert(leadingStaticOffsets.size() <= rank && 1463 "unexpected leadingStaticOffsets overflow"); 1464 assert(leadingStaticSizes.size() <= rank && 1465 "unexpected leadingStaticSizes overflow"); 1466 assert(leadingStaticStrides.size() <= rank && 1467 "unexpected leadingStaticStrides overflow"); 1468 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1469 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1470 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1471 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1472 unsigned numTrailingSizes = rank - staticSizes.size(); 1473 unsigned numTrailingStrides = rank - staticStrides.size(); 1474 staticOffsets.append(numTrailingOffsets, 0); 1475 llvm::append_range(staticSizes, 1476 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1477 staticStrides.append(numTrailingStrides, 1); 1478 1479 // Extract source offset and strides. 1480 int64_t sourceOffset; 1481 SmallVector<int64_t, 4> sourceStrides; 1482 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1483 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1484 (void)res; 1485 1486 // Compute target offset whose value is: 1487 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1488 int64_t targetOffset = sourceOffset; 1489 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1490 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1491 using namespace saturated_arith; 1492 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1493 } 1494 1495 // Compute target stride whose value is: 1496 // `sourceStrides_i * staticStrides_i`. 
1497 SmallVector<int64_t, 4> targetStrides; 1498 targetStrides.reserve(staticOffsets.size()); 1499 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1500 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1501 using namespace saturated_arith; 1502 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1503 } 1504 1505 // The type is now known. 1506 return MemRefType::get( 1507 staticSizes, sourceMemRefType.getElementType(), 1508 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1509 sourceMemRefType.getContext()), 1510 sourceMemRefType.getMemorySpace()); 1511 } 1512 1513 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1514 ArrayRef<OpFoldResult> leadingStaticOffsets, 1515 ArrayRef<OpFoldResult> leadingStaticSizes, 1516 ArrayRef<OpFoldResult> leadingStaticStrides) { 1517 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1518 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1519 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1520 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1521 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1522 ShapedType::kDynamicSize); 1523 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1524 staticStrides, ShapedType::kDynamicStrideOrOffset); 1525 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1526 staticSizes, staticStrides) 1527 .cast<MemRefType>(); 1528 } 1529 1530 Type SubViewOp::inferRankReducedResultType( 1531 unsigned resultRank, MemRefType sourceRankedTensorType, 1532 ArrayRef<int64_t> leadingStaticOffsets, 1533 ArrayRef<int64_t> leadingStaticSizes, 1534 ArrayRef<int64_t> leadingStaticStrides) { 1535 auto inferredType = 1536 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1537 leadingStaticSizes, leadingStaticStrides) 1538 .cast<MemRefType>(); 1539 assert(inferredType.getRank() >= resultRank && "expected "); 1540 int rankDiff = inferredType.getRank() - 
resultRank; 1541 if (rankDiff > 0) { 1542 auto shape = inferredType.getShape(); 1543 llvm::SmallDenseSet<unsigned> dimsToProject; 1544 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1545 SmallVector<int64_t> projectedShape; 1546 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1547 if (!dimsToProject.contains(pos)) 1548 projectedShape.push_back(shape[pos]); 1549 1550 AffineMap map; 1551 auto maps = inferredType.getAffineMaps(); 1552 if (!maps.empty() && maps.front()) 1553 map = getProjectedMap(maps.front(), dimsToProject); 1554 inferredType = 1555 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1556 inferredType.getMemorySpace()); 1557 } 1558 return inferredType; 1559 } 1560 1561 Type SubViewOp::inferRankReducedResultType( 1562 unsigned resultRank, MemRefType sourceRankedTensorType, 1563 ArrayRef<OpFoldResult> leadingStaticOffsets, 1564 ArrayRef<OpFoldResult> leadingStaticSizes, 1565 ArrayRef<OpFoldResult> leadingStaticStrides) { 1566 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1567 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1568 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1569 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1570 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1571 ShapedType::kDynamicSize); 1572 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1573 staticStrides, ShapedType::kDynamicStrideOrOffset); 1574 return SubViewOp::inferRankReducedResultType( 1575 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1576 staticStrides); 1577 } 1578 // Build a SubViewOp with mixed static and dynamic entries and custom result 1579 // type. If the type passed is nullptr, it is inferred. 
1580 void SubViewOp::build(OpBuilder &b, OperationState &result, 1581 MemRefType resultType, Value source, 1582 ArrayRef<OpFoldResult> offsets, 1583 ArrayRef<OpFoldResult> sizes, 1584 ArrayRef<OpFoldResult> strides, 1585 ArrayRef<NamedAttribute> attrs) { 1586 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1587 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1588 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1589 ShapedType::kDynamicStrideOrOffset); 1590 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1591 ShapedType::kDynamicSize); 1592 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1593 ShapedType::kDynamicStrideOrOffset); 1594 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1595 // Structuring implementation this way avoids duplication between builders. 1596 if (!resultType) { 1597 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1598 staticSizes, staticStrides) 1599 .cast<MemRefType>(); 1600 } 1601 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1602 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1603 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1604 result.addAttributes(attrs); 1605 } 1606 1607 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1608 // type. 1609 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1610 ArrayRef<OpFoldResult> offsets, 1611 ArrayRef<OpFoldResult> sizes, 1612 ArrayRef<OpFoldResult> strides, 1613 ArrayRef<NamedAttribute> attrs) { 1614 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1615 } 1616 1617 // Build a SubViewOp with static entries and inferred result type. 
1618 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1619 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1620 ArrayRef<int64_t> strides, 1621 ArrayRef<NamedAttribute> attrs) { 1622 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1623 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1624 return b.getI64IntegerAttr(v); 1625 })); 1626 SmallVector<OpFoldResult> sizeValues = 1627 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1628 return b.getI64IntegerAttr(v); 1629 })); 1630 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1631 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1632 return b.getI64IntegerAttr(v); 1633 })); 1634 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1635 } 1636 1637 // Build a SubViewOp with dynamic entries and custom result type. If the 1638 // type passed is nullptr, it is inferred. 1639 void SubViewOp::build(OpBuilder &b, OperationState &result, 1640 MemRefType resultType, Value source, 1641 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1642 ArrayRef<int64_t> strides, 1643 ArrayRef<NamedAttribute> attrs) { 1644 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1645 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1646 return b.getI64IntegerAttr(v); 1647 })); 1648 SmallVector<OpFoldResult> sizeValues = 1649 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1650 return b.getI64IntegerAttr(v); 1651 })); 1652 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1653 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1654 return b.getI64IntegerAttr(v); 1655 })); 1656 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1657 attrs); 1658 } 1659 1660 // Build a SubViewOp with dynamic entries and custom result type. If the type 1661 // passed is nullptr, it is inferred. 
1662 void SubViewOp::build(OpBuilder &b, OperationState &result, 1663 MemRefType resultType, Value source, ValueRange offsets, 1664 ValueRange sizes, ValueRange strides, 1665 ArrayRef<NamedAttribute> attrs) { 1666 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1667 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1668 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1669 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1670 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1671 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1672 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1673 } 1674 1675 // Build a SubViewOp with dynamic entries and inferred result type. 1676 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1677 ValueRange offsets, ValueRange sizes, ValueRange strides, 1678 ArrayRef<NamedAttribute> attrs) { 1679 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1680 } 1681 1682 /// For ViewLikeOpInterface. 1683 Value SubViewOp::getViewSource() { return source(); } 1684 1685 enum SubViewVerificationResult { 1686 Success, 1687 RankTooLarge, 1688 SizeMismatch, 1689 ElemTypeMismatch, 1690 MemSpaceMismatch, 1691 AffineMapMismatch 1692 }; 1693 1694 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1695 /// This function is slight variant of `is subsequence` algorithm where 1696 /// not matching dimension must be 1. 
1697 static SubViewVerificationResult 1698 isRankReducedType(Type originalType, Type candidateReducedType, 1699 std::string *errMsg = nullptr) { 1700 if (originalType == candidateReducedType) 1701 return SubViewVerificationResult::Success; 1702 if (!originalType.isa<MemRefType>()) 1703 return SubViewVerificationResult::Success; 1704 if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>()) 1705 return SubViewVerificationResult::Success; 1706 1707 ShapedType originalShapedType = originalType.cast<ShapedType>(); 1708 ShapedType candidateReducedShapedType = 1709 candidateReducedType.cast<ShapedType>(); 1710 1711 // Rank and size logic is valid for all ShapedTypes. 1712 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 1713 ArrayRef<int64_t> candidateReducedShape = 1714 candidateReducedShapedType.getShape(); 1715 unsigned originalRank = originalShape.size(), 1716 candidateReducedRank = candidateReducedShape.size(); 1717 if (candidateReducedRank > originalRank) 1718 return SubViewVerificationResult::RankTooLarge; 1719 1720 auto optionalUnusedDimsMask = 1721 computeRankReductionMask(originalShape, candidateReducedShape); 1722 1723 // Sizes cannot be matched in case empty vector is returned. 1724 if (!optionalUnusedDimsMask.hasValue()) 1725 return SubViewVerificationResult::SizeMismatch; 1726 1727 if (originalShapedType.getElementType() != 1728 candidateReducedShapedType.getElementType()) 1729 return SubViewVerificationResult::ElemTypeMismatch; 1730 1731 // Strided layout logic is relevant for MemRefType only. 
1732 MemRefType original = originalType.cast<MemRefType>(); 1733 MemRefType candidateReduced = candidateReducedType.cast<MemRefType>(); 1734 if (original.getMemorySpace() != candidateReduced.getMemorySpace()) 1735 return SubViewVerificationResult::MemSpaceMismatch; 1736 1737 llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue(); 1738 auto inferredType = 1739 getProjectedMap(getStridedLinearLayoutMap(original), unusedDims); 1740 AffineMap candidateLayout; 1741 if (candidateReduced.getAffineMaps().empty()) 1742 candidateLayout = getStridedLinearLayoutMap(candidateReduced); 1743 else 1744 candidateLayout = candidateReduced.getAffineMaps().front(); 1745 assert(inferredType.getNumResults() == 1 && 1746 candidateLayout.getNumResults() == 1); 1747 if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() || 1748 inferredType.getNumDims() != candidateLayout.getNumDims()) { 1749 if (errMsg) { 1750 llvm::raw_string_ostream os(*errMsg); 1751 os << "inferred type: " << inferredType; 1752 } 1753 return SubViewVerificationResult::AffineMapMismatch; 1754 } 1755 // Check that the difference of the affine maps simplifies to 0. 
1756 AffineExpr diffExpr = 1757 inferredType.getResult(0) - candidateLayout.getResult(0); 1758 diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(), 1759 inferredType.getNumSymbols()); 1760 auto cst = diffExpr.dyn_cast<AffineConstantExpr>(); 1761 if (!(cst && cst.getValue() == 0)) { 1762 if (errMsg) { 1763 llvm::raw_string_ostream os(*errMsg); 1764 os << "inferred type: " << inferredType; 1765 } 1766 return SubViewVerificationResult::AffineMapMismatch; 1767 } 1768 return SubViewVerificationResult::Success; 1769 } 1770 1771 template <typename OpTy> 1772 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result, 1773 OpTy op, Type expectedType, 1774 StringRef errMsg = "") { 1775 auto memrefType = expectedType.cast<ShapedType>(); 1776 switch (result) { 1777 case SubViewVerificationResult::Success: 1778 return success(); 1779 case SubViewVerificationResult::RankTooLarge: 1780 return op.emitError("expected result rank to be smaller or equal to ") 1781 << "the source rank. " << errMsg; 1782 case SubViewVerificationResult::SizeMismatch: 1783 return op.emitError("expected result type to be ") 1784 << expectedType 1785 << " or a rank-reduced version. (mismatch of result sizes) " 1786 << errMsg; 1787 case SubViewVerificationResult::ElemTypeMismatch: 1788 return op.emitError("expected result element type to be ") 1789 << memrefType.getElementType() << errMsg; 1790 case SubViewVerificationResult::MemSpaceMismatch: 1791 return op.emitError("expected result and source memory spaces to match.") 1792 << errMsg; 1793 case SubViewVerificationResult::AffineMapMismatch: 1794 return op.emitError("expected result type to be ") 1795 << expectedType 1796 << " or a rank-reduced version. (mismatch of result affine map) " 1797 << errMsg; 1798 } 1799 llvm_unreachable("unexpected subview verification result"); 1800 } 1801 1802 /// Verifier for SubViewOp. 
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify the result type against the type inferred from the source type and
  // the static offset/size/stride entries. The actual result may also be a
  // rank-reduced version of that inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

/// Prints `range` as "range <offset>:<size>:<stride>".
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`. The offset, size and stride lists of `op` must
/// have equal ranks (asserted).
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // Static entries are materialized as new ConstantIndexOps; dynamic entries
    // reuse the existing SSA value.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  // If the rank-reduced inference did not produce the requested rank, fall
  // back to the non-rank-reduced inferred type.
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V: memref<16x16xf32> to
///     memref<3x4xf32, offset:0, strides:[16, 1]>
///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant index, bail out and let the
    // constant-argument folder (registered alongside this pattern in
    // SubViewOp::getCanonicalizationPatterns) kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Deduce the result type of the SubViewOp with
    // `getCanonicalSubViewResultType`, using the cast's source operand type
    // and the SubViewOp's offsets/sizes/strides. This is the resulting type
    // if the memref.cast were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    // Cast back to the original result type so existing users are unaffected.
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview, keeping the rank of
/// the op's current result type when possible.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
1948 struct SubViewCanonicalizer { 1949 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 1950 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType()); 1951 } 1952 }; 1953 1954 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 1955 MLIRContext *context) { 1956 results 1957 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 1958 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 1959 SubViewOpMemRefCastFolder>(context); 1960 } 1961 1962 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 1963 auto resultShapedType = getResult().getType().cast<ShapedType>(); 1964 auto sourceShapedType = source().getType().cast<ShapedType>(); 1965 1966 if (resultShapedType.hasStaticShape() && 1967 resultShapedType == sourceShapedType) { 1968 return getViewSource(); 1969 } 1970 1971 return {}; 1972 } 1973 1974 //===----------------------------------------------------------------------===// 1975 // TensorLoadOp 1976 //===----------------------------------------------------------------------===// 1977 1978 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) { 1979 if (auto bufferCast = memref().getDefiningOp<BufferCastOp>()) 1980 // Approximate alias analysis by conservatively folding only when no there 1981 // is no interleaved operation. 1982 if (bufferCast->getBlock() == this->getOperation()->getBlock() && 1983 bufferCast->getNextNode() == this->getOperation()) 1984 return bufferCast.tensor(); 1985 return {}; 1986 } 1987 1988 //===----------------------------------------------------------------------===// 1989 // TransposeOp 1990 //===----------------------------------------------------------------------===// 1991 1992 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 
1993 static MemRefType inferTransposeResultType(MemRefType memRefType, 1994 AffineMap permutationMap) { 1995 auto rank = memRefType.getRank(); 1996 auto originalSizes = memRefType.getShape(); 1997 // Compute permuted sizes. 1998 SmallVector<int64_t, 4> sizes(rank, 0); 1999 for (auto en : llvm::enumerate(permutationMap.getResults())) 2000 sizes[en.index()] = 2001 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 2002 2003 // Compute permuted strides. 2004 int64_t offset; 2005 SmallVector<int64_t, 4> strides; 2006 auto res = getStridesAndOffset(memRefType, strides, offset); 2007 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2008 (void)res; 2009 auto map = 2010 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2011 map = permutationMap ? map.compose(permutationMap) : map; 2012 return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); 2013 } 2014 2015 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2016 AffineMapAttr permutation, 2017 ArrayRef<NamedAttribute> attrs) { 2018 auto permutationMap = permutation.getValue(); 2019 assert(permutationMap); 2020 2021 auto memRefType = in.getType().cast<MemRefType>(); 2022 // Compute result type. 
2023 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2024 2025 build(b, result, resultType, in, attrs); 2026 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2027 } 2028 2029 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2030 static void print(OpAsmPrinter &p, TransposeOp op) { 2031 p << "memref.transpose " << op.in() << " " << op.permutation(); 2032 p.printOptionalAttrDict(op->getAttrs(), 2033 {TransposeOp::getPermutationAttrName()}); 2034 p << " : " << op.in().getType() << " to " << op.getType(); 2035 } 2036 2037 static ParseResult parseTransposeOp(OpAsmParser &parser, 2038 OperationState &result) { 2039 OpAsmParser::OperandType in; 2040 AffineMap permutation; 2041 MemRefType srcType, dstType; 2042 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2043 parser.parseOptionalAttrDict(result.attributes) || 2044 parser.parseColonType(srcType) || 2045 parser.resolveOperand(in, srcType, result.operands) || 2046 parser.parseKeywordType("to", dstType) || 2047 parser.addTypeToList(dstType, result.types)) 2048 return failure(); 2049 2050 result.addAttribute(TransposeOp::getPermutationAttrName(), 2051 AffineMapAttr::get(permutation)); 2052 return success(); 2053 } 2054 2055 static LogicalResult verify(TransposeOp op) { 2056 if (!op.permutation().isPermutation()) 2057 return op.emitOpError("expected a permutation map"); 2058 if (op.permutation().getNumDims() != op.getShapedType().getRank()) 2059 return op.emitOpError( 2060 "expected a permutation map of same rank as the input"); 2061 2062 auto srcType = op.in().getType().cast<MemRefType>(); 2063 auto dstType = op.getType().cast<MemRefType>(); 2064 auto transposedType = inferTransposeResultType(srcType, op.permutation()); 2065 if (dstType != transposedType) 2066 return op.emitOpError("output type ") 2067 << dstType << " does not match transposed input type " << srcType 2068 << ", " << transposedType; 2069 return success(); 2070 } 
/// Attempts to fold a memref.cast operand into this op via `foldMemRefCast`.
/// On success, returns the op's own result (MLIR's convention for an in-place
/// fold).
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

/// Parses `%src[%byte_shift][%sizes...] attr-dict : type to type`. Exactly
/// one offset (byte shift) operand is required; offset and size operands are
/// resolved with index type.
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

/// Prints the form accepted by parseViewOp.
static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type:
  // one size operand per dynamic dimension of the result.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

/// Folds dynamic size operands of a ViewOp that are produced by
/// ConstantIndexOps into static dimensions of a new result type, then casts
/// back to the original result type.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands (source, byte shift, sizes) are
    // constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memrefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type; the byte shift operand is
    // passed through unchanged to the new ViewOp below.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    // Index into viewOp.sizes(), which holds one operand per dynamic dim.
    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Rewrites view(memref.cast(alloc)) into view(alloc), i.e. the view reads
/// directly from the alloc result and bypasses the cast.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"