1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/InferTypeOpInterface.h" 21 #include "llvm/ADT/STLExtras.h" 22 23 using namespace mlir; 24 using namespace mlir::memref; 25 26 /// Materialize a single constant operation from a given attribute value with 27 /// the desired resultant type. 28 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 29 Attribute value, Type type, 30 Location loc) { 31 return builder.create<mlir::ConstantOp>(loc, type, value); 32 } 33 34 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. 35 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) { 36 return llvm::to_vector<4>( 37 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t { 38 return a.cast<IntegerAttr>().getInt(); 39 })); 40 } 41 42 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if 43 /// it is a Value or into `staticVec` if it is an IntegerAttr. 44 /// In the case of a Value, a copy of the `sentinel` value is also pushed to 45 /// `staticVec`. This is useful to extract mixed static and dynamic entries that 46 /// come from an AttrSizedOperandSegments trait. 47 static void dispatchIndexOpFoldResult(OpFoldResult ofr, 48 SmallVectorImpl<Value> &dynamicVec, 49 SmallVectorImpl<int64_t> &staticVec, 50 int64_t sentinel) { 51 if (auto v = ofr.dyn_cast<Value>()) { 52 dynamicVec.push_back(v); 53 staticVec.push_back(sentinel); 54 return; 55 } 56 APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue(); 57 staticVec.push_back(apInt.getSExtValue()); 58 } 59 60 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, 61 SmallVectorImpl<Value> &dynamicVec, 62 SmallVectorImpl<int64_t> &staticVec, 63 int64_t sentinel) { 64 for (auto ofr : ofrs) 65 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // Common canonicalization pattern support logic 70 //===----------------------------------------------------------------------===// 71 72 /// This is a common class used for patterns of the form 73 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 74 /// into the root operation directly. 
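/// into the root operation directly.
///
/// For example (an illustrative sketch; value names are hypothetical), given
///   %0 = memref.cast %m : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0
/// the dealloc can be rewritten to operate on %m directly, since the cast
/// source is a ranked memref.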
75 static LogicalResult foldMemRefCast(Operation *op) { 76 bool folded = false; 77 for (OpOperand &operand : op->getOpOperands()) { 78 auto cast = operand.get().getDefiningOp<CastOp>(); 79 if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 80 operand.set(cast.getOperand()); 81 folded = true; 82 } 83 } 84 return success(folded); 85 } 86 87 //===----------------------------------------------------------------------===// 88 // Helpers for GlobalOp 89 //===----------------------------------------------------------------------===// 90 91 static Type getTensorTypeFromMemRefType(Type type) { 92 if (auto memref = type.dyn_cast<MemRefType>()) 93 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 94 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 95 return UnrankedTensorType::get(memref.getElementType()); 96 return NoneType::get(type.getContext()); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // AllocOp / AllocaOp 101 //===----------------------------------------------------------------------===// 102 103 template <typename AllocLikeOp> 104 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 105 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 106 "applies to only alloc or alloca"); 107 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 108 if (!memRefType) 109 return op.emitOpError("result must be a memref"); 110 111 if (static_cast<int64_t>(op.dynamicSizes().size()) != 112 memRefType.getNumDynamicDims()) 113 return op.emitOpError("dimension operand count does not equal memref " 114 "dynamic dimension count"); 115 116 unsigned numSymbols = 0; 117 if (!memRefType.getAffineMaps().empty()) 118 numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); 119 if (op.symbolOperands().size() != numSymbols) 120 return op.emitOpError( 121 "symbol operand count does not equal memref symbol count"); 122 123 return success(); 124 } 125 126 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 127 128 static LogicalResult verify(AllocaOp op) { 129 // An alloca op needs to have an ancestor with an allocation scope trait. 130 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 131 return op.emitOpError( 132 "requires an ancestor op with AutomaticAllocationScope trait"); 133 134 return verifyAllocLikeOp(op); 135 } 136 137 namespace { 138 /// Fold constant dimensions into an alloc like operation. 139 template <typename AllocLikeOp> 140 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 141 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 142 143 LogicalResult matchAndRewrite(AllocLikeOp alloc, 144 PatternRewriter &rewriter) const override { 145 // Check to see if any dimensions operands are constants. If so, we can 146 // substitute and drop them. 147 if (llvm::none_of(alloc.getOperands(), [](Value operand) { 148 return matchPattern(operand, matchConstantIndex()); 149 })) 150 return failure(); 151 152 auto memrefType = alloc.getType(); 153 154 // Ok, we have one or more constant operands. Collect the non-constant ones 155 // and keep track of the resultant memref type to build. 
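    // Illustrative sketch of the rewrite (value names are hypothetical):
    //   %a = memref.alloc(%c8, %n) : memref<?x?xf32>
    // becomes
    //   %0 = memref.alloc(%n) : memref<8x?xf32>
    //   %a = memref.cast %0 : memref<8x?xf32> to memref<?x?xf32>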
156 SmallVector<int64_t, 4> newShapeConstants; 157 newShapeConstants.reserve(memrefType.getRank()); 158 SmallVector<Value, 4> newOperands; 159 160 unsigned dynamicDimPos = 0; 161 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 162 int64_t dimSize = memrefType.getDimSize(dim); 163 // If this is already static dimension, keep it. 164 if (dimSize != -1) { 165 newShapeConstants.push_back(dimSize); 166 continue; 167 } 168 auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp(); 169 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 170 // Dynamic shape dimension will be folded. 171 newShapeConstants.push_back(constantIndexOp.getValue()); 172 } else { 173 // Dynamic shape dimension not folded; copy operand from old memref. 174 newShapeConstants.push_back(-1); 175 newOperands.push_back(alloc.getOperand(dynamicDimPos)); 176 } 177 dynamicDimPos++; 178 } 179 180 // Create new memref type (which will have fewer dynamic dimensions). 181 MemRefType newMemRefType = 182 MemRefType::Builder(memrefType).setShape(newShapeConstants); 183 assert(static_cast<int64_t>(newOperands.size()) == 184 newMemRefType.getNumDynamicDims()); 185 186 // Create and insert the alloc op for the new memref. 187 auto newAlloc = rewriter.create<AllocLikeOp>( 188 alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr()); 189 // Insert a cast so we have the same type as the old alloc. 190 auto resultCast = 191 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 192 193 rewriter.replaceOp(alloc, {resultCast}); 194 return success(); 195 } 196 }; 197 198 /// Fold alloc operations with no uses. Alloc has side effects on the heap, 199 /// but can still be deleted if it has zero uses. 200 struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { 201 using OpRewritePattern<AllocOp>::OpRewritePattern; 202 203 LogicalResult matchAndRewrite(AllocOp alloc, 204 PatternRewriter &rewriter) const override { 205 if (alloc.use_empty()) { 206 rewriter.eraseOp(alloc); 207 return success(); 208 } 209 return failure(); 210 } 211 }; 212 } // end anonymous namespace. 213 214 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 215 MLIRContext *context) { 216 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context); 217 } 218 219 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 220 MLIRContext *context) { 221 results.add<SimplifyAllocConst<AllocaOp>>(context); 222 } 223 224 //===----------------------------------------------------------------------===// 225 // AssumeAlignmentOp 226 //===----------------------------------------------------------------------===// 227 228 static LogicalResult verify(AssumeAlignmentOp op) { 229 unsigned alignment = op.alignment(); 230 if (!llvm::isPowerOf2_32(alignment)) 231 return op.emitOpError("alignment must be power of 2"); 232 return success(); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // BufferCastOp 237 //===----------------------------------------------------------------------===// 238 239 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) { 240 if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) 241 if (tensorLoad.memref().getType() == getType()) 242 return tensorLoad.memref(); 243 return {}; 244 } 245 246 namespace { 247 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast. 
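/// Illustrative sketch of the rewrite (value names are hypothetical):
///   %t = tensor.cast %src : tensor<4xf32> to tensor<?xf32>
///   %m = memref.buffer_cast %t : memref<?xf32>
/// becomes
///   %0 = memref.buffer_cast %src : memref<4xf32>
///   %m = memref.cast %0 : memref<4xf32> to memref<?xf32>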
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same-type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
334 /// ``` 335 /// 336 /// may fold into: 337 /// 338 /// ``` 339 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 340 /// ``` 341 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 342 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 343 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 344 345 // Requires ranked MemRefType. 346 if (!sourceType || !resultType) 347 return false; 348 349 // Requires same elemental type. 350 if (sourceType.getElementType() != resultType.getElementType()) 351 return false; 352 353 // Requires same rank. 354 if (sourceType.getRank() != resultType.getRank()) 355 return false; 356 357 // Only fold casts between strided memref forms. 358 int64_t sourceOffset, resultOffset; 359 SmallVector<int64_t, 4> sourceStrides, resultStrides; 360 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 361 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 362 return false; 363 364 // If cast is towards more static sizes along any dimension, don't fold. 365 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 366 auto ss = std::get<0>(it), st = std::get<1>(it); 367 if (ss != st) 368 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 369 return false; 370 } 371 372 // If cast is towards more static offset along any dimension, don't fold. 373 if (sourceOffset != resultOffset) 374 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 375 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 376 return false; 377 378 // If cast is towards more static strides along any dimension, don't fold. 379 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 380 auto ss = std::get<0>(it), st = std::get<1>(it); 381 if (ss != st) 382 if (MemRefType::isDynamicStrideOrOffset(ss) && 383 !MemRefType::isDynamicStrideOrOffset(st)) 384 return false; 385 } 386 387 return true; 388 } 389 390 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 391 if (inputs.size() != 1 || outputs.size() != 1) 392 return false; 393 Type a = inputs.front(), b = outputs.front(); 394 auto aT = a.dyn_cast<MemRefType>(); 395 auto bT = b.dyn_cast<MemRefType>(); 396 397 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 398 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 399 400 if (aT && bT) { 401 if (aT.getElementType() != bT.getElementType()) 402 return false; 403 if (aT.getAffineMaps() != bT.getAffineMaps()) { 404 int64_t aOffset, bOffset; 405 SmallVector<int64_t, 4> aStrides, bStrides; 406 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 407 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 408 aStrides.size() != bStrides.size()) 409 return false; 410 411 // Strides along a dimension/offset are compatible if the value in the 412 // source memref is static and the value in the target memref is the 413 // same. They are also compatible if either one is dynamic (see 414 // description of MemRefCastOp for details). 415 auto checkCompatible = [](int64_t a, int64_t b) { 416 return (a == MemRefType::getDynamicStrideOrOffset() || 417 b == MemRefType::getDynamicStrideOrOffset() || a == b); 418 }; 419 if (!checkCompatible(aOffset, bOffset)) 420 return false; 421 for (auto aStride : enumerate(aStrides)) 422 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 423 return false; 424 } 425 if (aT.getMemorySpace() != bT.getMemorySpace()) 426 return false; 427 428 // They must have the same rank, and any specified dimensions must match. 
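    // For example (illustrative): memref<8x?xf32> and memref<?x16xf32> are
    // compatible, whereas memref<8x16xf32> and memref<8x32xf32> are not.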
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(CloneOp op) { return success(); }

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Remove redundant clone operations: the clone is replaced by its source and
/// the matching dealloc (of either the clone result or the clone source) is
/// erased when the involved operations are in the same block. Clones without
/// uses are simply deleted.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // Removes the clone operation and the corresponding dealloc and alloc
    // operation (if any).
    auto tryRemoveClone = [&](Operation *sourceOp, Operation *dealloc,
                              Operation *alloc) {
      if (!sourceOp || !dealloc || !alloc ||
          alloc->getBlock() != dealloc->getBlock())
        return false;
      rewriter.replaceOp(cloneOp, source);
      rewriter.eraseOp(dealloc);
      return true;
    };

    // Removes unnecessary clones that are derived from the result of the clone
    // op.
    Operation *deallocOp = findDealloc(cloneOp.output());
    Operation *sourceOp = source.getDefiningOp();
    if (tryRemoveClone(sourceOp, deallocOp, sourceOp))
      return success();

    // Removes unnecessary clones that are derived from the source of the clone
    // op.
    deallocOp = findDealloc(source);
    if (tryRemoveClone(sourceOp, deallocOp, cloneOp))
      return success();

    return failure();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ?
getResult() : Value(); 535 } 536 537 //===----------------------------------------------------------------------===// 538 // DeallocOp 539 //===----------------------------------------------------------------------===// 540 namespace { 541 /// Fold Dealloc operations that are deallocating an AllocOp that is only used 542 /// by other Dealloc operations. 543 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { 544 using OpRewritePattern<DeallocOp>::OpRewritePattern; 545 546 LogicalResult matchAndRewrite(DeallocOp dealloc, 547 PatternRewriter &rewriter) const override { 548 // Check that the memref operand's defining operation is an AllocOp. 549 Value memref = dealloc.memref(); 550 if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp())) 551 return failure(); 552 553 // Check that all of the uses of the AllocOp are other DeallocOps. 554 for (auto *user : memref.getUsers()) 555 if (!isa<DeallocOp>(user)) 556 return failure(); 557 558 // Erase the dealloc operation. 559 rewriter.eraseOp(dealloc); 560 return success(); 561 } 562 }; 563 } // end anonymous namespace. 564 565 static LogicalResult verify(DeallocOp op) { 566 if (!op.memref().getType().isa<MemRefType>()) 567 return op.emitOpError("operand must be a memref"); 568 return success(); 569 } 570 571 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, 572 MLIRContext *context) { 573 results.add<SimplifyDeadDealloc>(context); 574 } 575 576 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 577 SmallVectorImpl<OpFoldResult> &results) { 578 /// dealloc(memrefcast) -> dealloc 579 return foldMemRefCast(*this); 580 } 581 582 //===----------------------------------------------------------------------===// 583 // DimOp 584 //===----------------------------------------------------------------------===// 585 586 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 587 int64_t index) { 588 auto loc = result.location; 589 Value indexValue = builder.create<ConstantIndexOp>(loc, index); 590 build(builder, result, memref, indexValue); 591 } 592 593 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 594 Value index) { 595 auto indexTy = builder.getIndexType(); 596 build(builder, result, indexTy, memref, index); 597 } 598 599 Optional<int64_t> DimOp::getConstantIndex() { 600 if (auto constantOp = index().getDefiningOp<ConstantOp>()) 601 return constantOp.getValue().cast<IntegerAttr>().getInt(); 602 return {}; 603 } 604 605 static LogicalResult verify(DimOp op) { 606 // Assume unknown index to be in range. 607 Optional<int64_t> index = op.getConstantIndex(); 608 if (!index.hasValue()) 609 return success(); 610 611 // Check that constant index is not knowingly out of range. 612 auto type = op.memrefOrTensor().getType(); 613 if (auto memrefType = type.dyn_cast<MemRefType>()) { 614 if (index.getValue() >= memrefType.getRank()) 615 return op.emitOpError("index is out of range"); 616 } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { 617 if (index.getValue() >= tensorType.getRank()) 618 return op.emitOpError("index is out of range"); 619 } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) { 620 // Assume index to be in range. 621 } else { 622 llvm_unreachable("expected operand with memref type"); 623 } 624 return success(); 625 } 626 627 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 628 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 629 630 // All forms of folding require a known index. 
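  // For example (illustrative): `memref.dim %m, %c0 : memref<4x?xf32>` folds
  // to the constant 4, whereas dimension 1 can only be folded by inspecting
  // the operation defining %m.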
631 if (!index) 632 return {}; 633 634 auto argTy = memrefOrTensor().getType(); 635 // Fold if the shape extent along the given index is known. 636 if (auto shapedTy = argTy.dyn_cast<ShapedType>()) { 637 // Folding for unranked types (UnrankedMemRefType) is not supported. 638 if (!shapedTy.hasRank()) 639 return {}; 640 if (!shapedTy.isDynamicDim(index.getInt())) { 641 Builder builder(getContext()); 642 return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]); 643 } 644 } 645 646 Operation *definingOp = memrefOrTensor().getDefiningOp(); 647 648 // dim(memref.tensor_load(memref)) -> dim(memref) 649 if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) { 650 setOperand(0, tensorLoadOp.memref()); 651 return getResult(); 652 } 653 654 // Fold dim to the operand of tensor.generate. 655 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { 656 auto resultType = 657 fromElements.getResult().getType().cast<RankedTensorType>(); 658 // The case where the type encodes the size of the dimension is handled 659 // above. 660 assert(resultType.getShape()[index.getInt()] == 661 RankedTensorType::kDynamicSize); 662 663 // Find the operand of the fromElements that corresponds to this index. 664 auto dynExtents = fromElements.dynamicExtents().begin(); 665 for (auto dim : resultType.getShape().take_front(index.getInt())) 666 if (dim == RankedTensorType::kDynamicSize) 667 dynExtents++; 668 669 return Value{*dynExtents}; 670 } 671 672 // The size at the given index is now known to be a dynamic size. 673 unsigned unsignedIndex = index.getValue().getZExtValue(); 674 675 if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) { 676 assert(subtensor.isDynamicSize(unsignedIndex) && 677 "Expected dynamic subtensor size"); 678 return subtensor.getDynamicSize(unsignedIndex); 679 } 680 681 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 682 auto memrefType = argTy.dyn_cast<MemRefType>(); 683 if (!memrefType) 684 return {}; 685 686 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 687 return *(alloc.getDynamicSizes().begin() + 688 memrefType.getDynamicDimIndex(unsignedIndex)); 689 690 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 691 return *(alloca.getDynamicSizes().begin() + 692 memrefType.getDynamicDimIndex(unsignedIndex)); 693 694 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 695 return *(view.getDynamicSizes().begin() + 696 memrefType.getDynamicDimIndex(unsignedIndex)); 697 698 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 699 assert(subview.isDynamicSize(unsignedIndex) && 700 "Expected dynamic subview size"); 701 return subview.getDynamicSize(unsignedIndex); 702 } 703 704 // dim(memrefcast) -> dim 705 if (succeeded(foldMemRefCast(*this))) 706 return getResult(); 707 708 return {}; 709 } 710 711 namespace { 712 /// Fold dim of a memref reshape operation to a load into the reshape's shape 713 /// operand. 714 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 715 using OpRewritePattern<DimOp>::OpRewritePattern; 716 717 LogicalResult matchAndRewrite(DimOp dim, 718 PatternRewriter &rewriter) const override { 719 auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>(); 720 721 if (!reshape) 722 return failure(); 723 724 // Place the load directly after the reshape to ensure that the shape memref 725 // was not mutated. 
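    // Illustrative sketch of the rewrite (value names are hypothetical):
    //   %r = memref.reshape %src(%shape)
    //            : (memref<?x?xf32>, memref<2xindex>) -> memref<?x?xf32>
    //   %d = memref.dim %r, %i : memref<?x?xf32>
    // becomes
    //   %d = memref.load %shape[%i] : memref<2xindex>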
    rewriter.setInsertionPointAfter(reshape);
    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
                                        llvm::makeArrayRef({dim.index()}));
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the cast.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

/// Helper method to get the `Value` that is the shape of the given `result`
/// at dimension `dimIndex`, using the `InferShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods: one that returns the shapes of all the
  // results as `Value`s and another that returns the shape of each result as a
  // list of per-dimension `Value`s. The former takes precedence over the
  // latter, so first check whether the op implements the first interface
  // method or the second, and get the value to use appropriately.
767 SmallVector<Value> reifiedResultShapes; 768 if (succeeded( 769 shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) { 770 if (reifiedResultShapes.size() <= resultNumber) 771 return nullptr; 772 Value resultShape = reifiedResultShapes[resultNumber]; 773 auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>(); 774 if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>()) 775 return nullptr; 776 return builder.create<tensor::ExtractOp>( 777 loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex)); 778 } 779 780 SmallVector<SmallVector<Value>> reifiedResultShapesPerDim; 781 if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim( 782 builder, reifiedResultShapesPerDim))) 783 return nullptr; 784 if (reifiedResultShapesPerDim.size() <= resultNumber || 785 reifiedResultShapesPerDim[resultNumber].size() != 786 static_cast<size_t>(result.getType().cast<ShapedType>().getRank())) 787 return nullptr; 788 OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex]; 789 if (auto attr = valueOrAttr.dyn_cast<Attribute>()) 790 return builder.createOrFold<ConstantIndexOp>( 791 loc, attr.cast<IntegerAttr>().getInt()); 792 return valueOrAttr.get<Value>(); 793 } 794 795 /// Fold dim of an operation that implements the InferShapedTypeOpInterface 796 struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> { 797 using OpRewritePattern<DimOp>::OpRewritePattern; 798 799 LogicalResult matchAndRewrite(DimOp dimOp, 800 PatternRewriter &rewriter) const override { 801 OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>(); 802 if (!dimValue) 803 return failure(); 804 auto shapedTypeOp = 805 dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner()); 806 if (!shapedTypeOp) 807 return failure(); 808 809 Optional<int64_t> dimIndex = dimOp.getConstantIndex(); 810 if (!dimIndex) 811 return failure(); 812 Value replacement = 813 getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex); 814 if (!replacement) 815 return failure(); 816 rewriter.replaceOp(dimOp, replacement); 817 return success(); 818 } 819 }; 820 } // end anonymous namespace. 
821 822 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 823 MLIRContext *context) { 824 results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>, 825 DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context); 826 } 827 828 // --------------------------------------------------------------------------- 829 // DmaStartOp 830 // --------------------------------------------------------------------------- 831 832 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 833 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 834 ValueRange destIndices, Value numElements, 835 Value tagMemRef, ValueRange tagIndices, Value stride, 836 Value elementsPerStride) { 837 result.addOperands(srcMemRef); 838 result.addOperands(srcIndices); 839 result.addOperands(destMemRef); 840 result.addOperands(destIndices); 841 result.addOperands({numElements, tagMemRef}); 842 result.addOperands(tagIndices); 843 if (stride) 844 result.addOperands({stride, elementsPerStride}); 845 } 846 847 void DmaStartOp::print(OpAsmPrinter &p) { 848 p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices() 849 << "], " << getDstMemRef() << '[' << getDstIndices() << "], " 850 << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() 851 << ']'; 852 if (isStrided()) 853 p << ", " << getStride() << ", " << getNumElementsPerStride(); 854 855 p.printOptionalAttrDict((*this)->getAttrs()); 856 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 857 << ", " << getTagMemRef().getType(); 858 } 859 860 // Parse DmaStartOp. 861 // Ex: 862 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 863 // %tag[%index], %stride, %num_elt_per_stride : 864 // : memref<3076 x f32, 0>, 865 // memref<1024 x f32, 2>, 866 // memref<1 x i32> 867 // 868 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 869 OpAsmParser::OperandType srcMemRefInfo; 870 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 871 OpAsmParser::OperandType dstMemRefInfo; 872 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 873 OpAsmParser::OperandType numElementsInfo; 874 OpAsmParser::OperandType tagMemrefInfo; 875 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 876 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 877 878 SmallVector<Type, 3> types; 879 auto indexType = parser.getBuilder().getIndexType(); 880 881 // Parse and resolve the following list of operands: 882 // *) source memref followed by its indices (in square brackets). 883 // *) destination memref followed by its indices (in square brackets). 884 // *) dma size in KiB. 885 if (parser.parseOperand(srcMemRefInfo) || 886 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 887 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 888 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 889 parser.parseComma() || parser.parseOperand(numElementsInfo) || 890 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 891 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 892 return failure(); 893 894 // Parse optional stride and elements per stride. 
895 if (parser.parseTrailingOperandList(strideInfo)) 896 return failure(); 897 898 bool isStrided = strideInfo.size() == 2; 899 if (!strideInfo.empty() && !isStrided) { 900 return parser.emitError(parser.getNameLoc(), 901 "expected two stride related operands"); 902 } 903 904 if (parser.parseColonTypeList(types)) 905 return failure(); 906 if (types.size() != 3) 907 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 908 909 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 910 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 911 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 912 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 913 // size should be an index. 914 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 915 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 916 // tag indices should be index. 917 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 918 return failure(); 919 920 if (isStrided) { 921 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 922 return failure(); 923 } 924 925 return success(); 926 } 927 928 LogicalResult DmaStartOp::verify() { 929 unsigned numOperands = getNumOperands(); 930 931 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 932 // the number of elements. 933 if (numOperands < 4) 934 return emitOpError("expected at least 4 operands"); 935 936 // Check types of operands. The order of these calls is important: the later 937 // calls rely on some type properties to compute the operand position. 938 // 1. Source memref. 939 if (!getSrcMemRef().getType().isa<MemRefType>()) 940 return emitOpError("expected source to be of memref type"); 941 if (numOperands < getSrcMemRefRank() + 4) 942 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 943 << " operands"; 944 if (!getSrcIndices().empty() && 945 !llvm::all_of(getSrcIndices().getTypes(), 946 [](Type t) { return t.isIndex(); })) 947 return emitOpError("expected source indices to be of index type"); 948 949 // 2. Destination memref. 950 if (!getDstMemRef().getType().isa<MemRefType>()) 951 return emitOpError("expected destination to be of memref type"); 952 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 953 if (numOperands < numExpectedOperands) 954 return emitOpError() << "expected at least " << numExpectedOperands 955 << " operands"; 956 if (!getDstIndices().empty() && 957 !llvm::all_of(getDstIndices().getTypes(), 958 [](Type t) { return t.isIndex(); })) 959 return emitOpError("expected destination indices to be of index type"); 960 961 // 3. Number of elements. 962 if (!getNumElements().getType().isIndex()) 963 return emitOpError("expected num elements to be of index type"); 964 965 // 4. Tag memref. 966 if (!getTagMemRef().getType().isa<MemRefType>()) 967 return emitOpError("expected tag to be of memref type"); 968 numExpectedOperands += getTagMemRefRank(); 969 if (numOperands < numExpectedOperands) 970 return emitOpError() << "expected at least " << numExpectedOperands 971 << " operands"; 972 if (!getTagIndices().empty() && 973 !llvm::all_of(getTagIndices().getTypes(), 974 [](Type t) { return t.isIndex(); })) 975 return emitOpError("expected tag indices to be of index type"); 976 977 // DMAs from different memory spaces supported. 
978 if (getSrcMemorySpace() == getDstMemorySpace()) 979 return emitOpError("DMA should be between different memory spaces"); 980 981 // Optional stride-related operands must be either both present or both 982 // absent. 983 if (numOperands != numExpectedOperands && 984 numOperands != numExpectedOperands + 2) 985 return emitOpError("incorrect number of operands"); 986 987 // 5. Strides. 988 if (isStrided()) { 989 if (!getStride().getType().isIndex() || 990 !getNumElementsPerStride().getType().isIndex()) 991 return emitOpError( 992 "expected stride and num elements per stride to be of type index"); 993 } 994 995 return success(); 996 } 997 998 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 999 SmallVectorImpl<OpFoldResult> &results) { 1000 /// dma_start(memrefcast) -> dma_start 1001 return foldMemRefCast(*this); 1002 } 1003 1004 // --------------------------------------------------------------------------- 1005 // DmaWaitOp 1006 // --------------------------------------------------------------------------- 1007 1008 void DmaWaitOp::build(OpBuilder &builder, OperationState &result, 1009 Value tagMemRef, ValueRange tagIndices, 1010 Value numElements) { 1011 result.addOperands(tagMemRef); 1012 result.addOperands(tagIndices); 1013 result.addOperands(numElements); 1014 } 1015 1016 void DmaWaitOp::print(OpAsmPrinter &p) { 1017 p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices() 1018 << "], " << getNumElements(); 1019 p.printOptionalAttrDict((*this)->getAttrs()); 1020 p << " : " << getTagMemRef().getType(); 1021 } 1022 1023 // Parse DmaWaitOp. 1024 // Eg: 1025 // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> 1026 // 1027 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { 1028 OpAsmParser::OperandType tagMemrefInfo; 1029 SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; 1030 Type type; 1031 auto indexType = parser.getBuilder().getIndexType(); 1032 OpAsmParser::OperandType numElementsInfo; 1033 1034 // Parse tag memref, its indices, and dma size. 1035 if (parser.parseOperand(tagMemrefInfo) || 1036 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || 1037 parser.parseComma() || parser.parseOperand(numElementsInfo) || 1038 parser.parseColonType(type) || 1039 parser.resolveOperand(tagMemrefInfo, type, result.operands) || 1040 parser.resolveOperands(tagIndexInfos, indexType, result.operands) || 1041 parser.resolveOperand(numElementsInfo, indexType, result.operands)) 1042 return failure(); 1043 1044 return success(); 1045 } 1046 1047 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1048 SmallVectorImpl<OpFoldResult> &results) { 1049 /// dma_wait(memrefcast) -> dma_wait 1050 return foldMemRefCast(*this); 1051 } 1052 1053 LogicalResult DmaWaitOp::verify() { 1054 // Mandatory non-variadic operands are tag and the number of elements. 1055 if (getNumOperands() < 2) 1056 return emitOpError() << "expected at least 2 operands"; 1057 1058 // Check types of operands. The order of these calls is important: the later 1059 // calls rely on some type properties to compute the operand position. 
1060 if (!getTagMemRef().getType().isa<MemRefType>()) 1061 return emitOpError() << "expected tag to be of memref type"; 1062 1063 if (getNumOperands() != 2 + getTagMemRefRank()) 1064 return emitOpError() << "expected " << 2 + getTagMemRefRank() 1065 << " operands"; 1066 1067 if (!getTagIndices().empty() && 1068 !llvm::all_of(getTagIndices().getTypes(), 1069 [](Type t) { return t.isIndex(); })) 1070 return emitOpError() << "expected tag indices to be of index type"; 1071 1072 if (!getNumElements().getType().isIndex()) 1073 return emitOpError() 1074 << "expected the number of elements to be of index type"; 1075 1076 return success(); 1077 } 1078 1079 //===----------------------------------------------------------------------===// 1080 // GlobalOp 1081 //===----------------------------------------------------------------------===// 1082 1083 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1084 TypeAttr type, 1085 Attribute initialValue) { 1086 p << type; 1087 if (!op.isExternal()) { 1088 p << " = "; 1089 if (op.isUninitialized()) 1090 p << "uninitialized"; 1091 else 1092 p.printAttributeWithoutType(initialValue); 1093 } 1094 } 1095 1096 static ParseResult 1097 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1098 Attribute &initialValue) { 1099 Type type; 1100 if (parser.parseType(type)) 1101 return failure(); 1102 1103 auto memrefType = type.dyn_cast<MemRefType>(); 1104 if (!memrefType || !memrefType.hasStaticShape()) 1105 return parser.emitError(parser.getNameLoc()) 1106 << "type should be static shaped memref, but got " << type; 1107 typeAttr = TypeAttr::get(type); 1108 1109 if (parser.parseOptionalEqual()) 1110 return success(); 1111 1112 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1113 initialValue = UnitAttr::get(parser.getBuilder().getContext()); 1114 return success(); 1115 } 1116 1117 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1118 if (parser.parseAttribute(initialValue, tensorType)) 1119 return failure(); 1120 if (!initialValue.isa<ElementsAttr>()) 1121 return parser.emitError(parser.getNameLoc()) 1122 << "initial value should be a unit or elements attribute"; 1123 return success(); 1124 } 1125 1126 static LogicalResult verify(GlobalOp op) { 1127 auto memrefType = op.type().dyn_cast<MemRefType>(); 1128 if (!memrefType || !memrefType.hasStaticShape()) 1129 return op.emitOpError("type should be static shaped memref, but got ") 1130 << op.type(); 1131 1132 // Verify that the initial value, if present, is either a unit attribute or 1133 // an elements attribute. 1134 if (op.initial_value().hasValue()) { 1135 Attribute initValue = op.initial_value().getValue(); 1136 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1137 return op.emitOpError("initial value should be a unit or elements " 1138 "attribute, but got ") 1139 << initValue; 1140 1141 // Check that the type of the initial value is compatible with the type of 1142 // the global variable. 1143 if (initValue.isa<ElementsAttr>()) { 1144 Type initType = initValue.getType(); 1145 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1146 if (initType != tensorType) 1147 return op.emitOpError("initial value expected to be of type ") 1148 << tensorType << ", but was of type " << initType; 1149 } 1150 } 1151 1152 // TODO: verify visibility for declarations. 
1153 return success(); 1154 } 1155 1156 //===----------------------------------------------------------------------===// 1157 // GetGlobalOp 1158 //===----------------------------------------------------------------------===// 1159 1160 LogicalResult 1161 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1162 // Verify that the result type is same as the type of the referenced 1163 // memref.global op. 1164 auto global = 1165 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1166 if (!global) 1167 return emitOpError("'") 1168 << name() << "' does not reference a valid global memref"; 1169 1170 Type resultType = result().getType(); 1171 if (global.type() != resultType) 1172 return emitOpError("result type ") 1173 << resultType << " does not match type " << global.type() 1174 << " of the global memref @" << name(); 1175 return success(); 1176 } 1177 1178 //===----------------------------------------------------------------------===// 1179 // LoadOp 1180 //===----------------------------------------------------------------------===// 1181 1182 static LogicalResult verify(LoadOp op) { 1183 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1184 return op.emitOpError("incorrect number of indices for load"); 1185 return success(); 1186 } 1187 1188 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1189 /// load(memrefcast) -> load 1190 if (succeeded(foldMemRefCast(*this))) 1191 return getResult(); 1192 return OpFoldResult(); 1193 } 1194 1195 namespace { 1196 /// Fold a load on a buffer_cast operation into an tensor.extract on the 1197 /// corresponding tensor. 1198 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> { 1199 using OpRewritePattern<LoadOp>::OpRewritePattern; 1200 1201 LogicalResult matchAndRewrite(LoadOp load, 1202 PatternRewriter &rewriter) const override { 1203 auto buffercast = load.memref().getDefiningOp<BufferCastOp>(); 1204 if (!buffercast) 1205 return failure(); 1206 1207 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(), 1208 load.indices()); 1209 return success(); 1210 } 1211 }; 1212 } // end anonymous namespace. 1213 1214 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 1215 MLIRContext *context) { 1216 results.add<LoadOfBufferCast>(context); 1217 } 1218 1219 //===----------------------------------------------------------------------===// 1220 // PrefetchOp 1221 //===----------------------------------------------------------------------===// 1222 1223 static void print(OpAsmPrinter &p, PrefetchOp op) { 1224 p << PrefetchOp::getOperationName() << " " << op.memref() << '['; 1225 p.printOperands(op.indices()); 1226 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1227 p << ", locality<" << op.localityHint(); 1228 p << ">, " << (op.isDataCache() ? 
"data" : "instr"); 1229 p.printOptionalAttrDict( 1230 op->getAttrs(), 1231 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1232 p << " : " << op.getMemRefType(); 1233 } 1234 1235 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1236 OperationState &result) { 1237 OpAsmParser::OperandType memrefInfo; 1238 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1239 IntegerAttr localityHint; 1240 MemRefType type; 1241 StringRef readOrWrite, cacheType; 1242 1243 auto indexTy = parser.getBuilder().getIndexType(); 1244 auto i32Type = parser.getBuilder().getIntegerType(32); 1245 if (parser.parseOperand(memrefInfo) || 1246 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1247 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1248 parser.parseComma() || parser.parseKeyword("locality") || 1249 parser.parseLess() || 1250 parser.parseAttribute(localityHint, i32Type, "localityHint", 1251 result.attributes) || 1252 parser.parseGreater() || parser.parseComma() || 1253 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1254 parser.resolveOperand(memrefInfo, type, result.operands) || 1255 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1256 return failure(); 1257 1258 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1259 return parser.emitError(parser.getNameLoc(), 1260 "rw specifier has to be 'read' or 'write'"); 1261 result.addAttribute( 1262 PrefetchOp::getIsWriteAttrName(), 1263 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1264 1265 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1266 return parser.emitError(parser.getNameLoc(), 1267 "cache type has to be 'data' or 'instr'"); 1268 1269 result.addAttribute( 1270 PrefetchOp::getIsDataCacheAttrName(), 1271 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1272 1273 return success(); 1274 } 1275 1276 static LogicalResult verify(PrefetchOp op) { 1277 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1278 return op.emitOpError("too few indices"); 1279 1280 return success(); 1281 } 1282 1283 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1284 SmallVectorImpl<OpFoldResult> &results) { 1285 // prefetch(memrefcast) -> prefetch 1286 return foldMemRefCast(*this); 1287 } 1288 1289 //===----------------------------------------------------------------------===// 1290 // ReinterpretCastOp 1291 //===----------------------------------------------------------------------===// 1292 1293 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1294 /// `staticSizes` and `staticStrides` are automatically filled with 1295 /// source-memref-rank sentinel values that encode dynamic entries. 
1296 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1297 MemRefType resultType, Value source, 1298 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1299 ArrayRef<OpFoldResult> strides, 1300 ArrayRef<NamedAttribute> attrs) { 1301 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1302 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1303 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1304 ShapedType::kDynamicStrideOrOffset); 1305 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1306 ShapedType::kDynamicSize); 1307 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1308 ShapedType::kDynamicStrideOrOffset); 1309 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1310 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1311 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1312 result.addAttributes(attrs); 1313 } 1314 1315 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1316 MemRefType resultType, Value source, 1317 int64_t offset, ArrayRef<int64_t> sizes, 1318 ArrayRef<int64_t> strides, 1319 ArrayRef<NamedAttribute> attrs) { 1320 SmallVector<OpFoldResult> sizeValues = 1321 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1322 return b.getI64IntegerAttr(v); 1323 })); 1324 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1325 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1326 return b.getI64IntegerAttr(v); 1327 })); 1328 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1329 strideValues, attrs); 1330 } 1331 1332 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1333 MemRefType resultType, Value source, Value offset, 1334 ValueRange sizes, ValueRange strides, 1335 ArrayRef<NamedAttribute> attrs) { 1336 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1337 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1338 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1339 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1340 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1341 } 1342 1343 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1344 // completed automatically, like we have for subview and subtensor. 1345 static LogicalResult verify(ReinterpretCastOp op) { 1346 // The source and result memrefs should be in the same memory space. 1347 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1348 auto resultType = op.getType().cast<MemRefType>(); 1349 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1350 return op.emitError("different memory spaces specified for source type ") 1351 << srcType << " and result memref type " << resultType; 1352 if (srcType.getElementType() != resultType.getElementType()) 1353 return op.emitError("different element types specified for source type ") 1354 << srcType << " and result memref type " << resultType; 1355 1356 // Match sizes in result memref type and in static_sizes attribute. 
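  // For example (illustrative): a result type of memref<4x?xf32> must carry
  // static_sizes = [4, kDynamicSize]; any other sizes are rejected below.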
1357 for (auto &en : 1358 llvm::enumerate(llvm::zip(resultType.getShape(), 1359 extractFromI64ArrayAttr(op.static_sizes())))) { 1360 int64_t resultSize = std::get<0>(en.value()); 1361 int64_t expectedSize = std::get<1>(en.value()); 1362 if (resultSize != expectedSize) 1363 return op.emitError("expected result type with size = ") 1364 << expectedSize << " instead of " << resultSize 1365 << " in dim = " << en.index(); 1366 } 1367 1368 // Match offset and strides in static_offset and static_strides attributes if 1369 // result memref type has an affine map specified. 1370 if (!resultType.getAffineMaps().empty()) { 1371 int64_t resultOffset; 1372 SmallVector<int64_t, 4> resultStrides; 1373 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1374 return failure(); 1375 1376 // Match offset in result memref type and in static_offsets attribute. 1377 int64_t expectedOffset = 1378 extractFromI64ArrayAttr(op.static_offsets()).front(); 1379 if (resultOffset != expectedOffset) 1380 return op.emitError("expected result type with offset = ") 1381 << resultOffset << " instead of " << expectedOffset; 1382 1383 // Match strides in result memref type and in static_strides attribute. 1384 for (auto &en : llvm::enumerate(llvm::zip( 1385 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1386 int64_t resultStride = std::get<0>(en.value()); 1387 int64_t expectedStride = std::get<1>(en.value()); 1388 if (resultStride != expectedStride) 1389 return op.emitError("expected result type with stride = ") 1390 << expectedStride << " instead of " << resultStride 1391 << " in dim = " << en.index(); 1392 } 1393 } 1394 return success(); 1395 } 1396 1397 //===----------------------------------------------------------------------===// 1398 // ReshapeOp 1399 //===----------------------------------------------------------------------===// 1400 1401 static LogicalResult verify(ReshapeOp op) { 1402 Type operandType = op.source().getType(); 1403 Type resultType = op.result().getType(); 1404 1405 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1406 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1407 if (operandElementType != resultElementType) 1408 return op.emitOpError("element types of source and destination memref " 1409 "types should be the same"); 1410 1411 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1412 if (!operandMemRefType.getAffineMaps().empty()) 1413 return op.emitOpError( 1414 "source memref type should have identity affine map"); 1415 1416 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1417 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1418 if (resultMemRefType) { 1419 if (!resultMemRefType.getAffineMaps().empty()) 1420 return op.emitOpError( 1421 "result memref type should have identity affine map"); 1422 if (shapeSize == ShapedType::kDynamicSize) 1423 return op.emitOpError("cannot use shape operand with dynamic length to " 1424 "reshape to statically-ranked memref type"); 1425 if (shapeSize != resultMemRefType.getRank()) 1426 return op.emitOpError( 1427 "length of shape operand differs from the result's memref rank"); 1428 } 1429 return success(); 1430 } 1431 1432 //===----------------------------------------------------------------------===// 1433 // StoreOp 1434 //===----------------------------------------------------------------------===// 1435 1436 static LogicalResult verify(StoreOp op) { 1437 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1438 
return op.emitOpError("store index operand count not equal to memref rank"); 1439 1440 return success(); 1441 } 1442 1443 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1444 SmallVectorImpl<OpFoldResult> &results) { 1445 /// store(memrefcast) -> store 1446 return foldMemRefCast(*this); 1447 } 1448 1449 //===----------------------------------------------------------------------===// 1450 // SubViewOp 1451 //===----------------------------------------------------------------------===// 1452 1453 namespace { 1454 /// Helpers to write more idiomatic operations. 1455 namespace saturated_arith { 1456 struct Wrapper { 1457 explicit Wrapper(int64_t v) : v(v) {} 1458 operator int64_t() { return v; } 1459 int64_t v; 1460 }; 1461 Wrapper operator+(Wrapper a, int64_t b) { 1462 if (ShapedType::isDynamicStrideOrOffset(a) || 1463 ShapedType::isDynamicStrideOrOffset(b)) 1464 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1465 return Wrapper(a.v + b); 1466 } 1467 Wrapper operator*(Wrapper a, int64_t b) { 1468 if (ShapedType::isDynamicStrideOrOffset(a) || 1469 ShapedType::isDynamicStrideOrOffset(b)) 1470 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1471 return Wrapper(a.v * b); 1472 } 1473 } // end namespace saturated_arith 1474 } // end namespace 1475 1476 /// A subview result type can be fully inferred from the source type and the 1477 /// static representation of offsets, sizes and strides. Special sentinels 1478 /// encode the dynamic case. 1479 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1480 ArrayRef<int64_t> leadingStaticOffsets, 1481 ArrayRef<int64_t> leadingStaticSizes, 1482 ArrayRef<int64_t> leadingStaticStrides) { 1483 // A subview may specify only a leading subset of offset/sizes/strides in 1484 // which case we complete with offset=0, sizes from memref type and strides=1. 1485 unsigned rank = sourceMemRefType.getRank(); 1486 assert(leadingStaticOffsets.size() <= rank && 1487 "unexpected leadingStaticOffsets overflow"); 1488 assert(leadingStaticSizes.size() <= rank && 1489 "unexpected leadingStaticSizes overflow"); 1490 assert(leadingStaticStrides.size() <= rank && 1491 "unexpected leadingStaticStrides overflow"); 1492 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1493 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1494 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1495 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1496 unsigned numTrailingSizes = rank - staticSizes.size(); 1497 unsigned numTrailingStrides = rank - staticStrides.size(); 1498 staticOffsets.append(numTrailingOffsets, 0); 1499 llvm::append_range(staticSizes, 1500 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1501 staticStrides.append(numTrailingStrides, 1); 1502 1503 // Extract source offset and strides. 1504 int64_t sourceOffset; 1505 SmallVector<int64_t, 4> sourceStrides; 1506 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1507 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1508 (void)res; 1509 1510 // Compute target offset whose value is: 1511 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 
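  // For example (illustrative): with sourceOffset = 4, sourceStrides = [16, 1]
  // and staticOffsets = [2, 3], the target offset is 4 + 2 * 16 + 3 * 1 = 39.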
1512 int64_t targetOffset = sourceOffset; 1513 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1514 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1515 using namespace saturated_arith; 1516 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1517 } 1518 1519 // Compute target stride whose value is: 1520 // `sourceStrides_i * staticStrides_i`. 1521 SmallVector<int64_t, 4> targetStrides; 1522 targetStrides.reserve(staticOffsets.size()); 1523 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1524 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1525 using namespace saturated_arith; 1526 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1527 } 1528 1529 // The type is now known. 1530 return MemRefType::get( 1531 staticSizes, sourceMemRefType.getElementType(), 1532 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1533 sourceMemRefType.getContext()), 1534 sourceMemRefType.getMemorySpace()); 1535 } 1536 1537 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1538 ArrayRef<OpFoldResult> leadingStaticOffsets, 1539 ArrayRef<OpFoldResult> leadingStaticSizes, 1540 ArrayRef<OpFoldResult> leadingStaticStrides) { 1541 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1542 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1543 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1544 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1545 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1546 ShapedType::kDynamicSize); 1547 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1548 staticStrides, ShapedType::kDynamicStrideOrOffset); 1549 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1550 staticSizes, staticStrides) 1551 .cast<MemRefType>(); 1552 } 1553 1554 Type SubViewOp::inferRankReducedResultType( 1555 unsigned resultRank, MemRefType sourceRankedTensorType, 1556 ArrayRef<int64_t> leadingStaticOffsets, 1557 ArrayRef<int64_t> leadingStaticSizes, 1558 ArrayRef<int64_t> leadingStaticStrides) { 1559 auto inferredType = 1560 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1561 leadingStaticSizes, leadingStaticStrides) 1562 .cast<MemRefType>(); 1563 assert(inferredType.getRank() >= resultRank && "expected "); 1564 int rankDiff = inferredType.getRank() - resultRank; 1565 if (rankDiff > 0) { 1566 auto shape = inferredType.getShape(); 1567 llvm::SmallDenseSet<unsigned> dimsToProject; 1568 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1569 SmallVector<int64_t> projectedShape; 1570 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1571 if (!dimsToProject.contains(pos)) 1572 projectedShape.push_back(shape[pos]); 1573 1574 AffineMap map; 1575 auto maps = inferredType.getAffineMaps(); 1576 if (!maps.empty() && maps.front()) 1577 map = getProjectedMap(maps.front(), dimsToProject); 1578 inferredType = 1579 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1580 inferredType.getMemorySpace()); 1581 } 1582 return inferredType; 1583 } 1584 1585 Type SubViewOp::inferRankReducedResultType( 1586 unsigned resultRank, MemRefType sourceRankedTensorType, 1587 ArrayRef<OpFoldResult> leadingStaticOffsets, 1588 ArrayRef<OpFoldResult> leadingStaticSizes, 1589 ArrayRef<OpFoldResult> leadingStaticStrides) { 1590 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1591 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1592 
dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1593 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1594 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1595 ShapedType::kDynamicSize); 1596 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1597 staticStrides, ShapedType::kDynamicStrideOrOffset); 1598 return SubViewOp::inferRankReducedResultType( 1599 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1600 staticStrides); 1601 } 1602 // Build a SubViewOp with mixed static and dynamic entries and custom result 1603 // type. If the type passed is nullptr, it is inferred. 1604 void SubViewOp::build(OpBuilder &b, OperationState &result, 1605 MemRefType resultType, Value source, 1606 ArrayRef<OpFoldResult> offsets, 1607 ArrayRef<OpFoldResult> sizes, 1608 ArrayRef<OpFoldResult> strides, 1609 ArrayRef<NamedAttribute> attrs) { 1610 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1611 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1612 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1613 ShapedType::kDynamicStrideOrOffset); 1614 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1615 ShapedType::kDynamicSize); 1616 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1617 ShapedType::kDynamicStrideOrOffset); 1618 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1619 // Structuring implementation this way avoids duplication between builders. 1620 if (!resultType) { 1621 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1622 staticSizes, staticStrides) 1623 .cast<MemRefType>(); 1624 } 1625 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1626 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1627 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1628 result.addAttributes(attrs); 1629 } 1630 1631 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1632 // type. 1633 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1634 ArrayRef<OpFoldResult> offsets, 1635 ArrayRef<OpFoldResult> sizes, 1636 ArrayRef<OpFoldResult> strides, 1637 ArrayRef<NamedAttribute> attrs) { 1638 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1639 } 1640 1641 // Build a SubViewOp with static entries and inferred result type. 1642 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1643 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1644 ArrayRef<int64_t> strides, 1645 ArrayRef<NamedAttribute> attrs) { 1646 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1647 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1648 return b.getI64IntegerAttr(v); 1649 })); 1650 SmallVector<OpFoldResult> sizeValues = 1651 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1652 return b.getI64IntegerAttr(v); 1653 })); 1654 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1655 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1656 return b.getI64IntegerAttr(v); 1657 })); 1658 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1659 } 1660 1661 // Build a SubViewOp with dynamic entries and custom result type. If the 1662 // type passed is nullptr, it is inferred. 
1663 void SubViewOp::build(OpBuilder &b, OperationState &result, 1664 MemRefType resultType, Value source, 1665 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1666 ArrayRef<int64_t> strides, 1667 ArrayRef<NamedAttribute> attrs) { 1668 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1669 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1670 return b.getI64IntegerAttr(v); 1671 })); 1672 SmallVector<OpFoldResult> sizeValues = 1673 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1674 return b.getI64IntegerAttr(v); 1675 })); 1676 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1677 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1678 return b.getI64IntegerAttr(v); 1679 })); 1680 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1681 attrs); 1682 } 1683 1684 // Build a SubViewOp with dynamic entries and custom result type. If the type 1685 // passed is nullptr, it is inferred. 1686 void SubViewOp::build(OpBuilder &b, OperationState &result, 1687 MemRefType resultType, Value source, ValueRange offsets, 1688 ValueRange sizes, ValueRange strides, 1689 ArrayRef<NamedAttribute> attrs) { 1690 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1691 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1692 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1693 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1694 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1695 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1696 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1697 } 1698 1699 // Build a SubViewOp with dynamic entries and inferred result type. 1700 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1701 ValueRange offsets, ValueRange sizes, ValueRange strides, 1702 ArrayRef<NamedAttribute> attrs) { 1703 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1704 } 1705 1706 /// For ViewLikeOpInterface. 1707 Value SubViewOp::getViewSource() { return source(); } 1708 1709 enum SubViewVerificationResult { 1710 Success, 1711 RankTooLarge, 1712 SizeMismatch, 1713 ElemTypeMismatch, 1714 MemSpaceMismatch, 1715 AffineMapMismatch 1716 }; 1717 1718 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1719 /// This function is slight variant of `is subsequence` algorithm where 1720 /// not matching dimension must be 1. 1721 static SubViewVerificationResult 1722 isRankReducedType(Type originalType, Type candidateReducedType, 1723 std::string *errMsg = nullptr) { 1724 if (originalType == candidateReducedType) 1725 return SubViewVerificationResult::Success; 1726 if (!originalType.isa<MemRefType>()) 1727 return SubViewVerificationResult::Success; 1728 if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>()) 1729 return SubViewVerificationResult::Success; 1730 1731 ShapedType originalShapedType = originalType.cast<ShapedType>(); 1732 ShapedType candidateReducedShapedType = 1733 candidateReducedType.cast<ShapedType>(); 1734 1735 // Rank and size logic is valid for all ShapedTypes. 
1736 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 1737 ArrayRef<int64_t> candidateReducedShape = 1738 candidateReducedShapedType.getShape(); 1739 unsigned originalRank = originalShape.size(), 1740 candidateReducedRank = candidateReducedShape.size(); 1741 if (candidateReducedRank > originalRank) 1742 return SubViewVerificationResult::RankTooLarge; 1743 1744 auto optionalUnusedDimsMask = 1745 computeRankReductionMask(originalShape, candidateReducedShape); 1746 1747 // Sizes cannot be matched in case empty vector is returned. 1748 if (!optionalUnusedDimsMask.hasValue()) 1749 return SubViewVerificationResult::SizeMismatch; 1750 1751 if (originalShapedType.getElementType() != 1752 candidateReducedShapedType.getElementType()) 1753 return SubViewVerificationResult::ElemTypeMismatch; 1754 1755 // Strided layout logic is relevant for MemRefType only. 1756 MemRefType original = originalType.cast<MemRefType>(); 1757 MemRefType candidateReduced = candidateReducedType.cast<MemRefType>(); 1758 if (original.getMemorySpace() != candidateReduced.getMemorySpace()) 1759 return SubViewVerificationResult::MemSpaceMismatch; 1760 1761 llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue(); 1762 auto inferredType = 1763 getProjectedMap(getStridedLinearLayoutMap(original), unusedDims); 1764 AffineMap candidateLayout; 1765 if (candidateReduced.getAffineMaps().empty()) 1766 candidateLayout = getStridedLinearLayoutMap(candidateReduced); 1767 else 1768 candidateLayout = candidateReduced.getAffineMaps().front(); 1769 assert(inferredType.getNumResults() == 1 && 1770 candidateLayout.getNumResults() == 1); 1771 if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() || 1772 inferredType.getNumDims() != candidateLayout.getNumDims()) { 1773 if (errMsg) { 1774 llvm::raw_string_ostream os(*errMsg); 1775 os << "inferred type: " << inferredType; 1776 } 1777 return SubViewVerificationResult::AffineMapMismatch; 1778 } 1779 // Check that the difference of the affine maps simplifies to 0. 1780 AffineExpr diffExpr = 1781 inferredType.getResult(0) - candidateLayout.getResult(0); 1782 diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(), 1783 inferredType.getNumSymbols()); 1784 auto cst = diffExpr.dyn_cast<AffineConstantExpr>(); 1785 if (!(cst && cst.getValue() == 0)) { 1786 if (errMsg) { 1787 llvm::raw_string_ostream os(*errMsg); 1788 os << "inferred type: " << inferredType; 1789 } 1790 return SubViewVerificationResult::AffineMapMismatch; 1791 } 1792 return SubViewVerificationResult::Success; 1793 } 1794 1795 template <typename OpTy> 1796 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result, 1797 OpTy op, Type expectedType, 1798 StringRef errMsg = "") { 1799 auto memrefType = expectedType.cast<ShapedType>(); 1800 switch (result) { 1801 case SubViewVerificationResult::Success: 1802 return success(); 1803 case SubViewVerificationResult::RankTooLarge: 1804 return op.emitError("expected result rank to be smaller or equal to ") 1805 << "the source rank. " << errMsg; 1806 case SubViewVerificationResult::SizeMismatch: 1807 return op.emitError("expected result type to be ") 1808 << expectedType 1809 << " or a rank-reduced version. 
(mismatch of result sizes) " 1810 << errMsg; 1811 case SubViewVerificationResult::ElemTypeMismatch: 1812 return op.emitError("expected result element type to be ") 1813 << memrefType.getElementType() << errMsg; 1814 case SubViewVerificationResult::MemSpaceMismatch: 1815 return op.emitError("expected result and source memory spaces to match.") 1816 << errMsg; 1817 case SubViewVerificationResult::AffineMapMismatch: 1818 return op.emitError("expected result type to be ") 1819 << expectedType 1820 << " or a rank-reduced version. (mismatch of result affine map) " 1821 << errMsg; 1822 } 1823 llvm_unreachable("unexpected subview verification result"); 1824 } 1825 1826 /// Verifier for SubViewOp. 1827 static LogicalResult verify(SubViewOp op) { 1828 MemRefType baseType = op.getSourceType(); 1829 MemRefType subViewType = op.getType(); 1830 1831 // The base memref and the view memref should be in the same memory space. 1832 if (baseType.getMemorySpace() != subViewType.getMemorySpace()) 1833 return op.emitError("different memory spaces specified for base memref " 1834 "type ") 1835 << baseType << " and subview memref type " << subViewType; 1836 1837 // Verify that the base memref type has a strided layout map. 1838 if (!isStrided(baseType)) 1839 return op.emitError("base type ") << baseType << " is not strided"; 1840 1841 // Verify result type against inferred type. 1842 auto expectedType = SubViewOp::inferResultType( 1843 baseType, extractFromI64ArrayAttr(op.static_offsets()), 1844 extractFromI64ArrayAttr(op.static_sizes()), 1845 extractFromI64ArrayAttr(op.static_strides())); 1846 1847 std::string errMsg; 1848 auto result = isRankReducedType(expectedType, subViewType, &errMsg); 1849 return produceSubViewErrorMsg(result, op, expectedType, errMsg); 1850 } 1851 1852 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) { 1853 return os << "range " << range.offset << ":" << range.size << ":" 1854 << range.stride; 1855 } 1856 1857 /// Return the list of Range (i.e. offset, size, stride). Each Range 1858 /// entry contains either the dynamic value or a ConstantIndexOp constructed 1859 /// with `b` at location `loc`. 1860 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, 1861 OpBuilder &b, Location loc) { 1862 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks(); 1863 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); 1864 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); 1865 SmallVector<Range, 8> res; 1866 unsigned rank = ranks[0]; 1867 res.reserve(rank); 1868 for (unsigned idx = 0; idx < rank; ++idx) { 1869 Value offset = 1870 op.isDynamicOffset(idx) 1871 ? op.getDynamicOffset(idx) 1872 : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx)); 1873 Value size = op.isDynamicSize(idx) 1874 ? op.getDynamicSize(idx) 1875 : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx)); 1876 Value stride = 1877 op.isDynamicStride(idx) 1878 ? op.getDynamicStride(idx) 1879 : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx)); 1880 res.emplace_back(Range{offset, size, stride}); 1881 } 1882 return res; 1883 } 1884 1885 namespace { 1886 /// Pattern to rewrite a subview op with MemRefCast arguments. 1887 /// This essentially pushes memref.cast past its consuming subview when 1888 /// `canFoldIntoConsumerOp` is true. 
1889 /// 1890 /// Example: 1891 /// ``` 1892 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32> 1893 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : 1894 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]> 1895 /// ``` 1896 /// is rewritten into: 1897 /// ``` 1898 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> 1899 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to 1900 /// memref<3x4xf32, offset:?, strides:[?, 1]> 1901 /// ``` 1902 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> { 1903 public: 1904 using OpRewritePattern<SubViewOp>::OpRewritePattern; 1905 1906 LogicalResult matchAndRewrite(SubViewOp subViewOp, 1907 PatternRewriter &rewriter) const override { 1908 // Any constant operand, just return to let SubViewOpConstantFolder kick in. 1909 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { 1910 return matchPattern(operand, matchConstantIndex()); 1911 })) 1912 return failure(); 1913 1914 auto castOp = subViewOp.source().getDefiningOp<CastOp>(); 1915 if (!castOp) 1916 return failure(); 1917 1918 if (!CastOp::canFoldIntoConsumerOp(castOp)) 1919 return failure(); 1920 1921 /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on 1922 /// the cast source operand type and the SubViewOp static information. This 1923 /// is the resulting type if the MemRefCastOp were folded. 1924 auto resultType = SubViewOp::inferRankReducedResultType( 1925 subViewOp.getType().getRank(), 1926 castOp.source().getType().cast<MemRefType>(), 1927 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), 1928 subViewOp.getMixedStrides()); 1929 Value newSubView = rewriter.create<SubViewOp>( 1930 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), 1931 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), 1932 subViewOp.static_sizes(), subViewOp.static_strides()); 1933 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(), 1934 newSubView); 1935 return success(); 1936 } 1937 }; 1938 } // namespace 1939 1940 /// A canonicalizer wrapper to replace SubViewOps. 1941 struct SubViewCanonicalizer { 1942 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 1943 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType()); 1944 } 1945 }; 1946 1947 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 1948 MLIRContext *context) { 1949 results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 1950 SubViewOp, SubViewCanonicalizer>, 1951 SubViewOpMemRefCastFolder>(context); 1952 } 1953 1954 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 1955 auto resultShapedType = getResult().getType().cast<ShapedType>(); 1956 auto sourceShapedType = source().getType().cast<ShapedType>(); 1957 1958 if (resultShapedType.hasStaticShape() && 1959 resultShapedType == sourceShapedType) { 1960 return getViewSource(); 1961 } 1962 1963 return {}; 1964 } 1965 1966 //===----------------------------------------------------------------------===// 1967 // TensorLoadOp 1968 //===----------------------------------------------------------------------===// 1969 1970 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) { 1971 if (auto bufferCast = memref().getDefiningOp<BufferCastOp>()) 1972 // Approximate alias analysis by conservatively folding only when no there 1973 // is no interleaved operation. 
1974 if (bufferCast->getBlock() == this->getOperation()->getBlock() && 1975 bufferCast->getNextNode() == this->getOperation()) 1976 return bufferCast.tensor(); 1977 return {}; 1978 } 1979 1980 //===----------------------------------------------------------------------===// 1981 // TransposeOp 1982 //===----------------------------------------------------------------------===// 1983 1984 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 1985 static MemRefType inferTransposeResultType(MemRefType memRefType, 1986 AffineMap permutationMap) { 1987 auto rank = memRefType.getRank(); 1988 auto originalSizes = memRefType.getShape(); 1989 // Compute permuted sizes. 1990 SmallVector<int64_t, 4> sizes(rank, 0); 1991 for (auto en : llvm::enumerate(permutationMap.getResults())) 1992 sizes[en.index()] = 1993 originalSizes[en.value().cast<AffineDimExpr>().getPosition()]; 1994 1995 // Compute permuted strides. 1996 int64_t offset; 1997 SmallVector<int64_t, 4> strides; 1998 auto res = getStridesAndOffset(memRefType, strides, offset); 1999 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank)); 2000 (void)res; 2001 auto map = 2002 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); 2003 map = permutationMap ? map.compose(permutationMap) : map; 2004 return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); 2005 } 2006 2007 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, 2008 AffineMapAttr permutation, 2009 ArrayRef<NamedAttribute> attrs) { 2010 auto permutationMap = permutation.getValue(); 2011 assert(permutationMap); 2012 2013 auto memRefType = in.getType().cast<MemRefType>(); 2014 // Compute result type. 2015 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); 2016 2017 build(b, result, resultType, in, attrs); 2018 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); 2019 } 2020 2021 // transpose $in $permutation attr-dict : type($in) `to` type(results) 2022 static void print(OpAsmPrinter &p, TransposeOp op) { 2023 p << "memref.transpose " << op.in() << " " << op.permutation(); 2024 p.printOptionalAttrDict(op->getAttrs(), 2025 {TransposeOp::getPermutationAttrName()}); 2026 p << " : " << op.in().getType() << " to " << op.getType(); 2027 } 2028 2029 static ParseResult parseTransposeOp(OpAsmParser &parser, 2030 OperationState &result) { 2031 OpAsmParser::OperandType in; 2032 AffineMap permutation; 2033 MemRefType srcType, dstType; 2034 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || 2035 parser.parseOptionalAttrDict(result.attributes) || 2036 parser.parseColonType(srcType) || 2037 parser.resolveOperand(in, srcType, result.operands) || 2038 parser.parseKeywordType("to", dstType) || 2039 parser.addTypeToList(dstType, result.types)) 2040 return failure(); 2041 2042 result.addAttribute(TransposeOp::getPermutationAttrName(), 2043 AffineMapAttr::get(permutation)); 2044 return success(); 2045 } 2046 2047 static LogicalResult verify(TransposeOp op) { 2048 if (!op.permutation().isPermutation()) 2049 return op.emitOpError("expected a permutation map"); 2050 if (op.permutation().getNumDims() != op.getShapedType().getRank()) 2051 return op.emitOpError( 2052 "expected a permutation map of same rank as the input"); 2053 2054 auto srcType = op.in().getType().cast<MemRefType>(); 2055 auto dstType = op.getType().cast<MemRefType>(); 2056 auto transposedType = inferTransposeResultType(srcType, op.permutation()); 2057 if (dstType != transposedType) 2058 return 
op.emitOpError("output type ") 2059 << dstType << " does not match transposed input type " << srcType 2060 << ", " << transposedType; 2061 return success(); 2062 } 2063 2064 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 2065 if (succeeded(foldMemRefCast(*this))) 2066 return getResult(); 2067 return {}; 2068 } 2069 2070 //===----------------------------------------------------------------------===// 2071 // ViewOp 2072 //===----------------------------------------------------------------------===// 2073 2074 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { 2075 OpAsmParser::OperandType srcInfo; 2076 SmallVector<OpAsmParser::OperandType, 1> offsetInfo; 2077 SmallVector<OpAsmParser::OperandType, 4> sizesInfo; 2078 auto indexType = parser.getBuilder().getIndexType(); 2079 Type srcType, dstType; 2080 llvm::SMLoc offsetLoc; 2081 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || 2082 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) 2083 return failure(); 2084 2085 if (offsetInfo.size() != 1) 2086 return parser.emitError(offsetLoc) << "expects 1 offset operand"; 2087 2088 return failure( 2089 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || 2090 parser.parseOptionalAttrDict(result.attributes) || 2091 parser.parseColonType(srcType) || 2092 parser.resolveOperand(srcInfo, srcType, result.operands) || 2093 parser.resolveOperands(offsetInfo, indexType, result.operands) || 2094 parser.resolveOperands(sizesInfo, indexType, result.operands) || 2095 parser.parseKeywordType("to", dstType) || 2096 parser.addTypeToList(dstType, result.types)); 2097 } 2098 2099 static void print(OpAsmPrinter &p, ViewOp op) { 2100 p << op.getOperationName() << ' ' << op.getOperand(0) << '['; 2101 p.printOperand(op.byte_shift()); 2102 p << "][" << op.sizes() << ']'; 2103 p.printOptionalAttrDict(op->getAttrs()); 2104 p << " : " << op.getOperand(0).getType() << " to " << op.getType(); 2105 } 2106 2107 static LogicalResult verify(ViewOp op) { 2108 auto baseType = op.getOperand(0).getType().cast<MemRefType>(); 2109 auto viewType = op.getType(); 2110 2111 // The base memref should have identity layout map (or none). 2112 if (baseType.getAffineMaps().size() > 1 || 2113 (baseType.getAffineMaps().size() == 1 && 2114 !baseType.getAffineMaps()[0].isIdentity())) 2115 return op.emitError("unsupported map for base memref type ") << baseType; 2116 2117 // The result memref should have identity layout map (or none). 2118 if (viewType.getAffineMaps().size() > 1 || 2119 (viewType.getAffineMaps().size() == 1 && 2120 !viewType.getAffineMaps()[0].isIdentity())) 2121 return op.emitError("unsupported map for result memref type ") << viewType; 2122 2123 // The base memref and the view memref should be in the same memory space. 2124 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2125 return op.emitError("different memory spaces specified for base memref " 2126 "type ") 2127 << baseType << " and view memref type " << viewType; 2128 2129 // Verify that we have the correct number of sizes for the result type. 
2130 unsigned numDynamicDims = viewType.getNumDynamicDims(); 2131 if (op.sizes().size() != numDynamicDims) 2132 return op.emitError("incorrect number of size operands for type ") 2133 << viewType; 2134 2135 return success(); 2136 } 2137 2138 Value ViewOp::getViewSource() { return source(); } 2139 2140 namespace { 2141 2142 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { 2143 using OpRewritePattern<ViewOp>::OpRewritePattern; 2144 2145 LogicalResult matchAndRewrite(ViewOp viewOp, 2146 PatternRewriter &rewriter) const override { 2147 // Return if none of the operands are constants. 2148 if (llvm::none_of(viewOp.getOperands(), [](Value operand) { 2149 return matchPattern(operand, matchConstantIndex()); 2150 })) 2151 return failure(); 2152 2153 // Get result memref type. 2154 auto memrefType = viewOp.getType(); 2155 2156 // Get offset from old memref view type 'memRefType'. 2157 int64_t oldOffset; 2158 SmallVector<int64_t, 4> oldStrides; 2159 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) 2160 return failure(); 2161 assert(oldOffset == 0 && "Expected 0 offset"); 2162 2163 SmallVector<Value, 4> newOperands; 2164 2165 // Offset cannot be folded into result type. 2166 2167 // Fold any dynamic dim operands which are produced by a constant. 2168 SmallVector<int64_t, 4> newShapeConstants; 2169 newShapeConstants.reserve(memrefType.getRank()); 2170 2171 unsigned dynamicDimPos = 0; 2172 unsigned rank = memrefType.getRank(); 2173 for (unsigned dim = 0, e = rank; dim < e; ++dim) { 2174 int64_t dimSize = memrefType.getDimSize(dim); 2175 // If this is already static dimension, keep it. 2176 if (!ShapedType::isDynamic(dimSize)) { 2177 newShapeConstants.push_back(dimSize); 2178 continue; 2179 } 2180 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); 2181 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 2182 // Dynamic shape dimension will be folded. 2183 newShapeConstants.push_back(constantIndexOp.getValue()); 2184 } else { 2185 // Dynamic shape dimension not folded; copy operand from old memref. 2186 newShapeConstants.push_back(dimSize); 2187 newOperands.push_back(viewOp.sizes()[dynamicDimPos]); 2188 } 2189 dynamicDimPos++; 2190 } 2191 2192 // Create new memref type with constant folded dims. 2193 MemRefType newMemRefType = 2194 MemRefType::Builder(memrefType).setShape(newShapeConstants); 2195 // Nothing new, don't fold. 2196 if (newMemRefType == memrefType) 2197 return failure(); 2198 2199 // Create new ViewOp. 2200 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType, 2201 viewOp.getOperand(0), 2202 viewOp.byte_shift(), newOperands); 2203 // Insert a cast so we have the same type as the old memref type. 
2204 rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType()); 2205 return success(); 2206 } 2207 }; 2208 2209 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> { 2210 using OpRewritePattern<ViewOp>::OpRewritePattern; 2211 2212 LogicalResult matchAndRewrite(ViewOp viewOp, 2213 PatternRewriter &rewriter) const override { 2214 Value memrefOperand = viewOp.getOperand(0); 2215 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>(); 2216 if (!memrefCastOp) 2217 return failure(); 2218 Value allocOperand = memrefCastOp.getOperand(); 2219 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>(); 2220 if (!allocOp) 2221 return failure(); 2222 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand, 2223 viewOp.byte_shift(), viewOp.sizes()); 2224 return success(); 2225 } 2226 }; 2227 2228 } // end anonymous namespace 2229 2230 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2231 MLIRContext *context) { 2232 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context); 2233 } 2234 2235 //===----------------------------------------------------------------------===// 2236 // TableGen'd op method definitions 2237 //===----------------------------------------------------------------------===// 2238 2239 #define GET_OP_CLASSES 2240 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" 2241