1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/InferTypeOpInterface.h" 21 #include "llvm/ADT/STLExtras.h" 22 23 using namespace mlir; 24 using namespace mlir::memref; 25 26 /// Materialize a single constant operation from a given attribute value with 27 /// the desired resultant type. 28 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 29 Attribute value, Type type, 30 Location loc) { 31 return builder.create<mlir::ConstantOp>(loc, type, value); 32 } 33 34 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. 35 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) { 36 return llvm::to_vector<4>( 37 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t { 38 return a.cast<IntegerAttr>().getInt(); 39 })); 40 } 41 42 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if 43 /// it is a Value or into `staticVec` if it is an IntegerAttr. 44 /// In the case of a Value, a copy of the `sentinel` value is also pushed to 45 /// `staticVec`. This is useful to extract mixed static and dynamic entries that 46 /// come from an AttrSizedOperandSegments trait. 47 static void dispatchIndexOpFoldResult(OpFoldResult ofr, 48 SmallVectorImpl<Value> &dynamicVec, 49 SmallVectorImpl<int64_t> &staticVec, 50 int64_t sentinel) { 51 if (auto v = ofr.dyn_cast<Value>()) { 52 dynamicVec.push_back(v); 53 staticVec.push_back(sentinel); 54 return; 55 } 56 APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue(); 57 staticVec.push_back(apInt.getSExtValue()); 58 } 59 60 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, 61 SmallVectorImpl<Value> &dynamicVec, 62 SmallVectorImpl<int64_t> &staticVec, 63 int64_t sentinel) { 64 for (auto ofr : ofrs) 65 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // Common canonicalization pattern support logic 70 //===----------------------------------------------------------------------===// 71 72 /// This is a common class used for patterns of the form 73 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 74 /// into the root operation directly. 
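///
/// As an illustrative sketch only (value names here are hypothetical, not
/// taken from this file), a fold of this form rewrites
///
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.store %v, %0[%i] : memref<?xf32>
/// ```
///
/// into
///
/// ```mlir
///   memref.store %v, %arg[%i] : memref<8xf32>
/// ```
///
/// provided the consuming op still accepts the more static operand type.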
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError(
        "symbol operand count does not equal memref symbol count");

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc-like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> newOperands;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(-1);
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(newOperands.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
          return !isa<StoreOp, DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.
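
// For illustration only: a hedged sketch of what SimplifyAllocConst does at
// the IR level (value names are hypothetical). An alloc whose dynamic size is
// defined by a constant, e.g.
//
//   %c16 = constant 16 : index
//   %0 = memref.alloc(%c16) : memref<?xf32>
//
// can be rewritten to a statically shaped alloc plus a cast that preserves
// the original result type:
//
//   %0 = memref.alloc() : memref<16xf32>
//   %1 = memref.cast %0 : memref<16xf32> to memref<?xf32>
//
// SimplifyDeadAlloc then erases allocs (together with their store/dealloc
// users) that have no other uses.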

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same-type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable
/// memref.cast operations are typically inserted as `view` and `subview` ops
/// are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has at least as much static
///    information as the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards a more static offset, don't fold.
379 if (sourceOffset != resultOffset) 380 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 381 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 382 return false; 383 384 // If cast is towards more static strides along any dimension, don't fold. 385 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 386 auto ss = std::get<0>(it), st = std::get<1>(it); 387 if (ss != st) 388 if (MemRefType::isDynamicStrideOrOffset(ss) && 389 !MemRefType::isDynamicStrideOrOffset(st)) 390 return false; 391 } 392 393 return true; 394 } 395 396 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 397 if (inputs.size() != 1 || outputs.size() != 1) 398 return false; 399 Type a = inputs.front(), b = outputs.front(); 400 auto aT = a.dyn_cast<MemRefType>(); 401 auto bT = b.dyn_cast<MemRefType>(); 402 403 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 404 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 405 406 if (aT && bT) { 407 if (aT.getElementType() != bT.getElementType()) 408 return false; 409 if (aT.getAffineMaps() != bT.getAffineMaps()) { 410 int64_t aOffset, bOffset; 411 SmallVector<int64_t, 4> aStrides, bStrides; 412 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 413 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 414 aStrides.size() != bStrides.size()) 415 return false; 416 417 // Strides along a dimension/offset are compatible if the value in the 418 // source memref is static and the value in the target memref is the 419 // same. They are also compatible if either one is dynamic (see 420 // description of MemRefCastOp for details). 421 auto checkCompatible = [](int64_t a, int64_t b) { 422 return (a == MemRefType::getDynamicStrideOrOffset() || 423 b == MemRefType::getDynamicStrideOrOffset() || a == b); 424 }; 425 if (!checkCompatible(aOffset, bOffset)) 426 return false; 427 for (auto aStride : enumerate(aStrides)) 428 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 429 return false; 430 } 431 if (aT.getMemorySpace() != bT.getMemorySpace()) 432 return false; 433 434 // They must have the same rank, and any specified dimensions must match. 435 if (aT.getRank() != bT.getRank()) 436 return false; 437 438 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 439 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 440 if (aDim != -1 && bDim != -1 && aDim != bDim) 441 return false; 442 } 443 return true; 444 } else { 445 if (!aT && !uaT) 446 return false; 447 if (!bT && !ubT) 448 return false; 449 // Unranked to unranked casting is unsupported 450 if (uaT && ubT) 451 return false; 452 453 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 454 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 455 if (aEltType != bEltType) 456 return false; 457 458 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 459 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 460 if (aMemSpace != bMemSpace) 461 return false; 462 463 return true; 464 } 465 466 return false; 467 } 468 469 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 470 return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value(); 471 } 472 473 //===----------------------------------------------------------------------===// 474 // CloneOp 475 //===----------------------------------------------------------------------===// 476 477 void CloneOp::getEffects( 478 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 479 &effects) { 480 effects.emplace_back(MemoryEffects::Read::get(), input(), 481 SideEffects::DefaultResource::get()); 482 effects.emplace_back(MemoryEffects::Write::get(), output(), 483 SideEffects::DefaultResource::get()); 484 } 485 486 namespace { 487 /// Fold Dealloc operations that are deallocating an AllocOp that is only used 488 /// by other Dealloc operations. 489 struct SimplifyClones : public OpRewritePattern<CloneOp> { 490 using OpRewritePattern<CloneOp>::OpRewritePattern; 491 492 LogicalResult matchAndRewrite(CloneOp cloneOp, 493 PatternRewriter &rewriter) const override { 494 if (cloneOp.use_empty()) { 495 rewriter.eraseOp(cloneOp); 496 return success(); 497 } 498 499 Value source = cloneOp.input(); 500 501 // Removes the clone operation and the corresponding dealloc and alloc 502 // operation (if any). 503 auto tryRemoveClone = [&](Operation *sourceOp, Operation *dealloc, 504 Operation *alloc) { 505 if (!sourceOp || !dealloc || !alloc || 506 alloc->getBlock() != dealloc->getBlock()) 507 return false; 508 rewriter.replaceOp(cloneOp, source); 509 rewriter.eraseOp(dealloc); 510 return true; 511 }; 512 513 // Removes unnecessary clones that are derived from the result of the clone 514 // op. 515 Operation *deallocOp = findDealloc(cloneOp.output()); 516 Operation *sourceOp = source.getDefiningOp(); 517 if (tryRemoveClone(sourceOp, deallocOp, sourceOp)) 518 return success(); 519 520 // Removes unnecessary clones that are derived from the source of the clone 521 // op. 522 deallocOp = findDealloc(source); 523 if (tryRemoveClone(sourceOp, deallocOp, cloneOp)) 524 return success(); 525 526 return failure(); 527 } 528 }; 529 530 } // end anonymous namespace. 531 532 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 533 MLIRContext *context) { 534 results.insert<SimplifyClones>(context); 535 } 536 537 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) { 538 return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value(); 539 } 540 541 //===----------------------------------------------------------------------===// 542 // DeallocOp 543 //===----------------------------------------------------------------------===// 544 545 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 546 SmallVectorImpl<OpFoldResult> &results) { 547 /// dealloc(memrefcast) -> dealloc 548 return foldMemRefCast(*this); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // DimOp 553 //===----------------------------------------------------------------------===// 554 555 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 556 int64_t index) { 557 auto loc = result.location; 558 Value indexValue = builder.create<ConstantIndexOp>(loc, index); 559 build(builder, result, memref, indexValue); 560 } 561 562 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 563 Value index) { 564 auto indexTy = builder.getIndexType(); 565 build(builder, result, indexTy, memref, index); 566 } 567 568 Optional<int64_t> DimOp::getConstantIndex() { 569 if (auto constantOp = index().getDefiningOp<ConstantOp>()) 570 return constantOp.getValue().cast<IntegerAttr>().getInt(); 571 return {}; 572 } 573 574 static LogicalResult verify(DimOp op) { 575 // Assume unknown index to be in range. 576 Optional<int64_t> index = op.getConstantIndex(); 577 if (!index.hasValue()) 578 return success(); 579 580 // Check that constant index is not knowingly out of range. 581 auto type = op.memrefOrTensor().getType(); 582 if (auto memrefType = type.dyn_cast<MemRefType>()) { 583 if (index.getValue() >= memrefType.getRank()) 584 return op.emitOpError("index is out of range"); 585 } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { 586 if (index.getValue() >= tensorType.getRank()) 587 return op.emitOpError("index is out of range"); 588 } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) { 589 // Assume index to be in range. 590 } else { 591 llvm_unreachable("expected operand with memref type"); 592 } 593 return success(); 594 } 595 596 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 597 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 598 599 // All forms of folding require a known index. 600 if (!index) 601 return {}; 602 603 auto argTy = memrefOrTensor().getType(); 604 // Fold if the shape extent along the given index is known. 605 if (auto shapedTy = argTy.dyn_cast<ShapedType>()) { 606 // Folding for unranked types (UnrankedMemRefType) is not supported. 607 if (!shapedTy.hasRank()) 608 return {}; 609 if (!shapedTy.isDynamicDim(index.getInt())) { 610 Builder builder(getContext()); 611 return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]); 612 } 613 } 614 615 Operation *definingOp = memrefOrTensor().getDefiningOp(); 616 617 // dim(memref.tensor_load(memref)) -> dim(memref) 618 if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) { 619 setOperand(0, tensorLoadOp.memref()); 620 return getResult(); 621 } 622 623 // Fold dim to the operand of tensor.generate. 624 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { 625 auto resultType = 626 fromElements.getResult().getType().cast<RankedTensorType>(); 627 // The case where the type encodes the size of the dimension is handled 628 // above. 629 assert(resultType.getShape()[index.getInt()] == 630 RankedTensorType::kDynamicSize); 631 632 // Find the operand of the fromElements that corresponds to this index. 
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
    assert(subtensor.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subtensor size");
    return subtensor.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    assert(subview.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return subview.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
                                        llvm::makeArrayRef({dim.index()}));
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the cast.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

/// Helper method to get the `Value` that is the shape of the given `result`
/// at dimension `dimIndex`, using the `InferShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value`s and another that returns the shape of each result as
  // a list of `SmallVector<Value>`. The former takes precedence over the
  // latter, so first check if the op implements the first interface method or
  // the second, and get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(
          shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}

/// Fold dim of an operation that implements the InferShapedTypeOpInterface.
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    Value replacement =
        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
    if (!replacement)
      return failure();
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // end anonymous namespace.
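
// For illustration only: a hedged sketch of the DimOfCastOp fold at the IR
// level (value names are hypothetical). A dim taken on the result of a cast,
// e.g.
//
//   %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
//   %d = memref.dim %0, %c1 : tensor<?x?xf32>
//
// is rewritten to query the cast's source directly:
//
//   %d = memref.dim %t, %c1 : tensor<4x?xf32>
//
// The same pattern is instantiated for memref.buffer_cast, and
// DimOfMemRefReshape similarly replaces dim of a memref.reshape with a load
// from the reshape's shape operand.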
790 791 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 792 MLIRContext *context) { 793 results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>, 794 DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context); 795 } 796 797 // --------------------------------------------------------------------------- 798 // DmaStartOp 799 // --------------------------------------------------------------------------- 800 801 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 802 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 803 ValueRange destIndices, Value numElements, 804 Value tagMemRef, ValueRange tagIndices, Value stride, 805 Value elementsPerStride) { 806 result.addOperands(srcMemRef); 807 result.addOperands(srcIndices); 808 result.addOperands(destMemRef); 809 result.addOperands(destIndices); 810 result.addOperands({numElements, tagMemRef}); 811 result.addOperands(tagIndices); 812 if (stride) 813 result.addOperands({stride, elementsPerStride}); 814 } 815 816 void DmaStartOp::print(OpAsmPrinter &p) { 817 p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices() 818 << "], " << getDstMemRef() << '[' << getDstIndices() << "], " 819 << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() 820 << ']'; 821 if (isStrided()) 822 p << ", " << getStride() << ", " << getNumElementsPerStride(); 823 824 p.printOptionalAttrDict((*this)->getAttrs()); 825 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 826 << ", " << getTagMemRef().getType(); 827 } 828 829 // Parse DmaStartOp. 830 // Ex: 831 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 832 // %tag[%index], %stride, %num_elt_per_stride : 833 // : memref<3076 x f32, 0>, 834 // memref<1024 x f32, 2>, 835 // memref<1 x i32> 836 // 837 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 838 OpAsmParser::OperandType srcMemRefInfo; 839 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 840 OpAsmParser::OperandType dstMemRefInfo; 841 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 842 OpAsmParser::OperandType numElementsInfo; 843 OpAsmParser::OperandType tagMemrefInfo; 844 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 845 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 846 847 SmallVector<Type, 3> types; 848 auto indexType = parser.getBuilder().getIndexType(); 849 850 // Parse and resolve the following list of operands: 851 // *) source memref followed by its indices (in square brackets). 852 // *) destination memref followed by its indices (in square brackets). 853 // *) dma size in KiB. 854 if (parser.parseOperand(srcMemRefInfo) || 855 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 856 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 857 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 858 parser.parseComma() || parser.parseOperand(numElementsInfo) || 859 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 860 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 861 return failure(); 862 863 // Parse optional stride and elements per stride. 
864 if (parser.parseTrailingOperandList(strideInfo)) 865 return failure(); 866 867 bool isStrided = strideInfo.size() == 2; 868 if (!strideInfo.empty() && !isStrided) { 869 return parser.emitError(parser.getNameLoc(), 870 "expected two stride related operands"); 871 } 872 873 if (parser.parseColonTypeList(types)) 874 return failure(); 875 if (types.size() != 3) 876 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 877 878 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 879 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 880 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 881 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 882 // size should be an index. 883 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 884 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 885 // tag indices should be index. 886 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 887 return failure(); 888 889 if (isStrided) { 890 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 891 return failure(); 892 } 893 894 return success(); 895 } 896 897 LogicalResult DmaStartOp::verify() { 898 unsigned numOperands = getNumOperands(); 899 900 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 901 // the number of elements. 902 if (numOperands < 4) 903 return emitOpError("expected at least 4 operands"); 904 905 // Check types of operands. The order of these calls is important: the later 906 // calls rely on some type properties to compute the operand position. 907 // 1. Source memref. 908 if (!getSrcMemRef().getType().isa<MemRefType>()) 909 return emitOpError("expected source to be of memref type"); 910 if (numOperands < getSrcMemRefRank() + 4) 911 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 912 << " operands"; 913 if (!getSrcIndices().empty() && 914 !llvm::all_of(getSrcIndices().getTypes(), 915 [](Type t) { return t.isIndex(); })) 916 return emitOpError("expected source indices to be of index type"); 917 918 // 2. Destination memref. 919 if (!getDstMemRef().getType().isa<MemRefType>()) 920 return emitOpError("expected destination to be of memref type"); 921 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 922 if (numOperands < numExpectedOperands) 923 return emitOpError() << "expected at least " << numExpectedOperands 924 << " operands"; 925 if (!getDstIndices().empty() && 926 !llvm::all_of(getDstIndices().getTypes(), 927 [](Type t) { return t.isIndex(); })) 928 return emitOpError("expected destination indices to be of index type"); 929 930 // 3. Number of elements. 931 if (!getNumElements().getType().isIndex()) 932 return emitOpError("expected num elements to be of index type"); 933 934 // 4. Tag memref. 935 if (!getTagMemRef().getType().isa<MemRefType>()) 936 return emitOpError("expected tag to be of memref type"); 937 numExpectedOperands += getTagMemRefRank(); 938 if (numOperands < numExpectedOperands) 939 return emitOpError() << "expected at least " << numExpectedOperands 940 << " operands"; 941 if (!getTagIndices().empty() && 942 !llvm::all_of(getTagIndices().getTypes(), 943 [](Type t) { return t.isIndex(); })) 944 return emitOpError("expected tag indices to be of index type"); 945 946 // DMAs from different memory spaces supported. 
947 if (getSrcMemorySpace() == getDstMemorySpace()) 948 return emitOpError("DMA should be between different memory spaces"); 949 950 // Optional stride-related operands must be either both present or both 951 // absent. 952 if (numOperands != numExpectedOperands && 953 numOperands != numExpectedOperands + 2) 954 return emitOpError("incorrect number of operands"); 955 956 // 5. Strides. 957 if (isStrided()) { 958 if (!getStride().getType().isIndex() || 959 !getNumElementsPerStride().getType().isIndex()) 960 return emitOpError( 961 "expected stride and num elements per stride to be of type index"); 962 } 963 964 return success(); 965 } 966 967 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 968 SmallVectorImpl<OpFoldResult> &results) { 969 /// dma_start(memrefcast) -> dma_start 970 return foldMemRefCast(*this); 971 } 972 973 // --------------------------------------------------------------------------- 974 // DmaWaitOp 975 // --------------------------------------------------------------------------- 976 977 void DmaWaitOp::build(OpBuilder &builder, OperationState &result, 978 Value tagMemRef, ValueRange tagIndices, 979 Value numElements) { 980 result.addOperands(tagMemRef); 981 result.addOperands(tagIndices); 982 result.addOperands(numElements); 983 } 984 985 void DmaWaitOp::print(OpAsmPrinter &p) { 986 p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices() 987 << "], " << getNumElements(); 988 p.printOptionalAttrDict((*this)->getAttrs()); 989 p << " : " << getTagMemRef().getType(); 990 } 991 992 // Parse DmaWaitOp. 993 // Eg: 994 // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> 995 // 996 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { 997 OpAsmParser::OperandType tagMemrefInfo; 998 SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; 999 Type type; 1000 auto indexType = parser.getBuilder().getIndexType(); 1001 OpAsmParser::OperandType numElementsInfo; 1002 1003 // Parse tag memref, its indices, and dma size. 1004 if (parser.parseOperand(tagMemrefInfo) || 1005 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || 1006 parser.parseComma() || parser.parseOperand(numElementsInfo) || 1007 parser.parseColonType(type) || 1008 parser.resolveOperand(tagMemrefInfo, type, result.operands) || 1009 parser.resolveOperands(tagIndexInfos, indexType, result.operands) || 1010 parser.resolveOperand(numElementsInfo, indexType, result.operands)) 1011 return failure(); 1012 1013 return success(); 1014 } 1015 1016 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1017 SmallVectorImpl<OpFoldResult> &results) { 1018 /// dma_wait(memrefcast) -> dma_wait 1019 return foldMemRefCast(*this); 1020 } 1021 1022 LogicalResult DmaWaitOp::verify() { 1023 // Mandatory non-variadic operands are tag and the number of elements. 1024 if (getNumOperands() < 2) 1025 return emitOpError() << "expected at least 2 operands"; 1026 1027 // Check types of operands. The order of these calls is important: the later 1028 // calls rely on some type properties to compute the operand position. 
1029 if (!getTagMemRef().getType().isa<MemRefType>()) 1030 return emitOpError() << "expected tag to be of memref type"; 1031 1032 if (getNumOperands() != 2 + getTagMemRefRank()) 1033 return emitOpError() << "expected " << 2 + getTagMemRefRank() 1034 << " operands"; 1035 1036 if (!getTagIndices().empty() && 1037 !llvm::all_of(getTagIndices().getTypes(), 1038 [](Type t) { return t.isIndex(); })) 1039 return emitOpError() << "expected tag indices to be of index type"; 1040 1041 if (!getNumElements().getType().isIndex()) 1042 return emitOpError() 1043 << "expected the number of elements to be of index type"; 1044 1045 return success(); 1046 } 1047 1048 //===----------------------------------------------------------------------===// 1049 // GlobalOp 1050 //===----------------------------------------------------------------------===// 1051 1052 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1053 TypeAttr type, 1054 Attribute initialValue) { 1055 p << type; 1056 if (!op.isExternal()) { 1057 p << " = "; 1058 if (op.isUninitialized()) 1059 p << "uninitialized"; 1060 else 1061 p.printAttributeWithoutType(initialValue); 1062 } 1063 } 1064 1065 static ParseResult 1066 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1067 Attribute &initialValue) { 1068 Type type; 1069 if (parser.parseType(type)) 1070 return failure(); 1071 1072 auto memrefType = type.dyn_cast<MemRefType>(); 1073 if (!memrefType || !memrefType.hasStaticShape()) 1074 return parser.emitError(parser.getNameLoc()) 1075 << "type should be static shaped memref, but got " << type; 1076 typeAttr = TypeAttr::get(type); 1077 1078 if (parser.parseOptionalEqual()) 1079 return success(); 1080 1081 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1082 initialValue = UnitAttr::get(parser.getBuilder().getContext()); 1083 return success(); 1084 } 1085 1086 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1087 if (parser.parseAttribute(initialValue, tensorType)) 1088 return failure(); 1089 if (!initialValue.isa<ElementsAttr>()) 1090 return parser.emitError(parser.getNameLoc()) 1091 << "initial value should be a unit or elements attribute"; 1092 return success(); 1093 } 1094 1095 static LogicalResult verify(GlobalOp op) { 1096 auto memrefType = op.type().dyn_cast<MemRefType>(); 1097 if (!memrefType || !memrefType.hasStaticShape()) 1098 return op.emitOpError("type should be static shaped memref, but got ") 1099 << op.type(); 1100 1101 // Verify that the initial value, if present, is either a unit attribute or 1102 // an elements attribute. 1103 if (op.initial_value().hasValue()) { 1104 Attribute initValue = op.initial_value().getValue(); 1105 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1106 return op.emitOpError("initial value should be a unit or elements " 1107 "attribute, but got ") 1108 << initValue; 1109 1110 // Check that the type of the initial value is compatible with the type of 1111 // the global variable. 1112 if (initValue.isa<ElementsAttr>()) { 1113 Type initType = initValue.getType(); 1114 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1115 if (initType != tensorType) 1116 return op.emitOpError("initial value expected to be of type ") 1117 << tensorType << ", but was of type " << initType; 1118 } 1119 } 1120 1121 // TODO: verify visibility for declarations. 
1122 return success(); 1123 } 1124 1125 //===----------------------------------------------------------------------===// 1126 // GetGlobalOp 1127 //===----------------------------------------------------------------------===// 1128 1129 LogicalResult 1130 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1131 // Verify that the result type is same as the type of the referenced 1132 // memref.global op. 1133 auto global = 1134 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1135 if (!global) 1136 return emitOpError("'") 1137 << name() << "' does not reference a valid global memref"; 1138 1139 Type resultType = result().getType(); 1140 if (global.type() != resultType) 1141 return emitOpError("result type ") 1142 << resultType << " does not match type " << global.type() 1143 << " of the global memref @" << name(); 1144 return success(); 1145 } 1146 1147 //===----------------------------------------------------------------------===// 1148 // LoadOp 1149 //===----------------------------------------------------------------------===// 1150 1151 static LogicalResult verify(LoadOp op) { 1152 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1153 return op.emitOpError("incorrect number of indices for load"); 1154 return success(); 1155 } 1156 1157 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1158 /// load(memrefcast) -> load 1159 if (succeeded(foldMemRefCast(*this))) 1160 return getResult(); 1161 return OpFoldResult(); 1162 } 1163 1164 namespace { 1165 /// Fold a load on a buffer_cast operation into an tensor.extract on the 1166 /// corresponding tensor. 1167 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> { 1168 using OpRewritePattern<LoadOp>::OpRewritePattern; 1169 1170 LogicalResult matchAndRewrite(LoadOp load, 1171 PatternRewriter &rewriter) const override { 1172 auto buffercast = load.memref().getDefiningOp<BufferCastOp>(); 1173 if (!buffercast) 1174 return failure(); 1175 1176 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(), 1177 load.indices()); 1178 return success(); 1179 } 1180 }; 1181 } // end anonymous namespace. 1182 1183 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 1184 MLIRContext *context) { 1185 results.add<LoadOfBufferCast>(context); 1186 } 1187 1188 //===----------------------------------------------------------------------===// 1189 // PrefetchOp 1190 //===----------------------------------------------------------------------===// 1191 1192 static void print(OpAsmPrinter &p, PrefetchOp op) { 1193 p << PrefetchOp::getOperationName() << " " << op.memref() << '['; 1194 p.printOperands(op.indices()); 1195 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1196 p << ", locality<" << op.localityHint(); 1197 p << ">, " << (op.isDataCache() ? 
"data" : "instr"); 1198 p.printOptionalAttrDict( 1199 op->getAttrs(), 1200 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1201 p << " : " << op.getMemRefType(); 1202 } 1203 1204 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1205 OperationState &result) { 1206 OpAsmParser::OperandType memrefInfo; 1207 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1208 IntegerAttr localityHint; 1209 MemRefType type; 1210 StringRef readOrWrite, cacheType; 1211 1212 auto indexTy = parser.getBuilder().getIndexType(); 1213 auto i32Type = parser.getBuilder().getIntegerType(32); 1214 if (parser.parseOperand(memrefInfo) || 1215 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1216 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1217 parser.parseComma() || parser.parseKeyword("locality") || 1218 parser.parseLess() || 1219 parser.parseAttribute(localityHint, i32Type, "localityHint", 1220 result.attributes) || 1221 parser.parseGreater() || parser.parseComma() || 1222 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1223 parser.resolveOperand(memrefInfo, type, result.operands) || 1224 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1225 return failure(); 1226 1227 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1228 return parser.emitError(parser.getNameLoc(), 1229 "rw specifier has to be 'read' or 'write'"); 1230 result.addAttribute( 1231 PrefetchOp::getIsWriteAttrName(), 1232 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1233 1234 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1235 return parser.emitError(parser.getNameLoc(), 1236 "cache type has to be 'data' or 'instr'"); 1237 1238 result.addAttribute( 1239 PrefetchOp::getIsDataCacheAttrName(), 1240 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1241 1242 return success(); 1243 } 1244 1245 static LogicalResult verify(PrefetchOp op) { 1246 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1247 return op.emitOpError("too few indices"); 1248 1249 return success(); 1250 } 1251 1252 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1253 SmallVectorImpl<OpFoldResult> &results) { 1254 // prefetch(memrefcast) -> prefetch 1255 return foldMemRefCast(*this); 1256 } 1257 1258 //===----------------------------------------------------------------------===// 1259 // ReinterpretCastOp 1260 //===----------------------------------------------------------------------===// 1261 1262 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1263 /// `staticSizes` and `staticStrides` are automatically filled with 1264 /// source-memref-rank sentinel values that encode dynamic entries. 
1265 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1266 MemRefType resultType, Value source, 1267 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1268 ArrayRef<OpFoldResult> strides, 1269 ArrayRef<NamedAttribute> attrs) { 1270 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1271 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1272 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1273 ShapedType::kDynamicStrideOrOffset); 1274 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1275 ShapedType::kDynamicSize); 1276 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1277 ShapedType::kDynamicStrideOrOffset); 1278 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1279 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1280 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1281 result.addAttributes(attrs); 1282 } 1283 1284 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1285 MemRefType resultType, Value source, 1286 int64_t offset, ArrayRef<int64_t> sizes, 1287 ArrayRef<int64_t> strides, 1288 ArrayRef<NamedAttribute> attrs) { 1289 SmallVector<OpFoldResult> sizeValues = 1290 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1291 return b.getI64IntegerAttr(v); 1292 })); 1293 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1294 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1295 return b.getI64IntegerAttr(v); 1296 })); 1297 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1298 strideValues, attrs); 1299 } 1300 1301 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1302 MemRefType resultType, Value source, Value offset, 1303 ValueRange sizes, ValueRange strides, 1304 ArrayRef<NamedAttribute> attrs) { 1305 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1306 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1307 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1308 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1309 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1310 } 1311 1312 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1313 // completed automatically, like we have for subview and subtensor. 1314 static LogicalResult verify(ReinterpretCastOp op) { 1315 // The source and result memrefs should be in the same memory space. 1316 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1317 auto resultType = op.getType().cast<MemRefType>(); 1318 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1319 return op.emitError("different memory spaces specified for source type ") 1320 << srcType << " and result memref type " << resultType; 1321 if (srcType.getElementType() != resultType.getElementType()) 1322 return op.emitError("different element types specified for source type ") 1323 << srcType << " and result memref type " << resultType; 1324 1325 // Match sizes in result memref type and in static_sizes attribute. 
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offsets and static_strides attributes
  // if result memref type has an affine map specified.
  if (!resultType.getAffineMaps().empty()) {
    int64_t resultOffset;
    SmallVector<int64_t, 4> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return failure();

    // Match offset in result memref type and in static_offsets attribute.
    int64_t expectedOffset =
        extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;

    // Match strides in result memref type and in static_strides attribute.
    for (auto &en : llvm::enumerate(llvm::zip(
             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
      int64_t resultStride = std::get<0>(en.value());
      int64_t expectedStride = std::get<1>(en.value());
      if (resultStride != expectedStride)
        return op.emitError("expected result type with stride = ")
               << expectedStride << " instead of " << resultStride
               << " in dim = " << en.index();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "source memref type should have identity affine map");

  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(StoreOp op) {
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
return op.emitOpError("store index operand count not equal to memref rank"); 1408 1409 return success(); 1410 } 1411 1412 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1413 SmallVectorImpl<OpFoldResult> &results) { 1414 /// store(memrefcast) -> store 1415 return foldMemRefCast(*this); 1416 } 1417 1418 //===----------------------------------------------------------------------===// 1419 // SubViewOp 1420 //===----------------------------------------------------------------------===// 1421 1422 namespace { 1423 /// Helpers to write more idiomatic operations. 1424 namespace saturated_arith { 1425 struct Wrapper { 1426 explicit Wrapper(int64_t v) : v(v) {} 1427 operator int64_t() { return v; } 1428 int64_t v; 1429 }; 1430 Wrapper operator+(Wrapper a, int64_t b) { 1431 if (ShapedType::isDynamicStrideOrOffset(a) || 1432 ShapedType::isDynamicStrideOrOffset(b)) 1433 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1434 return Wrapper(a.v + b); 1435 } 1436 Wrapper operator*(Wrapper a, int64_t b) { 1437 if (ShapedType::isDynamicStrideOrOffset(a) || 1438 ShapedType::isDynamicStrideOrOffset(b)) 1439 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1440 return Wrapper(a.v * b); 1441 } 1442 } // end namespace saturated_arith 1443 } // end namespace 1444 1445 /// A subview result type can be fully inferred from the source type and the 1446 /// static representation of offsets, sizes and strides. Special sentinels 1447 /// encode the dynamic case. 1448 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1449 ArrayRef<int64_t> leadingStaticOffsets, 1450 ArrayRef<int64_t> leadingStaticSizes, 1451 ArrayRef<int64_t> leadingStaticStrides) { 1452 // A subview may specify only a leading subset of offset/sizes/strides in 1453 // which case we complete with offset=0, sizes from memref type and strides=1. 1454 unsigned rank = sourceMemRefType.getRank(); 1455 assert(leadingStaticOffsets.size() <= rank && 1456 "unexpected leadingStaticOffsets overflow"); 1457 assert(leadingStaticSizes.size() <= rank && 1458 "unexpected leadingStaticSizes overflow"); 1459 assert(leadingStaticStrides.size() <= rank && 1460 "unexpected leadingStaticStrides overflow"); 1461 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1462 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1463 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1464 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1465 unsigned numTrailingSizes = rank - staticSizes.size(); 1466 unsigned numTrailingStrides = rank - staticStrides.size(); 1467 staticOffsets.append(numTrailingOffsets, 0); 1468 llvm::append_range(staticSizes, 1469 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1470 staticStrides.append(numTrailingStrides, 1); 1471 1472 // Extract source offset and strides. 1473 int64_t sourceOffset; 1474 SmallVector<int64_t, 4> sourceStrides; 1475 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1476 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1477 (void)res; 1478 1479 // Compute target offset whose value is: 1480 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset =
        Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected inferred rank to be no smaller than the result rank");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map;
    auto maps = inferredType.getAffineMaps();
    if (!maps.empty() && maps.front())
      map = getProjectedMap(maps.front(), dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
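  // Split each OpFoldResult into a static entry (or the matching dynamic
  // sentinel) and, for dynamic entries, the backing SSA value.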
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the type
// passed is nullptr, it is inferred.
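// For example (illustrative IR), an all-static subview built this way may
// print as:
//   %1 = memref.subview %0[2, 0][4, 16][1, 1]
//       : memref<8x16xf32> to memref<4x16xf32, #map>
// where #map is the strided layout with offset 32 and strides [16, 1].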
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

enum SubViewVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  MemSpaceMismatch,
  AffineMapMismatch
};

/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm, where
/// the dimensions that do not match must be 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() &&
      !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
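  // For example, a candidate shape [4, 5] is an acceptable reduction of
  // [4, 1, 5] because the only dropped dimension has static size 1, while it
  // is not an acceptable reduction of [4, 2, 5].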
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns the
/// rank-reduced type with rank `resultRank` when rank reduction is possible,
/// and the non-rank-reduced inferred type otherwise.
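/// For example (illustrative), a subview of a memref<8x16x4xf32> with sizes
/// [8, 1, 4] and resultRank = 2 yields a rank-reduced memref<8x4xf32, #map>
/// result type with a strided layout.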
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V : memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0 : memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let
    // SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using
    /// `getCanonicalSubViewResultType` on the cast source operand type and the
    /// SubViewOp static information. This is the resulting type if the
    /// MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
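/// The replacement subview is cast back to the original result type with a
/// memref.cast so that existing users of the op remain type-correct.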
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
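  // For example (illustrative), transposing memref<2x3xf32> with the
  // permutation map (d0, d1) -> (d1, d0) infers memref<3x2xf32, #map>, where
  // #map is the strided layout with strides [1, 3] and offset 0.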
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
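  // Prints the form:
  //   memref.view %src[%byte_shift][%sizes] attr-dict : srcType to dstType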
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"